In [2]:
from openprompt.data_utils import InputExample
import torch

In [3]:

from openprompt.data_utils.text_classification_dataset import AgnewsProcessor

In [5]:
# Load the AG News training split under its own name. `dataset` is reassigned
# to the two-example toy list in the next cell, which previously discarded this
# dict silently and reused one name for two different kinds of object.
AGNEWS_DIR = "../datasets/TextClassification/agnews"
agnews_dataset = {"train": AgnewsProcessor().get_train_examples(AGNEWS_DIR)}

In [96]:

# Sentiment classes, indexed by predicted class id: 0 -> negative, 1 -> positive.
classes = ["negative", "positive"]

# A deliberately tiny, unlabeled toy dataset with only two examples.
# `text_a` holds the input text; some other datasets also use `text_b` when an
# example carries more than one input sentence. Each guid is the example's position.
_toy_texts = [
    "Patient has severe chest pain and feels awful.",
    "The patient is doing well and has no problems.",
]
dataset = [InputExample(guid=idx, text_a=text) for idx, text in enumerate(_toy_texts)]


In [97]:
# Show the toy examples; each InputExample displays as JSON (label and tgt_text are unset).
dataset

[{
   "guid": 0,
   "label": null,
   "meta": {},
   "text_a": "Patient has severe chest pain and feels awful.",
   "text_b": "",
   "tgt_text": null
 },
 {
   "guid": 1,
   "label": null,
   "meta": {},
   "text_a": "The patient is doing well and has no problems.",
   "text_b": "",
   "tgt_text": null
 }]

In [98]:
from openprompt.plms import load_plm
# Load cased BERT (instantiated as BertForMaskedLM, per the log below), its
# tokenizer and config, and the tokenizer-wrapper class PromptDataLoader needs.
MODEL_TYPE, MODEL_NAME = "bert", "bert-base-cased"
plm, tokenizer, model_config, WrapperClass = load_plm(MODEL_TYPE, MODEL_NAME)


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [113]:
from openprompt.prompts import ManualTemplate, SoftTemplate, MixedTemplate
# Hard (manual) prompt: "<input text> Patient is [MASK]" — the MLM fills the mask slot.
MANUAL_TEMPLATE_TEXT = '{"placeholder":"text_a"} Patient is {"mask"}'
promptTemplate = ManualTemplate(tokenizer=tokenizer, text=MANUAL_TEMPLATE_TEXT)


In [100]:
# Inspect the parsed template: one dict per segment (placeholder, literal text, mask).
promptTemplate.text

[{'add_prefix_space': '', 'placeholder': 'text_a'},
 {'add_prefix_space': ' ', 'text': 'Patient is'},
 {'add_prefix_space': ' ', 'mask': None}]

In [101]:
from openprompt.prompts import ManualVerbalizer
# Words the MLM may predict at the mask slot, grouped by class; the verbalizer
# turns their scores into one logit per class.
LABEL_WORDS = {
    "negative": ["sick", "dying", "unwell"],
    "positive": ["healthy", "well", "great"],
}
promptVerbalizer = ManualVerbalizer(
    tokenizer=tokenizer,
    classes=classes,
    label_words=LABEL_WORDS,
)


In [31]:
# The ManualVerbalizer repr shows no configuration details.
promptVerbalizer

ManualVerbalizer()

In [102]:
from openprompt import PromptForClassification
# Classifier = PLM + prompt template + verbalizer; its forward pass returns
# per-class scores computed from the mask position.
promptModel = PromptForClassification(
    plm=plm,
    template=promptTemplate,
    verbalizer=promptVerbalizer,
)


In [103]:
# Display the module tree: a PromptModel wrapping BertForMaskedLM.
promptModel

PromptForClassification(
  (prompt_model): PromptModel(
    (plm): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(28996, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
          

In [104]:

from openprompt import PromptDataLoader
# Wrap each example with the manual template and tokenize it (done eagerly at
# construction, hence the progress bar); batches are consumed by promptModel below.
data_loader = PromptDataLoader(
    tokenizer_wrapper_class=WrapperClass,
    template=promptTemplate,
    tokenizer=tokenizer,
    dataset=dataset,
)


tokenizing: 2it [00:00, 2004.45it/s]


In [56]:
# Inspect how each example was wrapped: text pieces with loss_ids (1 only on the
# mask slot) and shortenable_ids (1 only on the input text — presumably marking
# what may be truncated), plus the guid metadata.
data_loader.wrapped_dataset

[[[{'text': 'Patient has severe chest pain and feels awful.',
    'loss_ids': 0,
    'shortenable_ids': 1},
   {'text': ' Patient is', 'loss_ids': 0, 'shortenable_ids': 0},
   {'text': '[MASK]', 'loss_ids': 1, 'shortenable_ids': 0}],
  {'guid': 0}],
 [[{'text': 'The patient is doing well and has no problems.',
    'loss_ids': 0,
    'shortenable_ids': 1},
   {'text': ' Patient is', 'loss_ids': 0, 'shortenable_ids': 0},
   {'text': '[MASK]', 'loss_ids': 1, 'shortenable_ids': 0}],
  {'guid': 1}]]

In [111]:
# Zero-shot inference: score each example with the pretrained MLM + prompt, no training.
# eval() switches off the model's Dropout layers; no_grad() skips autograd bookkeeping.
promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)  # one row of per-class scores per example in the batch
        print(logits)
        preds = torch.argmax(logits, dim=-1)
        # Index `classes` once per example: `classes[preds]` only works for a
        # single-element tensor and raises once a batch holds more than one example.
        for pred in preds.tolist():
            print(classes[pred])
# Intended labels: guid 0 -> 'negative', guid 1 -> 'positive'. In the recorded
# run below, the untuned prompt/verbalizer predicted 'positive' for both.


tensor([[-2.4273, -1.7207]])
positive
tensor([[-4.1333, -1.4446]])
positive


In [38]:
from openprompt import PromptModel as PromptModelBase



In [40]:
# Plain PromptModel: unlike PromptForClassification, it does not reduce the output
# to the mask-position logits, so the PLM's output at every position stays available.
p_base = PromptModelBase(plm=plm, template=promptTemplate)

In [106]:
# Check the output logits at every position, not just the mask.
# Select the batch explicitly instead of relying on `batch` leaking out of the
# inference loop. That loop left it bound to the last batch, which is reproduced here.
batch = list(data_loader)[-1]
with torch.no_grad():  # inspection only; no gradients needed
    base_output = p_base(batch)

In [107]:
# Convert the logits of this single-example batch to a NumPy array. squeeze(0)
# drops the batch dimension, presumably leaving (seq_len, vocab_size) — confirm.
# (Previously a trailing `.shape` stored only the shape tuple, not the array.)
base_output_np = base_output.logits.detach().cpu().numpy().squeeze(0)

In [108]:
# Softmax over the vocabulary axis gives per-position token probabilities; the
# argmax at each position is the predicted vocab/token id. `dim` is passed
# explicitly: an argument-less nn.Softmax() emits an implicit-dim deprecation warning.
sm = torch.nn.Softmax(dim=-1)
probs = sm(base_output.logits.squeeze(0))
output_token_predictions = probs.argmax(dim=-1)

  probs = sm(base_output.logits.squeeze(0))


In [109]:
# Decode the per-position predicted token ids back to text.
# NOTE(review): the decoded string is far longer than the tokenized input, so
# predictions at padding positions appear to be included. Consider keeping only
# positions where the batch's attention mask is 1 — TODO confirm the loader's padding length.
tokenizer.decode(output_token_predictions)

'. The patient is doing well and has no problems. Patient is.. The happy. He Very, and has no problems. Pat patient happy happy happy The The is very very, he has no problems The The Pet good. v Good The patient is very, He has no problems. Pat patient good. good The The patient is very very and he is healthy. The The The the patient The The patient is very very, he has no problems. Pat The good happy happy the healthy healthy always very very and has has health.. Pat Pat happy.. good good. Very The Very healthy and he still and. Be The The satisfied.. Pat patient is.. Pat good. happy The healthy healthy always very, and has no health symptoms has no problems. Pat Pat v happy happy... good.. happy always very very healthy He He He He healthy He has The healthy Ex Pat Pat. v und. the the healthy always The He is always very very, He he he He has has problems happy. Pat patient happy. und. the good good good happy The patient is very very very healthy, very He healthy He He has no health

In [110]:
# Input ids of the second toy example under the manual template (presumably
# copied from the batch's input_ids), decoded to confirm the tokenization.
example_input_ids = [101, 1109, 5351, 1110, 1833, 1218, 1105, 1144, 1185, 2645, 119, 7195, 9080, 1110, 103, 102]
tokenizer.decode(example_input_ids)

'[CLS] The patient is doing well and has no problems. Patient is [MASK] [SEP]'

Try a soft or mixed template

In [114]:
# Soft template for classification: NUM_SOFT_TOKENS trainable soft-prompt tokens
# plus "<input text> <eos> [mask]". The soft tokens do not show up in
# wrap_one_example's output, so they are presumably injected at the embedding level — confirm.
SOFT_TEMPLATE_TEXT = '{"placeholder":"text_a"} {"special": "<eos>"} {"mask"}'
NUM_SOFT_TOKENS = 100
# NOTE(review): bert-base-cased may not define an eos token; confirm how the
# tokenizer wrapper maps "<eos>" for BERT.
softTemplate = SoftTemplate(
    num_tokens=NUM_SOFT_TOKENS,
    text=SOFT_TEMPLATE_TEXT,
    tokenizer=tokenizer,
    model=plm,
)

# Alternative templates for QA-style data with two input fields (text_a, text_b):
# soft-only:
# mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft"} {"soft"} {"soft"} {"placeholder":"text_b"} {"soft"} {"mask"}.')
# mixed (the "Question:" text only initializes a soft token):
# mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer, text='{"placeholder":"text_a"} {"soft": "Question:"} {"placeholder":"text_b"}? Is it correct? {"soft"} {"mask"}.')

In [115]:
# Repr of the soft template; it lists the raw_embedding module it holds.
softTemplate

SoftTemplate(
  (raw_embedding): Embedding(28996, 768, padding_idx=0)
)

In [119]:
# Wrap the first toy example with the soft template to see the pieces it produces.
# With the MixedTemplate alternative, the {"soft": "Question:"} text only
# initializes the mixed template's soft token and is replaced by an empty
# piece in the wrapped output.
first_example = dataset[0]
wrapped_example = softTemplate.wrap_one_example(first_example)
print(wrapped_example)

[[{'text': 'Patient has severe chest pain and feels awful.', 'loss_ids': 0, 'shortenable_ids': 1}, {'text': '<eos>', 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '<mask>', 'loss_ids': 1, 'shortenable_ids': 0}], {'guid': 0}]
