In [432]:
import torch
from torch import nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [433]:
diabetes_df = pd.read_csv("diabetic_data.csv")

In [434]:
from sklearn.model_selection import train_test_split

class ReadmissionPredictionDataset(torch.utils.data.Dataset):
    def __init__(self, diabetes_df: pd.DataFrame, verbose = False):
        diabetes_df['readmitted'] = diabetes_df['readmitted'].replace({'<30': 'YES', '>30': 'YES'})
        diabetes_df = diabetes_df.drop(columns = 'payer_code')
        diabetes_df = diabetes_df.drop(columns = 'patient_nbr')
        diabetes_df = diabetes_df.drop(columns = 'medical_specialty')
        diabetes_df = diabetes_df.drop(columns = 'encounter_id')
        diabetes_df = diabetes_df[diabetes_df['diag_1'] != '?']
        diabetes_df = diabetes_df[diabetes_df['diag_2'] != '?']
        diabetes_df = diabetes_df[diabetes_df['diag_3'] != '?']
        diabetes_df = diabetes_df[diabetes_df['race'] != '?']
        diabetes_df = diabetes_df[diabetes_df['weight'] != '?']
        diabetes_df = diabetes_df[diabetes_df['admission_type_id'] != 5]
        diabetes_df = diabetes_df[diabetes_df['admission_type_id'] != 6]
        diabetes_df = diabetes_df[diabetes_df['admission_source_id'] != 17]

        diabetes_df = pd.get_dummies(diabetes_df, columns=['gender'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['race'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['admission_type_id'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['admission_source_id'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['discharge_disposition_id'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['diag_1'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['diag_2'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['diag_3'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['max_glu_serum'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['A1Cresult'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['metformin'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['repaglinide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['nateglinide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['chlorpropamide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['glimepiride'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['acetohexamide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['glipizide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['glyburide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['tolbutamide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['pioglitazone'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['rosiglitazone'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['acarbose'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['miglitol'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['troglitazone'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['tolazamide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['examide'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['citoglipton'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['insulin'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['glyburide-metformin'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['glipizide-metformin'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['glimepiride-pioglitazone'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['metformin-rosiglitazone'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['metformin-pioglitazone'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['change'], prefix='is', prefix_sep='')
        diabetes_df = pd.get_dummies(diabetes_df, columns=['diabetesMed'], prefix='is', prefix_sep='')

        age_mapping = {
            '[0-10)': 5,
            '[10-20)': 15,
            '[20-30)': 25,
            '[30-40)': 35,
            '[40-50)': 45,
            '[50-60)': 55,
            '[60-70)': 65,
            '[70-80)': 75,
            '[80-90)': 85,
            '[90-100)': 95
        }
        diabetes_df['age'] = diabetes_df['age'].replace(age_mapping)
        weight_mapping = {
            '[0-25)': 12.5,
            '[25-50)': 37.5,
            '[50-75)': 62.5,
            '[75-100)': 87.5,
            '[100-125)': 112.5,
            '[125-150)': 137.5,
            '[150-175)': 162.5,
            '[175-200)': 187.5
        }
        diabetes_df['weight'] = diabetes_df['weight'].replace(weight_mapping)

        readmit_mapping = {
            'YES': 1,
            'NO': 0
        }
        diabetes_df['readmitted'] = diabetes_df['readmitted'].replace(readmit_mapping)

        # groups = diabetes_df.groupby('readmitted')
        # counts = groups.count()
        # print(counts)

        # unique_values = diabetes_df['diag_1'].unique()
        # #print(unique_values)
        # groups_type = diabetes_df.groupby('admission_source_id')
        # counts_type = groups_type.count()
        #print(counts_type)
        #counts['encounter_id'].plot.bar()

        #data_x =  diabetes_df.loc[:, diabetes_df.columns[:len(diabetes_df.columns) - 1]]

        data_x = diabetes_df.select_dtypes(include=[int, float]).drop('readmitted', axis=1)

        data_y = diabetes_df['readmitted']

        self.input = torch.tensor(data_x.values).type(torch.float32)

        self.output = torch.tensor(data_y.values).type(torch.float32)
    
    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx):
        return (self.input[idx], self.output[idx])
    
train_df, test_df = train_test_split(diabetes_df, test_size=0.2, random_state=42)

train_dataset = ReadmissionPredictionDataset(train_df, verbose = True)
test_dataset = ReadmissionPredictionDataset(test_df)

# train_df.head()
print(train_dataset[0])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 100, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 100, shuffle = True)


(tensor([85.,  3., 58.,  3., 20.,  0.,  0.,  0.,  9.]), tensor(0.))


In [435]:
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(9,64), nn.Sigmoid(), nn.Linear(64,128), nn.Sigmoid(), nn.Linear(128, 1), nn.Sigmoid())

    def forward(self, x):
        yhat = self.layers(x)
        return yhat

In [436]:
from tqdm.notebook import tqdm

def train_network(model, train_loader, criterion, optimizer, nepoch=100):
    try:
        for epoch in tqdm(range(nepoch)):
            print('EPOCH %d'%epoch)
            total_loss = 0
            count = 0
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.squeeze(), labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                count += 1
            print('{:>12s} {:>7.5f}'.format('Train loss:', total_loss/count))
    except KeyboardInterrupt:
        print('Exiting from training early')
    return

In [437]:
def test_network(model, test_loader):
    correct = 0
    total = 0
    true, pred = [], []
    with torch.no_grad():
        for inputs, labels  in test_loader:
            outputs = model(inputs)
            predicted = torch.round(outputs).squeeze()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            true.append(labels)
            pred.append(predicted)
    acc = (100 * correct / total)
    print('accuracy: %0.3f' % (acc))
    true = np.concatenate(true)
    pred = np.concatenate(pred)
    return acc, true, pred

In [441]:
model = SimpleNet()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.0005)

In [446]:
train_network(model, train_loader, criterion, optimizer, nepoch=200)

  0%|          | 0/1000 [00:00<?, ?it/s]

EPOCH 0
 Train loss: 0.59280
EPOCH 1
 Train loss: 0.58838
EPOCH 2
 Train loss: 0.58685
EPOCH 3
 Train loss: 0.58882
EPOCH 4
 Train loss: 0.58601
EPOCH 5
 Train loss: 0.58874
EPOCH 6
 Train loss: 0.58654
EPOCH 7
 Train loss: 0.58656
EPOCH 8
 Train loss: 0.58770
EPOCH 9
 Train loss: 0.58790
EPOCH 10
 Train loss: 0.58781
EPOCH 11
 Train loss: 0.58568
EPOCH 12
 Train loss: 0.58823
EPOCH 13
 Train loss: 0.58544
EPOCH 14
 Train loss: 0.58487
EPOCH 15
 Train loss: 0.58787
EPOCH 16
 Train loss: 0.58388
EPOCH 17
 Train loss: 0.58542
EPOCH 18
 Train loss: 0.58578
EPOCH 19
 Train loss: 0.58554
EPOCH 20
 Train loss: 0.58598
EPOCH 21
 Train loss: 0.58445
EPOCH 22
 Train loss: 0.59545
EPOCH 23
 Train loss: 0.58911
EPOCH 24
 Train loss: 0.58522
EPOCH 25
 Train loss: 0.58464
EPOCH 26
 Train loss: 0.58464
EPOCH 27
 Train loss: 0.58440
EPOCH 28
 Train loss: 0.58344
EPOCH 29
 Train loss: 0.58447
EPOCH 30
 Train loss: 0.58647
EPOCH 31
 Train loss: 0.58788
EPOCH 32
 Train loss: 0.58508
EPOCH 33
 Train loss

In [447]:
acc, true, pred = test_network(model, test_loader)
print(pred)

accuracy: 56.827
[1. 0. 1. 0. 0. 0. 1. 1. 0. 1. 0. 1. 1. 0. 1. 0. 1. 1. 0. 1. 1. 0. 0. 1.
 0. 0. 1. 1. 1. 0. 0. 1. 0. 0. 0. 1. 1. 1. 1. 0. 1. 1. 1. 1. 0. 0. 1. 0.
 0. 0. 1. 1. 0. 1. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0.
 1. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 1. 1. 0. 1. 1. 1. 1. 0. 0. 0.
 0. 0. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 0. 0. 1. 1. 0. 1. 1.
 1. 1. 0. 1. 1. 1. 1. 0. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 1. 1. 1. 0. 1. 0.
 0. 0. 0. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 1. 0. 1. 1. 0. 0. 1.
 0. 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 1. 0. 0. 1. 1. 1. 1. 0. 1. 1. 1. 1. 0.
 0. 1. 1. 1. 0. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1. 1. 1. 0. 1.
 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 1. 0. 0.
 1. 0. 0. 1. 1. 1. 1. 0. 0. 1. 0. 1. 0. 0. 1. 1. 1. 1. 0. 0. 1. 1. 1. 0.
 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 1.
 0. 0. 0. 1. 1. 1. 0. 0. 1. 0. 1. 