In [None]:
!pip install openprompt

In [None]:
from openprompt.plms import load_plm

# Pretrained masked LM to prompt: OpenPrompt model family + checkpoint name.
PLM_FAMILY = "bert"
PLM_CHECKPOINT = "bert-base-cased"

# load_plm hands back the model, its tokenizer, its config and the tokenizer
# wrapper class that PromptDataLoader receives as `tokenizer_wrapper_class`.
plm, tokenizer, model_config, WrapperClass = load_plm(PLM_FAMILY, PLM_CHECKPOINT)

In [None]:
import json
from openprompt.data_utils import InputExample

# Location of the raw examples (/content is the Colab working directory).
RAW_EXAMPLES_PATH = "/content/raw_input_examples.json"

# The `with` block guarantees the file handle is closed; JSON text is UTF-8,
# so don't rely on the platform's default encoding.
with open(RAW_EXAMPLES_PATH, encoding="utf-8") as f:
  raw_input_examples = json.load(f)

# Each record must provide "index" (used as the example guid) and "text"
# (substituted for the template's text_a placeholder).
dataset = [
    InputExample(guid=example["index"], text_a=example["text"])
    for example in raw_input_examples
]

# Label set; the verbalizer's label_words are keyed by these names.
classes = ["positive", "negative"]

In [None]:
from openprompt.prompts import ManualTemplate

# Cloze-style prompt: the input sentence fills text_a and the PLM is asked to
# predict the word at the {"mask"} slot.
TEMPLATE_TEXT = '{"placeholder":"text_a"} It was {"mask"}.'
promptTemplate = ManualTemplate(tokenizer=tokenizer, text=TEMPLATE_TEXT)

In [None]:
from openprompt.prompts import ManualVerbalizer

# Words whose scores at the mask position stand in for each class.
# NOTE(review): the sets are asymmetric (1 negative word vs 3 positive words).
LABEL_WORDS = {
    "negative": ["bad"],
    "positive": ["good", "wonderful", "great"],
}
promptVerbalizer = ManualVerbalizer(
    tokenizer=tokenizer,
    classes=classes,
    label_words=LABEL_WORDS,
)

In [None]:
from openprompt import PromptForClassification

# Tie the pretrained LM, the template and the verbalizer together into one
# classifier: template wraps the input, PLM scores the mask, verbalizer maps
# those scores onto `classes`.
promptModel = PromptForClassification(
    plm=plm,
    verbalizer=promptVerbalizer,
    template=promptTemplate,
)

In [None]:
from openprompt import PromptDataLoader

# Apply the template to every InputExample and tokenize it with the PLM's
# wrapper class; iterating yields batches of tensors that promptModel consumes.
data_loader = PromptDataLoader(
    tokenizer_wrapper_class=WrapperClass,
    tokenizer=tokenizer,
    template=promptTemplate,
    dataset=dataset,
)

tokenizing: 2it [00:00, 205.13it/s]


In [None]:
import math

def logit_to_probability(logit):
  """Convert a binary logit (log-odds) into a probability via the logistic sigmoid.

  exp(x) / (1 + exp(x)) and 1 / (1 + exp(-x)) are the same value; the branch
  picks whichever keeps the exponent non-positive, so math.exp can never
  overflow (the naive form raises OverflowError for x > ~709).

  Only meaningful for a single two-class log-odds score; it is not a
  probability for one entry of a multi-class logit vector (use softmax there).

  Args:
    logit: anything float() accepts, e.g. a Python number or a one-element tensor.

  Returns:
    A float in [0, 1].
  """
  x = float(logit)
  if x >= 0:
    return 1.0 / (1.0 + math.exp(-x))
  odds = math.exp(x)
  return odds / (1.0 + odds)

In [None]:
import torch

# Zero-shot inference: the pretrained MLM scores the template's mask slot and the
# verbalizer maps those mask-position scores onto the label words of each class.
TOP_K = 5  # number of candidate tokens to display for the mask position

promptModel.eval()  # inference: make sure dropout-style layers are disabled
with torch.no_grad():
  for batch in data_loader:
    # Mask-position vocabulary logits, one row per example in the batch.
    lm_logits = promptModel.forward_without_verbalize(batch)
    # Verbalizer scores, one per entry of `classes`.
    cls_logits = promptModel(batch)

    # Softmax over the vocabulary / over the classes yields real distributions;
    # a sigmoid of a single entry of these score vectors is not a probability.
    lm_probs = torch.softmax(lm_logits, dim=-1)
    cls_probs = torch.softmax(cls_logits, dim=-1)

    # Report every example, not just the first one, so any batch_size works.
    for i in range(batch["input_ids"].shape[0]):
      input_ids = batch["input_ids"][i]
      # Strip padding so the printout shows only the real sequence.
      token_ids = input_ids[input_ids != tokenizer.pad_token_id].tolist()

      print('===================================================================')
      print("Token IDs : ", token_ids)
      print("Tokens : ", tokenizer.convert_ids_to_tokens(token_ids), "\n")

      # Show the top-k candidates instead of dumping the whole vocabulary.
      top_probs, top_ids = torch.topk(lm_probs[i], k=TOP_K)
      top_tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
      print("Top Tokens : ", list(zip(top_tokens, top_probs.tolist())))
      print("Most Probable Token : ", top_tokens[0])
      print("Token Probability : ", top_probs[0].item(), "\n")

      print("Class Logits : ", cls_logits[i].tolist())
      output_index = int(torch.argmax(cls_logits[i], dim=-1))
      print("Most Probable Output : ", classes[output_index])
      print("Output Probability : ", cls_probs[i, output_index].item())

Token IDs :  [101, 3986, 16127, 1108, 1141, 1104, 1103, 4459, 1107, 7854, 18465, 1116, 1104, 1117, 1159, 119, 1135, 1108, 103, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,