In [20]:
from tqdm import tqdm
from openprompt.data_utils import PROCESSORS
import torch
from openprompt.data_utils.utils import InputExample
import argparse
import numpy as np

from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer, ManualTemplate, SoftVerbalizer

from openprompt.prompts import SoftTemplate
from openprompt import PromptForClassification

from openprompt.plms.seq2seq import T5TokenizerWrapper, T5LMTokenizerWrapper
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
from openprompt.data_utils.data_sampler import FewShotSampler
from openprompt.plms import load_plm

from utils import Mimic_ICD9_Processor as MimicProcessor
import time
import os
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

import torchmetrics.functional.classification as metrics
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix

In [2]:
# Location of the saved checkpoint that torch.load reads in the next cell.
# Despite the "_dir" suffix, this points at the checkpoint file itself.
ckpt_dir = (
    "./checkpoints/icd9_50/emilyalsentzer/"
    "Bio_ClinicalBERT_tempmanual2_verbsoft0/"
    "version_21-01-2022--13-41/checkpoint.ckpt"
)

In [4]:
# Deserialize the checkpoint onto the CPU so it loads even if it was saved from
# a GPU that is not available here; the load_state_dict() calls further down
# copy the tensors into each module's own parameters.
# NOTE(review): torch.load unpickles the file, so only open checkpoints that
# come from a trusted source.
loaded_model = torch.load(ckpt_dir, map_location="cpu")

In [18]:
# Show the top-level entries stored in the checkpoint dictionary.
loaded_model.keys()

dict_keys(['plm', 'template', 'verbalizer'])

In [14]:
# NOTE(review): a bare class reference only displays the class object. The
# recorded output of this cell is a TypeError raised by state_dict() being
# called without an instance, so the source was presumably edited after the
# cell last ran -- re-run or remove this cell.
PromptForClassification

TypeError: state_dict() missing 1 required positional argument: 'self'

In [None]:
# the original way of instantiating a model

In [21]:
# Load the Bio_ClinicalBERT checkpoint through OpenPrompt's load_plm; the four
# returned objects are unpacked by position.
plm, tokenizer, model_config, WrapperClass = load_plm(
    "bert",
    "emilyalsentzer/Bio_ClinicalBERT",
)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# can load the trained state dict for the plm like this

In [26]:
# Overwrite the freshly loaded backbone weights with the trained ones stored
# under the checkpoint's "plm" entry.
plm.load_state_dict(loaded_model["plm"])

<All keys matched successfully>

In [27]:
# Display the module tree of the backbone after its weights were restored.
plm

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [28]:
# Manual template: read the template text from file; `choice=2` selects which
# entry of that file is used (presumably a zero-based line index -- confirm).
mytemplate = ManualTemplate(tokenizer=tokenizer).from_file(
    "scripts/mimic_icd9_top50/manual_template.txt",
    choice=2,
)


In [41]:
# Soft template: configured with 20 soft tokens initialised from the vocabulary,
# with the template text taken from entry 0 of the soft-template file.
# The path is a plain string; the original f-string prefix had no placeholders.
# NOTE(review): this object is not passed to PromptForClassification further
# down (the manual template is) -- confirm whether it is still needed.
soft_template = SoftTemplate(
    model=plm,
    tokenizer=tokenizer,
    num_tokens=20,
    initialize_from_vocab=True,
).from_file("scripts/mimic_icd9_top50/soft_template.txt", choice=0)

In [42]:
# Display the soft template module.
soft_template

SoftTemplate(
  (raw_embedding): Embedding(28996, 768, padding_idx=0)
)

In [44]:
# NOTE(review): the recorded output of this cell is a token-to-id mapping, not
# the SoftTemplate repr that the same expression displays in the previous cell,
# so the source was presumably changed after the cell ran -- re-run or remove.
soft_template

OrderedDict([('[PAD]', 0),
             ('[unused1]', 1),
             ('[unused2]', 2),
             ('[unused3]', 3),
             ('[unused4]', 4),
             ('[unused5]', 5),
             ('[unused6]', 6),
             ('[unused7]', 7),
             ('[unused8]', 8),
             ('[unused9]', 9),
             ('[unused10]', 10),
             ('[unused11]', 11),
             ('[unused12]', 12),
             ('[unused13]', 13),
             ('[unused14]', 14),
             ('[unused15]', 15),
             ('[unused16]', 16),
             ('[unused17]', 17),
             ('[unused18]', 18),
             ('[unused19]', 19),
             ('[unused20]', 20),
             ('[unused21]', 21),
             ('[unused22]', 22),
             ('[unused23]', 23),
             ('[unused24]', 24),
             ('[unused25]', 25),
             ('[unused26]', 26),
             ('[unused27]', 27),
             ('[unused28]', 28),
             ('[unused29]', 29),
             ('[unused30]', 30),
 

In [29]:
# Build a soft verbalizer on top of the restored backbone for the 50-class task.
soft_verb = SoftVerbalizer(
    tokenizer,
    plm,
    num_classes=50,
)

In [31]:
# Restore the trained verbalizer weights stored under the checkpoint's
# "verbalizer" entry.
soft_verb.load_state_dict(loaded_model["verbalizer"])

<All keys matched successfully>

In [33]:
# Assemble the restored backbone, the manual template and the restored soft
# verbalizer into a single prompt-classification model.
# NOTE(review): the checkpoint also has a 'template' entry that is never
# loaded; presumably it is empty for a manual template -- confirm.
# NOTE(review): call trained_model.eval() before running any evaluation --
# TODO confirm it is not already handled elsewhere.
trained_model = PromptForClassification(
    plm=plm,
    template=mytemplate,
    verbalizer=soft_verb,
)

In [34]:
# Display the module tree of the assembled classification model.
trained_model

PromptForClassification(
  (prompt_model): PromptModel(
    (plm): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(28996, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
          

# TODO

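Implement an evaluation of the model restored through this loading procedure, and compare its metrics with the results recorded during the original training run to confirm that the checkpoint was loaded correctly.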