# OpenPrompt demo

In this demo, you will learn to use openprompt to construct various prompt learning methods

1. Basic Usage
2. Mixed Template
3. Freeze the PLM
4. Soft Verbalizer
5. Knowledgeable Verbalizer

Firstly check the existence of GPU.

In [None]:
!nvidia-smi

Mon Jul  4 03:44:40 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# install necessary packages
!pip install transformers --quiet
!pip install datasets==2.0 --quiet
!pip install openprompt --quiet
!pip install torch --quiet

[K     |████████████████████████████████| 4.4 MB 15.5 MB/s 
[K     |████████████████████████████████| 596 kB 64.4 MB/s 
[K     |████████████████████████████████| 6.6 MB 55.2 MB/s 
[K     |████████████████████████████████| 101 kB 12.8 MB/s 
[K     |████████████████████████████████| 325 kB 16.0 MB/s 
[K     |████████████████████████████████| 212 kB 73.0 MB/s 
[K     |████████████████████████████████| 140 kB 16.0 MB/s 
[K     |████████████████████████████████| 1.1 MB 73.5 MB/s 
[K     |████████████████████████████████| 127 kB 74.2 MB/s 
[K     |████████████████████████████████| 94 kB 3.7 MB/s 
[K     |████████████████████████████████| 271 kB 77.6 MB/s 
[K     |████████████████████████████████| 144 kB 74.7 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible

### Load the CB dataset
The superglue.CB dataset is a small dataset for natural language entailment task.

In [None]:
from datasets import load_dataset
raw_dataset = load_dataset('super_glue', 'cb', cache_dir="../datasets/.cache/huggingface_datasets")
raw_dataset['train'][0]
# from datasets import load_from_disk
# raw_dataset = load_from_disk("/home/datasts_cache/super_glue.cb")
# Note that if you are running this scripts inside a GPU cluster, there are chances are you are not able to connect to huggingface website directly.
# In this case, we recommend you to run `raw_dataset = load_dataset(...)` on some machine that have internet connections.
# Then use `raw_dataset.save_to_disk(path)` method to save to local path.
# Thirdly upload the saved content into the machiine in cluster.
# Then use `load_from_disk` method to load the dataset.


Downloading builder script:   0%|          | 0.00/9.47k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/8.23k [00:00<?, ?B/s]

Downloading and preparing dataset super_glue/cb (download: 73.71 KiB, generated: 198.02 KiB, post-processed: Unknown size, total: 271.73 KiB) to ../datasets/.cache/huggingface_datasets/super_glue/cb/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7...


Downloading data:   0%|          | 0.00/75.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/250 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/56 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/250 [00:00<?, ? examples/s]

Dataset super_glue downloaded and prepared to ../datasets/.cache/huggingface_datasets/super_glue/cb/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

{'hypothesis': 'the language was peeled down',
 'idx': 0,
 'label': 0,
 'premise': 'It was a complex language. Not written down but handed down. One might say it was peeled down.'}

### Load the model and tokenizer


In [None]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")

Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


### Build the input

##### 1. construct a InputExample, which organize data into a form that can be used later

In [None]:
from openprompt.data_utils import InputExample

dataset = {}
for split in ['train', 'validation', 'test']:
    dataset[split] = []
    for data in raw_dataset[split]:
        input_example = InputExample(text_a = data['premise'], text_b = data['hypothesis'], label=int(data['label']), guid=data['idx'])
        dataset[split].append(input_example)
print(dataset['train'][0])


{
  "guid": 0,
  "label": 0,
  "meta": {},
  "text_a": "It was a complex language. Not written down but handed down. One might say it was peeled down.",
  "text_b": "the language was peeled down",
  "tgt_text": null
}



##### 2. Constructing Template
A template can be constructed from the yaml config, but it can also be constructed by directly passing arguments.


In [None]:
from openprompt.prompts import ManualTemplate
template_text = '{"placeholder":"text_a"} Deduction: {"placeholder":"text_b"}. Is it correct? {"mask"}.'
mytemplate = ManualTemplate(tokenizer=tokenizer, text=template_text)

# To better understand how does the template wrap the example, we visualize one instance.
wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
wrapped_example

[[{'loss_ids': 0,
   'shortenable_ids': 1,
   'text': 'It was a complex language. Not written down but handed down. One might say it was peeled down.'},
  {'loss_ids': 0, 'shortenable_ids': 0, 'text': ' Deduction:'},
  {'loss_ids': 0,
   'shortenable_ids': 1,
   'text': ' the language was peeled down'},
  {'loss_ids': 0, 'shortenable_ids': 0, 'text': '. Is it correct?'},
  {'loss_ids': 1, 'shortenable_ids': 0, 'text': '<mask>'},
  {'loss_ids': 0, 'shortenable_ids': 0, 'text': '.'}],
 {'guid': 0, 'label': 0}]

##### 3. use the tokenizer wrapper class to tokenize the wrapped_example.

We can't directly use the pre-trained tokenizer since it can't handle the complex situation in prompt learning (multiple pieces of template text and multiple pieces of shortenable input text; some tokens are soft and trainable while others are not).

In [None]:
wrapped_t5tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")
# or
from openprompt.plms import T5TokenizerWrapper
wrapped_t5tokenizer= T5TokenizerWrapper(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")

# You can see what a tokenized example looks like by
tokenized_example = wrapped_t5tokenizer.tokenize_one_example(wrapped_example, teacher_forcing=False)
print(tokenized_example)
print(tokenizer.convert_ids_to_tokens(tokenized_example['input_ids']))
print(tokenizer.convert_ids_to_tokens(tokenized_example['decoder_input_ids']))


{'input_ids': [94, 47, 3, 9, 1561, 1612, 5, 933, 1545, 323, 68, 14014, 323, 5, 555, 429, 497, 34, 47, 158, 400, 26, 323, 5, 374, 8291, 10, 8, 1612, 47, 158, 400, 26, 323, 3, 5, 27, 7, 34, 2024, 58, 32099, 3, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_input_ids': [0, 32099, 0], 'loss_ids': [0, 1, 0]}
['▁It', '▁was', '▁', 'a', '▁complex', '▁language', '.', '▁Not', '▁written

#### Building the Input for the whole dataset
Now it's time to convert the whole dataset into the input format!
Simply loop over the dataset to achieve it!

In [None]:
model_inputs = {}
for split in ['train', 'validation', 'test']:
    model_inputs[split] = []
    for sample in dataset[split]:
        tokenized_example = wrapped_t5tokenizer.tokenize_one_example(mytemplate.wrap_one_example(sample), teacher_forcing=False)
        model_inputs[split].append(tokenized_example)


Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors


### Build the dataloader
The dataloader object is an iterable object, iterate over it will provide the input tensors for each forward pass of the model. 

In [None]:
from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

tokenizing: 250it [00:00, 605.49it/s]


### Create the verbalizer


In [None]:

from openprompt.prompts import ManualVerbalizer
import torch

# for example the verbalizer contains multiple label words in each class
myverbalizer = ManualVerbalizer(tokenizer, num_classes=3,
                        label_words=[["yes"], ["no"], ["maybe"]])

print(myverbalizer.label_words_ids)
logits = torch.randn(2,len(tokenizer)) # creating a pseudo output from the plm, and
print(myverbalizer.process_logits(logits))

Parameter containing:
tensor([[[4273]],

        [[ 150]],

        [[2087]]])
tensor([[-1.1827, -1.1284, -0.9942],
        [-0.4532, -1.1774, -2.8767]])


### Using the pipeline for classification

In [None]:
from openprompt import PromptForClassification

use_cuda = torch.cuda.is_available()
print("GPU enabled? {}".format(use_cuda))
prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model=  prompt_model.cuda()

GPU enabled? True


### Now use a traditional training pipeline.

In [None]:

# Now the training is standard
from transformers import  AdamW, get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()
no_decay = ['bias', 'LayerNorm.weight']
# it's always good practice to set no decay to biase and LayerNorm parameters
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)

for epoch in range(5):
    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        if step %100 ==1:
            print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)




Epoch 0, average loss: 1.037403017282486
Epoch 1, average loss: 0.2624429762363434
Epoch 2, average loss: 0.017658442142419517
Epoch 3, average loss: 0.0011099744879174978
Epoch 4, average loss: 0.0012295878259465098


### Evaluate the model

In [None]:
validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

allpreds = []
alllabels = []
for step, inputs in enumerate(validation_dataloader):
    if use_cuda:
        inputs = inputs.cuda()
    logits = prompt_model(inputs)
    labels = inputs['label']
    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
print(acc)

tokenizing: 56it [00:00, 615.13it/s]


0.8928571428571429


# Mixed Template

First we reload the plm since the weight has been changed.

In [None]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


Load the mixed template

In [None]:
from openprompt.prompts import MixedTemplate

mytemplate_soft1 = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft": "Question:"} {"placeholder":"text_b"}? Is it correct? {"mask"}.')
wrapped_t5tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")

from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate_soft1, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

tokenizing: 250it [00:00, 676.37it/s]


In [None]:
from openprompt import PromptForClassification

use_cuda = True
prompt_model = PromptForClassification(plm=plm,template=mytemplate_soft1, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model=  prompt_model.cuda()


In [None]:
from transformers import  AdamW, get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

# it's always good practice to set no decay to biase and LayerNorm parameters
optimizer_grouped_parameters1 = [
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# Using different optimizer for prompt parameters and model parameters
optimizer_grouped_parameters2 = [
    {'params': [p for n,p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
]

optimizer1 = AdamW(optimizer_grouped_parameters1, lr=1e-4)
optimizer2 = AdamW(optimizer_grouped_parameters2, lr=1e-3)

for epoch in range(5):
    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        if step %100 ==1:
            print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)

In [None]:
validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate_soft1, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

allpreds = []
alllabels = []
for step, inputs in enumerate(validation_dataloader):
    if use_cuda:
        inputs = inputs.cuda()
    logits = prompt_model(inputs)
    labels = inputs['label']
    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
print(acc)


tokenizing: 56it [00:00, 582.26it/s]


0.9464285714285714


# Now we freeze the PLM


In [None]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")

from openprompt.prompts import MixedTemplate


mytemplate_soft2 = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft": "quenstion", "duplicate": 50} {"placeholder":"text_b"} {"soft": "yes", "duplicate": 16} {"soft": "no", "duplicate":16} {"soft": "maybe" , "duplicate": 16} {"mask"}.')
wrapped_t5tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")

from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate_soft2, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

from openprompt import PromptForClassification

use_cuda = True
## Freeze the plm
prompt_model = PromptForClassification(plm=plm,template=mytemplate_soft2, verbalizer=myverbalizer, freeze_plm=True)
if use_cuda:
    prompt_model=  prompt_model.cuda()


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
tokenizing: 250it [00:00, 687.21it/s]


In [None]:
from transformers import  AdamW, get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

# optimizer_grouped_parameters1 = [
#     {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
#     {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# ]

# Using different optimizer for prompt parameters and model parameters
optimizer_grouped_parameters2 = [
    {'params': [p for n,p in prompt_model.template.named_parameters() if "raw_embedding" not in n]}
]

# optimizer1 = AdamW(optimizer_grouped_parameters1, lr=1e-4)
optimizer2 = AdamW(optimizer_grouped_parameters2, lr=0.3)

for epoch in range(20):  # Longer epochs are needed
    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(prompt_model.template.parameters(), 1.0)
        # optimizer1.step()
        # optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        if step %100 ==1:
            print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)



Epoch 0, average loss: 0.8537187576293945
Epoch 1, average loss: 0.5277364104986191
Epoch 2, average loss: 0.6581695377826691
Epoch 3, average loss: 0.2453271597623825
Epoch 4, average loss: 0.25415898859500885
Epoch 5, average loss: 0.27045516669750214
Epoch 6, average loss: 0.10566383227705956
Epoch 7, average loss: 0.1770429015159607
Epoch 8, average loss: 0.0975570622831583
Epoch 9, average loss: 0.33903101086616516
Epoch 10, average loss: 0.09052591398358345
Epoch 11, average loss: 0.16264036670327187
Epoch 12, average loss: 0.08797828480601311
Epoch 13, average loss: 0.03713387344032526
Epoch 14, average loss: 0.09599068760871887
Epoch 15, average loss: 0.025538533926010132
Epoch 16, average loss: 0.04664657078683376
Epoch 17, average loss: 0.063438655808568
Epoch 18, average loss: 0.056047774851322174
Epoch 19, average loss: 0.007814905839040875


In [None]:
validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate_soft2, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

allpreds = []
alllabels = []
for step, inputs in enumerate(validation_dataloader):
    if use_cuda:
        inputs = inputs.cuda()
    logits = prompt_model(inputs)
    labels = inputs['label']
    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
print(acc)

tokenizing: 56it [00:00, 150.15it/s]


0.8214285714285714


# Verbalizer

We use a text-classification task to demonstrate the funtionality of verbalizer

In [None]:
raw_dataset = load_dataset("ag_news")
raw_dataset['train'][0]

Downloading builder script:   0%|          | 0.00/1.83k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset ag_news/default (download: 29.88 MiB, generated: 30.23 MiB, post-processed: Unknown size, total: 60.10 MiB) to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548...


Downloading data:   0%|          | 0.00/11.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/751k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Dataset ag_news downloaded and prepared to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

{'label': 2,
 'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."}

In [None]:
from openprompt.data_utils import InputExample

dataset = {}
for split in ['train', 'test']:
    dataset[split] = []
    for idx, data in enumerate(raw_dataset[split]):
        input_example = InputExample(text_a = data['text'], label=int(data['label']), guid=idx)
        dataset[split].append(input_example)
print(dataset['train'][0])

{
  "guid": 0,
  "label": 2,
  "meta": {},
  "text_a": "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
  "text_b": "",
  "tgt_text": null
}



In [None]:
# We sample a few examples to form the few-shot training pool
from openprompt.data_utils.data_sampler import FewShotSampler
sampler  = FewShotSampler(num_examples_per_label=16, num_examples_per_label_dev=16, also_sample_dev=True)
dataset['train'], dataset['validation'] = sampler(dataset['train'])



In [None]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [None]:
from openprompt.prompts import ManualTemplate
mytemplate = ManualTemplate(tokenizer=tokenizer, text='{"placeholder":"text_a"} {"placeholder":"text_b"} In this sentence, the topic is {"mask"}.')

wrapped_example = mytemplate.wrap_one_example(dataset['train'][0])
print(wrapped_example)


[[{'text': 'Yukos unit to be sold for knockdown \\$4 bn The Russian government yesterday night appeared to be preparing to sell the main asset of Yukos, the countrys embattled oil company, for as little as \\$4 billion - a fraction of its fair value -n a move analysts branded as daylight robbery ', 'loss_ids': 0, 'shortenable_ids': 1}, {'text': ' ', 'loss_ids': 0, 'shortenable_ids': 1}, {'text': ' In this sentence, the topic is', 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '<mask>', 'loss_ids': 1, 'shortenable_ids': 0}, {'text': '.', 'loss_ids': 0, 'shortenable_ids': 0}], {'guid': 61640, 'label': 2}]


In [None]:


from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

tokenizing: 64it [00:00, 745.78it/s]


##### Define the soft verbalizer

In [None]:
from openprompt.prompts import SoftVerbalizer
import torch
myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=4)


from openprompt import PromptForClassification

use_cuda = True
prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model=  prompt_model.cuda()

In [None]:

from transformers import  AdamW, get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()

no_decay = ['bias', 'LayerNorm.weight']

# it's always good practice to set no decay to biase and LayerNorm parameters
optimizer_grouped_parameters1 = [
    {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# Using different optimizer for prompt parameters and model parameters

optimizer_grouped_parameters2 = [
    {'params': prompt_model.verbalizer.group_parameters_1, "lr":3e-5},
    {'params': prompt_model.verbalizer.group_parameters_2, "lr":3e-4},
]


optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
optimizer2 = AdamW(optimizer_grouped_parameters2)


for epoch in range(10):
    tot_loss = 0
    for step, inputs in enumerate(train_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer1.step()
        optimizer1.zero_grad()
        optimizer2.step()
        optimizer2.zero_grad()
        if step %100 ==1:
            print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)



Epoch 0, average loss: 0.6962400078773499
Epoch 1, average loss: 0.5157100260257721
Epoch 2, average loss: 0.36224932968616486
Epoch 3, average loss: 0.1654350832104683
Epoch 4, average loss: 0.1426909863948822
Epoch 5, average loss: 0.10412478819489479
Epoch 6, average loss: 0.046542128548026085
Epoch 7, average loss: 0.05146203376352787
Epoch 8, average loss: 0.029171333648264408
Epoch 9, average loss: 0.022262783721089363


In [13]:
test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
allpreds = []
alllabels = []
for step, inputs in enumerate(test_dataloader):
    if use_cuda:
        inputs = inputs.cuda()
    logits = prompt_model(inputs)
    labels = inputs['label']
    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
print("test:", acc)  # roughly ~0.85

tokenizing: 7600it [00:10, 736.68it/s]


test: 0.6575


In [15]:
alllabels[:10],allpreds[:10]

([2, 3, 3, 3, 3, 3, 3, 3, 3, 3], [2, 3, 3, 3, 2, 3, 3, 3, 3, 2])

# KPT for zero-shot text classification


### Load the ag_news dataset

In [None]:
from datasets import load_dataset
raw_dataset = load_dataset("ag_news")
raw_dataset['train'][0]
from openprompt.data_utils import InputExample

dataset = {}
for split in ['train', 'test']:
    dataset[split] = []
    for idx, data in enumerate(raw_dataset[split]):
        input_example = InputExample(text_a = data['text'], label=int(data['label']), guid=idx)
        dataset[split].append(input_example)
print(dataset['train'][0])

Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


  0%|          | 0/2 [00:00<?, ?it/s]

{
  "guid": 0,
  "label": 2,
  "meta": {},
  "text_a": "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
  "text_b": "",
  "tgt_text": null
}



### We try a different PLM: roberta-large

In [23]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("roberta", "roberta-large")


Downloading:   0%|          | 0.00/482 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

### Define the label words
The label words in KPT is expanded by the knowledge bases. So there may be a lot of it. we save them to a file and load from the file.

In [24]:
label_words = ["politics,government,diplomatic,law,aristotle,diplomatical,governance,republic,politician,smooth,suave,state,expedient,sagacious,police,election,political,monarchy,parliament,dukes,polity,regime,democratic,ethics,communism,federation,anarchism,authoritarianism,populism,bland,aristocracy,tribe,power,negotiation,force,warfare,city,clans,tribes,company,country,plato,confucius,latin,polis,kingship,earls,counts,tribute,lordship,property,inheritance,confiscation,individualist,allegiance,espionage,conspiracy,treason,jewish,gentile,convention,observance,celibacy,pope,taxation,petition,legislation,majority,collegial,permeates,flabby,policy,governmental,pervades,abstract,straitjacket,justice,myopic,discernment,curfew,consciences,revenue,pervade,matriarchal,numbed,juridical,rightness,unaccountable,clannish,deliberative,monopoly,fatness,paternalism,monkish,coin,principle,sinfulness,morass,permeate,stuffy,timorous,meddles,factious,disentangle,grayness,etheric,blandness,indigent,expediency,comity,unclothed,subjectivity,transitory,guild,cesspool,psyche,quicksand,egoism,diplomat,colonies,agreement,profit,policy-making,education,democracy,debate,anarchy,executive,humorless,colorblind,monarchies,psychodrama,self-perpetuating,piercer,clubby,reflation,nontransparent,kafkaesque,postindustrial,democracies,unsustainability,socialism,oligarchies,self-regulating,hidebound,nonideological,wrongness,tyrannies,dispassion,clinton,governor,senate,squishy,soviet,heteronormative,administration,corporatism,metabolizes,covenantal,microbiota,shapers,solipsistic,originalism,seven-man,empowerment,standard-setting,groupthink,bureaucracy,sovereignty,146-nation,autocracy,constitutions,president,imperialism,privatization,restoration,presidential,preside,manifesto,interpol,france,minimalist,crafty,mastermind,federal,racketeer,campaigner,australia,contract,capitalism,administrative,presidency,ombudsman,confederation,generalissimo,fiefdom,officer,senator,ceo,autocratic,financier,secede,timocracy,anarchist,manipulative,civilization,civil,artifice,bloc,suzerainty,politicize,issue,multinational,shrewd,politricks,liberalism,warden,privatize,dishonest,federalization,govern,tenderpreneur,treasurer,authority,international,separatism,expert,statesperson,economist,chancellor,commissioner,guile,neocolonialism,impolitic,politik,papacy,socio,posturing,discourse,politico,chomsky,democrats,crist,polemics,hegelian,religio,politicians,elites,quietism,theo,circumlocution,parliamentary,sayers,roguery,sociopolitical,bonapartist,moralize,reactionary,religionist,egotistic,pragmatics,contestation,apolitical,raison,poli,pols,intelligentsia,viler,hausa,jurisdiction,manipulation,cabal,resourceful,democrat,judiciary,supremacy,demagogue,controller,shrewdness,corporatist,official,socialist,wizard,dominion,supervision,skillful,organisation,cleverness,statesman,employer,undeceive,comptroller,organization,magician,mislead,rebellion,civility,reich,marxism,cosmopolitanism,fedzilla,neoimperialism,leader,quango,corruption,extortion,misleader,nepotism,patronage,crossbencher,graft,mandarinate,scienda,stateswoman,embezzlement,presider,statocracy,politick,policial,ocracy,reformable,exclusionism,egoistical,coopt,laic,bureaucratize,liberalist,reactionism,confederal,popularism,delegitimation,realpolitik,establishmentarian,plebiscitary,confessionalism,demagogical,duumvirate,ethnocentric,coarsen,centrism,peoplehood,ideologic,liberalness,repub,revanchist,technocracy,balkanise,conscionable,christianism,contradictive,mobocracy,equalitarian,sermonise,mudslinger,sloganeer,technocratic,biopolitics,megalomanic,westernism,passivism,russophile,falsifiable,u.s.s.r.,presidium,nativism,elections,misinform,nationalise,ideology,constitutionalization,disenchant,coalition,cronyism,influence,hobbes,kleptocracy",
"sports,athletics,gymnastics,sportsman,competition,cycling,soccer,tennis,game,downfield,offside,judo,polo,team,skiing,hockey,baseball,football,fun,sportswoman,play,rugby,basketball,call,spar,kill,referee,ineligible,wipeout,schuss,luge,archery,upfield,funambulism,toboggan,skateboard,jackknife,ski,rollerblade,boast,mutation,lark,frolic,frisk,skylark,gambol,feature,disport,mutant,romp,cavort,rollick,coach,volleyball,athletic,sumo,television,sportsmanship,pastime,badminton,chess,position,sportaccord,equestrianism,sportsperson,athlete,competitions,golf,coaches,wrestling,cricket,championships,racers,challenge,motocross,leagues,variation,boxing,summercater,entertainment,tournament,champion,season,playoffs,athleticism,dexterity,foul,jog,handler,box,defense,defence,trial,series,cut,side,english,tuck,canoe,dribble,possession,bob,stroke,shot,equitation,row,aquatics,recreation,diversion,toss,pass,flip,occupation,line,job,paddle,carry,surf,racket,drive,surfboard,punt,onside,kick,submarine,bandy,kayak,drop,pack,umpire,backpack,scull,snorkel,shoot,rappel,field,mountaineer,start,curl,underarm,seed,surge,turn,underhand,underhanded,round,bout,hurdle,average,sleigh,loose,overhand,humor,legal,humour,wit,lead,hike,deficit,timer,witticism,jocularity,timekeeper,shooter,scout,home,ref,ironman,skate,manager,comedy,sportive,lacrosse,tradition,goal,biathlon,dodgeball,leisure,floorball,soccerplex,overarm,waggishness,jocosity,wittiness,windsurf,abseil,skin-dive,double-team,prizefight,outclass,shadowbox,birling,spread-eagle,offsides,man-to-man,one-on-one,most-valuable,waggery,motorsport,sportful,sporter,gameday,sportsaholic,nonsports,multisport,footballer,outsport,sportless,lusorious,acrobatic,sportlike,rugger,paddlesport,sportsplex,gamesome,pickleball,postseason,professional,passtime,competitive,slalom,birle,skateboarder,olympics,world,racquet,compete,bowling,competes,olympic,dropkick,sportsfield,clubs,skater,formula,racer,cheerlead,pharaoh,race,minigame,athletes,recreational,bike,snowboard,bicycle,championship,motorcycle,brand,youth,nascar,iran,model,f1,uci,teams,puck,track,racquetball,competitor,riders,postgame,subbuteo,enthusiasts,trashsport,popular,super,games,jousting,class,sponsorship,event,netball,softball,models,best,women,amateur,association,experience,peloponnese,car,venue,players,roller,fia,pigskin,fit,standards,drivers,european,national,tour,fitness,cars,esports,transgender,wogball,bucketball,tennikoit,snowsport,nongame,gamification,subgame,nongamer,vacationer,gameplayer,rioting,gaymer,sportsbook,hooliganism,zourkhaneh,gameography,watersport,fanwear,cross-country,soccerball,wintersports,woodball,concussion,disability,motorsports,interscholastic,tournaments,korfball,triathlon,intercollegiate,paralympic,olympian,bullfighting,boxers,subculture,crosscountry,mma,equestrian,wage,salary,indoors,pay-per-view,fina,nonresident,spectatorship,waterpolo,bloodsport,bobsledding,tourn,riflery,alpinism,bundesliga,nfl,fanatic",
"business,commerce,trade,market,retail,traffic,commercial,marketing,exchange,sell,deal,export,shop,transportation,finance,barter,noncommercial,resell,mercantilism,arbitrage,commercialism,commercialize,wholesale,negociate,transaction,merchandise,transact,auctioneer,doc,smuggle,import,deaccession,importation,merchant,merchandising,mercantile,trading,sale,exportation,e-commerce,affairs,goods,commodity,agriculture,resources,tourism,telecommunications,economic,consumer,agricultural,banking,communications,industries,globalization,trader,consumers,seller,resale,vendor,huckster,shipping,broker,pawn,ipo,hock,interchange,transport,evasion,distribution,antique,purchase,conversation,franchise,noaa,clear,auction,browse,trust,soak,stock,remainder,payment,dealings,realise,prehistoric,retailer,peddle,basketry,tradesman,vendible,simony,mercature,tradeful,untradeable,tradesfolk,tradable,swap,mercat,dicker,untraded,mart,oversold,dealer,transactive,intertraffic,dressmaking,marketable,hanse,venal,monger,vend,hoppo,troak,marketplace,cybercommerce,scorse,defrayment,liquidize,comparison-shop,impulse-buy,commercialise,defrayal,nonpayment,usance,sellable,overtrade,protrade,incoterm,faculty,bureau,self-sufficiency,nonexchange,pawnbrokery,exchangeable,chaffer,downtick,postdeal,department,cheap,tradesperson,mastercraftsman,carpentry,commission,cooperage,undealt,dealy,reexchange,vice,intercourse,investment,furriery,buyer,haggle,sector,merchants,sectors,forestry,office,contraband,issues,general,nonmarket,bootleg,labor,marketeer,misdeal,marketwide,enterprises,committee,china,kong,hong,local,higgler,circa,services,planning,secretary,taiwan,oversees,shopping,markets,employment,today,public,ministry,press,foreign,chairman,forum,rebuy,domestic,firm,butcherdom,board,regional,report,economics,environment,immigration,financial,institutional,firms,exchanges,customs,petroleum,corporations,xinhua,wto,currency,cornmarket,brokering,logroll,marketman,undersell,trades,handel,store,excambion,peddlery,oligopoly,cheesemonger,act,barterer,money,cybermarket,moc,woolhall,signatures,tradecraft,comercio,commercio,entrepreneurship,retailing,maritime,businesses,chamber,marketplaces,arts,marts,transactions,textiles,humanities,exports,lifeblood,prosperity,portals,outbound,remittance,agri,bookselling,cyberspace,telecom,steamboats,cultural,facilitation,businesspeople,industrie,protectionism,gateway,unctad,floriculture,connectivity,closeout,métier,truckman,oversell,craftsmaster,botanica,redeal,monopsony,buyback,patela,numismatist,barkeeping,ndrc,u.s.,commerical,usdoc,e-business,entrepot,containerization,busi,vitalization,agrarianism,capital,corporation",
"technology,engineering,science,biotechnology,internet,nanotechnology,robotics,communication,computer,industry,automation,wheel,technological,equipment,manufacturing,application,bionics,energy,technical,ergonomics,scientific,telephone,development,dolphin,systems,software,transhumanism,electronics,digital,tech,devices,tools,applications,computers,capabilities,expertise,biomedical,innovations,wireless,products,solutions,management,language,engineer,homo,neolithic,bioscience,scientist,biology,skill,good,service,knowledge,prehistory,lever,weapon,club,miniaturization,economy,pollution,value,productivity,discipline,bailiwick,machinery,source,bioengineering,subject,rocketry,sink,study,primates,technologies,crowbar,spoon,neo-luddism,anarcho-primitivism,ee,medicine,techno-progressivism,physics,innovation,phenomenal,tool,formality,developed,research,utility,merriam-webster,usability,electronic,safety,micro,focus,mathematics,advanced,multimedia,history,system,networking,electron,semiconductor,innovative,enterprise,global,hardware,uses,core,product,data,sophisticated,industrial,design,state-of-the-art,makers,components,use,processing,hominids,cyberculture,create,information,mobile,bipedal,companies,programs,creative,expand,strategy,new,program,generation,labs,dynamic,aims,networks,enables,component,capability,build,specialized,wood,charcoal,clothing,genome,eurasia,deforestation,goal-oriented,technician,nomad,uruk,sumer,hierarchy,pseudoscience,irrigation,neuroscience,furnace,bellows,forge,gold,metrology,copper,geology,silver,evolution,alloys,bronze,brass,alchemy,steel,cyberscience,climatology,sociology,evolve,biophysics,iraq,bionanoscience,alchemical,evolutionary,tribology,sedentism,silk,transformation,systematics,phrenology,superscience,technoscience,optimization,transmutation,horseshoe,microscopy,multiscience,developer,construction,antiscience,geophysics,proscience,transformational,demography,psychology,organon,conversion,scienceless,screw,ic,chasten,geoscience,converter,metamorphosis,pulley,fortran,cybernetics,nonscience,transform,wheelbarrow,architectonics,transformer,sciencelike,windmill,radiography,transmute,agronomy,clock,glycoscience,mutate,convert,technique,complicate,hydroscience,mcscience,theoretician,genetics,environmental,physic,assimilate,metamorphic,mining,physiology,metallurgy,interoperable,liberalize,innovate,electrochemistry,biologist,change,biological,ecological,deaden,reform,metaphysics,electricity,agrobiology,photoscience,decarboxylate,flight,complexify,ology,chemistry,acetylate,volatilize,skyscraper,transaminate,motor,logy,telegraph,biometrics,technologists,microelectronics,innovators,holography,algorithms,informatics,diagnostics,pbc,nano,inventions,sapir,startups,methodologies,telematics,functionality,gadgets,ione,biomedicine,interfaces,prognostics,semiconductors,cryptography,ignis,geospatial,cryogenics,radio,lifehack,airplane,techie,sociobiology,automobile,saponify,biocomputing,rarefy,actinochemistry,technoid,geroscience,exobiology,convertee,transmogrify,opacify,conversive,transchange,neurophysics,electrotelegraphy,classicize,transistor,hydrolyze,downshift,cyberpsychology,professionalize,microphonics,unscramble,remew,inactivate,conventionalize,bioclimatology,sysop,brutalize,satellite,telecommunication,miniaturisation,technologic,tecnology,gizmo,photomicrography,energid,fiberoptics,mirasol,telerobotics,relume,fluidics,idesia,blueshift,lightwave,ceroma,corrigent,reflectent,vocable,seawell,photomultipliers,echoscope,electromagnetics,ultrasonics,viridity,micrographics,orthogon,chiliad,autonomics,aits,tomograph,biomimetics,vadium,snocone,pyrometers,architecture,faust,goethe,technicism,citizenship"]

with open("./kpt_label_words.txt", 'w') as fout:
    for ws in label_words:
        fout.write(ws+"\n")

### Define the verbalizer and template

In [25]:
from openprompt.prompts import ManualVerbalizer, KnowledgeableVerbalizer, ManualTemplate

myverbalizer = KnowledgeableVerbalizer(tokenizer, num_classes=4).from_file("./kpt_label_words.txt")
mytemplate = ManualTemplate(tokenizer=tokenizer, text="""A {"mask"} news : {"placeholder": "text_a"} {"placeholder": "text_b"}""")

##Num of label words for each label: [376, 350, 287, 366]


### Perform contextualized calibration
The label words contains a lot of noise. For example, the probability of predicting `neocolonialism` will be noisy. The probability of prediction `machinery` will be much smaller than `technology`, although it may be also informative.
Therefore we perform contextualized calibration.

The calibration is without label. 

In [26]:
# (contextual) calibration
from openprompt import PromptDataLoader
from openprompt.data_utils.data_sampler import FewShotSampler
support_sampler = FewShotSampler(num_examples_total=200, also_sample_dev=False)
dataset['support'] = support_sampler(dataset['train'], seed=1)
for example in dataset['support']:
    example.label = -1 # remove the labels of support set for classification
support_dataloader = PromptDataLoader(dataset=dataset["support"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=128, 
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="tail")

tokenizing: 64it [00:00, 561.65it/s]


In [27]:
from openprompt import PromptForClassification
use_cuda = True
prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model=  prompt_model.cuda()


### Processing and Refining the label words

In [28]:
org_label_words_num = [len(prompt_model.verbalizer.label_words[i]) for i in range(4)]
from openprompt.utils.calibrate import calibrate
# calculate the calibration logits
cc_logits = calibrate(prompt_model, support_dataloader)
print("the calibration logits is", cc_logits)

# register the logits to the verbalizer so that the verbalizer will divide the calibration probability in producing label logits
# currently, only ManualVerbalizer and KnowledgeableVerbalizer support calibration.
prompt_model.verbalizer.register_calibrate_logits(cc_logits)
new_label_words_num = [len(prompt_model.verbalizer.label_words[i]) for i in range(4)]
print("Original number of label words per class: {} \n After filtering, number of label words per class: {}".format(org_label_words_num, new_label_words_num))


ContextCali: 100%|██████████| 16/16 [00:01<00:00, 14.17it/s]


the calibration logits is tensor([26.5108, -3.2521, 42.8338,  ..., -0.0709,  0.8171, 22.8033],
       device='cuda:0')
##Num of label words for each label: [239, 286, 225, 230]
Original number of label words per class: [376, 350, 287, 366] 
 After filtering, number of label words per class: [239, 286, 225, 230]


### Zero-shot test achieves high accuracy.

In [29]:
# zero-shot test
from tqdm import tqdm
import torch
test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=128,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="tail")
allpreds = []
alllabels = []
pbar = tqdm(test_dataloader)
for step, inputs in enumerate(pbar):
    if use_cuda:
        inputs = inputs.cuda()
    logits = prompt_model(inputs)
    labels = inputs['label']
    alllabels.extend(labels.cpu().tolist())
    allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
print("test:", acc)  # roughly ~0.853 when using template 0

tokenizing: 7600it [00:09, 838.97it/s]
100%|██████████| 1900/1900 [02:10<00:00, 14.60it/s]

test: 0.8356578947368422



