In [1]:
from tqdm import tqdm
from openprompt.data_utils import PROCESSORS
import torch
from openprompt.data_utils.utils import InputExample
import argparse
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer, ManualTemplate, MixedTemplate, SoftVerbalizer

from openprompt.prompts import SoftTemplate
from openprompt import PromptForClassification

from openprompt.plms.seq2seq import T5TokenizerWrapper, T5LMTokenizerWrapper
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
from openprompt.data_utils.data_sampler import FewShotSampler
from openprompt.plms import load_plm

from utils import Mimic_ICD9_Processor, Mimic_ICD9_Triage_Processor, Mimic_Mortality_Processor
import time
import os
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from loguru import logger
import json
import itertools
import torchmetrics.functional.classification as metrics
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix

In [2]:
def load_trained_prompt_model(ckpt_dir, params_dir,
                              use_cuda = True):
    
    '''
    Function to reload an already trained promptmodelclassifier. At moment this still requires data/task specific 
    manual template or verbalizers to be setup as they need to point to correct scripts.
    
    Args:
        ckpt_dir: path to save promptmodel
        params_dir: path to parameters of training - in newest pipeline will be same as checkpoints dir
        
    '''
    
    # load in saved paramters for the trained model
    with open(f"{params_dir}") as f:
        params = json.load(f)
    # set up the parameters based on the training config
    plm_type = params["model"]
    plm_name = params["model_name_or_path"]
    template_type = params["template_type"]
    template_id = params["template_id"]
    verbalizer_type = params["verbalizer_type"]
    verbalizer_id = params["verbalizer_id"]
    data_dir = params["data_dir"]
    dataset_name = params["dataset"]
    scripts_path = params["scripts_path"]
    init_from_vocab = params["init_from_vocab"]
    batch_size = params["batch_size"]
    tune_plm = params["tune_plm"]
    # set up datasets first    


    dataset = {}
    if dataset_name == "icd9_50":

        logger.warning(f"Using the following dataset: {dataset_name} ")
        Processor = Mimic_ICD9_Processor
        # update data_dir
        data_dir = f"{data_dir}/top_50_icd9"

        # get different splits
        dataset['train'] = Processor().get_examples(data_dir = data_dir, mode = "train")

        # the below class labels should align with the label encoder fitted to training data
        # you will need to generate this class label text file first using the mimic processor with generate_class_labels flag to set true
        # e.g. Processor().get_examples(data_dir = data_dir, mode = "train", generate_class_labels = True)[:10000]
        class_labels =Processor().load_class_labels()
        print(f"number of classes: {len(class_labels)}")
        scriptsbase = f"{scripts_path}/mimic_icd9_top50/"
        scriptformat = "txt"
        max_seq_l = 480 # this should be specified according to the running GPU's capacity 

        batchsize_t = batch_size
        batchsize_e = batch_size
        gradient_accumulation_steps = 4
        model_parallelize = False

    elif dataset_name == "icd9_triage":
        logger.warning(f"Using the following dataset: {dataset_name} ")
        Processor = Mimic_ICD9_Triage_Processor
        # update data_dir
        data_dir = f"{data_dir}/triage"

        # get different splits
        dataset['train'] = Processor().get_examples(data_dir = data_dir, mode = "train")

        # the below class labels should align with the label encoder fitted to training data
        # you will need to generate this class label text file first using the mimic processor with generate_class_labels flag to set true
        # e.g. Processor().get_examples(data_dir = data_dir, mode = "train", generate_class_labels = True)[:10000]
        class_labels =Processor().load_class_labels()
        print(f"number of classes: {len(class_labels)}")
        scriptsbase = f"{scripts_path}/mimic_triage/"
        scriptformat = "txt"
        max_seq_l = 480 # this should be specified according to the running GPU's capacity 

        batchsize_t = batch_size
        batchsize_e = batch_size
        gradient_accumulation_steps = 4
        model_parallelize = False
        
    elif dataset_name == "mortality":
        logger.warning(f"Using the following dataset: {dataset_name} ")
        Processor = Mimic_Mortality_Processor
        # update data_dir
        data_dir = "../clinical-outcomes-data/mimic3-clinical-outcomes/mp/"
        
        
        dataset['train'] = Processor().get_examples(data_dir = data_dir, mode = "train", balance_data = False, class_weights=False, sampler_weights= False)[:1000]
        
        # the below class labels should align with the label encoder fitted to training data
        # you will need to generate this class label text file first using the mimic processor with generate_class_labels flag to set true
        # e.g. Processor().get_examples(data_dir = args.data_dir, mode = "train", generate_class_labels = True)[:10000]
        class_labels = Processor().load_class_labels()
        print(f"class labels: {class_labels}")
        print(f"number of classes: {len(class_labels)}")
        scriptsbase = f"{scripts_path}/mimic_mortality/"
        scriptformat = "txt"
        max_seq_l = 480 # this should be specified according to the running GPU's capacity 
        batchsize_t = batch_size
        batchsize_e = batch_size
        gradient_accumulation_steps = 4
        model_parallelize = False


    else:
        #TODO implement icd9 triage and mimic readmission
        raise NotImplementedError
    
    
    ######### set up the pretrained model etc ###########
    
    # initialise the pretrained language model
    plm, tokenizer, model_config, WrapperClass = load_plm(plm_type, plm_name)    
    
    
    # load the already trained prompt model, which will consist of a separate state_dict for the plm/template/verbalizer
    loaded_model = torch.load(f"{ckpt_dir}/best-checkpoint.ckpt")
    
    
    # now load the trained state_dict into the plm model if it was tuned during training
    if tune_plm:
        freeze_plm = False
        print("PLM was tuned during training - loading the weights!")
#         plm.load_state_dict(loaded_model['plm'])
    else:
        freeze_plm=True
        print("PLM was frozen during training to initializing from original pretrained weights!")
    
    
    # decide which template and verbalizer to use
    if template_type == "manual":
        print(f"manual template selected, with id :{template_id}")
        mytemplate = ManualTemplate(tokenizer=tokenizer).from_file(f"{scriptsbase}/manual_template.txt", choice=template_id)

    elif template_type == "soft":
        print(f"soft template selected, with id :{template_id}, will load template weights")
        mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, num_tokens=params['soft_token_num'], initialize_from_vocab=init_from_vocab).from_file(f"{scriptsbase}/soft_template.txt", choice=template_id)
        # now load the state_dict from ckpt
#         mytemplate.load_state_dict(loaded_model['template'])

    elif template_type == "mixed":
        print(f"mixed template selected, with id :{template_id}, will load template weights")
        mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer).from_file(f"{scriptsbase}/mixed_template.txt", choice=template_id)
#         mytemplate.load_state_dict(loaded_model['template'])
    # now set verbalizer
    if verbalizer_type == "manual":
        print(f"manual verbalizer selected, with id :{verbalizer_id}")
        myverbalizer = ManualVerbalizer(tokenizer, classes=class_labels).from_file(f"{scriptsbase}/manual_verbalizer.{scriptformat}", choice=verbalizer_id)

    elif verbalizer_type == "soft":
        print(f"soft verbalizer selected!")
        myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=len(class_labels))
        # now load the state dict from saved checkpoint
#         myverbalizer.load_state_dict(loaded_model['verbalizer'])
        
#     # now bring it all together into the prompt classification model

    trained_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=freeze_plm)
    # now load state dicts
    trained_model.load_state_dict(state_dict = loaded_model)
    
    # send to cuda
    if use_cuda:
        print("using cuda!")
        trained_model =  trained_model.cuda()
   
    
    
    return trained_model

In [None]:
# testing the number of parameters based on mixed tokenizer of varying number of soft tokens

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


soft verbalizer selected!


# finetuned model

In [74]:
ckpt_dir = "./logs/icd9_triage/emilyalsentzer/Bio_ClinicalBERT_tempmixed2_verbsoft0_fewshot_32/version_28-03-2022--10-29/checkpoints/"
params_dir = f"{ckpt_dir}/hparams.txt"

In [75]:
trained_model = load_trained_prompt_model(ckpt_dir, params_dir, use_cuda = False)



loading train data
data path provided was: ../mimic3-icd9-data/intermediary-data//triage/train.csv


9559it [00:00, 14305.11it/s]


number of classes: 7


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


PLM was tuned during training - loading the weights!
mixed template selected, with id :2, will load template weights
soft verbalizer selected!


In [9]:
# trained_model.prompt_model.plm.bert

In [10]:
# trained_model.prompt_model.plm

In [None]:
trained_model

In [14]:
trained_model.template

MixedTemplate(
  (raw_embedding): Embedding(28996, 768, padding_idx=0)
  (soft_embedding): Embedding(4, 768)
)

In [77]:
sum(p.numel() for p in trained_model.verbalizer.parameters() if p.requires_grad)

626500

In [2]:
# function for getting trainable parameters of a model

def get_n_trainable_params(model):    

    
    # all trainable
    num_total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # split into the plm and classisifcation head
    num_plm_trainable = sum(p.numel() for p in model.prompt_model.plm.parameters() if p.requires_grad)
    
    # template trainable
    num_template_trainable = sum(p.numel() for p in model.template.soft_embedding.parameters() if p.requires_grad)
    
    # verbalizer trainable 
    num_verbalizer_trainable = sum(p.numel() for p in model.verbalizer.parameters() if p.requires_grad)
    
    # assert sum of the two = total
    assert num_plm_trainable+num_template_trainable+num_verbalizer_trainable == num_total_trainable
    
    print(f"Number of trainable parameters of PLM: {num_plm_trainable}\n")
    print('#'*50)
    print(f"Number of trainable parameters of template: {num_template_trainable}\n")
    print('#'*50)
    print(f"Number of trainable parameters of verbalizer: {num_verbalizer_trainable}\n")
    print('#'*50)
    print(f"Total number of trainable parameters of whole model: {num_total_trainable}")

In [6]:
get_n_trainable_params(trained_model)

Number of trainable parameters of PLM: 108340804

##################################################
Number of trainable parameters of template: 3072

##################################################
Number of trainable parameters of verbalizer: 626500

##################################################
Total number of trainable parameters of whole model: 108970376


# frozen model

In [5]:


ckpt_dir = "./logs/icd9_triage/frozen_plm/emilyalsentzer/Bio_ClinicalBERT_tempmixed2_verbsoft0_fewshot_32/version_28-03-2022--10-56/checkpoints/"
params_dir = f"{ckpt_dir}/hparams.txt"

trained_model = load_trained_prompt_model(ckpt_dir, params_dir, use_cuda = False)



loading train data
data path provided was: ../mimic3-icd9-data/intermediary-data//triage/train.csv


9559it [00:00, 10399.51it/s]


number of classes: 7


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


PLM was frozen during training to initializing from original pretrained weights!
mixed template selected, with id :2, will load template weights
soft verbalizer selected!


In [81]:
sum(p.numel() for p in trained_model.verbalizer.parameters() if p.requires_grad)

626500

In [17]:
trained_model.verbalizer.group_parameters_1[1].shape

torch.Size([768, 768])

In [22]:
get_n_trainable_params(trained_model)

Number of trainable parameters of PLM: 0

##################################################
Number of trainable parameters of template: 3072

##################################################
Number of trainable parameters of verbalizer: 626500

##################################################
Total number of trainable parameters of whole model: 629572


# look at a manual verb option to confirm that there will be no tuneable verb parameters



In [9]:
ckpt_dir = "./logs/icd9_triage/emilyalsentzer/Bio_ClinicalBERT_tempmixed2_verbmanual0_full_100/version_23-03-2022--14-05/checkpoints/"
params_dir = f"{ckpt_dir}/hparams.txt"

trained_model = load_trained_prompt_model(ckpt_dir, params_dir, use_cuda = False)



loading train data
data path provided was: ../mimic3-icd9-data/intermediary-data//triage/train.csv


9559it [00:00, 14207.03it/s]


number of classes: 7


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


PLM was tuned during training - loading the weights!
mixed template selected, with id :2, will load template weights
manual verbalizer selected, with id :0


In [10]:
get_n_trainable_params(trained_model)

Number of trainable parameters of PLM: 108340804

##################################################
Number of trainable parameters of template: 3072

##################################################
Number of trainable parameters of verbalizer: 0

##################################################
Total number of trainable parameters of whole model: 108343876


In [11]:
# now frozen verb manual
ckpt_dir = "./logs/icd9_triage/frozen_plm/emilyalsentzer/Bio_ClinicalBERT_tempmixed2_verbmanual0_full_100/version_23-03-2022--13-00/checkpoints/"
params_dir = f"{ckpt_dir}/hparams.txt"

trained_model = load_trained_prompt_model(ckpt_dir, params_dir, use_cuda = False)



loading train data
data path provided was: ../mimic3-icd9-data/intermediary-data//triage/train.csv


9559it [00:00, 12738.38it/s]


number of classes: 7


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


PLM was frozen during training to initializing from original pretrained weights!
mixed template selected, with id :2, will load template weights
manual verbalizer selected, with id :0


In [12]:
get_n_trainable_params(trained_model)

Number of trainable parameters of PLM: 0

##################################################
Number of trainable parameters of template: 3072

##################################################
Number of trainable parameters of verbalizer: 0

##################################################
Total number of trainable parameters of whole model: 3072


# now want to look at how mixed template dictates number of parameters

A mixed template consists of soft tokens that are initialized at weights from PLM, which will then be tuned, alongside manual tokens which will not be tuned. 

The dimension of these embeddings is typically going to be restricted to that of the PLM embeddings, e.g. 768 for most if not all bert based models.

i.e. n_trainable_template_embeddings = 768*N_soft_tokens

In [3]:
# instantiate the plm etc

freeze_plm = True
plm, tokenizer, model_config, WrapperClass = load_plm('bert', 'emilyalsentzer/Bio_ClinicalBERT')   




Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [24]:
plm.cls

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=28996, bias=True)
  )
)

In [20]:
head_name = [n for n,c in plm.named_children()][-1]
head_name

'cls'

In [85]:
sum(p.numel() for p in plm.parameters() if p.requires_grad)

0

In [4]:
temp_ids = [4]

for temp_id in temp_ids:
    print(f"Working on mixed templated id: {temp_id}")
    
    # mixed template 
    mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer).from_file(f"./scripts/mimic_triage/mixed_template.txt", choice=temp_id)


    print(f"soft verbalizer selected!")
    myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=7)
    

prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=freeze_plm)

get_n_trainable_params(prompt_model)

print("#"*50)

Working on mixed templated id: 4
soft verbalizer selected!
Number of trainable parameters of PLM: 0

##################################################
Number of trainable parameters of template: 1536

##################################################
Number of trainable parameters of verbalizer: 626500

##################################################
Total number of trainable parameters of whole model: 628036
##################################################


In [97]:
for p in prompt_model.verbalizer.parameters():
    if p.requires_grad:
        print(p)
        print(p.shape)

Parameter containing:
tensor([-0.4309, -0.3870, -0.4145,  ..., -0.5897, -0.5963, -0.5895],
       requires_grad=True)
torch.Size([28996])
Parameter containing:
tensor([[ 0.2472, -0.0149, -0.0300,  ..., -0.0315, -0.0111,  0.0043],
        [ 0.1146,  0.2637, -0.0006,  ..., -0.0473, -0.0150, -0.0258],
        [-0.0057, -0.0517,  0.2008,  ..., -0.0372, -0.1174, -0.0789],
        ...,
        [ 0.0089, -0.0227, -0.0123,  ...,  0.3439,  0.0539,  0.0234],
        [ 0.0407, -0.0772,  0.0116,  ...,  0.0249,  0.3978,  0.0338],
        [ 0.0570, -0.0126, -0.1157,  ...,  0.0136,  0.0458,  0.2184]],
       requires_grad=True)
torch.Size([768, 768])
Parameter containing:
tensor([ 1.6920e-02,  8.6847e-02,  6.4819e-02,  4.8258e-02,  3.1324e-02,
         3.8392e-02,  9.7983e-03,  2.8769e-02,  3.9382e-02,  2.7226e-02,
         3.8215e-02,  3.8464e-02,  4.5411e-02,  3.8577e-02,  5.3082e-02,
        -1.9338e-02,  3.6245e-02,  8.1958e-02,  4.4114e-02,  8.4626e-02,
         1.1240e-01,  8.1564e-02,  2.5398e

In [86]:
myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=7)

In [87]:
sum(p.numel() for p in myverbalizer.parameters() if p.requires_grad)

5376

5376

In [20]:
768*768

589824