# Notebook that handles fine-tuning from start to finish

In [2]:
# !pip install git+https://github.com/huggingface/transformers.git
!pip install "sagemaker>=2.140.0" "transformers==4.26.1" "datasets[s3]==2.10.1" --upgrade
# !pip install accelerate==0.20.3

Collecting sagemaker>=2.140.0
  Downloading sagemaker-2.185.0.tar.gz (884 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m884.9/884.9 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting transformers==4.26.1
  Using cached transformers-4.26.1-py3-none-any.whl (6.3 MB)
Collecting datasets[s3]==2.10.1
  Using cached datasets-2.10.1-py3-none-any.whl (469 kB)
Collecting huggingface-hub<1.0,>=0.11.0
  Using cached huggingface_hub-0.17.1-py3-none-any.whl (294 kB)
Collecting regex!=2019.12.17
  Using cached regex-2023.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (771 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Using cached tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
Collecting filelock
  Downloading filelock-3.12.4-py3-none-any.whl (11 kB)
Collecting aiohttp
  Using cached aiohttp-3.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_

In [3]:
import numpy as np
from sagemaker import get_execution_role
import boto3
import pandas as pd
from io import StringIO # Python 3.
from datasets import load_dataset,Dataset,DatasetDict,concatenate_datasets

from transformers import DataCollatorWithPadding,AutoModelForSequenceClassification, Trainer, TrainingArguments,AutoTokenizer,AutoModel,AutoConfig
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
import pandas as pd
import json

#from models.EDdisposition import EDdispositionClassifier

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# reads out files

import boto3

s3 = boto3.resource('s3')
bucket = s3.Bucket('chianglab-dataderivatives')

folders = set()

for obj in bucket.objects.all():
    prefix, delimiter, _ = obj.key.rpartition('/')
    if prefix:
        folders.add(prefix + '/')

print('Folders:')
for folder in folders:
    print(folder)
    


Folders:
mimic-iv-clinical-database-demo-2.2/
mimic-iv-2.2/
mimic-iv-ed-2.2/
mimic-iv-ed-demo-2.2/


In [5]:
bucket = 'chianglab-dataderivatives'
subfolder = 'mimic-iv-ed-2.2/'
conn = boto3.client('s3')
contents = conn.list_objects(Bucket=bucket, Prefix=subfolder)['Contents']

file_list = []

for f in contents:
    print(f['Key'])
    file_list.append(f['Key'][36:])

print(file_list)

mimic-iv-ed-2.2/
mimic-iv-ed-2.2/text_repr.json
['', '']


### Data we will be working with

In [9]:
bucket_name = 'chianglab-dataderivatives'
file_path = "mimic-iv-ed-demo-2.2/text_repr.json"


content_object = s3.Object(bucket_name, file_path)
file_content = content_object.get()['Body'].read().decode('utf-8')
json_content = json.loads(file_content)
df = pd.DataFrame(json_content).T
print("length of dataframe: "+ str(len(df)))
df.head(5)
# df['codes_headline'] = df['ID'].map(json_content)


length of dataframe: 210


Unnamed: 0,arrival,eddischarge,admission,discharge,triage,medrecon,vitals,pyxis,codes
37887480,"Patient 10014729, a 21 year old white - other ...",The ED disposition was admitted at 2125-03-19 ...,The patient was admitted at 2125-03-19 16:58:00.,The patient's discharge disposition was: home ...,"At triage: temperature was 99.1, pulse was 90....",The patient was previously taking the followin...,The patient had the following vitals: At 2125-...,The patient received the following medications...,The patient received the following diagnostic ...
34176810,"Patient 10018328, a 83 year old white female, ...",The ED disposition was admitted at 2154-02-05 ...,The patient was admitted at 2154-02-05 21:58:00.,The patient's discharge disposition was: home ...,"At triage: temperature was 97.7, pulse was 74....",The patient was previously taking the followin...,The patient had the following vitals: At 2154-...,,The patient received the following diagnostic ...
32103106,"Patient 10018328, a 83 year old white female, ...",The ED disposition was home at 2154-08-03 22:2...,The patient was not admitted.,The patient was not admitted.,"At triage: temperature was 96.2, pulse was 74....",The patient was previously taking the followin...,The patient had the following vitals: At 2154-...,The patient received the following medications...,The patient received the following diagnostic ...
38797992,"Patient 10020640, a 91 year old white female, ...",The ED disposition was admitted at 2153-02-13 ...,The patient was admitted at 2153-02-13 00:22:00.,The patient's discharge disposition was: skill...,"At triage: temperature was 99.2, pulse was 130...",The patient was previously taking the followin...,The patient had the following vitals: At 2153-...,The patient received the following medications...,The patient received the following diagnostic ...
33473053,"Patient 10015272, a 78 year old white female, ...",The ED disposition was admitted at 2137-06-12 ...,The patient was admitted at 2137-06-12 18:36:00.,The patient's discharge disposition was: home ...,"At triage: temperature was 97.5, pulse was 118...",The patient was previously taking the followin...,The patient had the following vitals: At 2137-...,The patient received the following medications...,The patient received the following diagnostic ...


In [10]:
df['eddischarge'] = [1 if 'admitted' in s.lower() else 0 for s in df['eddischarge']] # admitted = 1, Home = 0
df['medrecon'] = df['medrecon'].fillna("The patient was previously not taking any medications.")
df['pyxis'] = df['pyxis'].fillna("The patient did not receive any medications.")
df['vitals'] = df['vitals'].fillna("The patient had no vitals recorded")
df['codes'] = df['codes'].fillna("The patient received no diagnostic codes")
df = df.drop("admission",axis=1)
df = df.drop("discharge",axis=1)
# df = df.drop("eddischarge_category",axis=1)
df = df[[col for col in df.columns if col != 'eddischarge'] + ['eddischarge']] # rearrange column to the end
df['ID'] = df.arrival.astype(str).str.split().str[1].replace(",", " ", regex=True).to_list()
patient_IDS = df['ID'].to_list()
df = df.drop("ID",axis=1)
df

# remove admission and discharge columns 

Unnamed: 0,arrival,triage,medrecon,vitals,pyxis,codes,eddischarge
37887480,"Patient 10014729, a 21 year old white - other ...","At triage: temperature was 99.1, pulse was 90....",The patient was previously taking the followin...,The patient had the following vitals: At 2125-...,The patient received the following medications...,The patient received the following diagnostic ...,1
34176810,"Patient 10018328, a 83 year old white female, ...","At triage: temperature was 97.7, pulse was 74....",The patient was previously taking the followin...,The patient had the following vitals: At 2154-...,The patient did not receive any medications.,The patient received the following diagnostic ...,1
32103106,"Patient 10018328, a 83 year old white female, ...","At triage: temperature was 96.2, pulse was 74....",The patient was previously taking the followin...,The patient had the following vitals: At 2154-...,The patient received the following medications...,The patient received the following diagnostic ...,0
38797992,"Patient 10020640, a 91 year old white female, ...","At triage: temperature was 99.2, pulse was 130...",The patient was previously taking the followin...,The patient had the following vitals: At 2153-...,The patient received the following medications...,The patient received the following diagnostic ...,1
33473053,"Patient 10015272, a 78 year old white female, ...","At triage: temperature was 97.5, pulse was 118...",The patient was previously taking the followin...,The patient had the following vitals: At 2137-...,The patient received the following medications...,The patient received the following diagnostic ...,1
...,...,...,...,...,...,...,...
30272878,"Patient 10038999, a 45 year old white male, ar...","At triage: temperature was not recorded, pulse...",The patient was previously not taking any medi...,The patient had the following vitals: At 2131-...,The patient received the following medications...,The patient received the following diagnostic ...,1
31628990,"Patient 10009049, a 56 year old white male, ar...","At triage: temperature was 99.0, pulse was 87....",The patient was previously taking the followin...,The patient had the following vitals: At 2174-...,The patient received the following medications...,The patient received the following diagnostic ...,1
32405286,"Patient 10004457, a 65 year old white male, ar...","At triage: temperature was 97.6, pulse was 103...",The patient was previously taking the followin...,The patient had the following vitals: At 2141-...,The patient received the following medications...,The patient received the following diagnostic ...,1
34391979,"Patient 10004720, a 61 year old white male, ar...","At triage: temperature was not recorded, pulse...",The patient was previously taking the followin...,The patient had the following vitals: At 2186-...,The patient received the following medications...,The patient received the following diagnostic ...,1


In [11]:
# split dataframe here
def train_validate_test_split(df, train_percent=.7, validate_percent=.15, seed=None):
    np.random.seed(seed)
    df = df.reset_index()
    perm = np.random.permutation(df.index)
    m = len(df.index)
    train_end = int(train_percent * m)
    validate_end = int(validate_percent * m) + train_end
    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]
    train = train.set_index('index')
    validate = validate.set_index('index')
    test = test.set_index('index')
    return train, validate, test

In [55]:
t, val, t2 = train_validate_test_split(df, train_percent=.9, validate_percent=.05, seed=7)
print("70% Train:",len(t), "\n30% Test:",len(val+t2))

70% Train: 360017 
30% Test: 40002


In [11]:
import string
# check that seeding works across different notebooks by reading in the test_patients.txt and seeing if they match
test_patients = t2.arrival.astype(str).str.split().str[1].to_list()
validate_patients = val.arrival.astype(str).str.split().str[1].to_list()
test_patients = (test_patients+validate_patients)
test_patients = [''.join(char for char in item if char not in string.punctuation) for item in test_patients]
test_patients = list(set(test_patients))
test_patients.sort()


# extract test patient IDs into list
f = open("./models/data/test_patients.txt", "r")
data = f.read()
test_patients2 = data.split("\n")
test_patients2.sort()
test_patients2.pop(0) # need to pop the empty newline character
f.close()

print(len(test_patients), len(test_patients2))

# using == to check if lists are equal
if test_patients == test_patients2:
    print("The lists are identical")
else:
    print("The lists are not identical")
    

34 82805
The lists are not identical


In [56]:
remain = pd.concat([val, t2])
print("Size of validation + test after concat: ", len(remain)) # sanity check

# #resplit the our testing dataframe into an additional train and test split for fine tuning 
train, validate, test =  train_validate_test_split(remain, seed=7)
print("70% Train:",len(train), "\n30% Test:",len(validate+test))

Size of validation + test after concat:  40002
70% Train: 28001 
30% Test: 12001


In [15]:
# we stack and unstack later for easier tokenization
run=False
if run:
    disposition_train = train.eddischarge
    temp = train.drop("eddischarge",axis=1)
    train_stack = temp.stack().to_frame("headline")
    disposition_validation = validate.eddischarge
    temp = validate.drop("eddischarge",axis=1)
    validate_stack = temp.stack().to_frame("headline")
    disposition_test = test.eddischarge
    temp = test.drop("eddischarge",axis=1)
    test_stack = temp.stack().to_frame("headline")

    training_data_corpus = Dataset.from_pandas(train_stack)
    validation_data_corpus = Dataset.from_pandas(validate_stack)
    test_data_corpus = Dataset.from_pandas(test_stack)

In [16]:
def cut(df, set_type):
    col_names = df.columns.drop("eddischarge")
    l = []
    for i in col_names:
        temp = df[[i, 'eddischarge']].reset_index()
        temp = temp.sort_values(by=['index']).reset_index() # we sort the patient ID numerically before dropping it to preserve order in encoding
        temp = temp.drop(columns=["index", "level_0"])
        temp = temp.rename(columns={i: "headline", "eddischarge": "label"})
        l.append(temp)
        print("\""+i+ "\" Dataframe:", set_type, "set has been split")
    return l

In [17]:
print("################################################")
l1 = cut(train, "train")
print("################################################")
l2 = cut(validate, "validation")
print("################################################")
l3 = cut (test, "test")
print("################################################")

################################################
"arrival" Dataframe: train set has been split
"triage" Dataframe: train set has been split
"medrecon" Dataframe: train set has been split
"vitals" Dataframe: train set has been split
"pyxis" Dataframe: train set has been split
"codes" Dataframe: train set has been split
################################################
"arrival" Dataframe: validation set has been split
"triage" Dataframe: validation set has been split
"medrecon" Dataframe: validation set has been split
"vitals" Dataframe: validation set has been split
"pyxis" Dataframe: validation set has been split
"codes" Dataframe: validation set has been split
################################################
"arrival" Dataframe: test set has been split
"triage" Dataframe: test set has been split
"medrecon" Dataframe: test set has been split
"vitals" Dataframe: test set has been split
"pyxis" Dataframe: test set has been split
"codes" Dataframe: test set has been split
################

In [18]:
from transformers import TextClassificationPipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# tokenize
model = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

class Tokenizer():
    def tokenize(self,examples):
      """Mapping function to tokenize the sentences passed with truncation"""
      return tokenizer(examples["headline"], truncation=True, padding="max_length",
                        max_length=512, return_special_tokens_mask=True)
    def convert(self, l):
        """
        Run this method
        """
        arrival_hf=Dataset.from_pandas(l[0])
        triage_hf=Dataset.from_pandas(l[1])
        medrecon_hf=Dataset.from_pandas(l[2])
        vitals_hf=Dataset.from_pandas(l[3])
        codes_hf=Dataset.from_pandas(l[4])
        pyxis_hf=Dataset.from_pandas(l[5])

        arrival = arrival_hf.map(self.tokenize, batched=True)
        triage = triage_hf.map(self.tokenize, batched=True)
        medrecon = medrecon_hf.map(self.tokenize, batched=True)
        vitals = vitals_hf.map(self.tokenize, batched=True)
        codes = codes_hf.map(self.tokenize, batched=True)
        pyxis = pyxis_hf.map(self.tokenize, batched=True)

        arrival.set_format('torch', columns=["input_ids", "attention_mask", "label"] )
        triage.set_format('torch', columns=["input_ids", "attention_mask", "label"] )
        medrecon.set_format('torch', columns=["input_ids", "attention_mask", "label"] )
        vitals.set_format('torch', columns=["input_ids", "attention_mask", "label"] )
        codes.set_format('torch', columns=["input_ids", "attention_mask", "label"] )
        pyxis.set_format('torch', columns=["input_ids", "attention_mask", "label"] )

        return arrival, triage, medrecon, vitals, codes, pyxis

In [19]:
# calls methods and tokenizes text
processor = Tokenizer()
arrival_train_tokens, triage_train_tokens, medrecon_train_tokens, vitals_train_tokens, codes_train_tokens, pyxis_train_tokens, = processor.convert(l1)
arrival_val_tokens, triage_val_tokens, medrecon_val_tokens, vitals_val_tokens, codes_val_tokens, pyxis_val_tokens, = processor.convert(l2)
arrival_test_tokens, triage_test_tokens, medrecon_test_tokens, vitals_test_tokens, codes_test_tokens, pyxis_test_tokens, = processor.convert(l3)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

                                                   

In [20]:
arrival_dataset_cc = DatasetDict({
    'train': arrival_train_tokens,
    'test': arrival_test_tokens,
    'valid': arrival_val_tokens})

triage_dataset_cc = DatasetDict({
    'train': triage_train_tokens,
    'test': triage_test_tokens,
    'valid': triage_val_tokens})

medrecon_dataset_cc = DatasetDict({
    'train': medrecon_train_tokens,
    'test': medrecon_test_tokens,
    'valid': medrecon_val_tokens})

vitals_dataset_cc = DatasetDict({
    'train': vitals_train_tokens,
    'test': vitals_test_tokens,
    'valid': vitals_val_tokens})

codes_dataset_cc = DatasetDict({
    'train': codes_train_tokens,
    'test': codes_test_tokens,
    'valid': codes_val_tokens})

pyxis_dataset_cc = DatasetDict({
    'train': pyxis_train_tokens,
    'test': pyxis_test_tokens,
    'valid': pyxis_val_tokens})

In [21]:
train_dataloader_concat = [triage_dataset_cc["train"], arrival_dataset_cc["train"],medrecon_dataset_cc["train"],vitals_dataset_cc["train"],codes_dataset_cc["train"],pyxis_dataset_cc["train"]]
valid_dataloader_concat = [triage_dataset_cc["valid"], arrival_dataset_cc["valid"],medrecon_dataset_cc["valid"],vitals_dataset_cc["valid"],codes_dataset_cc["valid"],pyxis_dataset_cc["valid"]]
test_dataloader_concat = [triage_dataset_cc["test"], arrival_dataset_cc["test"],medrecon_dataset_cc["test"],vitals_dataset_cc["test"],codes_dataset_cc["test"],pyxis_dataset_cc["test"]]

In [24]:
class EDdispositionClassifier(nn.Module):
    """
    A task-specific custom transformer model for predicting ED Disposition. 
    This model loads a pre-trained transformer model and adds a new dropout 
    and linear layer at the end for fine-tuning and prediction on specific tasks.
    """
    def __init__(self, checkpoint, num_labels):
        """
        Args:
            checkpoint (str): The name of the pre-trained model or path to the model weights.
            num_labels (int): The number of output labels in the final classification layer.
        """
        super(EDdispositionClassifier, self).__init__()
        self.num_labels = num_labels # number of labels for classifier
        
        # checkpoint is the model name 
        self.model = model = AutoModel.from_pretrained(checkpoint, config = AutoConfig.from_pretrained(checkpoint, 
                                                                                                       output_attention = True, 
                                                                                                       output_hidden_state = True ) )
        # New Layer
        self.dropout = nn.Dropout(0.1) # to prevent overfittting
        self.classifier = nn.Linear(768, num_labels) #FC Layer - takes in a 768 token vector and is a Linear classifier with n labels
        
    def forward(self, input_ids = None, attention_mask=None, labels = None ):
        """
        Forward pass for the model.
        
        Args:
            input_ids (torch.Tensor, optional): Tensor of input IDs. Defaults to None.
            attention_mask (torch.Tensor, optional): Tensor for attention masks. Defaults to None.
            labels (torch.Tensor, optional): Tensor for labels. Defaults to None.
            
        Returns:
            TokenClassifierOutput: A named tuple with the following fields:
            - loss (torch.FloatTensor of shape (1,), optional, returned when label_ids is provided) – Classification loss.
            - logits (torch.FloatTensor of shape (batch_size, num_labels)) – Classification scores before SoftMax.
            - hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
            - attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
        """
        # calls on the Automodel to deploy correct model - in our case distilled-bert-uncased
        outputs = self.model(input_ids = input_ids, attention_mask = attention_mask  )
        
        # retrieves the last hidden state
        last_hidden_state = outputs[0]
        
        return last_hidden_state # The embedding
        
class TransformerModel(nn.Module):
    def __init__(self, num_layers=6, d_model=768, nhead=8, dim_feedforward=2048):
        super(TransformerModel, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers)
        self.softmax = nn.Softmax(dim=2)
        self.num_labels = 2 # number of labels for classifier
        self.dropout = nn.Dropout(0.1) # to prevent overfittting
        self.classifier = nn.Linear(768, 2) #FC Layer - takes in a 768 token vector and is a Linear classifier with n labels
        self.dense_layer = nn.Linear(768, 768)

    def forward(self, src, labels=None, attention=True):
        if attention:
            output = self.transformer_encoder(src)
            # project brand new layers into 768 dimensions 
            # src = self.dense_layer(output)
        
        # begin classification
        # include dropout from constructor to feed forward network
        sequence_outputs = self.dropout(src)
        
        # finally add linear layer from input
        logits = self.classifier(sequence_outputs[:, 0, : ].view(-1, 768 ))
        
        # calculates loss 
        loss = None
        if labels is not None:
            loss_func = nn.CrossEntropyLoss() # Change this if it becomes more than binary classification
            loss = loss_func(logits.view(-1, self.num_labels), labels.view(-1))
            
            # TokenClassifierOutput - returns predicted label
            return TokenClassifierOutput(loss=loss, logits=logits)#, hidden_states=outputs.hidden_states, attentions=new_vec.attentions)
        else:
            return logits
        
class EDDispositionFineTuneModel(nn.Module):
    def __init__(self, checkpoint, num_labels=2, input_dim=768, modalities=None):
        super(EDDispositionFineTuneModel, self).__init__()
        self.encoder = EDdispositionClassifier(checkpoint=checkpoint, num_labels=num_labels)
        self.predictor = TransformerModel()
        assert modalities is not None, "Number of modalities missing"
        self.modalities = modalities
    
    def forward(self, input_ids, attention_mask, label=None):
        # input_ids: dictionary of the batch
        # attention_mask: dictionary of the batch
        embedding = []
        for modality in range(self.modalities):
            embed = self.encoder(input_ids[modality], attention_mask[modality], label)
            embedding.append(embed)
        unified_embedding = torch.cat((embedding[0],embedding[1],embedding[2],embedding[3],embedding[4],embedding[5]),1) # concatenates embeddings on the second dimension
        outputs = self.predictor(unified_embedding, label)
        return outputs


In [25]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

full_model = EDDispositionFineTuneModel(checkpoint=model, num_labels=2, input_dim=768, modalities=6).to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [60]:
model_task_specific.parameters()
predictor.parameters()

<generator object Module.parameters at 0x7ff0d57f39e0>

In [27]:
# # Get all of the model's parameters as a list of tuples.
# params = list(model.named_parameters())

# print('The BERT model has {:} different named parameters.\n'.format(len(params)))

# print('==== Embedding Layer ====\n')

# for p in params[0:5]:
#     print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

# print('\n==== First Transformer ====\n')

# for p in params[5:21]:
#     print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

# print('\n==== Output Layer ====\n')

# for p in params[-4:]:
#     print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

In [26]:
from transformers import AdamW, get_scheduler

# optimizer = AdamW(model_task_specific.parameters(), lr = 5e-5 )
optimizer = AdamW(full_model.parameters(), lr = 5e-5 )

num_epoch = 2
BATCH = 8
num_training_steps = num_epoch * len(triage_dataset_cc['train']["input_ids"]) // BATCH
print(len(triage_dataset_cc['train']["input_ids"]))

lr_scheduler = get_scheduler(
    'linear',
    optimizer = optimizer,
    num_warmup_steps = 0,
    num_training_steps = num_training_steps,
)

735




In [29]:
from datasets import load_metric
metric = load_metric("f1")


Downloading builder script: 7.55kB [00:00, 4.06MB/s]                   


In [30]:
# TRAINING LOOP
num_epoch=1

from tqdm.auto import tqdm

progress_bar_train = tqdm(range(num_epoch * len(triage_dataset_cc['train']["input_ids"]) // BATCH ))
# progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epoch * len(triage_dataset_cc['valid']["input_ids"]) // BATCH ))

for epoch in range(num_epoch):
    full_model.train()
    print(f"Epoch {epoch}...")
    random_idx = np.random.permutation(np.arange(len(train_dataloader_concat[0]['input_ids'])))
    # for i, b in tqdm(enumerate(train_dl)):
    for step, idx in enumerate(range(0, len(random_idx), BATCH)):
        iter_rand_idx = random_idx[idx:idx+BATCH]
        input_ids, attention_mask = [], []
        for modality in train_dataloader_concat:
            input_ids.append(modality['input_ids'][iter_rand_idx].to(device))
            attention_mask.append(modality['attention_mask'][iter_rand_idx].to(device))
        label = modality['label'][iter_rand_idx].to(device)
        outputs = full_model(input_ids, attention_mask, label)
#         unified_embedding = torch.cat((embedding[0],embedding[1],embedding[2],embedding[3],embedding[4],embedding[5]),1) # concatenates embeddings on the second dimension
#         outputs = predictor(unified_embedding, modality['label'][i:i+BATCH].to(device))
        # updates weights accordingly
        loss = outputs.loss
        loss.backward() # computes gradients

        optimizer.step() # updates the weights and biases based on these gradients
        lr_scheduler.step() # updates the weights and biases based on these gradients
        optimizer.zero_grad() # used to clear the gradients of all parameters in a model
        progress_bar_train.update(1)
    
    # # run on validation set
    print("Validation")
    full_model.eval()
    # for i, b in tqdm(enumerate(train_dl)):
    for step, idx in enumerate(range(0, len(valid_dataloader_concat[0]['input_ids']), BATCH)):
        input_ids, attention_mask = [], []
        for modality in valid_dataloader_concat:
            input_ids.append(modality['input_ids'][idx:idx+BATCH].to(device))
            attention_mask.append(modality['attention_mask'][idx:idx+BATCH].to(device))
        label = modality['label'][idx:idx+BATCH].to(device)
        with torch.no_grad():
            outputs = full_model(input_ids, attention_mask, label)
        logits = outputs.logits # calculates the probabilities between the labels
        predictions = torch.argmax(logits, dim = -1 ) # takes the label closest to 1
        metric.add_batch(predictions = predictions, references = label) 
        loss = outputs.loss
        progress_bar_eval.update(1)
    
    print(metric.compute())

100%|██████████| 19/19 [06:41<00:00, 21.12s/it]


epoch training 0 done
loss: tensor(0.6854, grad_fn=<NllLossBackward0>)


100%|██████████| 4/4 [00:41<00:00, 10.44s/it]


epoch validation 0 done
{'precision': 0.7407407407407407}


100%|██████████| 19/19 [04:39<00:00, 14.72s/it]


epoch training 1 done
loss: tensor(0.6004, grad_fn=<NllLossBackward0>)


100%|██████████| 4/4 [00:23<00:00,  5.98s/it]


epoch validation 1 done
{'precision': 0.7857142857142857}


100%|██████████| 19/19 [06:36<00:00, 20.85s/it]


epoch training 2 done
loss: tensor(0.7934, grad_fn=<NllLossBackward0>)


100%|██████████| 4/4 [00:41<00:00, 10.26s/it]


epoch validation 2 done
{'precision': 0.6923076923076923}


100%|██████████| 19/19 [05:14<00:00, 16.54s/it]


epoch training 3 done
loss: tensor(0.7634, grad_fn=<NllLossBackward0>)


100%|██████████| 4/4 [00:25<00:00,  6.50s/it]


epoch validation 3 done
{'precision': 0.625}


100%|██████████| 19/19 [07:28<00:00, 23.59s/it]


epoch training 4 done
loss: tensor(0.7090, grad_fn=<NllLossBackward0>)


100%|██████████| 4/4 [00:28<00:00,  7.07s/it]

epoch validation 4 done
{'precision': 0.7222222222222222}





# Test predictions

In [31]:
logit_list = []
label_list = []
probs_list = []

full_model.eval()
for step, idx in tqdm(enumerate(range(0, len(test_dataloader_concat[0]['input_ids']), BATCH))):
    input_ids, attention_mask = [], []
    for modality in test_dataloader_concat:
        input_ids.append(modality['input_ids'][idx:idx+BATCH].to(device))
        attention_mask.append(modality['attention_mask'][idx:idx+BATCH].to(device))
    label = modality['label'][idx:idx+BATCH].to(device)
    with torch.no_grad():
        outputs = full_model(input_ids, attention_mask, label)
    logits = outputs.logits # calculates the probabilities between the labels
    predictions = torch.argmax(logits, dim = -1 ) # takes the label closest to 1
    loss = outputs.loss
    logits = outputs.logits # calculates the probabilities between the labels
    logit_list.append(logits[:, 1].cpu().detach().numpy())
    label_list.append(label.cpu().detach().numpy())
    probs_list.append(torch.sigmoid(logits[:, 1]).cpu().detach().numpy())
    predictions = torch.argmax(logits, dim = -1 ) # takes the label closest to 1
    metric.add_batch(predictions = predictions, references = label)
    # print("New Batch")
    # print(predictions)
    # print(modality['label'][i:i+BATCH])

print(metric.compute()) 

{'precision': 0.6190476190476191}
