## Code for Frozen BERT Weights for Regression
- This model trains and tests solely on patients who did not die

In [20]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModel

from multimodal import MultimodalDataset, LOSNetWeighted, collation

In [21]:
# Snapshot of GPU model, driver/CUDA version and memory use before the model is built.
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Sun Apr  7 09:19:09 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   46C    P8    19W / 300W |   1641MiB / 49140MiB |      0%      Default |
|                               |            

In [22]:
base_path = 'data/regression/with-outliers/partitioned'

# Read the pre-partitioned static tables (train / validation / test) in one pass.
static_train, static_val, static_test = (
    pd.read_csv(split_csv)
    for split_csv in (
        f'{base_path}/static_alive_train.csv',
        f'{base_path}/static_alive_val.csv',
        f'{base_path}/static_alive_test.csv',
    )
)

In [23]:
# Sanity check: (rows, columns) of the training split of the static table.
static_train.shape

(6678, 24)

In [24]:
notes = pd.read_csv('data/notes_cleaned.csv')

# Split the notes by membership of `id` in each static split.
# NOTE(review): these three frames are rebuilt (with a column subset and .copy())
# in a later cell before they are used, so this first split looks redundant.
notes_train = notes.loc[notes['id'].isin(static_train['id'])]
notes_val = notes.loc[notes['id'].isin(static_val['id'])]
notes_test = notes.loc[notes['id'].isin(static_test['id'])]

In [25]:
dynamic = pd.read_csv('data/dynamic_cleaned.csv')

# One independent frame per split (rows whose `id` is in that split's static table);
# .copy() lets the scaling step assign into each frame without touching `dynamic`.
dynamic_train = dynamic.loc[dynamic['id'].isin(static_train['id'])].copy()
dynamic_val = dynamic.loc[dynamic['id'].isin(static_val['id'])].copy()
dynamic_test = dynamic.loc[dynamic['id'].isin(static_test['id'])].copy()

In [26]:
features = [
    'aniongap', 'bicarbonate', 'bun', 'calcium', 'chloride',
    'creatinine', 'glucose', 'sodium', 'potassium',
]

# Learn mean/std on the training split only, then apply that same fitted
# transform to every split so validation/test statistics never leak into the fit.
scaler = StandardScaler()
scaler.fit(dynamic_train[features])

for split_frame in (dynamic_train, dynamic_val, dynamic_test):
    split_frame.loc[:, features] = scaler.transform(split_frame[features])

dynamic_train.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium
0,28793466,4/12/29 3:35,0.137698,-0.426118,-0.892771,0.414173,0.584317,-0.69464,0.067924,0.285866,-0.628766
1,26115624,9/7/50 0:22,-0.746722,-0.237967,-1.007057,-0.550493,1.27878,-0.69464,-0.615583,0.643713,-0.895698
2,28164589,3/11/59 1:11,-0.746722,2.019844,0.859625,-0.121752,0.16764,0.001459,-0.711513,1.53833,-1.16263
7,28478629,10/8/96 5:30,0.358803,0.702787,1.9263,-1.086418,-0.804607,0.558338,-0.075972,-0.429827,0.572427
8,25387632,5/14/48 9:25,-0.304512,0.702787,-0.588006,0.199803,-0.804607,-0.787453,-0.759479,-0.787674,-0.895698


### Dynamic train preprocessing

In [27]:
# Number of measurement rows per patient id; later passed to MultimodalDataset as `id_lengths`.
id_lengths_train = dynamic_train['id'].value_counts().to_dict()

# `charttime` is read from CSV as text (e.g. '4/12/29 3:35'). Sorting the raw strings
# is lexicographic ('10:00' sorts before '3:35'), which can scramble a patient's
# sequence, so order rows by the parsed timestamp instead.
# NOTE(review): assumes every charttime matches '%m/%d/%y %H:%M' — confirm against the CSV.
# Two-digit years are century-ambiguous, which only matters if one patient's rows
# straddle the parser's pivot year.
dynamic_train = dynamic_train.sort_values(
    by=['id', 'charttime'],
    key=lambda col: pd.to_datetime(col, format='%m/%d/%y %H:%M') if col.name == 'charttime' else col,
)

# Collapse to one entry per id: a list of per-timestep feature vectors, in sorted order.
dynamic_train = dynamic_train.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_train['id']).agg(list)

dynamic_train

id
20001361    [[-0.3045122933429746, -0.23796719989995685, -...
20003491    [[-0.08340719150706118, -0.426118097367006, -0...
20008098    [[0.5799081140006791, -0.8024198923011043, -0....
20009330    [[0.3588030121647657, -1.1787216872352027, -0....
20011505    [[2.127643826852073, -2.495777969504547, -0.85...
                                  ...                        
29991038    [[0.5799081140006791, -0.8024198923011043, 1.3...
29991539    [[0.5799081140006791, 1.0790890823693873, 0.40...
29994296    [[-0.525617395178888, -0.8024198923011043, 1.5...
29997500    [[-0.7467224970148014, 2.77244715957283, -0.51...
29998399    [[-0.7467224970148014, -0.23796719989995685, -...
Length: 6678, dtype: object

### Dynamic val preprocessing

In [28]:
# Number of measurement rows per patient id; later passed to MultimodalDataset as `id_lengths`.
id_lengths_val = dynamic_val['id'].value_counts().to_dict()

# `charttime` is read from CSV as text (e.g. '4/12/29 3:35'). Sorting the raw strings
# is lexicographic ('10:00' sorts before '3:35'), which can scramble a patient's
# sequence, so order rows by the parsed timestamp instead.
# NOTE(review): assumes every charttime matches '%m/%d/%y %H:%M' — confirm against the CSV.
dynamic_val = dynamic_val.sort_values(
    by=['id', 'charttime'],
    key=lambda col: pd.to_datetime(col, format='%m/%d/%y %H:%M') if col.name == 'charttime' else col,
)

# Collapse to one entry per id: a list of per-timestep feature vectors, in sorted order.
dynamic_val = dynamic_val.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_val['id']).agg(list)

dynamic_val

id
20031816    [[0.3588030121647657, 1.643541774770535, -0.39...
20033924    [[-0.525617395178888, -0.04981630243290767, -0...
20055925    [[-0.7467224970148014, 0.3264854925011907, -1....
20058274    [[3.8964846415393803, -3.248381559372744, 4.89...
20063894    [[0.3588030121647657, -0.6142689948340552, -0....
                                  ...                        
29911812    [[-1.1889327006866282, 0.1383345950341415, -0....
29917727    [[0.8010132158365925, -1.1787216872352027, 2.7...
29935333    [[-0.3045122933429746, 1.4553908773034858, -0....
29962832    [[0.5799081140006791, -0.6142689948340552, 4.1...
29996513    [[0.8010132158365925, -0.426118097367006, 0.32...
Length: 743, dtype: object

### Dynamic test preprocessing

In [29]:
# Number of measurement rows per patient id (same bookkeeping as the train/val splits).
id_lengths_test = dynamic_test['id'].value_counts().to_dict()

# `charttime` is read from CSV as text (e.g. '4/12/29 3:35'). Sorting the raw strings
# is lexicographic ('10:00' sorts before '3:35'), which can scramble a patient's
# sequence, so order rows by the parsed timestamp instead.
# NOTE(review): assumes every charttime matches '%m/%d/%y %H:%M' — confirm against the CSV.
dynamic_test = dynamic_test.sort_values(
    by=['id', 'charttime'],
    key=lambda col: pd.to_datetime(col, format='%m/%d/%y %H:%M') if col.name == 'charttime' else col,
)

# Collapse to one entry per id: a list of per-timestep feature vectors, in sorted order.
dynamic_test = dynamic_test.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_test['id']).agg(list)

dynamic_test

id
20003425    [[-1.631142904358455, 1.0790890823693873, -0.3...
20009550    [[0.3588030121647657, 1.2672399798364367, 0.25...
20024788    [[1.2432234195084193, -0.8024198923011043, 1.7...
20029679    [[-0.9678275988507148, 0.1383345950341415, -0....
20034400    [[0.13769791032885223, -0.9905707897681535, 1....
                                  ...                        
29932591    [[-0.9678275988507148, 2.2079944671716825, -0....
29934368    [[0.13769791032885223, -0.426118097367006, -0....
29941780    [[0.5799081140006791, -0.04981630243290767, 0....
29988601    [[0.8010132158365925, -0.23796719989995685, 0....
29993312    [[2.127643826852073, -0.6142689948340552, 2.53...
Length: 825, dtype: object

In [30]:
# Keep only the columns used from here on, then cut one independent copy per split.
note_columns = ['id', 'charttime', 'text', 'interval']
notes = notes[note_columns]

notes_train = notes.loc[notes['id'].isin(static_train['id'])].copy()
notes_val = notes.loc[notes['id'].isin(static_val['id'])].copy()
notes_test = notes.loc[notes['id'].isin(static_test['id'])].copy()

In [31]:
# Wrap each split's static table, grouped dynamic sequences, per-id row counts and notes.
train_data = MultimodalDataset(static=static_train, dynamic=dynamic_train, id_lengths=id_lengths_train, notes=notes_train)
validation_data = MultimodalDataset(static=static_val, dynamic=dynamic_val, id_lengths=id_lengths_val, notes=notes_val)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size=2000, shuffle=True, collate_fn=collation)

# Validation is NOT shuffled: the epoch validation loss is a mean of per-batch RMSE
# values, which depends on how samples fall into batches. A fixed order keeps the
# metric comparable across epochs, so early stopping and best-model selection react
# to the model rather than to batch composition.
val_loader = DataLoader(validation_data, batch_size=400, shuffle=False, collate_fn=collation)

In [32]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained clinical text encoder and freeze every weight so that
# no gradients are computed for it during training.
text_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
for weight in text_model.parameters():
    weight.requires_grad = False

# Build the network around the frozen encoder and place it on the selected device.
model = LOSNetWeighted(input_size=9, out_features=1, hidden_size=128, text_model=text_model)
model = model.to(device)

print(f'device: {device}')

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


device: cuda


In [33]:
# Verify the freeze: list each text-encoder parameter with its frozen flag.
encoder_params = model.text_model.named_parameters()
for name, param in encoder_params:
    print(f"Layer: {name}, Frozen: {not param.requires_grad}\n")

Layer: embeddings.word_embeddings.weight, Frozen: True

Layer: embeddings.position_embeddings.weight, Frozen: True

Layer: embeddings.token_type_embeddings.weight, Frozen: True

Layer: embeddings.LayerNorm.weight, Frozen: True

Layer: embeddings.LayerNorm.bias, Frozen: True

Layer: encoder.layer.0.attention.self.query.weight, Frozen: True

Layer: encoder.layer.0.attention.self.query.bias, Frozen: True

Layer: encoder.layer.0.attention.self.key.weight, Frozen: True

Layer: encoder.layer.0.attention.self.key.bias, Frozen: True

Layer: encoder.layer.0.attention.self.value.weight, Frozen: True

Layer: encoder.layer.0.attention.self.value.bias, Frozen: True

Layer: encoder.layer.0.attention.output.dense.weight, Frozen: True

Layer: encoder.layer.0.attention.output.dense.bias, Frozen: True

Layer: encoder.layer.0.attention.output.LayerNorm.weight, Frozen: True

Layer: encoder.layer.0.attention.output.LayerNorm.bias, Frozen: True

Layer: encoder.layer.0.intermediate.dense.weight, Frozen: True

In [34]:
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss: ``sqrt(MSE(predicted, actual) + eps)``.

    Args:
        eps: Small constant added under the square root. Without it, a batch
            with MSE exactly 0 yields a NaN gradient, because d/dx sqrt(x) is
            infinite at 0 and is multiplied by a zero upstream gradient.

    Note:
        ``nn.MSELoss`` broadcasts its inputs, so `predicted` and `actual`
        should have the same shape (e.g. both ``(N, 1)`` or both ``(N,)``).
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, predicted, actual):
        """Return the scalar RMSE between `predicted` and `actual`."""
        return torch.sqrt(self.mse(predicted, actual) + self.eps)

In [35]:
# RMSE objective, optimised with Adam at a learning rate of 1e-4.
# All model parameters are handed over; frozen ones never receive a gradient.
criterion = RMSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [36]:
# Second GPU snapshot, taken after the model has been moved to the device.
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Sun Apr  7 09:19:29 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   49C    P2    74W / 300W |   3941MiB / 49140MiB |      0%      Default |
|                               |            

In [37]:
# Directory that receives the plain-text loss logs written by the training loop.
loss_base_path = 'losses/bert-frozen-all/with-outliers/alive_regression'
# Create it up front so the first open(..., 'w') cannot fail on a fresh checkout.
os.makedirs(loss_base_path, exist_ok=True)

# TensorBoard event writer for this run.
writer = SummaryWriter('tensorboard/runs/with-outliers/alive_regression')

In [38]:
# ---- Training configuration -------------------------------------------------
epochs = 200
patience = 10           # epochs without a new best validation loss before stopping
training_loss = []      # mean per-batch training RMSE, one float per epoch
validation_loss = []    # mean per-batch validation RMSE, one float per epoch
stagnation = 0          # consecutive epochs without improvement


def _notes_to_device(note_batches, target_device):
    """Return each dict in `note_batches` with its tensor values moved to `target_device`."""
    return [
        {key: value.to(target_device) for key, value in note_batch.items()}
        for note_batch in note_batches
    ]


def _write_loss_line(path, value, truncate):
    """Write `value` (4 decimals) as one line; start a fresh file if `truncate`, else append."""
    with open(path, 'w' if truncate else 'a') as loss_f:
        loss_f.write(f'{value:.4f}\n')


n_train_batches = len(train_loader)
n_val_batches = len(val_loader)
# Print / log roughly every 10% of an epoch (at least every step for tiny loaders).
log_every = max(1, round(n_train_batches * 0.1))

for epoch in range(1, epochs + 1):
    first_epoch = epoch == 1

    # ---- Training pass --------------------------------------------------------
    print(f'training epoch: [{epoch}/{epochs}]')
    model.train()
    training_loss_epoch = 0.0

    for step, batch in enumerate(train_loader):
        packed_dynamic_X, notes_X, notes_intervals, los = batch

        packed_dynamic_X = packed_dynamic_X.to(device)
        los = los.to(device)
        # Helper uses its own loop variable, so the notebook-level `notes`
        # DataFrame is no longer overwritten by a batch dict.
        notes_X_gpu = _notes_to_device(notes_X, device)

        outputs = model(packed_dynamic_X, notes_X_gpu, notes_intervals)
        loss = criterion(outputs, los)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a Python float, not the tensor: summing `loss` itself keeps
        # every step's autograd history alive until the end of the epoch.
        loss_value = loss.item()
        training_loss_epoch += loss_value
        # Global step starts at 0 for the first batch of the first epoch.
        writer.add_scalar('Loss/train', loss_value, (epoch - 1) * n_train_batches + step)

        if step % log_every == 0:
            print(f'step: [{step+1}/{n_train_batches}] | loss: {loss_value:.4}')
            _write_loss_line(
                f'{loss_base_path}/loss_step.txt',
                loss_value,
                truncate=(first_epoch and step == 0),
            )

    avg_training_loss_epoch = training_loss_epoch / n_train_batches
    writer.add_scalar('Loss/train_avg', avg_training_loss_epoch, epoch)

    training_loss.append(avg_training_loss_epoch)
    print(f'Training epoch loss: {avg_training_loss_epoch:.4f}\n')
    _write_loss_line(f'{loss_base_path}/training_loss_epoch.txt', avg_training_loss_epoch, truncate=first_epoch)

    # ---- Validation pass ------------------------------------------------------
    print(f'validation epoch: [{epoch}/{epochs}]')

    model.eval()
    validation_loss_epoch = 0.0
    with torch.no_grad():
        for val_step, val_batch in enumerate(val_loader):
            packed_dynamic_X_val, notes_X_val, notes_intervals_val, los_val = val_batch

            packed_dynamic_X_val = packed_dynamic_X_val.to(device)
            los_val = los_val.to(device)
            notes_X_val_gpu = _notes_to_device(notes_X_val, device)

            val_outputs = model(packed_dynamic_X_val, notes_X_val_gpu, notes_intervals_val)
            val_loss_value = criterion(val_outputs, los_val).item()

            writer.add_scalar('Loss/val', val_loss_value, (epoch - 1) * n_val_batches + val_step)
            validation_loss_epoch += val_loss_value

    avg_validation_loss = validation_loss_epoch / n_val_batches
    writer.add_scalar('Loss/val_avg', avg_validation_loss, epoch)
    print(f'Validation epoch loss: {avg_validation_loss:.4f}\n')

    # ---- Checkpointing and early stopping -----------------------------------
    if not validation_loss or avg_validation_loss < min(validation_loss):
        stagnation = 0
        # NOTE(review): this checkpoint goes under 'no-outliers' although the data and
        # loss logs in this notebook are 'with-outliers' — looks like a copy-paste
        # slip that could overwrite another experiment's model. Path left unchanged
        # because a later cell may load from it; confirm and fix both together.
        torch.save(model.state_dict(), 'saved-models/no-outliers/bert_frozen_all_alive_regression_best_model.pth')
        print(f'new minimum validation loss')
        print(f'model saved\n')
    else:
        stagnation += 1

    validation_loss.append(avg_validation_loss)
    _write_loss_line(f'{loss_base_path}/validation_loss_epoch.txt', avg_validation_loss, truncate=first_epoch)

    if stagnation >= patience:
        print(f'No improvement over {patience} epochs')
        print('Early stopping\n')
        break

    print('===============================\n')

writer.close()
print(f'min training loss: {min(training_loss):.4f}')
print(f'min validation loss: {min(validation_loss):.4f}')
    

training epoch: [1/200]
step: [1/4] | loss: 7.486
step: [2/4] | loss: 8.521
step: [3/4] | loss: 8.5
step: [4/4] | loss: 8.356
Training epoch loss: 8.2158

validation epoch: [1/200]
Validation epoch loss: 7.6988

new minimum validation loss
model saved


training epoch: [2/200]
step: [1/4] | loss: 7.858
step: [2/4] | loss: 8.22
step: [3/4] | loss: 8.867
step: [4/4] | loss: 7.006
Training epoch loss: 7.9876

validation epoch: [2/200]
Validation epoch loss: 7.6922

new minimum validation loss
model saved


training epoch: [3/200]
step: [1/4] | loss: 8.19
step: [2/4] | loss: 8.444
step: [3/4] | loss: 8.005
step: [4/4] | loss: 8.078
Training epoch loss: 8.1792

validation epoch: [3/200]
Validation epoch loss: 7.6312

new minimum validation loss
model saved


training epoch: [4/200]
step: [1/4] | loss: 8.068
step: [2/4] | loss: 8.552
step: [3/4] | loss: 7.974
step: [4/4] | loss: 8.193
Training epoch loss: 8.1969

validation epoch: [4/200]
Validation epoch loss: 7.6915


training epoch: [5/20