In [None]:
import pandas as pd
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from sklearn import metrics
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix,accuracy_score, roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.metrics import roc_auc_score, confusion_matrix,accuracy_score, roc_curve, auc, precision_recall_curve

from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams["figure.figsize"] = 5,2

### Utilities: data loading, class-balanced batching, and model definition

In [None]:
import pickle
# Load the pre-built train/test sample lists. torch.load unpickles the files,
# so only open files that come from a trusted source.
train_data, test_data = (torch.load(path) for path in ('train_data.pt', 'test_data.pt'))

In [None]:
from torch.utils.data import Dataset
class Covid_19(Dataset):
    """Map-style dataset over a list of two-element records.

    Each record is indexed as ``record[0]`` (presumably a patient id) and
    ``record[1]`` (the sample). Only ``record[1]`` is returned by indexing.
    """

    def __init__(self, dataList):
        self.data_list = dataList

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the sample stored at position 1 of record ``idx``."""
        # Position 0 of the record is not used by this dataset.
        return self.data_list[idx][1]
    

In [None]:
import torch
import torch.utils.data
import torchvision


class Sampler(torch.utils.data.sampler.Sampler):
    """Draw dataset indices with replacement, weighted inversely to label frequency.

    Each index in ``indices`` gets weight ``1 / count(its label)``, so every
    label is drawn about equally often regardless of class imbalance.

    Args:
        dataset: indexable dataset. By default the label of item ``idx`` is
            ``dataset[idx][-1]``.
        indices: positions of ``dataset`` to sample from (default: all).
        num_samples: number of indices yielded per iteration
            (default: ``len(indices)``).
        callback_get_label: optional ``f(dataset, idx) -> label`` that
            replaces the default label lookup.
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
        self.indices = list(range(len(dataset))) if indices is None else indices
        self.callback_get_label = callback_get_label
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # Read each label once and reuse it for both counting and weighting.
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        """Return a hashable label for ``dataset[idx]``."""
        if self.callback_get_label is not None:
            label = self.callback_get_label(dataset, idx)
        else:
            label = dataset[idx][-1]
        # Tensors hash by identity, so two separate tensor(0) labels would be
        # counted as different classes; use the Python scalar instead.
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            label = label.item()
        return label

    def __iter__(self):
        drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        return self.num_samples

In [None]:
import torch.nn.functional as F

def _pad_stack(tensors):
    """Zero-pad 2-D tensors at the end of both dims to a common shape and stack them.

    Returns a tensor of shape (len(tensors), max_rows, max_cols).
    """
    max_rows = max(t.size(0) for t in tensors)
    max_cols = max(t.size(1) for t in tensors)
    # F.pad order for the last two dims is [left, right, top, bottom].
    return torch.stack([
        F.pad(t, [0, max_cols - t.size(1), 0, max_rows - t.size(0)])
        for t in tensors
    ])


def pad_collate(batch):
    """Collate 7-element samples into padded batch tensors.

    Each sample is (problem, lab, diag, orders, medAdmin, demo, label). The
    first five are 2-D tensors of possibly different shapes. They are
    zero-padded to the per-batch maximum and stacked. ``demo`` tensors are
    stacked as-is. Labels are returned as a plain list.

    Returns:
        (problem_batch, lab_batch, diag_batch, orders_batch, medAdmin_batch,
        pt_demo_batch, label_batch)
    """
    (pt_problem_batch, pt_lab_batch, pt_diag_batch, pt_orders_batch,
     pt_medAdmin_batch, pt_demo_batch, label_batch) = zip(*batch)
    return (
        _pad_stack(pt_problem_batch),
        _pad_stack(pt_lab_batch),
        _pad_stack(pt_diag_batch),
        _pad_stack(pt_orders_batch),
        _pad_stack(pt_medAdmin_batch),
        torch.stack(pt_demo_batch),
        list(label_batch),
    )

In [None]:
from torch.utils.data import DataLoader

# Both loaders batch identically. The training loader draws class-balanced
# indices from `Sampler`, so shuffling must stay off (a sampler and
# shuffle=True are mutually exclusive). drop_last=True discards the final
# partial batch, so every batch holds exactly 24 samples.
_loader_options = dict(
    batch_size=24,
    shuffle=False,
    num_workers=0,
    drop_last=True,
    collate_fn=pad_collate,
)

train_dataset = Covid_19(train_data)
validation_dataset = Covid_19(test_data)
trainSampler = Sampler(train_dataset)

dataloader = DataLoader(train_dataset, sampler=trainSampler, **_loader_options)
validation_loader = DataLoader(validation_dataset, **_loader_options)

In [None]:
class DBNet(nn.Module):
    """Multi-modal classifier: conv item encoders -> bidirectional GRU -> multi-hop attention -> linear.

    Each of the five modality inputs (problem, lab, diag, orders, medAdmin) is a
    (batch, n_items, n_features) tensor. A per-modality conv stack scores every
    item; the clamped scores weight a sum over items, giving one
    (batch, n_features) vector per modality. The five vectors are right-padded
    to ``GRU_input_size`` and fed as a length-5 sequence to the GRU. Multi-hop
    attention pools the GRU outputs, the result is concatenated with the
    (batch, demo_size) ``pt_demo`` tensor, and a linear layer produces logits.

    NOTE: encoder ModuleLists are laid out as [Conv, ReLU, Dropout] per conv
    (no ReLU after the last conv), then [AdaptiveAvgPool1d]. state_dicts saved
    with a different layout will not load without remapping keys.
    """

    def __init__(self, hidden_size=512, no_GRU_layers=4, GRU_input_size=376, no_hops=8,
                 demo_size=296, dropout_GRU=0.5, output_size=2):
        super(DBNet, self).__init__()

        ###(kernel_size,in_channel,out_channel, stride,pad)
        problem_kernels = [(7,1,8,2,0),(5,8,8,2,0),(3,8,1,1,0)]
        lab_kernels = [(19,1,8,2,0),(15,8,8,2,0),(11,8,8,2,0),(7,8,8,2,0),(4,8,8,2,0),(3,8,1,2,1)]
        diag_kernels = [(11,1,8,2,0),(7,8,8,2,0),(5,8,8,2,0),(3,8,1,2,1)]
        orders_kernels = [(19,1,8,2,0),(15,8,8,2,0),(11,8,8,2,0),(7,8,8,2,0),(4,8,8,2,0),(3,8,1,2,1)]
        medAdmin_kernels = [(19,1,8,2,0),(15,8,8,2,0),(11,8,8,2,0),(7,8,8,2,0),(4,8,8,2,0),(3,8,1,2,1)]

        self.problem_layers = nn.ModuleList()
        self.lab_layers = nn.ModuleList()
        self.diag_layers = nn.ModuleList()
        self.orders_layers = nn.ModuleList()
        self.medAdmin_layers = nn.ModuleList()

        self.make_encoder_block(problem_kernels, self.problem_layers)
        self.make_encoder_block(lab_kernels, self.lab_layers)
        self.make_encoder_block(diag_kernels, self.diag_layers)
        self.make_encoder_block(orders_kernels, self.orders_layers)
        self.make_encoder_block(medAdmin_kernels, self.medAdmin_layers)

        self.GRU = nn.GRU(GRU_input_size, hidden_size, no_GRU_layers, dropout=dropout_GRU,
                          batch_first=True, bias=False, bidirectional=True)
        self.GRU_dropout = nn.Dropout(p=dropout_GRU)
        # Kernel spans the whole bidirectional GRU output, so each step yields one score per hop.
        self.conv_att = nn.Conv1d(in_channels=1, out_channels=no_hops, kernel_size=hidden_size * 2, stride=1)

        self.linear = nn.Linear(hidden_size * 2 * no_hops + demo_size, output_size, bias=False)

        # Kept for attribute compatibility; attention uses torch.softmax over the step axis.
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        # Despite the name this is Hardtanh(0, 1), i.e. a clamp to [0, 1].
        self.tanh = nn.Hardtanh(0, 1)
        self.init_weights()
        self.no_hops = no_hops
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.no_GRU_layers = no_GRU_layers

    def make_encoder_block(self, kernel_list, layer_list):
        """Append a conv encoder to ``layer_list``.

        For each ``(kernel_size, in_channels, out_channels, stride, padding)``
        tuple this appends a Conv1d, a ReLU (on all but the last conv), and a
        Dropout(0.3). It finishes with AdaptiveAvgPool1d(1).
        """
        last = len(kernel_list) - 1
        for i, (kernel_size, in_channels, out_channels, stride, padding) in enumerate(kernel_list):
            layer_list.append(nn.Conv1d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding))
            if i < last:
                layer_list.append(nn.ReLU())
            layer_list.append(nn.Dropout(p=0.3))
        layer_list.append(nn.AdaptiveAvgPool1d(1))

    def init_weights(self):
        """Xavier-initialise every Conv1d weight; reset any BatchNorm1d to identity."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Conv1d):
                torch.nn.init.xavier_uniform_(m.weight)

    def _encode_modality(self, x, layers):
        """Score each item of ``x`` with ``layers`` and return the score-weighted item sum.

        Args:
            x: (batch, n_items, n_features) tensor.
            layers: conv encoder built by ``make_encoder_block``.
        Returns:
            (batch, n_features) tensor when the last conv has one output channel.
        """
        batch, n_items, n_features = x.shape
        # Run every item through the conv stack in one call: (batch*n_items, 1, n_features).
        out = x.reshape(batch * n_items, 1, n_features)
        for layer in layers:
            out = layer(out)
        # out is (batch*n_items, channels, 1) -> (batch, channels, n_items).
        scores = out.reshape(batch, n_items, -1).transpose(1, 2)
        scores = self.tanh(scores)
        return torch.squeeze(torch.bmm(scores, x), dim=1)

    def problem_encoder(self, problem):
        """Encode the problem modality to (batch, n_features)."""
        return self._encode_modality(problem, self.problem_layers)

    def lab_encoder(self, lab):
        """Encode the lab modality to (batch, n_features)."""
        return self._encode_modality(lab, self.lab_layers)

    def diag_encoder(self, diag):
        """Encode the diag modality to (batch, n_features)."""
        return self._encode_modality(diag, self.diag_layers)

    def orders_encoder(self, orders):
        """Encode the orders modality to (batch, n_features)."""
        return self._encode_modality(orders, self.orders_layers)

    def medAdmin_encoder(self, medAdmin):
        """Encode the medAdmin modality to (batch, n_features)."""
        return self._encode_modality(medAdmin, self.medAdmin_layers)

    def init_GRU(self, batch_size):
        """Zero initial GRU state of shape (2 * no_GRU_layers, batch_size, hidden_size).

        The state has the dtype/device of the model's parameters.
        """
        reference = next(self.parameters())
        return reference.new_zeros(self.no_GRU_layers * 2, batch_size, self.hidden_size)

    def GRU_Decoder(self, inputs, batch_size=None):
        """Run dropout + the bidirectional GRU over ``inputs`` (batch, steps, GRU_input_size).

        ``batch_size`` is accepted for backward compatibility; the initial state
        is sized from ``inputs`` so it can never disagree with the data.
        """
        inputs = self.GRU_dropout(inputs)
        init_state = self.init_GRU(inputs.size(0))
        outputs, states = self.GRU(inputs, init_state)
        return outputs, states

    def Self_Attention(self, hidden_states, batch_size=None):
        """Multi-hop attention pooling over the step axis.

        Args:
            hidden_states: (batch, steps, 2 * hidden_size) GRU outputs.
            batch_size: ignored; kept for backward compatibility.
        Returns:
            (batch, no_hops * 2 * hidden_size) tensor.
        """
        batch, steps, width = hidden_states.shape
        # One conv call scores every step for every hop: (batch*steps, no_hops, 1).
        scores = self.conv_att(hidden_states.reshape(batch * steps, 1, width))
        scores = scores.reshape(batch, steps, self.no_hops).transpose(1, 2)  # (batch, no_hops, steps)
        weights = torch.softmax(scores, dim=2)  # each hop's weights sum to 1 over steps
        output = torch.bmm(weights, hidden_states)  # (batch, no_hops, width)
        return output.reshape(batch, -1)

    def forward(self, problem, lab, diag, orders, medAdmin, pt_demo, batch_size=None):
        """Return ``(logits, sigmoid(logits))``, each of shape (batch, output_size).

        ``batch_size`` is accepted for backward compatibility; the batch size is
        taken from the input tensors.

        Raises:
            ValueError: if an encoded modality is wider than the GRU input size.
        """
        encoded = [
            self.problem_encoder(problem),
            self.lab_encoder(lab),
            self.diag_encoder(diag),
            self.orders_encoder(orders),
            self.medAdmin_encoder(medAdmin),
        ]
        width = self.GRU.input_size
        padded = []
        for enc in encoded:
            if enc.size(1) > width:
                raise ValueError("encoded modality has {} features but the GRU accepts at most {}".format(
                    enc.size(1), width))
            padded.append(nn.functional.pad(enc, (0, width - enc.size(1))))
        # (batch, 5 modalities, GRU_input_size): the modalities form the GRU sequence.
        GRU_input = torch.stack(padded, dim=1)
        outputs_GRU, _ = self.GRU_Decoder(GRU_input)
        GRU_out = self.Self_Attention(outputs_GRU)
        context = torch.cat((GRU_out, pt_demo), 1)
        linear_y = self.linear(context)
        out = self.sigmoid(linear_y)
        return linear_y, out

### Training and validation

In [None]:
import os
import time

# Training configuration.
EPOCHS = 200
batch_size = 24            # batch size configured on the DataLoaders above
LEARNING_RATE = 0.003
SEED = 0
CHECKPOINT_DIR = "Models_saved"

# Seed before building the model so weight init and the weighted sampler are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = DBNet()
model = model.to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("Starting Training of {} model".format(type(model).__name__))

epoch_times = []
best_AUC = 0
for epoch in range(1, EPOCHS + 1):
    start_time = time.time()
    avg_loss = 0.
    counter = 0
    correct = 0
    total = 0
    model.train()
    for sample in dataloader:
        problem, lab, diag, orders, medAdmin, pt_demo, label = sample
        optimizer.zero_grad()
        # Two-column one-hot targets for BCEWithLogitsLoss on the 2-unit output.
        one_hot_label = torch.tensor(np.eye(2)[np.array(label, dtype="int")])
        label = torch.tensor(label).type(torch.long)
        problem, lab, diag, orders, medAdmin, pt_demo = (
            t.to(device) for t in (problem, lab, diag, orders, medAdmin, pt_demo))
        counter += 1
        # Pass the real batch size so it always matches the tensors.
        out, predict = model(problem, lab, diag, orders, medAdmin, pt_demo, problem.size(0))
        one_hot_label = one_hot_label.type_as(out).to(device)
        loss = criterion(out, one_hot_label)
        _, predicted = torch.max(predict.detach().cpu(), 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
    current_time = time.time()
    print("Epoch {}/{} Done, Total Loss: {}, Accuracy : {} ".format(
        epoch, EPOCHS, avg_loss / max(counter, 1), correct / max(total, 1)))
    print("Total Time Elapsed: {} seconds".format(str(current_time - start_time)))
    epoch_times.append(current_time - start_time)

    # Validation: no gradients are needed, so skip building the autograd graph.
    val_total = 0
    val_correct = 0
    model.eval()
    predicted_list = []
    prediction_probablity = []
    label_list = []
    with torch.no_grad():
        for sample in validation_loader:
            problem, lab, diag, orders, medAdmin, pt_demo, label = sample
            label = torch.tensor(label).type(torch.long)
            problem, lab, diag, orders, medAdmin, pt_demo = (
                t.to(device) for t in (problem, lab, diag, orders, medAdmin, pt_demo))
            out, predict = model(problem, lab, diag, orders, medAdmin, pt_demo, problem.size(0))
            predict = predict.cpu()
            _, predicted = torch.max(predict, 1)
            predicted_list.append(predicted.numpy())
            # Sigmoid score of the positive (index 1) output is used for ROC/AUC.
            prediction_probablity.append(predict[:, 1].numpy())
            label_list.append(label.numpy())
            val_total += label.size(0)
            val_correct += (predicted == label).sum().item()
    Accuracy = val_correct / val_total
    # Concatenate rather than np.array(...) so batches of unequal size are handled.
    y = np.concatenate(label_list)
    false_positive_rate, recall, thresholds = roc_curve(y, np.concatenate(prediction_probablity))
    roc_auc = auc(false_positive_rate, recall)
    if roc_auc > best_AUC:
        best_AUC = roc_auc
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, "State_checkpoints_{}.thr".format(epoch)))
        torch.save(model, os.path.join(CHECKPOINT_DIR, "Model_checkpoints_{}.thr".format(epoch)))
    print("AUC: {} , Accuracy: {}".format(roc_auc, Accuracy))
print("Total Training Time: {} seconds".format(str(sum(epoch_times))))
