In [4]:
import os
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0,6,7'

from functools import partial
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler




from typing import Dict
from tqdm import tqdm
from openprompt.data_utils import PROCESSORS
import torch
from openprompt.data_utils.utils import InputExample
import argparse
import numpy as np
import pandas as pd
import seaborn as sn

from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer, ManualTemplate, SoftVerbalizer

from openprompt.prompts import SoftTemplate, MixedTemplate
from openprompt import PromptForClassification


from transformers import  AdamW, get_linear_schedule_with_warmup,get_constant_schedule_with_warmup  # use AdamW is a standard practice for transformer 
from transformers.optimization import Adafactor, AdafactorSchedule  # use Adafactor is the default setting for T5

# from openprompt.utils.logging import logger
from loguru import logger

from utils import Mimic_ICD9_Processor, Mimic_ICD9_Triage_Processor, Mimic_Mortality_Processor, customPromptDataLoader
import time
import os
from datetime import datetime
# from torch.utils.tensorboard import SummaryWriter
from utils import SummaryWriter # this version is better for logging hparams with metrics..

import torchmetrics.functional.classification as metrics
from sklearn.metrics import balanced_accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix, roc_auc_score 

from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler

import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

import json
import itertools
from collections import Counter

import os 
# # Kill all processes on GPU 6 and 7
# os.system("""kill $(nvidia-smi | awk '$5=="PID" {p=1} p && $2 >= 6 && $2 < 7 {print $5}')""")

'''
Script to run different setups of prompt learning.

Right now this is primarily set up for the mimic_top50_icd9 task, although it is quite flexible to other datasets. Any datasets need a corresponding processor class in utils.


example usage. python prompt_experiment_runner.py --model bert --model_name_or_path bert-base-uncased --num_epochs 10 --tune_plm

other example usage:
- python prompt_experiment_runner.py --model t5 --model_name_or_path razent/SciFive-base-Pubmed_PMC --num_epochs 10 --template_id 0 --template_type soft --max_steps 15000 --tune_plm


'''




import random

from openprompt.utils.reproduciblity import set_seed
set_seed(42)

from openprompt.plms.seq2seq import T5TokenizerWrapper, T5LMTokenizerWrapper
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
from openprompt.data_utils.data_sampler import FewShotSampler
from openprompt.plms import load_plm


# set up some variables to add to checkpoint and logs filenames
time_now = str(datetime.now().strftime("%d-%m-%Y--%H-%M"))
version = f"version_{time_now}"


dataset = "icd9_triage"
model_name_or_path = "emilyalsentzer/Bio_ClinicalBERT"
template_id = 2
verbalizer_id = 0
template_type = 'mixed'
verbalizer_type = 'soft'
model = 'bert'
scripts_path = './scripts/'
plm_lr = 1e-05
prompt_lr = 0.3
warmup_step_prompt = 50
plm_warmup_steps = 50
num_epochs = 5
num_samples = 5
batch_size = 4
gpu_num = 0
optimizer="adafactor" 
training_size = "fewshot"
tune_plm = True
project_root = '/mnt/sdg/niallt/saved_models/mimic-tasks/prompt-based-models/'
few_shot_n = 32

    # actually want to save the checkpoints and logs in same place now. Becomes a lot easier to manage later
if tune_plm:
    logger.warning("Unfreezing the plm - will be updated during training")
    freeze_plm = False
    # set checkpoint, logs and params save_dirs    
    logs_dir = f"{project_root}/raytune_results/{dataset}/{model_name_or_path}_temp{template_type}{template_id}_verb{verbalizer_type}{verbalizer_id}_{training_size}_{few_shot_n}/{version}"

else:
    logger.warning("Freezing the plm")
    freeze_plm = True
    # set checkpoint, logs and params save_dirs    
    logs_dir = f"{project_root}/raytune_results/{dataset}/frozen_plm/{model_name_or_path}_temp{template_type}{template_id}_verb{verbalizer_type}{verbalizer_id}_{training_size}_{few_shot_n}/{version}"

# set up tensorboard logger
writer = SummaryWriter(logs_dir)

def train_mimic(prompt_model, train_dataloader, num_epochs, mode = "train", 
                ckpt_dir = None, dataset = "icd9_triage",
                data_dir = data_dir,
                config = None):

    

    plm, tokenizer, model_config, WrapperClass = load_plm(model, model_name_or_path)

    # edit based on whether or not plm was frozen during training


    # initialise empty dataset
    dataset = {}

    # crude setting of sampler to None - changed for mortality with umbalanced dataset

    sampler = None
    # Below are multiple dataset examples, although right now just mimic ic9-top50. 
    if dataset == "icd9_triage":
        logger.warning(f"Using the following dataset: {dataset} ")
        Processor = Mimic_ICD9_Triage_Processor
        # update data_dir
        data_dir = f"{data_dir}/triage"

        ce_class_weights = ce_class_weights
        sampler_weights = sampler_weights    
        balance_data = balance_data

        # get different splits
        dataset['train'] = Processor().get_examples(data_dir = data_dir, mode = "train")
        dataset['validation'] = Processor().get_examples(data_dir = data_dir, mode = "valid")
        dataset['test'] = Processor().get_examples(data_dir = data_dir, mode = "test")
        # the below class labels should align with the label encoder fitted to training data
        # you will need to generate this class label text file first using the mimic processor with generate_class_labels flag to set true
        # e.g. Processor().get_examples(data_dir = data_dir, mode = "train", generate_class_labels = True)[:10000]
        class_labels =Processor().load_class_labels()
        print(f"number of classes: {len(class_labels)}")
        scriptsbase = f"{scripts_path}/mimic_triage/"
        scriptformat = "txt"
        max_seq_l = 480 # this should be specified according to the running GPU's capacity 
        
        batchsize_t = config['batch_size'] 
        batchsize_e = config['batch_size'] 
        gradient_accumulation_steps = config['gradient_accum_steps']
        model_parallelize = False # if multiple gpus are available, one can use model_parallelize

    else:
        
        raise NotImplementedError


    # Now define the template and verbalizer. 
    # Note that soft template can be combined with hard template, by loading the hard template from file. 
    # For example, the template in soft_template.txt is {}
    # The choice_id 1 is the hard template 

    # decide which template and verbalizer to use
    if template_type == "manual":
        print(f"manual template selected, with id :{template_id}")
        mytemplate = ManualTemplate(tokenizer=tokenizer).from_file(f"{scriptsbase}/manual_template.txt", choice=template_id)

    elif template_type == "soft":
        print(f"soft template selected, with id :{template_id}")
        mytemplate = SoftTemplate(model=plm, tokenizer=tokenizer, num_tokens=soft_token_num, initialize_from_vocab=init_from_vocab).from_file(f"{scriptsbase}/soft_template.txt", choice=template_id)


    elif template_type == "mixed":
        print(f"mixed template selected, with id :{template_id}")
        mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer).from_file(f"{scriptsbase}/mixed_template.txt", choice=template_id)
    # now set verbalizer
    if verbalizer_type == "manual":
        print(f"manual verbalizer selected, with id :{verbalizer_id}")
        myverbalizer = ManualVerbalizer(tokenizer, classes=class_labels).from_file(f"{scriptsbase}/manual_verbalizer.{scriptformat}", choice=verbalizer_id)

    elif verbalizer_type == "soft":
        print(f"soft verbalizer selected!")
        myverbalizer = SoftVerbalizer(tokenizer, plm, num_classes=len(class_labels))
    # are we using cuda and if so which number of device
    use_cuda = True
    
    cuda_device = "cpu"
    if use_cuda:
        if torch.cuda.is_available():
            cuda_device = "cuda:0"
    
    
            cuda_device = torch.device(f'cuda:{gpu_num}')
    # now set the default gpu to this one
    torch.cuda.set_device(cuda_device)


    print(f"tune_plm value: {tune_plm}")
    prompt_model = PromptForClassification(plm=plm,template=mytemplate, verbalizer=myverbalizer, freeze_plm=freeze_plm, plm_eval_mode=plm_eval_mode)
    if use_cuda:
        prompt_model=  prompt_model.to(cuda_device)

    if model_parallelize:
        prompt_model.parallelize()


    # if doing few shot learning - produce the datasets here:
    if training_size == "fewshot":
        logger.warning(f"Will be performing few shot learning.")
    # create the few_shot sampler for when we want to run training and testing with few shot learning
        support_sampler = FewShotSampler(num_examples_per_label = few_shot_n, also_sample_dev=False)

        # create a fewshot dataset from training, val and test. Seems to be what several papers do...
        dataset['train'] = support_sampler(dataset['train'], seed=seed)
        dataset['validation'] = support_sampler(dataset['validation'], seed=seed)
        dataset['test'] = support_sampler(dataset['test'], seed=seed)

    # are we doing training?
    do_training = (not no_training)
    if do_training:
        # if we have a sampler .e.g weightedrandomsampler. Do not shuffle
        if "WeightedRandom" in type(sampler).__name__:
            logger.warning("Sampler is WeightedRandom - will not be shuffling training data!")
            shuffle = False
        else:
            shuffle = True
        logger.warning(f"Do training is True - creating train and validation dataloders!")
        train_dataloader = customPromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer, 
            tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 
            batch_size=batchsize_t,shuffle=shuffle, sampler = sampler, teacher_forcing=False, predict_eos_token=False,
            truncate_method="tail")

        validation_dataloader = customPromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer, 
            tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 
            batch_size=batchsize_e,shuffle=False, teacher_forcing=False, predict_eos_token=False,
            truncate_method="tail")


    # zero-shot test
    test_dataloader = customPromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer, 
        tokenizer_wrapper_class=WrapperClass, max_seq_length=max_seq_l, decoder_max_length=3, 
        batch_size=batchsize_e,shuffle=False, teacher_forcing=False, predict_eos_token=False,
        truncate_method="tail")


    #TODO update this to handle class weights for imabalanced datasets
    if ce_class_weights:
        logger.warning("we have some task specific class weights - passing to CE loss")
        # get from the class_weight function
        # task_class_weights = torch.tensor(task_class_weights, dtype=torch.float).to(cuda_device)
        
        # set manually cause above didnt work
        task_class_weights = torch.tensor([1,16.1], dtype=torch.float).to(cuda_device)
        loss_func = torch.nn.CrossEntropyLoss(weight = task_class_weights, reduction = 'mean')
    else:
        loss_func = torch.nn.CrossEntropyLoss()

    # get total steps as a function of the max epochs, batch_size and len of dataloader
    tot_step = max_steps

    if tune_plm:
        
        logger.warning("We will be tuning the PLM!") # normally we freeze the model when using soft_template. However, we keep the option to tune plm
        no_decay = ['bias', 'LayerNorm.weight'] # it's always good practice to set no decay to biase and LayerNorm parameters
        optimizer_grouped_parameters_plm = [
            {'params': [p for n, p in prompt_model.plm.named_parameters() if (not any(nd in n for nd in no_decay))], 'weight_decay': 0.01},
            {'params': [p for n, p in prompt_model.plm.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer_plm = AdamW(optimizer_grouped_parameters_plm, lr=config['plm_lr'])
        scheduler_plm = get_linear_schedule_with_warmup(
            optimizer_plm, 
            num_warmup_steps=plm_warmup_steps, num_training_steps=tot_step)
    else:
        logger.warning("We will not be tunning the plm - i.e. the PLM layers are frozen during training")
        optimizer_plm = None
        scheduler_plm = None

    # if using soft template
    if template_type == "soft" or template_type == "mixed":
        logger.warning(f"{template_type} template used - will be fine tuning the prompt embeddings!")
        optimizer_grouped_parameters_template = [{'params': [p for name, p in prompt_model.template.named_parameters() if 'raw_embedding' not in name]}] # note that you have to remove the raw_embedding manually from the optimization
        if optimizer.lower() == "adafactor":
            optimizer_template = Adafactor(optimizer_grouped_parameters_template,  
                                    lr=config['prompt_lr'],
                                    relative_step=False,
                                    scale_parameter=False,
                                    warmup_init=False)  # when lr is 0.3, it is the same as the configuration of https://arxiv.org/abs/2104.08691
            scheduler_template = get_constant_schedule_with_warmup(optimizer_template, num_warmup_steps=warmup_step_prompt) # when num_warmup_steps is 0, it is the same as the configuration of https://arxiv.org/abs/2104.08691
        elif optimizer.lower() == "adamw":
            optimizer_template = AdamW(optimizer_grouped_parameters_template, lr=config['prompt_lr']) # usually lr = 0.5
            scheduler_template = get_linear_schedule_with_warmup(
                            optimizer_template, 
                            num_warmup_steps=warmup_step_prompt, num_training_steps=tot_step) # usually num_warmup_steps is 500

    elif template_type == "manual":
        optimizer_template = None
        scheduler_template = None


    if verbalizer_type == "soft":
        logger.warning("Soft verbalizer used - will be fine tuning the verbalizer/answer embeddings!")
        optimizer_grouped_parameters_verb = [
        {'params': prompt_model.verbalizer.group_parameters_1, "lr":config['plm_lr']},
        {'params': prompt_model.verbalizer.group_parameters_2, "lr":config['plm_lr']}        
        ]
        optimizer_verb= AdamW(optimizer_grouped_parameters_verb)
        scheduler_verb = get_linear_schedule_with_warmup(
                            optimizer_verb, 
                            num_warmup_steps=warmup_step_prompt, num_training_steps=tot_step) # usually num_warmup_steps is 500

    elif verbalizer_type == "manual":
        optimizer_verb = None
        scheduler_verb = None


    # set model to train 
    prompt_model.train()

    # set up some counters
    actual_step = 0
    glb_step = 0

    # some validation metrics to monitor
    best_val_acc = 0
    best_val_f1 = 0
    best_val_prec = 0    
    best_val_recall = 0

 

    # this will be set to true when max steps are reached
    leave_training = False

    for epoch in tqdm(range(num_epochs)):
        print(f"On epoch: {epoch}")
        tot_loss = 0 
        epoch_loss = 0
        for step, inputs in enumerate(train_dataloader):       

            if use_cuda:
                inputs = inputs.to(cuda_device)
            logits = prompt_model(inputs)
            labels = inputs['label']
            loss = loss_func(logits, labels)

            # normalize loss to account for gradient accumulation
            loss = loss / gradient_accumulation_steps 

            # propogate backward to calculate gradients
            loss.backward()
            tot_loss += loss.item()

            actual_step+=1
            # log loss to tensorboard  every 50 steps    

            # clip gradients based on gradient accumulation steps
            if actual_step % gradient_accumulation_steps == 0:
                # log loss
                aveloss = tot_loss/(step+1)
                # write to tensorboard
                writer.add_scalar("train/batch_loss", aveloss, glb_step)        

                # clip grads            
                torch.nn.utils.clip_grad_norm_(prompt_model.parameters(), 1.0)
                glb_step += 1

                # backprop the loss and update optimizers and then schedulers too
                # plm
                if optimizer_plm is not None:
                    optimizer_plm.step()
                    optimizer_plm.zero_grad()
                if scheduler_plm is not None:
                    scheduler_plm.step()
                # template
                if optimizer_template is not None:
                    optimizer_template.step()
                    optimizer_template.zero_grad()
                if scheduler_template is not None:
                    scheduler_template.step()
                # verbalizer
                if optimizer_verb is not None:
                    optimizer_verb.step()
                    optimizer_verb.zero_grad()
                if scheduler_verb is not None:
                    scheduler_verb.step()

                # check if we are over max steps
                if glb_step > max_steps:
                    logger.warning("max steps reached - stopping training!")
                    leave_training = True
                    break

        # get epoch loss and write to tensorboard

        epoch_loss = tot_loss/len(train_dataloader)
        print("Epoch {}, loss: {}".format(epoch, epoch_loss), flush=True)   
        writer.add_scalar("train/epoch_loss", epoch_loss, epoch)

        
        # run a run through validation set to get some metrics        
        val_loss, val_acc, val_prec_weighted, val_prec_macro, val_recall_weighted,val_recall_macro, val_f1_weighted,val_f1_macro, val_auc_weighted,val_auc_macro, cm_figure = evaluate(prompt_model, validation_dataloader,
                                                                                                                                                                                        use_cuda=use_cuda, cuda_device = cuda_device,
                                                                                                                                                                                        loss_func = loss_func)

        writer.add_scalar("valid/loss", val_loss, epoch)
        writer.add_scalar("valid/balanced_accuracy", val_acc, epoch)
        writer.add_scalar("valid/precision_weighted", val_prec_weighted, epoch)
        writer.add_scalar("valid/precision_macro", val_prec_macro, epoch)
        writer.add_scalar("valid/recall_weighted", val_recall_weighted, epoch)
        writer.add_scalar("valid/recall_macro", val_recall_macro, epoch)
        writer.add_scalar("valid/f1_weighted", val_f1_weighted, epoch)
        writer.add_scalar("valid/f1_macro", val_f1_macro, epoch)

        #TODO add binary classification metrics e.g. roc/auc
        writer.add_scalar("valid/auc_weighted", val_auc_weighted, epoch)
        writer.add_scalar("valid/auc_macro", val_auc_macro, epoch)        

        # add cm to tensorboard
        writer.add_figure("valid/Confusion_Matrix", cm_figure, epoch)

        # save checkpoint if validation accuracy improved
        if val_acc >= best_val_acc:
            # only save ckpts if no_ckpt is False - we do not always want to save - especially when developing code
            if ckpt_dir != None:
                logger.warning("Accuracy improved! Saving checkpoint!")
                torch.save(prompt_model.state_dict(),f"{ckpt_dir}/best-checkpoint.ckpt")
            best_val_acc = val_acc


        if glb_step > max_steps:
            leave_training = True
            break
    
        if leave_training:
            logger.warning("Leaving training as max steps have been met!")
            break 

        
        # now we want to send these back to raytune
        tune.report(loss=val_loss, accuracy = val_acc)
           
   
# ## evaluate

# %%

def evaluate(prompt_model, dataloader, mode = "validation", 
                class_labels = None, use_cuda = True, cuda_device = None, loss_func= None):

    prompt_model.eval()

    tot_loss = 0
    allpreds = []
    alllabels = []
    #record logits from the the model
    alllogits = []
    # store probabilties i.e. softmax applied to logits
    allscores = []

    allids = []
    with torch.no_grad():
        for step, inputs in enumerate(dataloader):
            if use_cuda:
                inputs = inputs.to(cuda_device)
            logits = prompt_model(inputs)
            labels = inputs['label']

            loss = loss_func(logits, labels)
            tot_loss += loss.item()

            # add labels to list
            alllabels.extend(labels.cpu().tolist())

            # add ids to list - they are already a list so no need to send to cpu
            allids.extend(inputs['guid'])

            # add logits to list
            alllogits.extend(logits.cpu().tolist())
            #use softmax to normalize, as the sum of probs should be 1
            # if binary classification we just want the positive class probabilities
            if len(class_labels) > 2:  
                allscores.extend(torch.nn.functional.softmax(logits).cpu().tolist())
            else:

                allscores.extend(torch.nn.functional.softmax(logits)[:,1].cpu().tolist())

            # add predicted labels    
            allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

    
    val_loss = tot_loss/len(dataloader)    
    # get sklearn based metrics
    acc = balanced_accuracy_score(alllabels, allpreds)
    f1_weighted = f1_score(alllabels, allpreds, average = 'weighted')
    f1_macro = f1_score(alllabels, allpreds, average = 'macro')
    prec_weighted = precision_score(alllabels, allpreds, average = 'weighted')
    prec_macro = precision_score(alllabels, allpreds, average = 'macro')
    recall_weighted = recall_score(alllabels, allpreds, average = 'weighted')
    recall_macro = recall_score(alllabels, allpreds, average = 'macro')


    # roc_auc  - only really good for binary classification but can try for multiclass too
    # use scores instead of predicted labels to give probs
    
    if len(class_labels) > 2:   
        roc_auc_weighted = roc_auc_score(alllabels, allscores, average = "weighted", multi_class = "ovr")
        roc_auc_macro = roc_auc_score(alllabels, allscores, average = "macro", multi_class = "ovr")
                  
    else:
        roc_auc_weighted = roc_auc_score(alllabels, allscores, average = "weighted")
        roc_auc_macro = roc_auc_score(alllabels, allscores, average = "macro")         


    
   
    return val_loss, acc, prec_weighted, prec_macro, recall_weighted, recall_macro, f1_weighted, f1_macro, roc_auc_weighted, roc_auc_macro



# create raytune config
config = {
    "plm_lr":tune.loguniform(1e-4, 1e-5),
    "prompt_lr":tune.loguniform(1e-4, 1e-5),
    "batch_size": tune.choice([4]),
    "grad_accum_steps":tune.choice([2,5,10]),
    "dropout": tune.choice([0.1,0.2,0.5]),
    "optimizer": tune.choice(['adamw']),
}



# create the trainable ray tune class
trainable = tune.with_parameters(
    train_mimic,   
    num_epochs=num_epochs,
    num_gpus=gpus_per_trial,
    data_dir = data_dir)
# run the analysis
analysis = tune.run(
    trainable,
    resources_per_trial={
        "cpu": 1,
        "gpu": gpus_per_trial
    },
    metric="loss",
    mode="min",
    config=config,    
    num_samples=num_samples,
    local_dir = f"{save_dir}",
    name=f"tune_mimic_{dataset}"
    
    )
import ray
ray.shutdown()


print(f"Best config based on ray tune analysis!:\n {analysis.best_config}")



NameError: name 'data_dir' is not defined