In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler
from sklearn.impute import KNNImputer
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from transformers import AutoTokenizer, AutoModel
from torch.nn import LSTM

In [2]:
!nvidia-smi

Thu Apr  4 16:20:14 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   38C    P8    23W / 300W |      1MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
notes = pd.read_csv('data/notes_with_interval.csv')

notes.head()

Unnamed: 0,note_id,id,note_type,note_seq,charttime,text,icu_death,interval
0,17915608-RR-64,20008098,RR,64,1975-02-06 16:23:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,0,3
1,17915608-RR-65,20008098,RR,65,1975-02-07 10:50:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,0,2
2,17915608-RR-66,20008098,RR,66,1975-02-07 20:17:00,INDICATION: ___ year old man s/p RUL lobectom...,0,2
3,17915608-RR-67,20008098,RR,67,1975-02-08 12:20:00,INDICATION: ___ year old man s/p VATS to Open...,0,1
4,17915608-RR-68,20008098,RR,68,1975-02-09 07:26:00,INDICATION: ___ year old man s/p open RUL lob...,0,1


In [4]:
notes[notes['id'] == 22987108]

Unnamed: 0,note_id,id,note_type,note_seq,charttime,text,icu_death,interval
6799,10007818-RR-10,22987108,RR,10,2046-06-11 08:38:00,EXAMINATION: DUPLEX DOPP ABD/PEL\n\nINDICATIO...,1,11
6800,10007818-RR-11,22987108,RR,11,2046-06-11 14:46:00,EXAMINATION: Ultrasound-guided paracentesis\n...,1,11
6801,10007818-RR-12,22987108,RR,12,2046-06-11 17:29:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,1,11
6802,10007818-RR-13,22987108,RR,13,2046-06-13 14:27:00,EXAMINATION: Ultrasound-guided paracentesis\n...,1,9
6803,10007818-RR-14,22987108,RR,14,2046-06-15 16:42:00,EXAMINATION: RENAL U.S.\n\nINDICATION: ___ y...,1,7
6804,10007818-RR-15,22987108,RR,15,2046-06-17 15:24:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,1,5
6805,10007818-RR-16,22987108,RR,16,2046-06-17 20:13:00,EXAMINATION: CHEST (PORTABLE AP)\n\nINDICATIO...,1,5
6806,10007818-RR-17,22987108,RR,17,2046-06-19 13:55:00,EXAMINATION: CHEST (PA AND LAT)\n\nINDICATION...,1,3
6807,10007818-RR-18,22987108,RR,18,2046-06-19 12:38:00,EXAMINATION: US Interventional Procedure\n\nI...,1,3
6808,10007818-RR-19,22987108,RR,19,2046-06-20 11:58:00,"INDICATION: ___ year old man with cirrhosis, ...",1,2


In [5]:
static = pd.read_csv('data/static.csv')
static = static[static['id'].isin(notes['id'])]

static['gender'] = static['gender'].map(lambda x: 1 if x == 'M' else 0).astype(int)
static.head()

Unnamed: 0,id,hosp_admittime,hosp_dischtime,icu_intime,icu_outtime,los_icu,icu_death,gender,race,admission_age,...,atrial_fibrillation,malignant_cancer,chf,ckd,cld,copd,diabetes,hypertension,ihd,stroke
11,21999692,2/20/77 21:08,3/6/77 16:40,2/22/77 16:35,2/25/77 17:54,3.05,0,1,BLACK/AFRICAN AMERICAN,55.139306,...,0,0,0,0,0,1,0,0,0,0
22,25936663,1/11/45 23:02,1/22/45 16:06,1/12/45 15:10,1/13/45 18:50,1.15,0,1,WHITE,65.03217,...,0,1,0,0,1,0,0,1,1,0
25,25675339,11/2/83 21:12,11/19/83 15:36,11/6/83 20:59,11/10/83 0:44,3.16,0,1,WHITE,61.837481,...,0,0,0,0,0,0,0,0,0,0
30,27993048,11/19/67 8:23,12/25/67 14:53,11/26/67 16:26,12/5/67 16:53,9.02,0,0,WHITE,56.881238,...,0,0,1,1,0,0,1,1,1,0
32,22987108,6/10/46 16:37,7/12/46 0:00,6/22/46 11:46,7/13/46 0:27,20.53,1,1,WHITE,69.439961,...,1,0,0,1,0,0,0,1,1,0


In [6]:
static.isna().sum()

id                        0
hosp_admittime            0
hosp_dischtime            0
icu_intime                0
icu_outtime               0
los_icu                   0
icu_death                 0
gender                    0
race                      0
admission_age             0
weight_admit            125
height                 1381
admission_type            0
first_careunit            0
charlson_score            0
atrial_fibrillation       0
malignant_cancer          0
chf                       0
ckd                       0
cld                       0
copd                      0
diabetes                  0
hypertension              0
ihd                       0
stroke                    0
dtype: int64

In [7]:
numeric_cols = static.select_dtypes(include=[np.number]).columns

imputer = KNNImputer(n_neighbors=3)
static_imputed_numeric = imputer.fit_transform(static[numeric_cols])

static[numeric_cols] = static_imputed_numeric

static = static[numeric_cols]

static.head()

Unnamed: 0,id,los_icu,icu_death,gender,admission_age,weight_admit,height,charlson_score,atrial_fibrillation,malignant_cancer,chf,ckd,cld,copd,diabetes,hypertension,ihd,stroke
11,21999692.0,3.05,0.0,1.0,55.139306,52.0,172.0,11.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
22,25936663.0,1.15,0.0,1.0,65.03217,189.0,185.0,5.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0
25,25675339.0,3.16,0.0,1.0,61.837481,95.0,159.333333,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
30,27993048.0,9.02,0.0,0.0,56.881238,78.366667,168.0,8.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,1.0,1.0,0.0
32,22987108.0,20.53,1.0,1.0,69.439961,86.2,185.0,8.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0


In [8]:
static_x = static.drop(['icu_death', 'los_icu'], axis=1)

static_y = static['icu_death']

In [9]:
X_train, X_test, y_train, y_test = train_test_split(
    static_x,static_y,test_size=0.2,random_state=10,
    stratify=static_y, shuffle=True
)

In [10]:
correlation_matrix = X_train.corr()

threshold = 0.9
threshold_matrix = correlation_matrix[abs(correlation_matrix) >= threshold]
threshold_matrix = threshold_matrix.stack().reset_index().dropna(how='any')

threshold_matrix.columns = ['variable_1', 'variable_2', 'correlation']
threshold_matrix = threshold_matrix.sort_values(by='correlation', ascending=False)
threshold_matrix = threshold_matrix[threshold_matrix.variable_1 != threshold_matrix.variable_2]

threshold_matrix['sorted_pair'] = threshold_matrix.apply(lambda row: tuple(sorted([row['variable_1'], 
                                                                                   row['variable_2']])), 
                                                         axis=1)

threshold_matrix = threshold_matrix.drop_duplicates(subset=['sorted_pair']).drop('sorted_pair', axis=1)

threshold_matrix

Unnamed: 0,variable_1,variable_2,correlation


In [11]:
X_train.head()

Unnamed: 0,id,gender,admission_age,weight_admit,height,charlson_score,atrial_fibrillation,malignant_cancer,chf,ckd,cld,copd,diabetes,hypertension,ihd,stroke
15505,21201300.0,1.0,63.176846,175.1,178.0,6.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0
4115,26145609.0,1.0,66.842143,96.5,170.0,5.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0
17914,20545926.0,1.0,64.588663,61.7,185.0,7.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0
11280,24761294.0,0.0,40.921687,64.7,163.0,12.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
1340,24842483.0,0.0,89.894356,45.0,152.0,5.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


In [12]:
len(X_train.columns)

16

In [13]:
class ICUDeathPredictionDataset(Dataset):
    '''
    Parameters
    ----------
    static: the static DataFrame
    notes: the notes DataFrame
    '''
    
    def __init__(self, static, notes, labels):
        self.static = static
        self.notes = notes
        self.labels = labels
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    def __len__(self):
        return len(self.static)
    
    def __getitem__(self, idx):
        static_X = self.static.iloc[idx]
        patient_id = static_X['id']

        # notes
        notes = self.notes[self.notes['id'] == patient_id].sort_index()['text'].tolist()
        notes_X = self.tokenizer(notes, return_tensors='pt', truncation=True, padding='max_length')
        patient_timesteps = len(notes)

        return static_X, notes_X, patient_timesteps

In [14]:
class ICUDeathPredictionModel(nn.Module):
    def __init__(self, static_input_size, out_features=2, hidden_size=256, num_layers=1, batch_first=True):
        super(ICUDeathPredictionModel, self).__init__()

        self.text_model = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

        self.lstm = LSTM(input_size=768, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first)

        self.fc = nn.Sequential(
            nn.LayerNorm(normalized_shape=static_input_size + hidden_size),
            nn.Linear(in_features=static_input_size + hidden_size, out_features=out_features),
            )

    def forward(self, static_X_batch, notes_X_batch, patient_timesteps):
        embeddings = [self.text_model(**tokenized_notes).pooler_output for tokenized_notes in notes_X_batch]
        padded_embeddings = pad_sequence(embeddings, batch_first=True)
        packed_embeddings = pack_padded_sequence(input=padded_embeddings, lengths=patient_timesteps, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.time_series_model(packed_embeddings)

        ht = ht[-1]

        st = torch.tensor(static_X_batch.to_numpy())
        combined_representation = torch.cat((ht, st), dim=1)

        y_pred = self.fc(combined_representation)

        return y_pred

In [15]:
train_data = ICUDeathPredictionDataset(static=static_x, notes=notes, labels=static_y)

train_loader = DataLoader(train_data, batch_size=2, shuffle=True)

In [16]:
model = ICUDeathPredictionModel(static_input_size=len(static_x.columns))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f'device: {device}')

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


device: cuda


In [17]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [18]:
epochs = 5
training_loss = []
validation_loss = []

for epoch in range(1, epochs+1):
    print(f'epoch: [{epoch}/{epochs}]')
    model.train()
    training_loss_epoch = 0

    for step, batch in enumerate(train_loader):
        static_X, notes_X, patient_timesteps = batch

        packed_dynamic_X = packed_dynamic_X.to(device)
        los = los.to(device)

        notes_X_gpu = []
        for notes in notes_X:
            notes_gpu = {key: value.to(device) for key, value in notes.items()}
            notes_X_gpu.append(notes_gpu)

        outputs = model(packed_dynamic_X, notes_X_gpu)

        loss = criterion(outputs, los)
        training_loss_epoch += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            print(f'step: [{step}/{len(train_loader)}] | loss: {loss.item():.4}')

            if step == 0 and epoch == 1:
                with open('data/losses/loss_step.txt', 'w') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

            else:
                with open('data/losses/loss_step.txt', 'a') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

    avg_training_loss_epoch = training_loss_epoch / len(train_loader)
    training_loss.append(avg_training_loss_epoch.item())
    print(f'epoch loss: {avg_training_loss_epoch.item():.4f}\n')

    if epoch == 1:
        with open('data/losses/loss_epoch.txt', 'w') as loss_epoch_f:
            loss_epoch_f.write(f'{avg_training_loss_epoch.item():.4f}\n')

    else:
        with open('data/losses/loss_epoch.txt', 'a') as loss_epoch_f:
            loss_epoch_f.write(f'{avg_training_loss_epoch.item():.4f}\n')

    print('===============================\n')

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


epoch: [1/5]


ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).