## Code For Fully Frozen Bert Weights (Classification)
- The patients here have not been separated based on whether they died

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from multimodal import MultimodalClassifierDataset, LOSNetWeighted, collation
from sklearn.preprocessing import StandardScaler
from transformers import AutoModel
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score

In [2]:
!nvidia-smi

Sun Apr  7 18:48:59 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    Off  | 00000000:00:05.0 Off |                  Off |
| 37%   55C    P8    20W / 230W |      1MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
base_path = 'data/split/with-outliers/combined'

static_train = pd.read_csv(f'{base_path}/static_train.csv')
static_val = pd.read_csv(f'{base_path}/static_val.csv')
static_test = pd.read_csv(f'{base_path}/static_test.csv')

In [4]:
notes = pd.read_csv('data/notes_cleaned.csv')
notes_train = notes[notes['id'].isin(static_train['id'])]
notes_val = notes[notes['id'].isin(static_val['id'])]
notes_test = notes[notes['id'].isin(static_test['id'])]

In [5]:
dynamic = pd.read_csv('data/dynamic_cleaned.csv')
dynamic_train = dynamic[dynamic['id'].isin(static_train['id'])].copy()
dynamic_val = dynamic[dynamic['id'].isin(static_val['id'])].copy()
dynamic_test = dynamic[dynamic['id'].isin(static_test['id'])].copy()

In [6]:
features = ['aniongap', 'bicarbonate', 'bun', 'calcium', 'chloride', 'creatinine', 'glucose', 'sodium', 'potassium']

scaler = StandardScaler()

dynamic_train.loc[:, features] = scaler.fit_transform(dynamic_train[features])
dynamic_val.loc[:, features] = scaler.transform(dynamic_val[features])
dynamic_test.loc[:, features] = scaler.transform(dynamic_test[features])  

dynamic_train.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium
0,28793466,4/12/29 3:35,0.076122,-0.358918,-0.907262,0.44144,0.577309,-0.713705,0.061863,0.287982,-0.640548
2,28164589,3/11/59 1:11,-0.745351,2.006437,0.803869,-0.096334,0.165747,0.014806,-0.704427,1.503503,-1.165692
3,20329785,5/12/34 19:32,-0.950719,2.734239,-0.20049,1.83965,-1.892064,-0.61657,-0.633693,-0.753893,-0.377975
4,24566943,6/25/55 15:45,0.076122,-0.176968,0.320289,0.333885,0.577309,0.063373,0.285855,0.461628,-0.377975
5,21792938,4/13/28 14:18,-0.950719,0.914735,0.357487,-1.38699,-0.794565,1.42326,-0.633693,-1.27483,1.197458


### Dynamic train preprocessing

In [7]:
id_lengths_train = dynamic_train['id'].value_counts().to_dict()
dynamic_train = dynamic_train.sort_values(by=['id', 'charttime'])
dynamic_train = dynamic_train.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_train['id']).agg(list)

dynamic_train

id
20001305    [[-0.5399823568385593, 0.36888343933745094, 0....
20001361    [[-0.33461420568125255, -0.17696774170196863, ...
20003491    [[-0.12924605452394583, -0.35891813538177514, ...
20009330    [[0.2814902477906676, -1.0867197101010013, -0....
20009550    [[0.2814902477906676, 1.2786354077364837, 0.20...
                                  ...                        
29993312    [[1.9244354570491216, -0.5408685290615817, 2.4...
29994296    [[-0.5399823568385593, -0.7228189227413881, 1....
29996513    [[0.6922265501052811, -0.35891813538177514, 0....
29997500    [[-0.745350507995866, 2.734238557174936, -0.53...
29998399    [[-0.745350507995866, -0.17696774170196863, -0...
Length: 7853, dtype: object

### Dynamic val preprocessing

In [8]:
id_lengths_val = dynamic_val['id'].value_counts().to_dict()
dynamic_val = dynamic_val.sort_values(by=['id', 'charttime'])
dynamic_val = dynamic_val.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_val['id']).agg(list)

dynamic_val

id
20008098    [[0.4868583989479744, -0.7228189227413881, -0....
20018116    [[-0.9507186591531728, 1.2786354077364837, 0.1...
20020590    [[-0.5399823568385593, 0.5508338330172575, 0.0...
20032048    [[-0.33461420568125255, 0.18693304565764443, -...
20034762    [[0.2814902477906676, -0.7228189227413881, -0....
                                  ...                        
29934368    [[0.07612209663336089, -0.35891813538177514, -...
29961750    [[-1.1560868103104796, 0.004982651977837905, -...
29970938    [[-0.33461420568125255, 0.004982651977837905, ...
29981257    [[-0.12924605452394583, 0.004982651977837905, ...
29981653    [[-0.5399823568385593, 0.5508338330172575, -0....
Length: 873, dtype: object

### Dynamic test preprocessing

In [9]:
id_lengths_test = dynamic_test['id'].value_counts().to_dict()
dynamic_test = dynamic_test.sort_values(by=['id', 'charttime'])
dynamic_test = dynamic_test.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_test['id']).agg(list)

dynamic_test

id
20003425    [[-1.5668231126250929, 1.096685014056677, -0.3...
20015722    [[-0.9507186591531728, 0.18693304565764443, -0...
20042079    [[-0.745350507995866, 0.36888343933745094, 0.0...
20042619    [[-0.33461420568125255, 0.5508338330172575, -0...
20043333    [[-0.745350507995866, -0.35891813538177514, -0...
                                  ...                        
29914458    [[0.8975947012625878, -0.17696774170196863, -0...
29923363    [[-0.12924605452394583, -0.17696774170196863, ...
29925024    [[-0.5399823568385593, -0.35891813538177514, -...
29954601    [[0.8975947012625878, 0.36888343933745094, 0.3...
29988601    [[0.6922265501052811, -0.17696774170196863, 0....
Length: 970, dtype: object

In [10]:
notes = notes[['id', 'charttime', 'text', 'interval']]

notes_train = notes[notes['id'].isin(static_train['id'])].copy()
notes_val = notes[notes['id'].isin(static_val['id'])].copy()
notes_test = notes[notes['id'].isin(static_test['id'])].copy()

In [11]:
train_data = MultimodalClassifierDataset(
    static=static_train, dynamic=dynamic_train, 
    id_lengths=id_lengths_train, notes=notes_train
    )
validation_data = MultimodalClassifierDataset(
    static=static_val, dynamic=dynamic_val, 
    id_lengths=id_lengths_val, notes=notes_val
    )

train_loader = DataLoader(train_data, batch_size=2000, shuffle=True, collate_fn=collation)
val_loader = DataLoader(validation_data, batch_size=400, shuffle=True, collate_fn=collation)

In [12]:
out_features = static_train['los_icu_binned'].nunique()

out_features

10

In [13]:
text_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

for params in text_model.parameters():
    params.requires_grad = False

model = LOSNetWeighted(input_size=9, out_features=out_features, hidden_size=128, text_model=text_model, task='cls')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'
model = model.to(device)

print(f'device: {device}')

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


device: cuda


In [14]:
for name, param in model.text_model.named_parameters():
    print(f"Layer: {name}, Frozen: {not param.requires_grad}\n")

Layer: embeddings.word_embeddings.weight, Frozen: True

Layer: embeddings.position_embeddings.weight, Frozen: True

Layer: embeddings.token_type_embeddings.weight, Frozen: True

Layer: embeddings.LayerNorm.weight, Frozen: True

Layer: embeddings.LayerNorm.bias, Frozen: True

Layer: encoder.layer.0.attention.self.query.weight, Frozen: True

Layer: encoder.layer.0.attention.self.query.bias, Frozen: True

Layer: encoder.layer.0.attention.self.key.weight, Frozen: True

Layer: encoder.layer.0.attention.self.key.bias, Frozen: True

Layer: encoder.layer.0.attention.self.value.weight, Frozen: True

Layer: encoder.layer.0.attention.self.value.bias, Frozen: True

Layer: encoder.layer.0.attention.output.dense.weight, Frozen: True

Layer: encoder.layer.0.attention.output.dense.bias, Frozen: True

Layer: encoder.layer.0.attention.output.LayerNorm.weight, Frozen: True

Layer: encoder.layer.0.attention.output.LayerNorm.bias, Frozen: True

Layer: encoder.layer.0.intermediate.dense.weight, Frozen: True

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [16]:
!nvidia-smi

Sun Apr  7 18:49:39 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    Off  | 00000000:00:05.0 Off |                  Off |
| 31%   55C    P2    67W / 230W |   1573MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [17]:
loss_base_path = 'losses/bert-frozen-all/with-outliers/combined_classification'
model_save_path = 'saved-models/no-outliers/bert_frozen_all_combined_classification_best_model.pth'
writer = SummaryWriter('tensorboard/runs/with-outliers/combined_classification')

In [18]:
epochs = 200
training_loss = []
validation_loss = []
train_weighted_f1_scores = []
val_weighted_f1_scores = []
patience = 10
stagnation = 0

for epoch in range (1, epochs+1):

    print(f'training epoch: [{epoch}/{epochs}]')
    model.train()
    training_loss_epoch = 0
    all_true_labels = []
    all_predicted_labels = []

    for step, batch in enumerate(train_loader):
        packed_dynamic_X, notes_X, notes_intervals, los = batch

        packed_dynamic_X = packed_dynamic_X.to(device)
        los = los.to(device)

        notes_X_gpu = []
        for notes in notes_X:
            notes_gpu = {key: value.to(device) for key, value in notes.items()}
            notes_X_gpu.append(notes_gpu)

        outputs = model(packed_dynamic_X, notes_X_gpu, notes_intervals)
        predicted_labels = torch.argmax(outputs, dim=1)
        true_labels = torch.argmax(los, dim=1)

        loss = criterion(outputs, true_labels)
        writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + step)
        training_loss_epoch += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_true_labels.extend(true_labels.cpu().numpy())
        all_predicted_labels.extend(predicted_labels.cpu().numpy())

        if step % max(1, round(len(train_loader) * 0.1)) == 0:
            print(f'step: [{step+1}/{len(train_loader)}] | loss: {loss.item():.4}')

            if step+1 == 1 and epoch == 1:
                with open(f'{loss_base_path}/loss_step.txt', 'w') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

            else:
                with open(f'{loss_base_path}/loss_step.txt', 'a') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

    avg_training_loss_epoch = training_loss_epoch / len(train_loader)
    writer.add_scalar('Loss/train_avg', avg_training_loss_epoch.item(), epoch)

    training_loss.append(avg_training_loss_epoch.item())
    print(f'Training epoch loss: {avg_training_loss_epoch.item():.4f}\n')

    train_weighted_f1_score = f1_score(all_true_labels, all_predicted_labels, average='weighted')
    train_weighted_f1_scores.append(round(train_weighted_f1_score, 4))
    print(f'Training Weighted F1 epoch score: {train_weighted_f1_score:.4f}\n')

    if epoch == 1:
        with open(f'{loss_base_path}/training_loss_epoch.txt', 'w') as loss_epoch_train_f:
            loss_epoch_train_f.write(f'{avg_training_loss_epoch.item():.4f}\n')

    else:
        with open(f'{loss_base_path}/training_loss_epoch.txt', 'a') as loss_epoch_train_f:
            loss_epoch_train_f.write(f'{avg_training_loss_epoch.item():.4f}\n')

    print(f'validation epoch: [{epoch}/{epochs}]')

    model.eval()
    with torch.no_grad():
        validation_loss_epoch = 0
        val_all_true_labels = []
        val_all_predicted_labels = []

        for val_step, val_batch in enumerate(val_loader):
            packed_dynamic_X_val, notes_X_val, notes_intervals_val, los_val = val_batch

            packed_dynamic_X_val = packed_dynamic_X_val.to(device)
            los_val = los_val.to(device)

            notes_X_val_gpu = []
            for notes in notes_X_val:
                notes_val_gpu = {key: value.to(device) for key, value in notes.items()}
                notes_X_val_gpu.append(notes_val_gpu)

            val_outputs = model(packed_dynamic_X_val, notes_X_val_gpu, notes_intervals_val)

            val_predicted_labels = torch.argmax(outputs, dim=1)
            val_true_labels = torch.argmax(los_val, dim=1)

            val_loss = criterion(val_outputs, val_true_labels)
            writer.add_scalar('Loss/val', val_loss.item(), epoch * len(val_loader) + val_step)
            validation_loss_epoch += val_loss

            val_all_true_labels.extend(true_labels.cpu().numpy())
            val_all_predicted_labels.extend(predicted_labels.cpu().numpy())

        avg_validation_loss = validation_loss_epoch / len(val_loader)
        writer.add_scalar('Loss/val_avg', avg_validation_loss.item(), epoch)
        print(f'Validation epoch loss: {avg_validation_loss.item():.4f}\n')

        val_weighted_f1_score = f1_score(val_all_true_labels, val_all_predicted_labels, average='weighted')
        val_weighted_f1_scores.append(round(train_weighted_f1_score, 4))
        print(f'Validation Weighted F1 epoch score: {train_weighted_f1_score:.4f}\n')
        
        if len(validation_loss) == 0 or (avg_validation_loss.item() < min(validation_loss)):
            stagnation = 0
            torch.save(model.state_dict(), model_save_path)
            print(f'new minimum validation loss')
            print(f'model saved\n')


        else:
            stagnation += 1

        validation_loss.append(avg_validation_loss.item())

        if epoch == 1:
            with open(f'{loss_base_path}/validation_loss_epoch.txt', 'w') as loss_epoch_val_f:
                loss_epoch_val_f.write(f'{avg_validation_loss.item():.4f}\n')

        else:
            with open(f'{loss_base_path}/validation_loss_epoch.txt', 'a') as loss_epoch_val_f:
                loss_epoch_val_f.write(f'{avg_validation_loss.item():.4f}\n')

        if stagnation >= patience:
            print(f'No improvement over {patience} epochs')
            print('Early stopping\n')
            break

    model.train()

    print('===============================\n')

writer.close()
print(f'min training loss: {min(training_loss):.4f}')
print(f'min validation loss: {min(validation_loss):.4f}')

training epoch: [1/200]
step: [1/4] | loss: 2.331
step: [2/4] | loss: 2.328
step: [3/4] | loss: 2.325
step: [4/4] | loss: 2.322
Training epoch loss: 2.3265

Training Weighted F1 epoch score: 0.0038

validation epoch: [1/200]
Validation epoch loss: 2.3218

Validation Weighted F1 epoch score: 0.0038

new minimum validation loss
model saved


training epoch: [2/200]
step: [1/4] | loss: 2.321
step: [2/4] | loss: 2.319
step: [3/4] | loss: 2.316
step: [4/4] | loss: 2.314
Training epoch loss: 2.3175

Training Weighted F1 epoch score: 0.0063

validation epoch: [2/200]
Validation epoch loss: 2.3118

Validation Weighted F1 epoch score: 0.0063

new minimum validation loss
model saved


training epoch: [3/200]
step: [1/4] | loss: 2.312
step: [2/4] | loss: 2.31
step: [3/4] | loss: 2.308
step: [4/4] | loss: 2.304
Training epoch loss: 2.3083

Training Weighted F1 epoch score: 0.0292

validation epoch: [3/200]
Validation epoch loss: 2.3016

Validation Weighted F1 epoch score: 0.0292

new minimum valid

KeyboardInterrupt: 