## DSCA-NET

- Reference codes : Deep Learning for Healthcare, CS598, UIUC, Spring2022

In [1]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split


# set seed
seed = 100
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "./preprocessed/"

In [64]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pid.pkl'), 'rb'))
x_dem = pickle.load(open(os.path.join(DATA_PATH,'x_dem.pkl'), 'rb'))
x_per = pickle.load(open(os.path.join(DATA_PATH,'x_per_added.pkl'), 'rb'))
x_d = pickle.load(open(os.path.join(DATA_PATH,'x_d_added.pkl'), 'rb'))


In [65]:
print(len(pids))
print(len(x_dem))
print(len(x_per))
print(len(x_d))

46520
46520
46520
46520


In [66]:
print("Patient ID:", pids[0])
print("DemoGraphic:", x_dem[0])
print("Examination of first visit:", x_per[0][0])
print("Prescription of first visit:", x_d[0][0])


# for visit in range(len(vids[3])):
#     print(f"\t{visit}-th visit id:", vids[3][visit])
#     print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
#     print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 249
DemoGraphic: [0, 1, 1, 3, 0]
Examination of first visit: [9.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Prescription of first visit: [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [67]:
### Leaving Only patients who have both of examination and prescriptions
dem_no_idx = set()
per_no_idx = set()
d_no_idx = set()

for idx in range(len(x_dem)):
    
    if len(x_dem[idx]) == 0:
        dem_no_idx.add(idx)
    
    if len(x_per[idx]) == 0:
        per_no_idx.add(idx)
    
    if len(x_d[idx]) == 0:
        d_no_idx.add(idx)

In [68]:
tot_no_idx = per_no_idx | d_no_idx
all_idx = set(range(len(x_dem)))
alive_idx = all_idx - tot_no_idx

alive_idx_list = list(alive_idx)
alive_idx_list.sort()


x_dem_f = [x_dem[i] for i in alive_idx_list]
x_per_f = [x_per[i] for i in alive_idx_list]
x_d_f = [x_d[i] for i in alive_idx_list]


In [69]:
print(len(x_dem_f))
print(len(x_per_f))
print(len(x_d_f))


19432
19432
19432


### Build Dataset

In [70]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, x_dem, x_per, x_d):
        
        self.x_dem = x_dem
        self.x_per = x_per
        self.x_d = x_d
    
    def __len__(self):
        
        return len(self.x_d)
    
    def __getitem__(self, index):
        
        x_dem = self.x_dem[index]
        x_per = self.x_per[index]
        x_d = self.x_d[index]
        return x_dem, x_per, x_d
        

dataset = CustomDataset(x_dem_f, x_per_f, x_d_f)

### Collate

In [78]:
num_patients = len(x_per_f) # 11619
num_visits = [len(patient) for patient in x_per_f] #
max_num_visits = max(num_visits)


### Number of codes

num_codes = len(x_d_f[0][0])
num_per_codes = len(x_per_f[0][0])
print(max_num_visits)


146


In [79]:
len(x_d_f[0])

11

In [15]:
# p = [0] * 4
# ps = [p] * 7
# print(ps)
# a = [[1,2,3,4]]
# a.extend(ps)
# a
# torch.tensor(a)

In [80]:
## 1. (x_d) For all patients, Number of Visit (len(x_d)) must be same -> Need to be padded
## 2. (x_d) For all patients, Number of Prescription codes must be same -> Already Done
## 3. (x_per) For all patients, Number of Visit (len(x_per)) must be same -> Need to be padded
## 4. Mask -> 0 For padded intputs -> No Need
# def collate_fn(data):

#     x_dem_now, x_per_now, x_d_now = zip(*data)

#     x_dem_now = torch.tensor(x_dem_now[0], dtype=torch.float32)
#     x_per_now = x_per_now[0]
#     x_d_now = x_d_now[0]

#     # print(len(x_per_now))
#     # print(len(x_d_now))


#     # x_d_new = torch.zeros((1, max_num_visits, num_codes), dtype=torch.long)
#     # x_per_new = torch.zeros((1, max_num_visits, num_per_codes), dtype=torch.long)


#     ## (x_d) 
#     length_valid_visit = len(x_d_now)
#     # print("Valid Visit : {0}".format(length_valid_visit))
#     num_pad_visit = max_num_visits - length_valid_visit
#     pad = [0] * num_codes  # [0, ... , 0] : For 400 zeros
#     tmp = x_d_now.copy()
#     tmp.extend([pad] * num_pad_visit)
#     x_d_now = torch.tensor(tmp.copy(), dtype=torch.float32)

#     # (x_per)
#     length_valid = len(x_per_now)
#     num_pad_visit = max_num_visits - length_valid
#     pad = [0] * num_per_codes
#     tmp2 = x_per_now.copy()
#     tmp2.extend([pad] * num_pad_visit)
#     x_per_now = torch.tensor(tmp2.copy(), dtype=torch.float32)
    
#     return x_dem_now, x_per_now, x_d_now

## No collating
def collate_fn(data):

    x_dem_now, x_per_now, x_d_now = zip(*data)

    x_dem_now = torch.tensor(x_dem_now[0], dtype=torch.float32)
    x_per_now = torch.tensor(x_per_now[0], dtype=torch.float32)
    x_d_now = torch.tensor(x_d_now[0], dtype=torch.float32)


    
    return x_dem_now, x_per_now, x_d_now

In [104]:
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
loader_iter = iter(loader)
x_dem_new, x_per_new, x_d_new = next(loader_iter)
# print(x[0])
# print(x_d_new[0])

In [105]:
# x_per, x_d : num patients, max num visits, num codes
# x_dem : num_patients, num demo codes

print(x_dem_new.size())
print(x_per_new.size())  
print(x_d_new.size())  


torch.Size([5])
torch.Size([11, 19])
torch.Size([11, 400])


In [155]:
split = int(len(dataset)*0.9)

lengths = [split, len(dataset) - split]

train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 17488
Length of val dataset: 1944


In [156]:
def load_data(train_dataset, val_dataset, collate_fn):
    
    batch_size = 1
    train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [157]:
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)


In [158]:
print(len(train_loader)) ## divided by 32, no remainder
print(len(val_loader)) ## remainder exists

17488
1944


torch.Size([5])
torch.Size([24, 19])
torch.Size([24, 400])
tensor([[0., 0., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]])


In [159]:
## cross attention
def CA(x_c, x_d, x_c1, x_c2, x_c3, x_d1, x_d2, x_d3):
    # concatenate x_dem and x_per_new

        # x_c = x_dem_per_comb_frac.clone()
        # x_d = x_d_new_frac.clone()

        # c_linear = nn.Linear(x_c.size()[0], w_dim)
        # d_linear = nn.Linear(x_d.size()[0], w_dim)
        tanh = nn.Tanh()
        sigmoid = nn.Sigmoid()


        # print(x_c.size())  # 5 + 128 = 133
        # print(x_d.size())  # 128

        # x_c1 = c_linear(x_c)
        # x_d1 = d_linear(x_d)

        # x_c2 = c_linear(x_c)
        # x_d2 = d_linear(x_d)


        s_cd = sigmoid(x_c3 * tanh(x_c1 + x_d1))
        s_dc = sigmoid(x_d3 * tanh(x_c2 + x_d2))

        # weight_c = torch.sum(x_c, dim=0)
        # weight_d = torch.sum(x_d, dim=0)
        # s_dc = weight_d * x_c
        # s_cd = weight_c * x_d


        x = torch.concat([torch.mul(s_cd, x_d), torch.mul(s_dc, x_c)])
        s = torch.concat([s_cd, s_dc])

        # print(x.size())

        return x,s


In [20]:
# class DSCARNN(nn.Module):

#     def __init__(self, h_t_minus_1, ca_t_minus_1, x_t, s_t, num_hidden_dim, num_codes_d, tau):
#         super().__init__()

#         self.h_t_minus_1 = h_t_minus_1
#         self.ca_t_minus_1 = ca_t_minus_1
#         self.x_t = x_t
#         self.s_t = s_t
#         self.num_hidden_dim = num_hidden_dim
#         self.num_codes_d = num_codes_d
        
#         self.linear = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim * 2)
#         self.linear_dec = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim)
#         self.linear_out = nn.Linear(self.num_hidden_dim *2, self.num_codes_d)


    
#         self.tau = tau
#         self.sigmoid = nn.Sigmoid()
#         self.tanh = nn.Tanh()
#         self.relu = nn.ReLU()
    
#     def forward(self):
#         d_t = self.sigmoid(self.linear(self.x_t))
#         ca_t = self.s_t + (1 - d_t) * self.ca_t_minus_1

#         ca_t_tilde = self.tanh(self.linear(ca_t))

#         h_tilde = self.sigmoid(torch.concat([self.linear_dec(self.x_t), self.linear_dec(self.h_t_minus_1)], dim=0))

#         h_t = self.tau * ca_t_tilde + (1-self.tau) * self.sigmoid(h_tilde)

#         y_t_hat = self.sigmoid(self.linear_out(self.relu(self.linear(h_t))))

#         return h_t, ca_t, y_t_hat

In [160]:
# a = torch.zeros((1, 32))
# b = torch.zeros((1, 32))
# c = torch.concat([a, b], dim=0)

In [161]:
class DSCANet(nn.Module):
    
    def __init__(self, num_codes, num_codes_d, tau):
        super().__init__()

        # self.embedding = nn.Embedding(num_embeddings = num_codes, embedding_dim = 128)
        # self.embedding_d = nn.Embedding(num_embeddings = num_codes_d, embedding_dim = 128)
        self.num_hidden_dim = 128
        self.embedding = nn.Linear(num_codes + 5, self.num_hidden_dim)
        self.embedding_d = nn.Linear(num_codes_d, self.num_hidden_dim)

        self.num_codes_d = num_codes_d

        self.c_linear = nn.Linear(self.num_hidden_dim, self.num_hidden_dim)
        self.d_linear = nn.Linear(self.num_hidden_dim, self.num_hidden_dim)

        self.linear = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim * 2)
        self.linear_dec = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim)
        self.linear_out = nn.Linear(self.num_hidden_dim *2, self.num_codes_d)

        self.tau = tau
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    
    def forward(self, x_dem_new, x_per_new, x_d_new):
        
        # batch_size = x_d_new.shape[0] # batch_size = 1
        x_dem_new = x_dem_new.squeeze(0) # num_demos
        x_per_new = x_per_new.squeeze(0) # num_visit, num_examinations
        x_d_new = x_d_new.squeeze(0) #num_visit, num_codes

        # print(x_dem_new.size()) # 5
        # print(x_per_new.size()) # 76, 4
        # print(x_d_new.size()) # 76,400

        if len(x_per_new.size()) == 1:
            x_per_new = x_per_new.unsqueeze(0)
            x_d_new = x_d_new.unsqueeze(0)

        num_visits = x_per_new.size()[0]

        for v_idx in range(num_visits):
            if v_idx == 0:

                # if v_idx == 0, initialize h_t and ca_t
                h_t_minus_1 = nn.init.kaiming_uniform_(torch.empty(1, self.num_hidden_dim * 2)).squeeze(0)
                ca_t_minus_1 = nn.init.kaiming_uniform_(torch.empty(1, self.num_hidden_dim * 2)).squeeze(0)

                
                # getting embeddings and cross attentions
                # print(x_dem_new.size())
                # print(x_per_new.size())

                x_c = self.embedding(torch.concat([x_dem_new, x_per_new[v_idx]], dim=0))
                x_d = self.embedding_d(x_d_new[v_idx])
                x_c1 = self.c_linear(x_c)
                x_d1 = self.d_linear(x_d)
                x_c2 = self.c_linear(x_c)
                x_d2 = self.d_linear(x_d)
                x_c3 = self.c_linear(x_c)
                x_d3 = self.d_linear(x_d)

                x_t, s_t = CA(x_c, x_d, x_c1, x_c2, x_c3, x_d1, x_d2, x_d3)

                

                d_t = self.sigmoid(self.linear(x_t))
                ca_t = s_t + (1 - d_t) * ca_t_minus_1

                ca_t_tilde = self.tanh(self.linear(ca_t))

                h_tilde = self.sigmoid(torch.concat([self.linear_dec(x_t), self.linear_dec(h_t_minus_1)], dim=0))

                h_t = self.tau * ca_t_tilde + (1-self.tau) * self.sigmoid(h_tilde)
                y_t_hat = self.sigmoid(self.linear_out(self.relu(self.linear(h_t))))

                h_t_minus_1 = h_t
                ca_t_minus_1 = ca_t                

                y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)

                
                

            else:
                # getting embeddings and cross attentions
                x_c = self.embedding(torch.concat([x_dem_new, x_per_new[v_idx]], dim=0))
                x_d = self.embedding_d(x_d_new[v_idx])
                x_c1 = self.c_linear(x_c)
                x_d1 = self.d_linear(x_d)
                x_c2 = self.c_linear(x_c)
                x_d2 = self.d_linear(x_d)
                x_c3 = self.c_linear(x_c)
                x_d3 = self.d_linear(x_d)

                x_t, s_t = CA(x_c, x_d, x_c1, x_c2, x_c3, x_d1, x_d2, x_d3)
                
                # RNN
                # h_t_minus_1, ca_t_minus_1, y_t_hat = DSCARNN(h_t_minus_1, ca_t_minus_1, x_t, s_t, self.num_hidden_dim, self.num_codes_d, self.tau).forward()
                d_t = self.sigmoid(self.linear(x_t))
                ca_t = s_t + (1 - d_t) * ca_t_minus_1

                ca_t_tilde = self.tanh(self.linear(ca_t))

                h_tilde = self.sigmoid(torch.concat([self.linear_dec(x_t), self.linear_dec(h_t_minus_1)], dim=0))

                h_t = self.tau * ca_t_tilde + (1-self.tau) * self.sigmoid(h_tilde)

                y_t_hat = self.sigmoid(self.linear_out(self.relu(self.linear(h_t))))

                h_t_minus_1 = h_t
                ca_t_minus_1 = ca_t

                
                y_hats = torch.concat([y_hats, y_t_hat.unsqueeze(0)], dim=0)

                


        return y_hats
    

# load the model here
dsca_net = DSCANet(num_codes = num_per_codes, num_codes_d=num_codes, tau=0.6)
y_hats = dsca_net.forward(x_dem_new, x_per_new, x_d_new)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


In [162]:
print(y_hats.size())
print(x_d_new.size())

torch.Size([1, 400])
torch.Size([1, 400])


In [163]:
dsca_net

DSCANet(
  (embedding): Linear(in_features=24, out_features=128, bias=True)
  (embedding_d): Linear(in_features=400, out_features=128, bias=True)
  (c_linear): Linear(in_features=128, out_features=128, bias=True)
  (d_linear): Linear(in_features=128, out_features=128, bias=True)
  (linear): Linear(in_features=256, out_features=256, bias=True)
  (linear_dec): Linear(in_features=256, out_features=128, bias=True)
  (linear_out): Linear(in_features=256, out_features=400, bias=True)
  (sigmoid): Sigmoid()
  (tanh): Tanh()
  (relu): ReLU()
)

In [164]:
dsca_net.parameters()

<generator object Module.parameters at 0x15363fac0>

In [165]:
criterion = nn.BCELoss(reduction='none')
optimizer = torch.optim.Adam(list(dsca_net.parameters()), lr=1e-2)

In [166]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, val_loader):
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    
    for idx, (x_dem_new, x_per_new, x_d_new) in enumerate(val_loader):
        y_hat = model.forward(x_dem_new, x_per_new, x_d_new)
        y_hat = (y_hat > 0.5).int()
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, x_d_new.detach().to('cpu')), dim=0)

        if idx == 100:
            break
    """
    TODO:
        Calculate precision, recall, f1, and roc auc scores.
        Use `average='binary'` for calculating precision, recall, and fscore.
    """
    p, r, f, roc_auc = None, None, None, None
    # print(y_pred.size()) ## batch_size *76, 400
    # print(y_true.size()) ## batch_size *76, 400

    # Micro F1
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
    return p, r, f

In [170]:
def train(model, train_loader, val_loader, n_epochs):
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for idx, (x_dem_new, x_per_new, x_d_new) in enumerate(train_loader):
            """
            TODO:
                1. zero grad
                2. model forward
                3. calculate loss
                4. loss backward
                5. optimizer step
            """
            loss = None
            # your code here
            optimizer.zero_grad()
            
            ## model forward
            y_hat = model.forward(x_dem_new, x_per_new, x_d_new)
            print(idx)
            print(y_hat.size())
            print(x_d_new.size())
            print(y_hat)
            print(x_d_new)

            if len(x_d_new.size()) == 1:
                x_d_new = x_d_new.unsqueeze(0)
                loss = criterion(y_hat, x_d_new)
            else:
                loss = criterion(y_hat, x_d_new).sum()
            
            ##loss
            # print(loss.size()) (76,400) -> 1
            
            loss.requires_grad_(True)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            # if idx % 100 == 0:
            #     print(idx)
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval_model(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'
              .format(epoch+1, p, r, f))

In [171]:
# number of epochs to train the model
# about 40 minutes, for 1 epoch
n_epochs = 1
dsca_net = DSCANet(num_codes = num_per_codes, num_codes_d=num_codes, tau=0.6)
train(dsca_net, train_loader, val_loader, n_epochs)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
torch.Size([2, 400])
torch.Size([2, 400])
tensor([[0.5229, 0.5477, 0.5396, 0.5172, 0.4919, 0.5178, 0.4695, 0.5455, 0.5310,
         0.5226, 0.4765, 0.5032, 0.5240, 0.4897, 0.4844, 0.4900, 0.4830, 0.4966,
         0.4762, 0.4860, 0.5015, 0.5105, 0.4573, 0.4715, 0.4923, 0.5189, 0.4944,
         0.5004, 0.4704, 0.5083, 0.4881, 0.5111, 0.4982, 0.5476, 0.4938, 0.4703,
         0.5359, 0.5158, 0.5177, 0.4576, 0.5016, 0.5151, 0.4673, 0.5049, 0.4990,
         0.4894, 0.5034, 0.4712, 0.5505, 0.5159, 0.4861, 0.4934, 0.5134, 0.4508,
         0.4941, 0.4930, 0.5004, 0.4809, 0.5192, 0.5071, 0.5139, 0.5529, 0.5265,
         0.4969, 0.4778, 0.4787, 0.5065, 0.4950, 0.4419, 0.4889, 0.5122, 0.5019,
         0.4793, 0.4803, 0.4868, 0.4848, 0.5080, 0.4492, 0.5048, 0.4885, 0.5275,
         0.4493, 0.4906, 0.4890, 0.4866, 0.4956, 0.5254, 0.5384, 0.5746, 0.5252,
         0.5005, 0.5164, 0.4534, 0.5152, 0.5141, 0.5045, 0.5158, 0.4786, 0.4889,
         0.5387, 0.5369, 0.5055, 0.4737, 0.5392, 0.5075, 0.5200, 

In [169]:
for idx, (a, b, c) in enumerate(train_loader):
    # print(a.size())
    # print(b.size())
    # print(c.size())
    if idx == 5963:
        print(a.size())
        print(b.size())
        print(c.size())
        print(c)
        break
    else:
        continue




torch.Size([5])
torch.Size([2, 19])
torch.Size([2, 400])
tensor([[0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [None]:
p,r,f = eval_model(dsca_net, val_loader)
# print(len(p))
# print(len(r))
# print(len(f))

print(p)
print(r)
print(f)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0.03514340289247296
0.4955447549615229
0.0656322501911033


0.0018108455175380844
0.4999100233939176
0.0036086193783380215


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 1 	 Training Loss: 21108.849920
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 1 	 Validation p: 0.00, r:0.49, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 2 	 Training Loss: 21108.802270
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 2 	 Validation p: 0.00, r:0.51, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 3 	 Training Loss: 21108.780032
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 3 	 Validation p: 0.00, r:0.50, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 4 	 Training Loss: 21108.721962
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 4 	 Validation p: 0.00, r:0.50, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 5 	 Training Loss: 21108.513059
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 5 	 Validation p: 0.00, r:0.50, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 6 	 Training Loss: 21108.921375
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 6 	 Validation p: 0.00, r:0.51, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 7 	 Training Loss: 21108.800465
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 7 	 Validation p: 0.00, r:0.51, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 8 	 Training Loss: 21108.706733
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 8 	 Validation p: 0.00, r:0.51, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
Epoch: 9 	 Training Loss: 21108.709818
torch.Size([7676, 400])
torch.Size([7676, 400])
Epoch: 9 	 Validation p: 0.00, r:0.51, f: 0.00


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900


In [71]:
y_hats = dsca_net.forward(x_dem_new, x_per_new, x_d_new)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)
