In [1]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from tqdm import tqdm, trange

import torch.nn.functional as F

import json

import random

from time import sleep

In [2]:
X_train, y_train = torch.load("clean_dataset/X_train.pt"), torch.load("clean_dataset/y_train.pt")
X_test, y_test   = torch.load("clean_dataset/X_test.pt"),  torch.load("clean_dataset/y_test.pt")
X_valid, y_valid = torch.load("clean_dataset/X_valid.pt"), torch.load("clean_dataset/y_valid.pt")

In [3]:
with open("clean_dataset/diseases.json") as file:
  diseases = json.loads(file.read())

In [4]:
train_dataset = TensorDataset(X_train, y_train)
test_dataset  = TensorDataset(X_test, y_test)
valid_dataset = TensorDataset(X_valid, y_valid)

In [5]:
dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset, valid_dataset])

In [6]:
train_dataset, test_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.8, 0.1, 0.1])

In [7]:
batch_size = 256
class_count = len(diseases)
feature_count = len(train_dataset[0][0])

In [8]:
X_train.shape

torch.Size([1025602, 894])

In [9]:
feature_count, class_count

(894, 49)

In [10]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size)

In [11]:
class ResBlock(nn.Module):
  def __init__(self, in_features):
    super().__init__()
    self.in_features = in_features
    self.net = nn.Sequential(
      nn.Linear(in_features, in_features),
      nn.BatchNorm1d(in_features),
      nn.ReLU(),
      nn.Dropout(0.3),
      nn.Linear(in_features, in_features),
      nn.BatchNorm1d(in_features),
      nn.ReLU(),
      nn.Dropout(0.3),
      nn.Linear(in_features, in_features),
      nn.BatchNorm1d(in_features),
      nn.ReLU(),
      nn.Dropout(0.3),
      nn.Linear(in_features, in_features),
      nn.BatchNorm1d(in_features),
      nn.ReLU(),
      nn.Dropout(0.3)
    )
  
  def forward(self, x):
    return x + self.net(x)

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [13]:
model = nn.Sequential(
  nn.Linear(feature_count, feature_count // 2),
  nn.ReLU(),
  ResBlock(feature_count // 2),
  nn.Linear(feature_count // 2, feature_count // 4),
  nn.ReLU(),
  ResBlock(feature_count // 4),
  nn.Linear(feature_count // 4, feature_count // 8),
  nn.ReLU(),
  ResBlock(feature_count // 8),
  nn.Linear(feature_count // 8, class_count),
  nn.ReLU(),
  ResBlock(class_count),
  nn.Softmax()
  )


In [14]:
model.to(device)

Sequential(
  (0): Linear(in_features=894, out_features=447, bias=True)
  (1): ReLU()
  (2): ResBlock(
    (net): Sequential(
      (0): Linear(in_features=447, out_features=447, bias=True)
      (1): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.3, inplace=False)
      (4): Linear(in_features=447, out_features=447, bias=True)
      (5): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.3, inplace=False)
      (8): Linear(in_features=447, out_features=447, bias=True)
      (9): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.3, inplace=False)
      (12): Linear(in_features=447, out_features=447, bias=True)
      (13): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Dropout(p=0.3, inplace=False)
    )
  

In [15]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [16]:
count_parameters(model)

1597321

In [17]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [18]:
model.apply(init_weights)

  torch.nn.init.xavier_uniform(m.weight)


Sequential(
  (0): Linear(in_features=894, out_features=447, bias=True)
  (1): ReLU()
  (2): ResBlock(
    (net): Sequential(
      (0): Linear(in_features=447, out_features=447, bias=True)
      (1): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.3, inplace=False)
      (4): Linear(in_features=447, out_features=447, bias=True)
      (5): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.3, inplace=False)
      (8): Linear(in_features=447, out_features=447, bias=True)
      (9): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.3, inplace=False)
      (12): Linear(in_features=447, out_features=447, bias=True)
      (13): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Dropout(p=0.3, inplace=False)
    )
  

In [19]:
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

In [20]:
loss_function = torch.nn.CrossEntropyLoss()

In [21]:
epochs = 1000
symptom_dropping_chance = 0.1

In [22]:
model.train()

Sequential(
  (0): Linear(in_features=894, out_features=447, bias=True)
  (1): ReLU()
  (2): ResBlock(
    (net): Sequential(
      (0): Linear(in_features=447, out_features=447, bias=True)
      (1): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.3, inplace=False)
      (4): Linear(in_features=447, out_features=447, bias=True)
      (5): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.3, inplace=False)
      (8): Linear(in_features=447, out_features=447, bias=True)
      (9): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.3, inplace=False)
      (12): Linear(in_features=447, out_features=447, bias=True)
      (13): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Dropout(p=0.3, inplace=False)
    )
  

In [None]:
pbar = trange(epochs)

for epoch in pbar:
  correct = 0
  total_train = 0
  for X, y in train_loader:
    optimizer.zero_grad()
    X, y = X.to(device), y.to(device)

    r = random.uniform(0, 1)
    if r < symptom_dropping_chance:
      lucky_batch = random.randrange(batch_size)
      if lucky_batch < len(X):
        one_indexes = (X[lucky_batch] == 1).nonzero()
        lucky_index = random.randrange(one_indexes.numel())
      
        X[lucky_batch][one_indexes[lucky_index]] = 0.0

    y_hat = model(X)
    loss = loss_function(y_hat, y)
    loss.backward()
    total_train += len(X)
    optimizer.step()
    correct += torch.sum(torch.argmax(y_hat, dim = 1) == y)
  if True:
    model.eval()
    correct_valid = 0
    total_valid = 0
    for X_valid, y_valid in valid_loader:
      X_valid, y_valid = X_valid.to(device), y_valid.to(device)
      y_hat = model(X_valid)
      correct_valid += torch.sum(torch.argmax(y_hat, dim = 1) == y_valid)
      total_valid += len(X_valid)
    latest_valid_acc = correct_valid / total_valid
    model.train()

  pbar.set_description(f'Accuracy {correct/total_train}; Validation accuracy {latest_valid_acc}')

In [24]:
torch.save(model.state_dict(), "model.pt")

In [25]:
model.eval()
correct_test = 0
total_test = 0
for X_test, y_test in tqdm(test_loader):
  X_test, y_test = X_test.to(device), y_test.to(device)
  y_hat = model(X_test)
  correct_test += torch.sum(torch.argmax(y_hat, dim = 1) == y_test)
  total_test += len(X_test)
latest_test_acc = correct_test / total_test

100%|██████████| 505/505 [00:05<00:00, 96.67it/s] 


In [26]:
latest_test_acc

tensor(0.9593, device='cuda:0')