In [1]:
from openprompt.data_utils import InputExample
import torch
import pandas as pd
import os
import json, csv
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable

from openprompt.utils.logging import logger

from openprompt.data_utils.utils import InputExample
from openprompt.data_utils.data_processor import DataProcessor

import pandas as pd
import numpy as np
from tqdm import tqdm

from torchnlp.encoders import LabelEncoder

In [3]:
# top 50 icd 9 data

# set a local pc directory if not on alejos machines
local_pc = False
if local_pc:
    mimic_data_dir = "C://Users/ntaylor/Documents/GitHub/Neural_Networks/DPhil_NLP/mimic-icd9-classification/clinical-longformer/data/intermediary-data/top_50_icd9"
else:

    mimic_data_dir = "/home/niallt/NLP_DPhil/DPhil_projects/mimic-icd9-classification//data/intermediary-data/top_50_icd9"
mimic_data = pd.read_csv(f"{mimic_data_dir}/train.csv")

In [4]:
mimic_data.head()

Unnamed: 0,text,label
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,4240
1,: : : Sex: F Service: NEONATOLOGY HISTORY: wee...,V3001
2,: : : Sex: M Service: CARDIOTHORACIC Allergies...,41041
3,: : : Sex: F Service: MEDICINE Allergies: Peni...,51881
4,: : : Sex: F Service: CARDIOTHORACIC Allergies...,3962


fortunatley the label encoder class seems to always map the same code from this data set to same idx. What we want to do now is create a list of the icd9 codes as class names in the order by which the label encoder indexes them. This is useful for the autoprompt pipeline...


In [32]:
le = LabelEncoder(np.unique(mimic_data.label).tolist(), reserved_labels = [])

In [37]:
icd9_list = list(le.tokens.keys())
icd9_list

['0380',
 '03811',
 '03842',
 '03849',
 '0389',
 '042',
 '1623',
 '1983',
 '29181',
 '3962',
 '41011',
 '41041',
 '41071',
 '41401',
 '41519',
 '4240',
 '4241',
 '4271',
 '42731',
 '4280',
 '42823',
 '430',
 '431',
 '4321',
 '43310',
 '43411',
 '43491',
 '4373',
 '44101',
 '4414',
 '486',
 '5070',
 '51881',
 '51884',
 '53240',
 '56212',
 '5712',
 '5715',
 '5761',
 '5770',
 '5789',
 '5849',
 '85221',
 '99662',
 '99811',
 '99859',
 'V3000',
 'V3001',
 'V3101',
 'V3401']

In [38]:
# write these to files as the classes for prompt learning pipeline
save_dir = "../prompt-based-models/scripts/"

textfile = open(f"{save_dir}/labels.txt", "w")

for element in icd9_list:

    textfile.write(element + "\n")

textfile.close()

In [69]:
# read back in as list

my_file = open(f"{save_dir}/labels.txt", "r")

content = my_file.read().split("\n")

content

['0380',
 '03811',
 '03842',
 '03849',
 '0389',
 '042',
 '1623',
 '1983',
 '29181',
 '3962',
 '41011',
 '41041',
 '41071',
 '41401',
 '41519',
 '4240',
 '4241',
 '4271',
 '42731',
 '4280',
 '42823',
 '430',
 '431',
 '4321',
 '43310',
 '43411',
 '43491',
 '4373',
 '44101',
 '4414',
 '486',
 '5070',
 '51881',
 '51884',
 '53240',
 '56212',
 '5712',
 '5715',
 '5761',
 '5770',
 '5789',
 '5849',
 '85221',
 '99662',
 '99811',
 '99859',
 'V3000',
 'V3001',
 'V3101',
 'V3401',
 '']

In [21]:
class Mimic_ICD9_Processor(DataProcessor):


    '''
    Function to convert mimic icd9 dataset to a open prompt ready dataset. 
    
    We also instantiate a LabelEncoder() class which is fitted to the given dataset. Fortunately it appears
    to create the same mapping for each set, given each set contains all classes. 

    This is not ideal, and need to think of a better way to store the label encoder based on training data.
    

  
    
    '''
    # TODO Test needed
    def __init__(self):
        super().__init__()        

    def get_examples(self, data_dir, mode = "train", label_encoder = None,
                     generate_class_labels = False, class_labels_save_dir = "scripts/"):

        path = f"{data_dir}/{mode}.csv"
        print(f"loading {mode} data")
        print(f"data path provided was: {path}")
        examples = []
        df = pd.read_csv(path)

        # need to either initializer and fit the label encoder if not provided
        if label_encoder is None:
            self.label_encoder = LabelEncoder(np.unique(df.label).tolist(), reserved_labels = [])
        else: 
            print("we were given a label encoder")
            self.label_encoder = label_encoder

        
        for idx, row in tqdm(df.iterrows()):
#             print(row)
            body, label = row['text'],row['label']
            label = self.label_encoder.encode(label)
#             print(f"body : {body}")
#             print(f"label: {label}")
#             print(f"labels original: {self.label_encoder.index_to_token[label]}")
            
            text_a = body.replace('\\', ' ')

            example = InputExample(
                guid=str(idx), text_a=text_a, label=int(label))
            examples.append(example)
            
        logger.info(f"Returning {len(examples)} samples!") 

#         now we want to return a list of the non-encoded labels based on the fitted label encoder
        if generate_class_labels:
            logger.info(f"Saving class labels to: {class_labels_save_dir}")
            class_labels = self.generate_class_labels()
            # write these to files as the classes for prompt learning pipeline
            save_dir = "../prompt-based-models/scripts/"

            textfile = open(f"{class_labels_save_dir}/labels_test.txt", "w")

            for element in class_labels:

                textfile.write(element + "\n")

            textfile.close() 

        return examples

    def generate_class_labels(self):
        # now we want to return a list of the non-encoded labels based on the fitted label encoder
        try:
            return list(self.label_encoder.tokens.keys())
        except:
            print("No class labels as haven't fitted any data yet. Run get_examples first!")
            raise NotImplementedError

    
    def load_class_labels(self, file_path = "./scripts/labels.txt"):
        # function to load pre-generated class labels
        # returns list of class labels

        text_file = open(f"{file_path}", "r")

        class_labels = text_file.read().split("\n")

        return class_labels

In [22]:
# get different splits
dataset = {}
dataset['train'] = Mimic_ICD9_Processor().get_examples(data_dir = f"{mimic_data_dir}", mode = "train")[:10]
dataset['valid'] = Mimic_ICD9_Processor().get_examples(data_dir = f"{mimic_data_dir}",mode = "valid", generate_class_labels = False )


loading train data
data path provided was: /home/niallt/NLP_DPhil/DPhil_projects/mimic-icd9-classification//data/intermediary-data/top_50_icd9/train.csv


14360it [00:01, 9275.16it/s]


loading valid data
data path provided was: /home/niallt/NLP_DPhil/DPhil_projects/mimic-icd9-classification//data/intermediary-data/top_50_icd9/valid.csv


4693it [00:00, 7903.11it/s]


In [23]:
dataset['train']

[{
   "guid": "0",
   "label": 15,
   "meta": {},
   "text_a": ": : : Sex: F Service: CARDIOTHORACIC Allergies: Patient recorded as having No Known Allergies to Drugs : Chief Complaint: SOB with exertion, heart murmur since y/o Major Surgical or Invasive Procedure: Mitral valve replacement(mm CE tissue History of Present Illness: y/o female with known MVP who was diagnosed with a heart murmur at age . She was evaluated with serial TTE's which showed worsening MR. Echo showed LVEF % with Mitral valve regurgitant fraction of %. She denies any symptoms. Past Medical History: Hyperlipidemia, MVP/MR, Depression, Obesity Social History: social Etoh, live with mother, deniesDA or tobacco use Family History: noncontributory Physical Exam: y/o F in bed NAD Neuro AA&Ox, nonfocal Chest CTAB resp unlab median sternotomy stable, c/d/i no d/c, RRR no m/r/g chest tubes and epicardial wires removed. Abd S/NT/ND/BS+ EXT warm with trace edema Pertinent Results: RADIOLOGY Preliminary Report CHEST (PA & L

In [63]:
# dataset['valid'][0]

In [81]:
labels = Mimic_ICD9_Processor().load_class_labels()
labels

['0380',
 '03811',
 '03842',
 '03849',
 '0389',
 '042',
 '1623',
 '1983',
 '29181',
 '3962',
 '41011',
 '41041',
 '41071',
 '41401',
 '41519',
 '4240',
 '4241',
 '4271',
 '42731',
 '4280',
 '42823',
 '430',
 '431',
 '4321',
 '43310',
 '43411',
 '43491',
 '4373',
 '44101',
 '4414',
 '486',
 '5070',
 '51881',
 '51884',
 '53240',
 '56212',
 '5712',
 '5715',
 '5761',
 '5770',
 '5789',
 '5849',
 '85221',
 '99662',
 '99811',
 '99859',
 'V3000',
 'V3001',
 'V3101',
 'V3401',
 '']

In [41]:
df = pd.read_csv("../data/intermediary-data/triage/train.csv")
df.head()

Unnamed: 0,text,label,triage-category
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,4240,Cardiology
1,: : : Sex: F Service: NEONATOLOGY HISTORY: wee...,V3001,Obstetrics
2,: : : Sex: M Service: CARDIOTHORACIC Allergies...,41041,Cardiology
3,: : : Sex: F Service: MEDICINE Allergies: Peni...,51881,Respiratory
4,: : : Sex: M Service: ADMISSION DIAGNOSIS: . S...,41401,Cardiology


In [46]:
# one for the triage data

class Mimic_ICD9_Triage_Processor(DataProcessor):


    '''
    Function to convert mimic icd9 triage dataset to a open prompt ready dataset. 
    
    We also instantiate a LabelEncoder() class which is fitted to the given dataset. Fortunately it appears
    to create the same mapping for each set, given each set contains all classes. 

    This is not ideal, and need to think of a better way to store the label encoder based on training data.
    

  
    
    '''
    # TODO Test needed
    def __init__(self):
        super().__init__()        

    def get_examples(self, data_dir, mode = "train", label_encoder = None,
                     generate_class_labels = False, class_labels_save_dir = "./scripts/mimic_triage/"):

        path = f"{data_dir}/{mode}.csv"
        print(f"loading {mode} data")
        print(f"data path provided was: {path}")
        examples = []
        df = pd.read_csv(path)


        # need to either initializer and fit the label encoder if not provided
        if label_encoder is None:
            self.label_encoder = LabelEncoder(np.unique(df["triage-category"]).tolist(), reserved_labels = [])
        else: 
            print("we were given a label encoder")
            self.label_encoder = label_encoder

        
        for idx, row in tqdm(df.iterrows()):
#             print(row)
            body, label = row['text'],row['triage-category']
            label = self.label_encoder.encode(label)
#             print(f"body : {body}")
#             print(f"label: {label}")
#             print(f"labels original: {self.label_encoder.index_to_token[label]}")
            
            text_a = body.replace('\\', ' ')

            example = InputExample(
                guid=str(idx), text_a=text_a, label=int(label))
            examples.append(example)
            
        print(f"Returning {len(examples)} samples!") 

#         now we want to return a list of the non-encoded labels based on the fitted label encoder
        if generate_class_labels:
        
            if not os.path.exists(class_labels_save_dir):
                os.makedirs(class_labels_save_dir)
            print(f"Saving class labels to: {class_labels_save_dir}")
            class_labels = self.generate_class_labels()
            # write these to files as the classes for prompt learning pipeline           

            textfile = open(f"{class_labels_save_dir}/labels.txt", "w")

            for element in class_labels:

                textfile.write(element + "\n")

            textfile.close() 

        return examples

    def generate_class_labels(self):
        # now we want to return a list of the non-encoded labels based on the fitted label encoder
        try:
            return list(self.label_encoder.tokens.keys())
        except:
            print("No class labels as haven't fitted any data yet. Run get_examples first!")
            raise NotImplementedError

    
    def load_class_labels(self, file_path = "./scripts/mimic_triage/labels.txt"):
        # function to load pre-generated class labels
        # returns list of class labels

        text_file = open(f"{file_path}", "r")

        class_labels = text_file.read().split("\n")

        return class_labels

In [47]:
triage_train = Mimic_ICD9_Triage_Processor().get_examples(data_dir = f"../data/intermediary-data/triage", mode = "train", generate_class_labels= True)

loading train data
data path provided was: ../data/intermediary-data/triage/train.csv


9559it [00:01, 7669.29it/s]

Returning 9559 samples!
Saving class labels to: ./scripts/mimic_triage/





In [30]:
triage_train

[{
   "guid": "0",
   "label": 5,
   "meta": {},
   "text_a": ": : : Sex: F Service: CARDIOTHORACIC Allergies: Patient recorded as having No Known Allergies to Drugs : Chief Complaint: SOB with exertion, heart murmur since y/o Major Surgical or Invasive Procedure: Mitral valve replacement(mm CE tissue History of Present Illness: y/o female with known MVP who was diagnosed with a heart murmur at age . She was evaluated with serial TTE's which showed worsening MR. Echo showed LVEF % with Mitral valve regurgitant fraction of %. She denies any symptoms. Past Medical History: Hyperlipidemia, MVP/MR, Depression, Obesity Social History: social Etoh, live with mother, deniesDA or tobacco use Family History: noncontributory Physical Exam: y/o F in bed NAD Neuro AA&Ox, nonfocal Chest CTAB resp unlab median sternotomy stable, c/d/i no d/c, RRR no m/r/g chest tubes and epicardial wires removed. Abd S/NT/ND/BS+ EXT warm with trace edema Pertinent Results: RADIOLOGY Preliminary Report CHEST (PA & LA

# adapt below to work with mimic data

In [10]:
# load pretrained language model (plm)


from openprompt.plms import load_plm

# plm, tokenizer, model_config, WrapperClass = load_plm("t5", "t5-base")
# plm, tokenizer, model_config, WrapperClass = load_plm("t5", "razent/SciFive-base-Pubmed_PMC")

# plm, tokenizer, model_config, WrapperClass = load_plm("roberta", "roberta-large")
plm, tokenizer, model_config, WrapperClass = load_plm("bert", "emilyalsentzer/Bio_ClinicalBERT")


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
tokenizer

PreTrainedTokenizer(name_or_path='razent/SciFive-base-Pubmed_PMC', vocab_size=32100, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>',

In [4]:
plm

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dr

In [4]:
plm

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [24]:
# set up templates - either manual, knowledgeable or soft
from openprompt.prompts import ManualTemplate, MixedTemplate, SoftTemplate
# mytemplate = ManualTemplate(tokenizer=tokenizer, text='{"placeholder":"text_a"} {"placeholder":"text_b"} In this sentence, the topic is {"mask"}.')
# mytemplate = ManualTemplate(tokenizer=tokenizer).from_file("scripts/mimic_icd9_top50/manual_template.txt", choice=2)
# mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer).from_file(f"scripts/mimic_icd9_top50/mixed_template.txt", choice=0)
# {"placeholder": "text_a"} {"soft": "This"} patient {"soft":"has diagnosis"} {"mask"}.
mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, num_tokens=20, initialize_from_vocab=True).from_file(f"scripts/mimic_icd9_top50/soft_template.txt", choice=0)


wrapped_example = mytemplate.wrap_one_example(dataset['train'][0]) 
print(wrapped_example)

[[{'text': ": : : Sex: F Service: CARDIOTHORACIC Allergies: Patient recorded as having No Known Allergies to Drugs : Chief Complaint: SOB with exertion, heart murmur since y/o Major Surgical or Invasive Procedure: Mitral valve replacement(mm CE tissue History of Present Illness: y/o female with known MVP who was diagnosed with a heart murmur at age . She was evaluated with serial TTE's which showed worsening MR. Echo showed LVEF % with Mitral valve regurgitant fraction of %. She denies any symptoms. Past Medical History: Hyperlipidemia, MVP/MR, Depression, Obesity Social History: social Etoh, live with mother, deniesDA or tobacco use Family History: noncontributory Physical Exam: y/o F in bed NAD Neuro AA&Ox, nonfocal Chest CTAB resp unlab median sternotomy stable, c/d/i no d/c, RRR no m/r/g chest tubes and epicardial wires removed. Abd S/NT/ND/BS+ EXT warm with trace edema Pertinent Results: RADIOLOGY Preliminary Report CHEST (PA & LAT : AM CHEST (PA & LAT Reason: assess LLL atelectas

In [21]:
from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer, 
    tokenizer_wrapper_class=WrapperClass, max_seq_length=512, decoder_max_length=3, 
    batch_size=2,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="tail")
# next(iter(train_dataloader))

# ## Define the verbalizer
# In classification, you need to define your verbalizer, which is a mapping from logits on the vocabulary to the final label probability. Let's have a look at the verbalizer details:

from openprompt.prompts import SoftVerbalizer, ManualVerbalizer
import torch

# for example the verbalizer contains multiple label words in each class
# myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=4,
#          label_words=["politics", "sports", "business", "technology"])
# or without label words
myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=50)

# or manual
# myverbalizer = ManualVerbalizer(tokenizer, num_classes=4).from_file("scripts/TextClassification/agnews/manual_verbalizer.txt")



tokenizing: 10it [00:00, 36.82it/s]


In [17]:
myverbalizer

SoftVerbalizer(
  (head): RobertaLMHead(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (decoder): Linear(in_features=1024, out_features=50, bias=False)
  )
)

In [22]:



from openprompt import PromptForClassification

use_cuda = False
prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model=  prompt_model.cuda()

# ## below is standard training


# from transformers import  AdamW, get_linear_schedule_with_warmup
# loss_func = torch.nn.CrossEntropyLoss()

# no_decay = ['bias', 'LayerNorm.weight']

# # it's always good practice to set no decay to biase and LayerNorm parameters
# optimizer_grouped_parameters1 = [
#     {'params': [p for n, p in prompt_model.plm.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
#     {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# ]

# # Using different optimizer for prompt parameters and model parameters

# # optimizer_grouped_parameters2 = [
# #     {'params': prompt_model.verbalizer.group_parameters_1, "lr":3e-5},
# #     {'params': prompt_model.verbalizer.group_parameters_2, "lr":3e-4},
# # ]


# optimizer1 = AdamW(optimizer_grouped_parameters1, lr=3e-5)
# # optimizer2 = AdamW(optimizer_grouped_parameters2)


# for epoch in range(5):
#     print(f"On epoch: {epoch}")
#     tot_loss = 0 
#     for step, inputs in enumerate(train_dataloader):
#         if use_cuda:
#             inputs = inputs.cuda()
#         logits = prompt_model(inputs)
#         labels = inputs['label']
#         loss = loss_func(logits, labels)
#         loss.backward()
#         tot_loss += loss.item()
#         optimizer1.step()
#         optimizer1.zero_grad()
#         # optimizer2.step()
#         # optimizer2.zero_grad()
#         print(tot_loss/(step+1))
    
# # ## evaluate

# # %%

# print("running validation!")
# validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer, 
#     tokenizer_wrapper_class=WrapperClass, max_seq_length=512, decoder_max_length=3, 
#     batch_size=2,shuffle=False, teacher_forcing=False, predict_eos_token=False,
#     truncate_method="head")

# prompt_model.eval()

# allpreds = []
# alllabels = []
# with torch.no_grad():
#     for step, inputs in enumerate(validation_dataloader):
#         if use_cuda:
#             inputs = inputs.cuda()
#         logits = prompt_model(inputs)
#         labels = inputs['label']
#         alllabels.extend(labels.cpu().tolist())
#         allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

# acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
# print("validation:",acc)


# test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer, 
#     tokenizer_wrapper_class=WrapperClass, max_seq_length=512, decoder_max_length=3, 
#     batch_size=2,shuffle=False, teacher_forcing=False, predict_eos_token=False,
#     truncate_method="head")
# allpreds = []
# alllabels = []
# with torch.no_grad():
#     for step, inputs in enumerate(test_dataloader):
#         if use_cuda:
#             inputs = inputs.cuda()
#         logits = prompt_model(inputs)
#         labels = inputs['label']
#         alllabels.extend(labels.cpu().tolist())
#         allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
# acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)
# print("test:", acc)  # roughly ~0.85

In [23]:
prompt_model

PromptForClassification(
  (prompt_model): PromptModel(
    (plm): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(28996, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
          