# Calibration
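我对这个任务的理解还不完全：从下文的输出来看，校准（calibration）会先在一批去掉标签的 support 样本上计算 logits，再据此过滤 verbalizer 中每个类别的标签词（标签词数量由 [376, 350, 287, 366] 减少到 [223, 272, 230, 242]）。
主要的 API 是：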
cc_logits = calibrate(prompt_model, support_dataloader)
verbalizer.register_calibrate_logits(logits=cc_logits)

In [1]:
import copy
import os

import torch
from openprompt.data_utils.text_classification_dataset import AgnewsProcessor
from openprompt.data_utils.utils import InputExample
from tqdm import tqdm

In [2]:
# Directory holding the AG News files. The OpenPrompt checkout root defaults to
# the original author's location; set the OPENPROMPT_ROOT environment variable
# to run the notebook on another machine without editing this cell.
data_path = os.path.join(
    os.environ.get('OPENPROMPT_ROOT', '/home/wy/OpenPrompt'),
    'datasets/TextClassification/agnews',
)
data_path

'/home/wy/OpenPrompt/datasets/TextClassification/agnews'

In [3]:
# Read the train and test splits of AG News from `data_path` into one dict
# keyed by split name.
dataset = {
    'train': AgnewsProcessor().get_train_examples(data_dir=data_path),
    'test': AgnewsProcessor().get_test_examples(data_dir=data_path),
}

In [28]:
# Peek at the first training example to see which fields an example carries.
dataset['train'][0]

{
  "guid": "0",
  "label": 2,
  "meta": {},
  "text_a": "Wall St. Bears Claw Back Into the Black (Reuters)",
  "text_b": "Reuters - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again.",
  "tgt_text": null
}

In [4]:
from openprompt.plms import load_plm

In [5]:
# Load the "roberta-large" checkpoint. Judging by the names it is unpacked into,
# load_plm returns the model, its tokenizer, the model config and the
# tokenizer-wrapper class that PromptDataLoader is given later.
plm, tokenizer, model_config, WrapperClass = load_plm("roberta","roberta-large")

In [6]:
# File with the manual prompt templates for AG News. The OpenPrompt checkout
# root defaults to the original author's location; set OPENPROMPT_ROOT to
# override it on another machine.
template_path = os.path.join(
    os.environ.get('OPENPROMPT_ROOT', '/home/wy/OpenPrompt'),
    'scripts/TextClassification/agnews/manual_template.txt',
)

In [7]:
from openprompt.prompts import ManualTemplate

In [8]:
# Build an empty manual template bound to the tokenizer, then fill it from the
# first entry (choice=0) of the template file.
my_template = ManualTemplate(tokenizer=tokenizer)
my_template = my_template.from_file(path=template_path, choice=0)

In [9]:
# Show the parsed template: literal text pieces, the mask slot and the
# text_a / text_b placeholders.
my_template.text

[{'add_prefix_space': ' ', 'text': 'A'},
 {'add_prefix_space': ' ', 'mask': None},
 {'add_prefix_space': ' ', 'text': 'news :'},
 {'add_prefix_space': ' ', 'placeholder': 'text_a'},
 {'add_prefix_space': ' ', 'placeholder': 'text_b'}]

In [10]:
from openprompt import PromptDataLoader

In [11]:
from openprompt.prompts import ManualVerbalizer, KnowledgeableVerbalizer

In [12]:
# File with the knowledgeable-verbalizer label words for AG News. The OpenPrompt
# checkout root defaults to the original author's location; set OPENPROMPT_ROOT
# to override it on another machine.
verbalizer_path = os.path.join(
    os.environ.get('OPENPROMPT_ROOT', '/home/wy/OpenPrompt'),
    'scripts/TextClassification/agnews/knowledgeable_verbalizer.txt',
)

In [13]:
# Create a 4-class knowledgeable verbalizer, then load its label words from file.
my_verbalizer = KnowledgeableVerbalizer(tokenizer=tokenizer, num_classes=4)
my_verbalizer = my_verbalizer.from_file(verbalizer_path)

##Num of label words for each label: [376, 350, 287, 366]


In [14]:
from openprompt.data_utils.data_sampler import FewShotSampler
# Sample a 200-example support set from the training split (no dev split),
# with a fixed seed so the draw is repeatable.
support_sampler = FewShotSampler(
    num_examples_total=200,
    also_sample_dev=False,
)
dataset['support'] = support_sampler(dataset['train'], seed=1)

In [15]:
# Report how many examples each split holds, one count per line.
for split_name in ('train', 'test', 'support'):
    print(len(dataset[split_name]))

120000
7600
200


In [16]:
# Calibration is run on unlabeled text, so the gold labels are blanked out.
# Do this on a deep copy: the sampler may hand back the very same example
# objects that live in dataset['train'] (TODO confirm), in which case
# overwriting their labels in place would silently corrupt the training split.
dataset['support'] = copy.deepcopy(dataset['support'])
for example in dataset['support']:
    example.label = -1  # remove the true label
support_dataloader = PromptDataLoader(dataset=dataset['support'], template=my_template, tokenizer=tokenizer,tokenizer_wrapper_class=WrapperClass,
                                      max_seq_length=512, decoder_max_length=3, batch_size=5, shuffle=False, teacher_forcing=False, predict_eos_token=False,
                                      truncate_method='tail')

tokenizing: 200it [00:00, 683.10it/s]


In [17]:
from openprompt import PromptForClassification

In [18]:
# Use the GPU only when one is actually available, so the notebook also runs
# (slowly) on a CPU-only machine instead of failing at .cuda().
use_cuda = torch.cuda.is_available()
# Combine the backbone, template and verbalizer into one classification model.
prompt_model = PromptForClassification(plm=plm, template=my_template, verbalizer=my_verbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()

In [19]:
# Inspect the label words loaded for each class (one list per class; the output is long).
my_verbalizer.label_words

[[' politics',
  ' government',
  ' diplomatic',
  ' law',
  ' aristotle',
  ' diplomatical',
  ' governance',
  ' republic',
  ' politician',
  ' smooth',
  ' suave',
  ' state',
  ' expedient',
  ' sagacious',
  ' police',
  ' election',
  ' political',
  ' monarchy',
  ' parliament',
  ' dukes',
  ' polity',
  ' regime',
  ' democratic',
  ' ethics',
  ' communism',
  ' federation',
  ' anarchism',
  ' authoritarianism',
  ' populism',
  ' bland',
  ' aristocracy',
  ' tribe',
  ' power',
  ' negotiation',
  ' force',
  ' warfare',
  ' city',
  ' clans',
  ' tribes',
  ' company',
  ' country',
  ' plato',
  ' confucius',
  ' latin',
  ' polis',
  ' kingship',
  ' earls',
  ' counts',
  ' tribute',
  ' lordship',
  ' property',
  ' inheritance',
  ' confiscation',
  ' individualist',
  ' allegiance',
  ' espionage',
  ' conspiracy',
  ' treason',
  ' jewish',
  ' gentile',
  ' convention',
  ' observance',
  ' celibacy',
  ' pope',
  ' taxation',
  ' petition',
  ' legislation',
  '

In [20]:
# Count the label words per class before calibration. Iterating the per-class
# lists directly avoids hard-coding the number of classes.
org_label_words_num = [len(class_words) for class_words in prompt_model.verbalizer.label_words]
org_label_words_num

[376, 350, 287, 366]

In [21]:
from openprompt.utils.calibrate import calibrate
# Run the prompt model over the label-free support set to obtain the
# calibration logits, then print them for inspection.
cc_logits = calibrate(prompt_model, support_dataloader)
print('calibration logits is', cc_logits)

ContextCali: 100%|██████████| 40/40 [00:04<00:00,  8.63it/s]


calibration logits is tensor([ 3.4085e+01, -4.2505e+00,  4.5795e+01,  ...,  4.3056e-02,
         8.8509e-01,  2.6045e+01], device='cuda:0')


In [23]:
# Hand the calibration logits to the verbalizer. The counts it prints show the
# per-class label-word lists shrinking after this call.
prompt_model.verbalizer.register_calibrate_logits(logits=cc_logits)

##Num of label words for each label: [223, 272, 230, 242]


In [25]:
# Count the label words per class after calibration. Iterating the per-class
# lists directly avoids hard-coding the number of classes.
new_label_words_num = [len(class_words) for class_words in prompt_model.verbalizer.label_words]
new_label_words_num

[223, 272, 230, 242]

In [26]:
# Put the before/after label-word counts side by side.
summary_template = "Original number of label words per class: {} \n After filtering, number of label words per class: {}"
print(summary_template.format(org_label_words_num, new_label_words_num))

Original number of label words per class: [376, 350, 287, 366] 
 After filtering, number of label words per class: [223, 272, 230, 242]


In [27]:
# Wrap the test split with the same template, tokenizer and length settings
# that were used for the support set.
test_dataloader = PromptDataLoader(
    dataset=dataset["test"],
    template=my_template,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=512,
    decoder_max_length=3,
    batch_size=5,
    shuffle=False,
    teacher_forcing=False,
    predict_eos_token=False,
    truncate_method="tail",
)

tokenizing: 7600it [00:06, 1100.14it/s]


In [29]:
# Score the calibrated prompt model on the test split.
# eval() puts the model explicitly in inference mode, and no_grad() skips
# autograd bookkeeping: the model was built with freeze_plm=False, so without
# it every forward pass would presumably keep activations for a backward pass
# that never happens.
prompt_model.eval()
allpreds = []
alllabels = []
pbar = tqdm(test_dataloader)
with torch.no_grad():
    for step, inputs in enumerate(pbar):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        # The predicted class is the index of the largest logit.
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

100%|██████████| 1520/1520 [02:24<00:00, 10.54it/s]


In [30]:
# Accuracy: the share of test examples whose predicted class equals the gold label.
num_correct = sum(pred == gold for pred, gold in zip(allpreds, alllabels))
acc = num_correct / len(allpreds)
acc

0.854078947368421