In [None]:
!pip install openprompt==0.1.1 \
'torch>=1.9.0' \
'transformers>=4.10.0' \
sentencepiece==0.1.96 \
'scikit-learn>=0.24.2' \
'tqdm>=4.62.2' \
tensorboardX \
nltk \
yacs \
dill \
datasets \
rouge==1.0.0 \
scipy==1.4.1 \
fugashi \
ipadic \
unidic-lite

# BERT
OpenPrompt

In [None]:
import openprompt.plms as plms
from openprompt.plms.mlm import MLMTokenizerWrapper
from transformers import BertConfig, BertForMaskedLM, BertTokenizer

In [None]:
# Point the 'bert' key of OpenPrompt's model-class mapping at the BERT
# config / tokenizer / masked-LM classes and the MLM tokenizer wrapper.
# Any existing 'bert' entry is overwritten.
plms._MODEL_CLASSES['bert'] = plms.ModelClass(
    config=BertConfig,
    tokenizer=BertTokenizer,
    model=BertForMaskedLM,
    wrapper=MLMTokenizerWrapper,
)

In [None]:
# Display the model-class mapping to confirm that the 'bert' entry is present.
plms._MODEL_CLASSES

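# Step 1: Define a task
We classify short medical passages by the organ they concern (`lung` or `brain`). The candidate classes and a few example passages are defined below.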

In [None]:
from openprompt.data_utils import InputExample
# Candidate labels for the task.
classes = ["lung", "brain"]

# Unlabelled passages to classify. Every (guid, text) pair below becomes one
# InputExample; the trailing comment on each entry records the intended class.
# The two "virus" passages (guids 1 and 3) are left disabled -- presumably
# because "virus" is not one of `classes`.
dataset = [
    InputExample(guid=guid, text_a=passage)
    for guid, passage in (
        (0, "Asthma affects lungs  and can be hard to diagnose. The signs of asthma can seem like the signs of COPD, pneumonia, bronchitis, pulmonary embolism, anxiety, and heart disease."),  # lung
        # (1, "COVID-19 is caused by a coronavirus called SARS-CoV-2"),  # virus
        (2, "When your brain is damaged, it can affect many different things, including your memory, your sensation, and even your personality. Brain disorders include any conditions or disabilities that affect your brain."),  # brain
        # (3, "Symptoms may appear 2-14 days after exposure to the virus"),  # virus
        (
            4,
            """Neurodegenerative diseases cause your brain and nerves to deteriorate over time. They can change your personality and cause confusion. They can also destroy your brain’s tissue and nerves.

Some brain diseases, such as Alzheimer’s disease, may develop as you age. """,
        ),  # brain
    )
]

# Step 2: Define a Pre-trained Language Model (PLM)

In [None]:
from transformers import AutoTokenizer, AutoModel
from openprompt.plms import load_plm
# Model family key and checkpoint identifier passed to `load_plm`.
PLM_FAMILY = "bert"
PLM_CHECKPOINT = "dmis-lab/biobert-v1.1"

# `load_plm` returns four objects: the language model, its tokenizer, the
# model config, and the tokenizer-wrapper class used later by the data loader.
plm, tokenizer, model_config, WrapperClass = load_plm(PLM_FAMILY, PLM_CHECKPOINT)


In [None]:
# Quick look at how the loaded tokenizer splits a sample sentence.
tokenizer.tokenize("I think this drug is not a solution")

# Step 3: Define a Template.


In [None]:
from openprompt.prompts import ManualTemplate
# Prompt layout. In OpenPrompt's template syntax, {"mask"} marks the slot the
# PLM has to fill in and {"placeholder": "text_a"} is where each example's
# `text_a` is inserted.
# Earlier variant, kept for reference:
#   '{"placeholder":"text_a"}: This effects {"mask"}'
template_text = 'A {"mask"} disorder :  {"placeholder": "text_a"}'

promptTemplate = ManualTemplate(tokenizer=tokenizer, text=template_text)

# Step 4: Define a Verbalizer


In [None]:
from openprompt.prompts import ManualVerbalizer
# Associate every class with the label words that stand for it at the
# template's mask position.
promptVerbalizer = ManualVerbalizer(
    tokenizer=tokenizer,
    classes=classes,
    label_words=dict(
        lung=["breathe", "lungs"],
        brain=["head", "brain"],
    ),
)

# Step 5: Combine them into a PromptModel
With the task defined, we now have a PLM, a Template and a Verbalizer; we combine them into a PromptModel.

In [None]:
from openprompt import PromptForClassification
# Bundle the PLM, the template and the verbalizer into a single classifier.
promptModel = PromptForClassification(plm=plm, template=promptTemplate, verbalizer=promptVerbalizer)

# Step 6: Define a DataLoader
A PromptDataLoader is essentially a prompt-aware version of the PyTorch DataLoader; in addition to the dataset, it takes a Tokenizer, a Template and a TokenizerWrapper.

In [None]:
from openprompt import PromptDataLoader
# Build the loader that feeds `dataset` through the template and tokenizer.
data_loader = PromptDataLoader(
    # what to load and how to wrap / tokenize it
    dataset=dataset,
    template=promptTemplate,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
    # sequence-length handling
    max_seq_length=256,
    decoder_max_length=3,
    truncate_method="head",
    # batching: one example per batch, in dataset order
    batch_size=1,
    shuffle=False,
    # remaining loader options, all disabled
    teacher_forcing=False,
    predict_eos_token=False,
)


# Step 7: Train and inference
Done! Training and inference now proceed just as in any other PyTorch workflow.

In [None]:
# Zero-shot inference: score every example with the pretrained MLM through the
# prompt, without any fine-tuning.
import torch

promptModel.eval()  # evaluation mode
with torch.no_grad():  # no gradients are needed for inference
    for batch in data_loader:
        logits = promptModel(batch)
        print(logits)
        preds = torch.argmax(logits, dim=-1)
        # Map each predicted index to its class name individually. Indexing
        # the `classes` list with the tensor itself (`classes[preds]`) only
        # works while the batch holds exactly one example.
        for class_idx in preds.flatten().tolist():
            print(classes[class_idx])
# Per the annotations on the dataset entries, the intended labels are
# lung, brain, brain; the zero-shot predictions printed above may differ.
