In [2]:
from datasets import load_dataset
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate,ManualVerbalizer
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
from transformers import AdamW,get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter
from pprint import pprint
from tqdm import tqdm
import torch


# Load the 'cb' configuration of the 'super_glue' dataset and show one raw record.
print('Load Dataset..')
raw_dataset = load_dataset('super_glue', 'cb')
print(raw_dataset['train'][0])

# Convert every raw record into an OpenPrompt InputExample, grouped by split name.
# The premise becomes text_a, the hypothesis text_b, and the record's idx the guid.
dataset = {
    split_name: [
        InputExample(
            guid=record['idx'],
            text_a=record['premise'],
            text_b=record['hypothesis'],
            label=int(record['label']),
        )
        for record in raw_dataset[split_name]
    ]
    for split_name in ('train', 'validation', 'test')
}
print(dataset['train'][0])

Load Dataset..


Found cached dataset super_glue (C:/Users/cq906/.cache/huggingface/datasets/super_glue/cb/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


  0%|          | 0/3 [00:00<?, ?it/s]

{'premise': 'It was a complex language. Not written down but handed down. One might say it was peeled down.', 'hypothesis': 'the language was peeled down', 'idx': 0, 'label': 0}
{
  "guid": 0,
  "label": 0,
  "meta": {},
  "text_a": "It was a complex language. Not written down but handed down. One might say it was peeled down.",
  "text_b": "the language was peeled down",
  "tgt_text": null
}



In [3]:
print('Load Model..')
# load_plm hands back four objects: the model, its tokenizer, its config,
# and the tokenizer-wrapper class used further down.
plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")

print('Build Template..')
# Two input placeholders (text_a, text_b) followed by a mask slot for the answer.
template_text = '{"placeholder":"text_a"} Question: {"placeholder":"text_b"}? Is it correct? {"mask"}.'
mytemplate = ManualTemplate(text=template_text, tokenizer=tokenizer)

# wrap_one_example returns two items:
#   1. the list of template text segments,
#   2. a dict holding the example's label and guid.
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
pprint(wrapped_example)

Load Model..
Build Template..
[[{'loss_ids': 0,
   'shortenable_ids': 1,
   'text': 'It was a complex language. Not written down but handed down. One '
           'might say it was peeled down.'},
  {'loss_ids': 0, 'shortenable_ids': 0, 'text': ' Question:'},
  {'loss_ids': 0,
   'shortenable_ids': 1,
   'text': ' the language was peeled down'},
  {'loss_ids': 0, 'shortenable_ids': 0, 'text': '? Is it correct?'},
  {'loss_ids': 1, 'shortenable_ids': 0, 'text': '<mask>'},
  {'loss_ids': 0, 'shortenable_ids': 0, 'text': '.'}],
 {'guid': 0, 'label': 0}]


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [4]:
# When T5 is used for classification, only <pad> <extra_id_0> <eos> has to be
# passed to the decoder. The loss is calculated at <extra_id_0>, so setting
# decoder_max_length=3 saves space.
wrapped_t5tokenizer = WrapperClass(
    tokenizer=tokenizer,
    max_seq_length=128,
    decoder_max_length=3,
    truncate_method='head',
)

# Tokenize the single wrapped example built above and inspect the result.
tokenized_example = wrapped_t5tokenizer.tokenize_one_example(
    wrapped_example,
    teacher_forcing=False,
)
print(tokenized_example)

{'input_ids': [94, 47, 3, 9, 1561, 1612, 5, 933, 1545, 323, 68, 14014, 323, 5, 555, 429, 497, 34, 47, 158, 400, 26, 323, 5, 11860, 10, 8, 1612, 47, 158, 400, 26, 323, 3, 58, 27, 7, 34, 2024, 58, 32099, 3, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_input_ids': [0, 32099, 0], 'loss_ids': [0, 1, 0]}


In [5]:
# Tokenize every example of every split; results are stored per split name.
model_inputs = {}
for split in ('train', 'validation', 'test'):
    print('Processing: ',split.title())
    # Wrap each example with the template first, then tokenize the wrapped result.
    model_inputs[split] = [
        wrapped_t5tokenizer.tokenize_one_example(
            mytemplate.wrap_one_example(example),
            teacher_forcing=False,
        )
        for example in tqdm(dataset[split])
    ]

Processing:  Train


100%|██████████| 250/250 [00:00<00:00, 1069.05it/s]


Processing:  Validation


100%|██████████| 56/56 [00:00<00:00, 1515.43it/s]


Processing:  Test


  0%|          | 0/250 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 250/250 [00:00<00:00, 1062.88it/s]


In [6]:
def _build_dataloader(split, batch_size, shuffle):
    """Create a PromptDataLoader over ``dataset[split]``.

    All loaders share the same template, tokenizer, wrapper class, length
    limits and truncation method; only the split, the batch size and whether
    the examples are shuffled differ.

    Args:
        split: key into ``dataset`` ('train', 'validation' or 'test').
        batch_size: number of examples per batch.
        shuffle: whether the loader shuffles the examples.

    Returns:
        The configured PromptDataLoader.
    """
    return PromptDataLoader(
        dataset=dataset[split], template=mytemplate, tokenizer=tokenizer,
        tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
        batch_size=batch_size, shuffle=shuffle, teacher_forcing=False, predict_eos_token=False,
        truncate_method="head")


print('Build Dataloader..')
train_dataloader = _build_dataloader("train", batch_size=16, shuffle=True)
# Evaluation loaders are NOT shuffled: a fixed order keeps the batch composition
# (and therefore any per-batch averaged metric) identical from run to run.
valid_dataloader = _build_dataloader("validation", batch_size=16, shuffle=False)
test_dataloader = _build_dataloader("test", batch_size=4, shuffle=False)

Build Dataloader..


tokenizing: 250it [00:00, 1158.91it/s]
tokenizing: 56it [00:00, 1322.92it/s]
tokenizing: 250it [00:00, 1084.20it/s]


In [7]:
# Three classes, one label word each.
myverbalizer = ManualVerbalizer(
    tokenizer,
    num_classes=3,
    label_words=[["yes"], ["no"], ["maybe"]],
)
print(myverbalizer.label_words_ids)

# Sanity check: create a pseudo output from the PLM (random logits over the
# vocabulary for a batch of 2) and run it through the verbalizer.
logits = torch.randn(2, len(tokenizer))
print(myverbalizer.process_logits(logits))

Parameter containing:
tensor([[[4273]],

        [[ 150]],

        [[2087]]])
tensor([[-1.1512, -1.8715, -0.6351],
        [-2.9707, -0.3543, -1.3981]])


In [10]:
# Combine the PLM, the template and the verbalizer into one classifier; the PLM
# itself is fine-tuned as well (freeze_plm=False).
PromptModel = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
PromptModel = PromptModel.cuda()
loss_func = torch.nn.CrossEntropyLoss()

# It's always good practice to apply no weight decay to bias and layer-norm parameters.
# 'LayerNorm.weight' is the BERT-style parameter name; T5 names its layer-norm
# weights '...layer_norm.weight', so that pattern is needed here as well —
# otherwise the T5 layer norms end up in the decayed group.
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']

# Split the parameters into the two groups in a single pass over the model.
decay_params, no_decay_params = [], []
for param_name, param in PromptModel.named_parameters():
    if any(pattern in param_name for pattern in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1.5e-4)
# TensorBoard event files are written under ./log_dir
writer = SummaryWriter(log_dir='log_dir')

In [11]:
# Train for 20 epochs; after each epoch evaluate on the validation loader and
# keep the checkpoint with the lowest mean validation loss.
best_loss = float('inf')
for epoch in range(20):
    epoch_desc = 'epoch {}'.format(epoch+1)

    # ---- training pass ----
    train_loss = 0
    PromptModel.train()
    par = tqdm(train_dataloader, desc=epoch_desc)
    for inputs in par:
        inputs.cuda()
        logits = PromptModel(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        par.set_postfix(loss=loss.item())
    # Log the mean of the per-batch training losses for this epoch.
    writer.add_scalars('Loss',{'train':train_loss/len(train_dataloader)},epoch+1)

    # ---- validation pass ----
    PromptModel.eval()
    val_loss = 0
    par = tqdm(valid_dataloader, desc=epoch_desc)
    # No gradients are needed for evaluation; disabling autograd avoids
    # building the computation graph and saves memory.
    with torch.no_grad():
        for inputs in par:
            inputs.cuda()
            logits = PromptModel(inputs)
            labels = inputs['label']
            loss = loss_func(logits, labels)
            val_loss += loss.item()
            par.set_postfix(loss=loss.item())
    val_loss = val_loss/len(valid_dataloader)
    writer.add_scalars('Loss',{'valid':val_loss},epoch+1)

    # Checkpoint only when the validation loss improves.
    if val_loss < best_loss:
        best_loss = val_loss
        print('Save Model..')
        torch.save(PromptModel.state_dict(),'t5-base-prompt.pt')

epoch 1: 100%|██████████| 16/16 [00:10<00:00,  1.51it/s, loss=0.197]
epoch 1: 100%|██████████| 4/4 [00:00<00:00,  6.76it/s, loss=0.824]


Save Model..


epoch 2: 100%|██████████| 16/16 [00:08<00:00,  1.79it/s, loss=0.286] 
epoch 2: 100%|██████████| 4/4 [00:00<00:00,  6.82it/s, loss=0.0678]


Save Model..


epoch 3: 100%|██████████| 16/16 [00:08<00:00,  1.79it/s, loss=0.609]  
epoch 3: 100%|██████████| 4/4 [00:00<00:00,  6.77it/s, loss=1.2]   
epoch 4: 100%|██████████| 16/16 [00:08<00:00,  1.80it/s, loss=0.00548]
epoch 4: 100%|██████████| 4/4 [00:00<00:00,  6.81it/s, loss=0.000546]


Save Model..


epoch 5: 100%|██████████| 16/16 [00:09<00:00,  1.71it/s, loss=0.00705] 
epoch 5: 100%|██████████| 4/4 [00:00<00:00,  6.36it/s, loss=0.325] 


Save Model..


epoch 6: 100%|██████████| 16/16 [00:09<00:00,  1.67it/s, loss=0.00115]
epoch 6: 100%|██████████| 4/4 [00:00<00:00,  6.47it/s, loss=0.0162]


Save Model..


epoch 7: 100%|██████████| 16/16 [00:09<00:00,  1.75it/s, loss=0.000661]
epoch 7: 100%|██████████| 4/4 [00:00<00:00,  6.54it/s, loss=0.00572]


Save Model..


epoch 8: 100%|██████████| 16/16 [00:09<00:00,  1.75it/s, loss=0.000236]
epoch 8: 100%|██████████| 4/4 [00:00<00:00,  6.52it/s, loss=1.02]  
epoch 9: 100%|██████████| 16/16 [00:09<00:00,  1.77it/s, loss=9.56e-5] 
epoch 9: 100%|██████████| 4/4 [00:00<00:00,  6.68it/s, loss=0.104] 
epoch 10: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s, loss=0.00345] 
epoch 10: 100%|██████████| 4/4 [00:00<00:00,  6.64it/s, loss=1.02]  
epoch 11: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s, loss=4.47e-5] 
epoch 11: 100%|██████████| 4/4 [00:00<00:00,  6.55it/s, loss=0.151] 
epoch 12: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s, loss=0.000157]
epoch 12: 100%|██████████| 4/4 [00:00<00:00,  6.39it/s, loss=1.45]  
epoch 13: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s, loss=5.15e-5] 
epoch 13: 100%|██████████| 4/4 [00:00<00:00,  6.62it/s, loss=2.74e-5]


Save Model..


epoch 14: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s, loss=0.000708]
epoch 14: 100%|██████████| 4/4 [00:00<00:00,  6.60it/s, loss=1.24]   
epoch 15: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s, loss=0.000138]
epoch 15: 100%|██████████| 4/4 [00:00<00:00,  6.56it/s, loss=0.184]
epoch 16: 100%|██████████| 16/16 [00:09<00:00,  1.76it/s, loss=0.000298]
epoch 16: 100%|██████████| 4/4 [00:00<00:00,  6.48it/s, loss=6.03e-5]


Save Model..


epoch 17: 100%|██████████| 16/16 [00:09<00:00,  1.74it/s, loss=0.000795]
epoch 17: 100%|██████████| 4/4 [00:00<00:00,  6.55it/s, loss=0.000708]
epoch 18: 100%|██████████| 16/16 [00:09<00:00,  1.77it/s, loss=0.000168]
epoch 18: 100%|██████████| 4/4 [00:00<00:00,  6.58it/s, loss=0.0446] 
epoch 19: 100%|██████████| 16/16 [00:08<00:00,  1.80it/s, loss=0.000298]
epoch 19: 100%|██████████| 4/4 [00:00<00:00,  6.74it/s, loss=0.0107]
epoch 20: 100%|██████████| 16/16 [00:08<00:00,  1.79it/s, loss=0.002]   
epoch 20: 100%|██████████| 4/4 [00:00<00:00,  6.69it/s, loss=0.00581]


In [12]:
print('Load Best Model..')
# Evaluate on the CPU from here on.
PromptModel.cpu()
# The checkpoint was saved while the model was on the GPU. map_location='cpu'
# restores the tensors directly onto the CPU instead of first re-creating them
# on the device they were saved from (which also requires CUDA to be available).
PromptModel.load_state_dict(torch.load('t5-base-prompt.pt', map_location='cpu'))


Load Best Model..


<All keys matched successfully>

In [13]:
# Collect predictions and gold labels over the validation loader and report accuracy.
allpreds = []
alllabels = []
PromptModel.eval()
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for inputs in valid_dataloader:
        inputs = inputs.to('cpu')
        logits = PromptModel(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        # The predicted class is the index of the highest logit.
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum(int(pred == gold) for pred, gold in zip(allpreds, alllabels))/len(allpreds)
print('Best Accuracy:{:.2f}%'.format(acc*100))


Best Accuracy:94.64%
