In [1]:
# Directory holding the MedNLI 1.0.0 JSONL splits (train / dev / test).
# The path is relative, so it is resolved against whatever the working directory
# is when the files are actually read, not when this cell runs.
MEDNLI_PATH = '../physionet.org/files/mednli/1.0.0'

In [2]:
# Enable the autoreload extension in mode 2 so edits to imported modules
# (e.g. the local `utils` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Switch to the project directory (presumably so the local `utils` package used
# by the imports below can be found).
# The explicit `%cd` form is used instead of a bare `cd`: the bare form only works
# through IPython's automagic and silently stops being a magic as soon as a Python
# variable named `cd` exists.
# NOTE: the target is relative, so re-running this cell from inside the project
# directory will fail; restart the kernel before a full re-run.
%cd PatientTrajectoryForecasting

/home/sifal.klioui/PatientTrajectoryForecasting


In [4]:
import os
import json
import wandb
import torch
import transformers
from transformers import BertTokenizer, Trainer, TrainingArguments
from utils.bert_classification import MosaicBertForSequenceClassification
from transformers.models.bert.configuration_bert import BertConfig
from torch.optim import AdamW
from dataclasses import dataclass
from torch.nn import CrossEntropyLoss
from utils.mednli import evaluate_model
from datasets import load_dataset, Dataset, DatasetDict
from typing import Optional, Tuple, Union
from functools import partial
import numpy as np
from tqdm import tqdm
import math
import torch
from torch.utils.data import DataLoader
from utils.mednli import compute_metrics, evaluate_model, load_mednli, convert_to_dataset, NLIDataset

2024-06-04 20:54:52.434886: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
def _get_model (num_labels: int,
    pretrained_model_name: str = 'bert-base-uncased',
    model_config: Optional[Union[dict, str]] = None,
    pretrained_checkpoint: Optional[str] = None,
    alibi_starting_size = 1024):
    """Build a `MosaicBertForSequenceClassification` model with `num_labels` outputs.

    Args:
        num_labels: Number of target classes; written to `config.num_labels`.
        pretrained_model_name: Name or path handed to
            `transformers.AutoConfig.from_pretrained` to obtain the base config.
        model_config: Config overrides applied on top of the base config. May be
            `None` (no overrides), a dict of config fields, or a name/path that
            `BertConfig.get_config_dict` can resolve into such a dict.
        pretrained_checkpoint: Forwarded unchanged to
            `MosaicBertForSequenceClassification.from_pretrained`.
        alibi_starting_size: Accepted for interface compatibility; this function
            does not read it.

    Returns:
        The model returned by `MosaicBertForSequenceClassification.from_pretrained`.
    """
    if model_config is None:
        # The declared default: no overrides. Previously `None` was passed straight
        # to `BertConfig.get_config_dict`, which cannot resolve it.
        config_overrides = {}
    elif isinstance(model_config, dict):
        # Already a dict of fields; copy so the caller's dict is never mutated.
        config_overrides = dict(model_config)
    else:
        # A model name or path: load its config dict and fold in any leftover kwargs.
        config_overrides, unused_kwargs = BertConfig.get_config_dict(model_config)
        config_overrides.update(unused_kwargs)

    config, unused_kwargs = transformers.AutoConfig.from_pretrained(
        pretrained_model_name, return_unused_kwargs=True, **config_overrides)
    # This lets us use non-standard config fields (e.g. `starting_alibi_size`)
    config.update(unused_kwargs)
    config.num_labels = num_labels

    model = MosaicBertForSequenceClassification.from_pretrained(
            pretrained_checkpoint=pretrained_checkpoint, config=config)

    return model

In [6]:
# List the checkpoint directories saved during pre-training.
# Explicit `%ls` rather than a bare `ls`, so the cell does not depend on automagic
# (which is bypassed if a variable named `ls` is ever defined).
%ls ../bert_mimic_model_512/

[0m[01;34mstep_0[0m/  [01;34mstep_10000[0m/  [01;34mstep_20000[0m/  [01;34mstep_30000[0m/  [01;34mstep_40000[0m/


In [7]:
# MedNLI is a three-way classification task.
num_labels = 3

# The same hub identifier is used both for the base configuration and for the
# config overrides handed to `_get_model`.
pretrained_model_name = 'mosaicml/mosaic-bert-base-seqlen-512'
model_config = 'mosaicml/mosaic-bert-base-seqlen-512'

# Weights from the step-40000 checkpoint; converted to an absolute path right away
# so the value no longer depends on the working directory afterwards.
relative_path = os.path.join('..', 'bert_mimic_model_512/step_40000', 'pytorch_model.bin')
pretrained_checkpoint = os.path.abspath(relative_path)

# Zero-argument factory: every call builds a fresh model from the same settings.
get_model = partial(
    _get_model,
    num_labels=num_labels,
    pretrained_model_name=pretrained_model_name,
    model_config=model_config,
    pretrained_checkpoint=pretrained_checkpoint,
)

In [8]:
# All variants use the same tokenizer :))
# `use_fast` is an `AutoTokenizer` option; passed to the slow `BertTokenizer` class it
# has no effect. Going through `AutoTokenizer` makes the request for the fast
# implementation actually take effect.
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)

In [9]:
# Read the three MedNLI splits, in train / dev / test order, from MEDNLI_PATH.
train_data, dev_data, test_data = (
    load_mednli(os.path.join(MEDNLI_PATH, split_file))
    for split_file in ('mli_train_v1.jsonl', 'mli_dev_v1.jsonl', 'mli_test_v1.jsonl')
)

In [10]:
# Sanity check: every split must contain exactly the three MedNLI labels.
# Sets are compared directly (no sorting), so an unexpected non-string label such
# as None produces the assertion message instead of a TypeError from `sorted`.
EXPECTED_LABELS = {'contradiction', 'entailment', 'neutral'}

for split_name, partition in (('train', train_data), ('dev', dev_data), ('test', test_data)):
    found_labels = {item['gold_label'] for item in partition}
    assert found_labels == EXPECTED_LABELS, (
        f'there are some issues with the labels in your dataset: '
        f'the {split_name} split contains {sorted(map(str, found_labels))}'
    )

In [11]:
# Convert the loaded records of each split with the project helper.
train_dataset = convert_to_dataset(train_data)
dev_dataset = convert_to_dataset(dev_data)
test_dataset = convert_to_dataset(test_data)

# Create a DatasetDict
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': dev_dataset,
    'test': test_dataset
})

# Tokenize the dataset
def preprocess_function(examples, max_length = 512):
    """Tokenize a batch of premise/hypothesis pairs to exactly `max_length` tokens.

    Args:
        examples: Batch mapping with 'premise' and 'hypothesis' entries, as passed
            by `DatasetDict.map(..., batched=True)`.
        max_length: Length every encoded pair is padded (and, if longer, truncated) to.

    Returns:
        The tokenizer output for the batch.
    """
    return tokenizer(
        examples['premise'],
        examples['hypothesis'],
        return_tensors='pt',
        padding='max_length',
        # Without truncation a pair longer than `max_length` would yield a longer row
        # than its neighbours, which cannot be stacked into a single 'pt' tensor.
        truncation=True,
        max_length=max_length,
    )

# The raw text columns are dropped once encoded; other columns are kept as they are.
encoded_dataset = dataset_dict.map(preprocess_function, batched=True, remove_columns=['premise', 'hypothesis'])

Map:   0%|          | 0/11232 [00:00<?, ? examples/s]

Map:   0%|          | 0/1395 [00:00<?, ? examples/s]

Map:   0%|          | 0/1422 [00:00<?, ? examples/s]

In [12]:
# Wrap each tokenized split of `encoded_dataset` in the project's NLIDataset.
train_dataset, validation_dataset, test_dataset = (
    NLIDataset(encoded_dataset[split_name])
    for split_name in ('train', 'validation', 'test')
)

# Only the training loader shuffles; the evaluation loaders keep the stored order
# and use larger batches.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=128)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=128)

In [13]:
@dataclass
class Args:
    """Paths and hyperparameters for the fine-tuning run.

    Every setting is a real dataclass field, so it appears in the repr, takes part
    in equality, and can be overridden at construction time, e.g.
    `Args(learning_rate=3e-5)`.
    """
    # Path to a raw notes text file.
    train_file : str = '/scratch/sifal.klioui/notes_v2/notes.txt'
    # Maximum sequence length in tokens (an integer; previously annotated as `str`).
    max_seq_length : int = 512

    # Batch sizes.
    per_device_train_batch_size : int = 16
    per_device_eval_batch_size : int = 256

    # Optimizer settings: learning rate, Adam betas, epsilon and weight decay.
    learning_rate : float = 1e-5
    beta1 : float = 0.9
    beta2 : float = 0.98
    eps : float = 1e-06
    weight_decay : float = 1e-6

    # Number of passes over the training data.
    num_train_epochs : int = 10

args = Args()

In [14]:
def get_model_and_optimizer(args):
    """Create a fresh model via `get_model()` together with its AdamW optimizer.

    Parameters whose name contains "bias" or "LayerNorm.weight" are placed in a
    group with zero weight decay; all remaining parameters use `args.weight_decay`.

    Args:
        args: Object exposing `weight_decay`, `learning_rate`, `beta1`, `beta2`
            and `eps`.

    Returns:
        A `(model, optimizer)` tuple.
    """
    model = get_model()
    no_decay = ["bias", "LayerNorm.weight"]

    # Split the parameters in a single pass, keeping their original relative order.
    decayed_params, exempt_params = [], []
    for param_name, param in model.named_parameters():
        is_exempt = any(fragment in param_name for fragment in no_decay)
        (exempt_params if is_exempt else decayed_params).append(param)

    optimizer_grouped_parameters = [
        {"params": decayed_params, "weight_decay": args.weight_decay},
        {"params": exempt_params, "weight_decay": 0.0},
    ]

    # The top-level `weight_decay` is only a default for groups that do not set
    # their own value; both groups above set it explicitly.
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        weight_decay=args.weight_decay,
    )

    # No learning-rate scheduler is created here.
    return model, optimizer

In [15]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [16]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG before the model is built, so that randomly initialised
# weights and the shuffling order of the training DataLoader are repeatable after
# a kernel restart.
SEED = 42
torch.manual_seed(SEED)

model, optimizer = get_model_and_optimizer(args)
model.to(DEVICE)

# Total number of optimizer steps over the whole run; only the (currently disabled)
# schedulers below would consume it.
total_steps = args.num_train_epochs * len(train_dataloader)
#lr_s = transformers.get_wsd_schedule(optimizer, round(total_steps *0.06), 0, total_steps - round((total_steps*0.06)), 0.02 * args.learning_rate)
#scheduler = ReduceLROnPlateau(optimizer, 'min', min_lr = args.learning_rate *0.02, patience=2, factor=0.7, threshold = 0.04)

# Loss used by the manual training loop and passed to `evaluate_model`.
criterion = CrossEntropyLoss()

Found these missing keys in the checkpoint: bert.pooler.dense.weight, bert.pooler.dense.bias, classifier.weight, classifier.bias
the number of which is equal to 4
Found these unexpected keys in the checkpoint: cls.predictions.transform.dense.weight, cls.predictions.transform.dense.bias, cls.predictions.transform.LayerNorm.weight, cls.predictions.transform.LayerNorm.bias, cls.predictions.decoder.weight, cls.predictions.decoder.bias
the number of which is equal to 6


In [None]:
# Manual fine-tuning loop: the encoder, dropout and classifier head are called one
# after another and the loss is computed with `criterion`, instead of going through
# the model's own forward pass.
for epoch in range(args.num_train_epochs):
    model.train()
    sum_loss = 0
    loop = tqdm(train_dataloader, position=0, leave=True)

    for batch in loop:
        optimizer.zero_grad()

        # Move the tensors used by this step onto the training device.
        input_ids, attention_mask, token_type_ids, labels = (
            batch[key].to(DEVICE)
            for key in ('input_ids', 'attention_mask', 'token_type_ids', 'labels')
        )

        # Encoder forward pass; element 1 of its output is used as the pooled representation.
        # NOTE(review): `labels` is also forwarded to the encoder here - confirm the
        # encoder actually expects it.
        outputs = model.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )
        logits = model.classifier(model.dropout(outputs[1]))

        loss = criterion(logits.view(-1, model.num_labels), labels.view(-1))

        # Backward pass and parameter update (no scheduler step is taken).
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the progress bar and the running sum.
        batch_loss = loss.item()
        loop.set_postfix(loss=batch_loss)
        sum_loss += batch_loss

    print(f'mean epoch loss = {sum_loss/len(train_dataloader)}')

    # Validate after every epoch.
    metrics, val_loss = evaluate_model(model, validation_dataloader, DEVICE, criterion)
    print(metrics,'validation', val_loss)

100%|██████████| 351/351 [02:04<00:00,  2.81it/s, loss=1.13]


mean epoch loss = 1.0999709201334549


100%|██████████| 11/11 [00:04<00:00,  2.56it/s]


{'accuracy': 0.34838709677419355, 'f1': 0.34838709677419355, 'precision': 0.34838709677419355, 'recall': 0.34838709677419355} validation 1.0889542427929966


100%|██████████| 351/351 [01:44<00:00,  3.35it/s, loss=1.11]


mean epoch loss = 1.0976798948738988


100%|██████████| 11/11 [00:04<00:00,  2.55it/s]


{'accuracy': 0.31827956989247314, 'f1': 0.31827956989247314, 'precision': 0.31827956989247314, 'recall': 0.31827956989247314} validation 1.1119456941431218


100%|██████████| 351/351 [01:44<00:00,  3.35it/s, loss=1.1] 


mean epoch loss = 1.1044039923241336


100%|██████████| 11/11 [00:04<00:00,  2.56it/s]


{'accuracy': 0.3311827956989247, 'f1': 0.3311827956989247, 'precision': 0.3311827956989247, 'recall': 0.3311827956989247} validation 1.0987502228129993


 43%|████▎     | 152/351 [00:45<00:59,  3.36it/s, loss=1.06]

In [None]:
# Evaluate the current state of `model` on the test dataloader and print what
# `evaluate_model` returns (a metrics object and a loss value).
metrics, test_loss = evaluate_model(model, test_dataloader, DEVICE, criterion)
print(metrics,'test', test_loss)

In [None]:
# Fine-tuning loop that relies on the model's own forward pass: the labels are passed
# in and the training loss is taken from `outputs.loss`; `criterion` is only handed to
# `evaluate_model`.
# NOTE(review): this cell reuses the same `model` and `optimizer` objects as the manual
# loop above; if both cells are run, training continues from that state rather than
# starting fresh.
for epoch in range(args.num_train_epochs):
    model.train()
    sum_loss = 0
    loop = tqdm(train_dataloader, position=0, leave=True)

    for batch in loop:
        optimizer.zero_grad()

        # Gather the model inputs (labels included) on the training device.
        model_inputs = {
            key: batch[key].to(DEVICE)
            for key in ('input_ids', 'attention_mask', 'token_type_ids', 'labels')
        }

        # Forward pass through the full classification model.
        outputs = model(**model_inputs)
        loss = outputs.loss

        # Backward pass and parameter update (no scheduler step is taken).
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the progress bar and the running sum.
        batch_loss = loss.item()
        loop.set_postfix(loss=batch_loss)
        sum_loss += batch_loss

    print(f'mean epoch loss = {sum_loss/len(train_dataloader)}')

    # Validate after every epoch.
    metrics, val_loss = evaluate_model(model, validation_dataloader, DEVICE, criterion)
    print(metrics,'validation', val_loss)

In [None]:
# Evaluate the current state of `model` on the test dataloader and print what
# `evaluate_model` returns (a metrics object and a loss value).
metrics, test_loss = evaluate_model(model, test_dataloader, DEVICE, criterion)
print(metrics,'test', test_loss)