In [None]:
import pandas as pd
import numpy as np
import torch.nn as nn
import torch

%load_ext autoreload
%autoreload 2

In [6]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [7]:
pd.set_option('display.max_rows', 100, 'display.max_columns', None)

In [8]:
train_file = 'data/train_raw.csv'
test_file = 'data/test_raw.csv'

In [9]:
import random
random.seed(42)

# read data
print('Reading CSV...')
t_df = pd.read_csv(train_file, index_col=0, dtype='float')
test = pd.read_csv(test_file, index_col=0, dtype='float')
print('Done')

t_df['id'] = t_df['id'].astype(int)
test['id'] = test['id'].astype(int)

# split train-validation
sick = set(t_df[t_df['SepsisLabel'] == 1.0]['id'].unique())
healthy = set(t_df['id'].unique()) - sick
t_sick = set(random.sample(sick, int(len(sick) * 0.75)))
v_sick = sick - t_sick
t_healthy = set(random.sample(healthy, int(len(healthy) * 0.75)))
v_healthy = healthy - t_healthy

train = t_df[t_df.id.isin(list(t_sick) + list(t_healthy))].sort_values(['id', 'SepsisLabel'])
valid = t_df[t_df.id.isin(list(v_sick) + list(v_healthy))].sort_values(['id', 'SepsisLabel'])

Reading CSV...
Done


since Python 3.9 and will be removed in a subsequent version.
  t_sick = set(random.sample(sick, int(len(sick) * 0.75)))
since Python 3.9 and will be removed in a subsequent version.
  t_healthy = set(random.sample(healthy, int(len(healthy) * 0.75)))


# Under-sampling (maybe?)

In [194]:
remove_amount = len(t_healthy) - 2 * len(t_sick)
# remove_amount = 5000
remove_healthy = random.sample(healthy, remove_amount)
train = train[~train['id'].isin(remove_healthy)]

since Python 3.9 and will be removed in a subsequent version.
  remove_healthy = random.sample(healthy, remove_amount)


# Remove columns

In [195]:
# train = train.drop(train.columns[8: 34], axis=1)
# valid = valid.drop(valid.columns[8: 34], axis=1)

In [196]:
train

Unnamed: 0,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,EtCO2,BaseExcess,HCO3,FiO2,pH,PaCO2,SaO2,AST,BUN,Alkalinephos,Calcium,Chloride,Creatinine,Bilirubin_direct,Glucose,Lactate,Magnesium,Phosphate,Potassium,Bilirubin_total,TroponinI,Hct,Hgb,PTT,WBC,Fibrinogen,Platelets,Age,Unit1,Unit2,HospAdmTime,ICULOS,SepsisLabel,id,Unit3,Gender_0.0,Gender_1.0
23.0,81.000000,100.000000,38.000000,136.500000,90.000000,71.000000,12.500000,35.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,65.79,0.0,0.0,-0.02,1.0,0.0,1,1.0,0.0,1.0
24.0,81.000000,100.000000,38.000000,136.500000,90.000000,71.000000,12.500000,35.000000,0.500000,0.500000,0.500000,0.500000,0.500000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.500000,0.500000,0.000000,0.000000,0.000000,0.0,0.0,0.500000,0.500000,0.000000,0.000000,0.0,0.000000,65.79,0.0,0.0,-0.02,2.0,0.0,1,1.0,0.0,1.0
25.0,81.750000,100.000000,38.000000,136.500000,90.000000,71.000000,12.125000,35.000000,0.333333,0.333333,0.333333,0.333333,0.333333,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.333333,0.333333,0.000000,0.000000,0.000000,0.0,0.0,0.333333,0.333333,0.000000,0.000000,0.0,0.000000,65.79,0.0,0.0,-0.02,3.0,0.0,1,1.0,0.0,1.0
26.0,81.230769,100.000000,38.082500,119.625000,87.750000,63.500000,12.038462,35.000000,0.250000,0.250000,0.500000,0.250000,0.250000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.250000,0.250000,0.000000,0.000000,0.000000,0.0,0.0,0.250000,0.250000,0.000000,0.000000,0.0,0.000000,65.79,0.0,0.0,-0.02,4.0,0.0,1,1.0,0.0,1.0
27.0,81.750000,100.000000,38.082500,115.730769,81.692308,62.461538,12.012500,35.000000,0.200000,0.400000,0.400000,0.200000,0.200000,0.0,0.0,0.200000,0.0,0.200000,0.200000,0.200000,0.0,0.400000,0.200000,0.200000,0.200000,0.200000,0.0,0.0,0.400000,0.400000,0.200000,0.200000,0.0,0.200000,65.79,0.0,0.0,-0.02,5.0,0.0,1,1.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
776714.0,92.990280,97.882567,36.447763,128.615857,89.877763,60.150472,15.603230,47.057456,0.078947,0.078947,0.236842,0.105263,0.078947,0.0,0.0,0.078947,0.0,0.052632,0.078947,0.078947,0.0,0.078947,0.000000,0.105263,0.026316,0.105263,0.0,0.0,0.078947,0.078947,0.052632,0.078947,0.0,0.078947,38.75,0.0,1.0,-0.05,38.0,0.0,19998,0.0,0.0,1.0
776715.0,73.451416,97.098386,37.161072,123.626040,73.992964,52.020070,15.955769,47.057456,0.076923,0.076923,0.230769,0.102564,0.076923,0.0,0.0,0.076923,0.0,0.051282,0.076923,0.076923,0.0,0.076923,0.000000,0.102564,0.025641,0.102564,0.0,0.0,0.076923,0.076923,0.051282,0.076923,0.0,0.076923,38.75,0.0,1.0,-0.05,39.0,0.0,19998,0.0,0.0,1.0
776716.0,73.451416,97.098386,37.161072,123.626040,73.992964,52.020070,15.955769,47.057456,0.075000,0.075000,0.225000,0.100000,0.075000,0.0,0.0,0.075000,0.0,0.050000,0.075000,0.075000,0.0,0.075000,0.000000,0.100000,0.025000,0.100000,0.0,0.0,0.075000,0.075000,0.050000,0.075000,0.0,0.075000,38.75,0.0,1.0,-0.05,40.0,0.0,19998,0.0,0.0,1.0
776717.0,73.050176,97.010936,37.161072,124.847282,76.665763,54.668775,15.995084,47.057456,0.073171,0.073171,0.219512,0.097561,0.073171,0.0,0.0,0.073171,0.0,0.048780,0.073171,0.073171,0.0,0.073171,0.000000,0.097561,0.024390,0.097561,0.0,0.0,0.073171,0.073171,0.048780,0.073171,0.0,0.073171,38.75,0.0,1.0,-0.05,41.0,0.0,19998,0.0,0.0,1.0


In [197]:
scaling_columns = ['HR', 'O2Sat', 'Temp', 'SBP', 'MAP', 'DBP', 'Resp', 'EtCO2', 'Age', 'HospAdmTime', 'ICULOS']  # rest are already scaled
scaler = MinMaxScaler()
train[scaling_columns] = scaler.fit_transform(train[scaling_columns])
valid[scaling_columns] = scaler.transform(valid[scaling_columns])
test[scaling_columns] = scaler.transform(test[scaling_columns])

In [198]:
train_y = train['SepsisLabel'].astype(int)
train_X = train.drop(['SepsisLabel'], axis=1)

valid_y = valid['SepsisLabel'].astype(int)
valid_X = valid.drop(['SepsisLabel'], axis=1)

test_y = test['SepsisLabel'].astype(int)
test_X = test.drop(['SepsisLabel'], axis=1)

In [199]:
from patient_LSTM import TrainDataset, patientLSTM, batch_collate

WINDOW = 30

train_ds = TrainDataset(train_X, train_y, WINDOW)
valid_ds = TrainDataset(valid_X, valid_y, WINDOW)
test_ds = TrainDataset(test_X, test_y, WINDOW)

In [200]:
from torch.utils.data import DataLoader, WeightedRandomSampler

samples_weight = []
for patient in train_ds.patients:
    if patient in sick:
        samples_weight.append(1. / len(t_sick))
    else:
        samples_weight.append(1. / len(t_healthy))
samples_weight = torch.tensor(samples_weight)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

# train_dl = DataLoader(train_ds, batch_size=8, collate_fn=batch_collate, sampler=sampler)
train_dl = DataLoader(train_ds, batch_size=8, collate_fn=batch_collate, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=32, collate_fn=batch_collate)
test_dl = DataLoader(test_ds, batch_size=32, collate_fn=batch_collate)

In [201]:
for i, (data, target) in enumerate(train_dl):
    print("batch index {}, class 0 {}, other classes {}".format(
        i,
        len(target[target == 0]),
        len(target[target == 1])))

batch index 0, class 0 5, other classes 3
batch index 1, class 0 8, other classes 0
batch index 2, class 0 6, other classes 2
batch index 3, class 0 6, other classes 2
batch index 4, class 0 7, other classes 1
batch index 5, class 0 8, other classes 0
batch index 6, class 0 7, other classes 1
batch index 7, class 0 8, other classes 0
batch index 8, class 0 7, other classes 1
batch index 9, class 0 6, other classes 2
batch index 10, class 0 5, other classes 3
batch index 11, class 0 7, other classes 1
batch index 12, class 0 6, other classes 2
batch index 13, class 0 6, other classes 2
batch index 14, class 0 6, other classes 2
batch index 15, class 0 6, other classes 2
batch index 16, class 0 8, other classes 0
batch index 17, class 0 8, other classes 0
batch index 18, class 0 6, other classes 2
batch index 19, class 0 6, other classes 2
batch index 20, class 0 6, other classes 2
batch index 21, class 0 5, other classes 3
batch index 22, class 0 6, other classes 2
batch index 23, class

In [3]:
from score_functions import F1
from LSTMTrainer import RNNTrainer
from sklearn.metrics import f1_score

model = patientLSTM(features_dim=42, hidden_dim=256, n_layers=2, dropout=0.3)

loss_fn = nn.BCEWithLogitsLoss()
score_fn = F1()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.999))

epochs = 100

trainer = RNNTrainer(model, loss_fn, optimizer)

trainer.fit(train_dl, valid_dl, epochs, score_fn=score_fn, checkpoints=f'models/LSTM_window{WINDOW}', early_stopping=15)

In [4]:
trainer.test(test_dl, score_fn=score_fn)

In [2]:
import imblearn

In [10]:
train

Unnamed: 0,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,EtCO2,BaseExcess,HCO3,FiO2,pH,PaCO2,SaO2,AST,BUN,Alkalinephos,Calcium,Chloride,Creatinine,Bilirubin_direct,Glucose,Lactate,Magnesium,Phosphate,Potassium,Bilirubin_total,TroponinI,Hct,Hgb,PTT,WBC,Fibrinogen,Platelets,Age,Unit1,Unit2,HospAdmTime,ICULOS,SepsisLabel,id,Unit3,Gender_0.0,Gender_1.0
0.0,61.000000,99.000000,36.440000,124.000000,65.000000,43.000000,17.500000,35.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,75.91,0.0,1.0,-98.60,1.0,0.0,0,0.0,1.0,0.0
1.0,61.000000,99.000000,36.440000,124.000000,65.000000,43.000000,17.500000,35.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,75.91,0.0,1.0,-98.60,2.0,0.0,0,0.0,1.0,0.0
2.0,63.250000,98.250000,36.440000,124.750000,64.250000,41.500000,24.625000,35.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,75.91,0.0,1.0,-98.60,3.0,0.0,0,0.0,1.0,0.0
3.0,58.230769,99.461538,36.440000,123.538462,64.769231,41.153846,13.807692,35.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,75.91,0.0,1.0,-98.60,4.0,0.0,0,0.0,1.0,0.0
4.0,63.475000,99.150000,36.440000,121.150000,66.275000,42.400000,20.012500,35.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,75.91,0.0,1.0,-98.60,5.0,0.0,0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
776714.0,92.990280,97.882567,36.447763,128.615857,89.877763,60.150472,15.603230,47.057456,0.078947,0.078947,0.236842,0.105263,0.078947,0.0,0.0,0.078947,0.0,0.052632,0.078947,0.078947,0.0,0.078947,0.0,0.105263,0.026316,0.105263,0.0,0.0,0.078947,0.078947,0.052632,0.078947,0.0,0.078947,38.75,0.0,1.0,-0.05,38.0,0.0,19998,0.0,0.0,1.0
776715.0,73.451416,97.098386,37.161072,123.626040,73.992964,52.020070,15.955769,47.057456,0.076923,0.076923,0.230769,0.102564,0.076923,0.0,0.0,0.076923,0.0,0.051282,0.076923,0.076923,0.0,0.076923,0.0,0.102564,0.025641,0.102564,0.0,0.0,0.076923,0.076923,0.051282,0.076923,0.0,0.076923,38.75,0.0,1.0,-0.05,39.0,0.0,19998,0.0,0.0,1.0
776716.0,73.451416,97.098386,37.161072,123.626040,73.992964,52.020070,15.955769,47.057456,0.075000,0.075000,0.225000,0.100000,0.075000,0.0,0.0,0.075000,0.0,0.050000,0.075000,0.075000,0.0,0.075000,0.0,0.100000,0.025000,0.100000,0.0,0.0,0.075000,0.075000,0.050000,0.075000,0.0,0.075000,38.75,0.0,1.0,-0.05,40.0,0.0,19998,0.0,0.0,1.0
776717.0,73.050176,97.010936,37.161072,124.847282,76.665763,54.668775,15.995084,47.057456,0.073171,0.073171,0.219512,0.097561,0.073171,0.0,0.0,0.073171,0.0,0.048780,0.073171,0.073171,0.0,0.073171,0.0,0.097561,0.024390,0.097561,0.0,0.0,0.073171,0.073171,0.048780,0.073171,0.0,0.073171,38.75,0.0,1.0,-0.05,41.0,0.0,19998,0.0,0.0,1.0
