In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from multimodal import MultimodalDataset, MultimodalNetwork, collation
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np

In [2]:
!nvidia-smi

Sun Mar 31 16:01:23 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:00:05.0 Off |                  Off |
| 37%   55C    P8    24W / 300W |      1MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
notes = pd.read_csv('data/notes_with_interval.csv')

notes.head()

Unnamed: 0,note_id,id,note_type,note_seq,charttime,text,icu_death,interval
0,17915608-RR-64,20008098,RR,64,1975-02-06 16:23:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,0,3
1,17915608-RR-65,20008098,RR,65,1975-02-07 10:50:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,0,2
2,17915608-RR-66,20008098,RR,66,1975-02-07 20:17:00,INDICATION: ___ year old man s/p RUL lobectom...,0,2
3,17915608-RR-67,20008098,RR,67,1975-02-08 12:20:00,INDICATION: ___ year old man s/p VATS to Open...,0,1
4,17915608-RR-68,20008098,RR,68,1975-02-09 07:26:00,INDICATION: ___ year old man s/p open RUL lob...,0,1


In [4]:
dynamic = pd.read_csv('data/dynamic_cleaned.csv')

dynamic.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium
0,26115624,9/7/50 0:22,12.0,22.0,9.0,7.9,111.0,0.6,97.0,141.0,3.6
1,21792938,4/13/28 14:18,11.0,28.0,46.0,7.1,96.0,5.0,95.0,130.0,5.2
2,28398464,12/9/34 8:10,18.0,21.0,24.0,10.3,100.0,0.7,93.0,135.0,4.0
3,28478629,10/8/96 5:30,17.0,27.0,86.0,7.4,96.0,3.3,142.0,135.0,4.7
4,22195489,9/18/45 21:05,17.0,29.0,40.0,11.3,97.0,6.4,118.0,139.0,4.2


In [5]:
static = pd.read_csv('data/static.csv')
static = static[static['id'].isin(dynamic['id'])]
static.head()

Unnamed: 0,id,hosp_admittime,hosp_dischtime,icu_intime,icu_outtime,los_icu,icu_death,gender,race,admission_age,...,atrial_fibrillation,malignant_cancer,chf,ckd,cld,copd,diabetes,hypertension,ihd,stroke
11,21999692,2/20/77 21:08,3/6/77 16:40,2/22/77 16:35,2/25/77 17:54,3.05,0,M,BLACK/AFRICAN AMERICAN,55.139306,...,0,0,0,0,0,1,0,0,0,0
22,25936663,1/11/45 23:02,1/22/45 16:06,1/12/45 15:10,1/13/45 18:50,1.15,0,M,WHITE,65.03217,...,0,1,0,0,1,0,0,1,1,0
25,25675339,11/2/83 21:12,11/19/83 15:36,11/6/83 20:59,11/10/83 0:44,3.16,0,M,WHITE,61.837481,...,0,0,0,0,0,0,0,0,0,0
30,27993048,11/19/67 8:23,12/25/67 14:53,11/26/67 16:26,12/5/67 16:53,9.02,0,F,WHITE,56.881238,...,0,0,1,1,0,0,1,1,1,0
32,22987108,6/10/46 16:37,7/12/46 0:00,6/22/46 11:46,7/13/46 0:27,20.53,1,M,WHITE,69.439961,...,1,0,0,1,0,0,0,1,1,0


In [6]:
def get_outcome(row, source_df):
    working_id = row['id']
    outcome = source_df[source_df['id'] == working_id].icu_death.iloc[0]

    return outcome

In [7]:
dynamic['icu_death'] = dynamic.apply(lambda row: get_outcome(row, static), axis=1)

dynamic.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium,icu_death
0,26115624,9/7/50 0:22,12.0,22.0,9.0,7.9,111.0,0.6,97.0,141.0,3.6,0
1,21792938,4/13/28 14:18,11.0,28.0,46.0,7.1,96.0,5.0,95.0,130.0,5.2,0
2,28398464,12/9/34 8:10,18.0,21.0,24.0,10.3,100.0,0.7,93.0,135.0,4.0,1
3,28478629,10/8/96 5:30,17.0,27.0,86.0,7.4,96.0,3.3,142.0,135.0,4.7,0
4,22195489,9/18/45 21:05,17.0,29.0,40.0,11.3,97.0,6.4,118.0,139.0,4.2,1


In [8]:
features = ['aniongap', 'bicarbonate', 'bun', 'calcium', 'chloride', 'creatinine', 'glucose', 'sodium', 'potassium']

scaler = StandardScaler()

dynamic[features] = scaler.fit_transform(dynamic[features])

dynamic.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium,icu_death
0,26115624,9/7/50 0:22,-0.718835,-0.298273,-1.014269,-0.640887,1.416042,-0.746533,-0.560555,0.721273,-0.895593,0
1,21792938,4/13/28 14:18,-0.936561,0.813247,0.36349,-1.49212,-0.689949,1.595494,-0.584288,-1.243671,1.3574,0
2,28398464,12/9/34 8:10,0.58752,-0.483526,-0.455718,1.912812,-0.128351,-0.693305,-0.608022,-0.350515,-0.332345,1
3,28478629,10/8/96 5:30,0.369794,0.627994,1.852959,-1.172908,-0.689949,0.69062,-0.026548,-0.350515,0.65334,0
4,22195489,9/18/45 21:05,0.369794,0.998501,0.14007,2.976854,-0.549549,2.340684,-0.311352,0.36401,-0.05072,1


In [9]:
notes['icu_death'] = notes.apply(lambda row: get_outcome(row, static), axis=1)
notes = notes[['id', 'charttime', 'text', 'interval']]

notes.head()

Unnamed: 0,id,charttime,text,interval
0,20008098,1975-02-06 16:23:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,3
1,20008098,1975-02-07 10:50:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,2
2,20008098,1975-02-07 20:17:00,INDICATION: ___ year old man s/p RUL lobectom...,2
3,20008098,1975-02-08 12:20:00,INDICATION: ___ year old man s/p VATS to Open...,1
4,20008098,1975-02-09 07:26:00,INDICATION: ___ year old man s/p open RUL lob...,1


In [10]:
id_lengths = dynamic['id'].value_counts().to_dict()
dynamic = dynamic.sort_values(by=['id', 'charttime'])
dynamic = dynamic.apply(lambda x: list(x[features]), axis=1).groupby(dynamic['id']).agg(list)

dynamic

id
20008098    [[0.5875197296491911, -0.8540328822821452, -0....
20013244    [[-0.501109573269091, 0.6279939835876895, -0.4...
20015730    [[-0.9365612944364038, 0.2574872671202309, -0....
20020562    [[-0.9365612944364038, 0.07223390888650152, -0...
20021110    [[0.15206800848187826, 0.6279939835876895, -0....
                                  ...                        
29990184    [[-0.06565785210177813, 1.924767491223795, 0.2...
29990494    [[0.8052455902328475, -0.6687795240484159, -0....
29992506    [[-0.7188354338527474, 0.44274062535396025, -0...
29994296    [[-0.501109573269091, -0.8540328822821452, 1.4...
29997500    [[-0.7188354338527474, 2.6657809241587125, -0....
Length: 3146, dtype: object

In [11]:
train_data = MultimodalDataset(static=static, dynamic=dynamic, id_lengths=id_lengths, notes=notes)

train_loader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=collation)

In [12]:
model = MultimodalNetwork(input_size=9, out_features=1, hidden_size=128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f'device: {device}')

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


device: cuda


In [13]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [14]:
!nvidia-smi

Sun Mar 31 16:01:55 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:00:05.0 Off |                  Off |
| 30%   53C    P2    80W / 300W |   1641MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [15]:
epochs = 5
training_loss = []
validation_loss = []

for epoch in range(1, epochs+1):
    print(f'epoch: [{epoch}/{epochs}]')
    model.train()
    training_loss_epoch = 0

    for step, batch in enumerate(train_loader):
        packed_dynamic_X, notes_X, notes_intervals, los = batch

        packed_dynamic_X = packed_dynamic_X.to(device)
        los = los.to(device)

        notes_X_gpu = []
        for notes in notes_X:
            notes_gpu = {key: value.to(device) for key, value in notes.items()}
            notes_X_gpu.append(notes_gpu)

        outputs = model(packed_dynamic_X, notes_X_gpu, notes_intervals)

        loss = criterion(outputs, los)
        training_loss_epoch += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            print(f'step: [{step}/{len(train_loader)}] | loss: {loss.item():.4}')

            if step == 0 and epoch == 1:
                with open('data/losses/loss_step.txt', 'w') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

            else:
                with open('data/losses/loss_step.txt', 'a') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

    avg_training_loss_epoch = training_loss_epoch / len(train_loader)
    training_loss.append(avg_training_loss_epoch.item())
    print(f'epoch loss: {avg_training_loss_epoch.item():.4f}\n')

    if epoch == 1:
        with open('data/losses/loss_epoch.txt', 'w') as loss_epoch_f:
            loss_epoch_f.write(f'{avg_training_loss_epoch.item():.4f}\n')

    else:
        with open('data/losses/loss_epoch.txt', 'a') as loss_epoch_f:
            loss_epoch_f.write(f'{avg_training_loss_epoch.item():.4f}\n')

    print('===============================\n')

epoch: [1/5]
step: [0/1573] | loss: 4.244
step: [50/1573] | loss: 7.21
step: [100/1573] | loss: 8.383
step: [150/1573] | loss: 29.6
step: [200/1573] | loss: 37.03
step: [250/1573] | loss: 15.95
step: [300/1573] | loss: 195.1
step: [350/1573] | loss: 461.2
step: [400/1573] | loss: 15.41
step: [450/1573] | loss: 321.3
step: [500/1573] | loss: 105.2
step: [550/1573] | loss: 6.917
step: [600/1573] | loss: 1.412
step: [650/1573] | loss: 84.04
step: [700/1573] | loss: 8.559
step: [750/1573] | loss: 5.281
step: [800/1573] | loss: 8.33
step: [850/1573] | loss: 2.586
step: [900/1573] | loss: 34.06
step: [950/1573] | loss: 15.84
step: [1000/1573] | loss: 3.315
step: [1050/1573] | loss: 7.359
step: [1100/1573] | loss: 373.5
step: [1150/1573] | loss: 125.2
step: [1200/1573] | loss: 121.7
step: [1250/1573] | loss: 17.9
step: [1300/1573] | loss: 57.5
step: [1350/1573] | loss: 3.868
step: [1400/1573] | loss: 9.39
step: [1450/1573] | loss: 11.65
step: [1500/1573] | loss: 5.402
step: [1550/1573] | loss

RuntimeError: CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 47.54 GiB total capacity; 45.25 GiB already allocated; 33.12 MiB free; 46.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF