In [42]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

In [48]:
# Columns AKIDataset keeps from the CSV and gap-fills stay by stay.
# 'stay_key' is not read from the file: the dataset adds it as a copy of
# 'stay_id' before selecting these columns.
cols = [
    'Time', 'stay_id', 'stay_key', 'hadm_id',
    'age', 'gender',
    'Heart Rate', 'Respiratory Rate', 'SpO2/SaO2', 'pH',
    'Potassium', 'Calcium', 'Glucose', 'Sodium', 'HCO3',
    'White Blood Cells', 'Hemoglobin', 'Red Blood Cells', 'Platelet Count',
    'Weight', 'Urea Nitrogen', 'Creatinine', 'Blood Pressure',
    '1 hours urine output', '6 hours urine output',
    'AKI',
    'gcs', 'ventilation', 'vasoactive medications', 'sedative medications',
]

# Model inputs, in order: 'time_since' (hours elapsed since the first row of
# the stay, computed by the dataset) followed by every entry of `cols` that is
# not the timestamp, an identifier or the 'AKI' label.
features = ['time_since'] + [
    name for name in cols
    if name not in ('Time', 'stay_id', 'stay_key', 'hadm_id', 'AKI')
]

In [73]:
class AKIDataset(Dataset):
    """Per-stay time series for AKI prediction, loaded from a CSV file.

    One item corresponds to one distinct ``stay_id``: a 2-D array holding the
    ``features`` columns for every row of that stay, and a 1-D array holding
    the matching ``AKI`` values. The class reads the module-level ``cols`` and
    ``features`` lists.

    Attributes:
        dataframe: The gap-filled table, with ``Time_x`` (row timestamp),
            ``Time_y`` (first timestamp of the stay) and ``time_since``
            (hours between the two). Treat it as read-only: row positions
            per stay are cached at construction time.
        stay_ids: Distinct stay keys in order of first appearance; item
            ``idx`` is the stay ``stay_ids[idx]``.
    """

    def __init__(self, csv_file):
        """Read ``csv_file``, fill gaps inside each stay and index the stays.

        Args:
            csv_file: Anything ``pandas.read_csv`` accepts (path or file-like
                object). The table must contain ``stay_id`` and every name in
                ``cols`` except ``stay_key``, which is created here.
        """
        raw = pd.read_csv(csv_file)
        raw['stay_key'] = raw['stay_id']

        # Group by a plain array of keys rather than by column name, so the
        # result keeps every column of `cols` (ids included) without relying
        # on how pandas treats a grouping column that is also selected.
        stay_of_row = raw['stay_id'].to_numpy()
        filled = raw[cols].groupby(stay_of_row, sort=False).ffill()
        # The backward fill must be grouped too. A frame-wide bfill() would
        # fill the leading gaps of one stay with values (and AKI labels) taken
        # from the next stay in the file. Columns that are entirely missing
        # for a stay intentionally stay NaN.
        filled = filled.groupby(stay_of_row, sort=False).bfill()

        # First timestamp of each stay; the merge yields Time_x (row time) and
        # Time_y (stay start). Assumes rows of a stay are in time order -- TODO confirm.
        in_time = filled.groupby('stay_key')[['Time']].first()
        merged = pd.merge(filled, in_time, left_on=['stay_key'], right_index=True, how='left')
        elapsed = pd.to_datetime(merged['Time_x']) - pd.to_datetime(merged['Time_y'])
        merged['time_since'] = elapsed / np.timedelta64(1, 'h')

        self.dataframe = merged
        self.stay_ids = merged.stay_key.unique()
        # Positional row indices per stay, computed once so __getitem__ does
        # not have to scan the whole table for every sample.
        self._rows_by_stay = merged.groupby('stay_key', sort=False).indices

    def __len__(self):
        """Return the number of distinct stays."""
        return len(self.stay_ids)

    def __getitem__(self, idx):
        """Return ``(data, label)`` for the stay at position ``idx``.

        Returns:
            data: ``(rows_in_stay, len(features))`` numpy array.
            label: ``(rows_in_stay,)`` numpy array of the ``AKI`` column.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        key = self.stay_ids[idx]
        # A key with no cached rows (e.g. a missing stay id) yields empty
        # arrays, matching what a boolean-mask lookup would return.
        rows = self._rows_by_stay.get(key, np.empty(0, dtype=np.intp))
        stay = self.dataframe.iloc[rows]
        return stay[features].to_numpy(), stay['AKI'].to_numpy()

In [74]:
# Load and preprocess the whole CSV once; every later cell indexes into `ds`.
ds = AKIDataset(csv_file='time_series.csv')

In [82]:
# Shapes of (data, label) for the stay at position 3, fetched once.
tuple(part.shape for part in ds[3])

((203, 27), (203,))