In [1]:
import pandas as pd
import numpy as np

%load_ext autoreload
%autoreload 2

In [2]:
from sklearn.preprocessing import StandardScaler, OneHotEncoder, FunctionTransformer, MinMaxScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.feature_selection import RFECV
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV

from IPython.display import display

In [3]:
pd.set_option('display.max_rows', 100, 'display.max_columns', None)

In [4]:
train_file = 'data/train_raw.csv'
# test_file = 'data/train_raw.csv'

In [33]:
import random
random.seed(42)
# read data
print('Reading CSV...')
t_df = pd.read_csv(train_file, index_col=0, dtype='float')
print('Done')
t_df['id'] = t_df['id'].astype(int)
t_df['id'] = t_df['id'].astype(int)
# split train-validation
sick = set(t_df[t_df['SepsisLabel'] == 1.0]['id'].unique())
healthy = set(t_df['id'].unique()) - sick
t_sick = set(random.sample(tuple(sick), 1000))
v_sick = sick - t_sick
t_healthy = set(random.sample(tuple(healthy), 1000))
v_healthy = healthy - t_healthy

train = t_df[t_df.id.isin(list(t_sick) + list(t_healthy))].sort_values(['id', 'SepsisLabel'])
valid = t_df[t_df.id.isin(list(v_sick) + list(v_healthy))].sort_values(['id', 'SepsisLabel'])

Reading CSV...
Done


In [34]:
train

Unnamed: 0,HR,O2Sat,Temp,SBP,MAP,DBP,Resp,EtCO2,BaseExcess,HCO3,FiO2,pH,PaCO2,SaO2,AST,BUN,Alkalinephos,Calcium,Chloride,Creatinine,Bilirubin_direct,Glucose,Lactate,Magnesium,Phosphate,Potassium,Bilirubin_total,TroponinI,Hct,Hgb,PTT,WBC,Fibrinogen,Platelets,Age,Unit1,Unit2,HospAdmTime,ICULOS,SepsisLabel,id,Unit3,Gender_0.0,Gender_1.0
23.0,81.000000,100.000000,38.0000,136.500000,90.000000,71.000000,12.500000,35.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.0,65.79,0.0,0.0,-0.02,1.0,0.0,1,1.0,0.0,1.0
24.0,81.000000,100.000000,38.0000,136.500000,90.000000,71.000000,12.500000,35.000000,0.500000,0.500000,0.500000,0.500000,0.500000,0.0,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.500000,0.500000,0.0,0.0,0.0,0.000000,0.0,0.500000,0.500000,0.0,0.0,0.000000,0.0,65.79,0.0,0.0,-0.02,2.0,0.0,1,1.0,0.0,1.0
25.0,81.750000,100.000000,38.0000,136.500000,90.000000,71.000000,12.125000,35.000000,0.333333,0.333333,0.333333,0.333333,0.333333,0.0,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.333333,0.333333,0.0,0.0,0.0,0.000000,0.0,0.333333,0.333333,0.0,0.0,0.000000,0.0,65.79,0.0,0.0,-0.02,3.0,0.0,1,1.0,0.0,1.0
26.0,81.230769,100.000000,38.0825,119.625000,87.750000,63.500000,12.038462,35.000000,0.250000,0.250000,0.500000,0.250000,0.250000,0.0,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.250000,0.250000,0.0,0.0,0.0,0.000000,0.0,0.250000,0.250000,0.0,0.0,0.000000,0.0,65.79,0.0,0.0,-0.02,4.0,0.0,1,1.0,0.0,1.0
27.0,81.750000,100.000000,38.0825,115.730769,81.692308,62.461538,12.012500,35.000000,0.200000,0.400000,0.400000,0.200000,0.200000,0.0,0.000000,0.2,0.000000,0.2,0.2,0.2,0.0,0.400000,0.200000,0.2,0.2,0.2,0.000000,0.0,0.400000,0.400000,0.2,0.2,0.000000,0.2,65.79,0.0,0.0,-0.02,5.0,0.0,1,1.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
775883.0,67.000000,97.000000,37.4400,143.000000,74.500000,46.500000,30.500000,47.057456,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.000000,0.0,33.53,1.0,0.0,-0.71,1.0,0.0,19977,0.0,1.0,0.0
775884.0,67.000000,97.000000,37.4400,143.000000,74.500000,46.500000,30.500000,47.057456,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.500000,0.0,0.500000,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.500000,0.0,0.000000,0.000000,0.0,0.0,0.500000,0.0,33.53,1.0,0.0,-0.71,2.0,0.0,19977,0.0,1.0,0.0
775885.0,65.125000,97.375000,37.4400,143.375000,73.375000,45.750000,36.125000,47.057456,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.333333,0.0,0.333333,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.333333,0.0,0.000000,0.000000,0.0,0.0,0.333333,0.0,33.53,1.0,0.0,-0.71,3.0,0.0,19977,0.0,1.0,0.0
775886.0,65.125000,97.115385,37.4400,143.375000,73.375000,45.750000,36.125000,47.057456,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.250000,0.0,0.250000,0.0,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.250000,0.0,0.000000,0.000000,0.0,0.0,0.250000,0.0,33.53,1.0,0.0,-0.71,4.0,0.0,19977,0.0,1.0,0.0


In [56]:
# scaling_columns = ['HR', 'O2Sat', 'Temp', 'SBP', 'MAP', 'DBP', 'Resp', 'EtCO2', 'Age', 'HospAdmTime', 'ICULOS']  # rest are already scaled
# scaler = MinMaxScaler()
# scaler.fit(train[scaling_columns])

In [57]:
train_y = train['SepsisLabel']
train_X = train.drop(['SepsisLabel'], axis=1)

valid_y = valid['SepsisLabel']
valid_X = valid.drop(['SepsisLabel'], axis=1)

In [58]:
from patient_LSTM import TrainDataset, TestDataset, patientLSTM, batch_collate

WINDOW = 3

train_ds = TrainDataset(train_X, train_y, WINDOW)
valid_ds = TrainDataset(valid_X, valid_y, WINDOW)


In [59]:
from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=batch_collate)
valid_dl = DataLoader(valid_ds, batch_size=8, collate_fn=batch_collate)

In [60]:
from score_functions import F1
from LSTMTrainer import RNNTrainer
import torch.nn as nn
import torch
from sklearn.metrics import f1_score

model = patientLSTM(43, 64, 1)

loss_fn = nn.BCEWithLogitsLoss()
score_fn = F1()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))

epochs = 100

trainer = RNNTrainer(model, loss_fn, optimizer)

trainer.fit(train_dl, valid_dl, epochs, score_fn=score_fn)

--- EPOCH 1/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(432) tensor(430) tensor(570) tensor(568)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(16) tensor(735) tensor(16850) tensor(399)
--- EPOCH 2/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(423) tensor(431) tensor(569) tensor(577)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(130) tensor(2704) tensor(14881) tensor(285)
--- EPOCH 3/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(461) tensor(442) tensor(558) tensor(539)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(73) tensor(415) tensor(17170) tensor(342)
--- EPOCH 4/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(559) tensor(522) tensor(478) tensor(441)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(415) tensor(17585) tensor(0) tensor(0)
--- EPOCH 5/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(428) tensor(339) tensor(661) tensor(572)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(69) tensor(2) tensor(17583) tensor(346)
--- EPOCH 6/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(461) tensor(405) tensor(595) tensor(539)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(69) tensor(3) tensor(17582) tensor(346)
--- EPOCH 7/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(526) tensor(455) tensor(545) tensor(474)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(73) tensor(185) tensor(17400) tensor(342)
--- EPOCH 8/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(426) tensor(335) tensor(665) tensor(574)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(415) tensor(17585) tensor(0) tensor(0)
--- EPOCH 9/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(567) tensor(485) tensor(515) tensor(433)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(415) tensor(17585) tensor(0) tensor(0)
--- EPOCH 10/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(427) tensor(320) tensor(680) tensor(573)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(415) tensor(17585) tensor(0) tensor(0)
--- EPOCH 11/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(480) tensor(407) tensor(593) tensor(520)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(84) tensor(191) tensor(17394) tensor(331)
--- EPOCH 12/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(460) tensor(395) tensor(605) tensor(540)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(415) tensor(17585) tensor(0) tensor(0)
--- EPOCH 13/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

tensor(512) tensor(398) tensor(602) tensor(488)


test_batch:   0%|          | 0/2250 [00:00<?, ?it/s]

tensor(415) tensor(17585) tensor(0) tensor(0)
--- EPOCH 14/100 ---


train_batch:   0%|          | 0/250 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [23]:
torch.tensor([1]).squeeze(1).shape

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)