In [35]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from tqdm import tqdm, trange

import torch.nn.functional as F

import json

In [36]:
# JSON text is UTF-8 by spec; pass the encoding explicitly so reading does
# not depend on the platform's locale default (e.g. cp1252 on Windows).
with open("clean_dataset/diseases.json", encoding="utf-8") as file:
  diseases = json.load(file)

In [37]:
# JSON text is UTF-8 by spec; pass the encoding explicitly so reading does
# not depend on the platform's locale default (e.g. cp1252 on Windows).
with open("clean_dataset/evidences.json", encoding="utf-8") as file:
  evidences = json.load(file)

In [38]:
# Network widths: the classifier's output size is the number of entries in
# diseases.json, its input size the number of entries in evidences.json.
class_count, feature_count = len(diseases), len(evidences)

In [39]:
class ResBlock(nn.Module):
  """Width-preserving residual MLP block computing ``x + net(x)``.

  ``net`` is four repetitions of Linear -> BatchNorm1d -> ReLU -> Dropout(0.3),
  every layer ``in_features`` wide, so the skip connection needs no projection.
  """

  def __init__(self, in_features):
    super().__init__()
    self.in_features = in_features
    stack = []
    for _ in range(4):
      stack += [
        nn.Linear(in_features, in_features),
        nn.BatchNorm1d(in_features),
        nn.ReLU(),
        nn.Dropout(0.3),
      ]
    # Keep a single flat Sequential (not one per repetition) so the parameter
    # names stay net.0 ... net.15, the layout stored in existing checkpoints.
    self.net = nn.Sequential(*stack)

  def forward(self, x):
    branch = self.net(x)
    return x + branch

In [40]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [41]:
# Funnel MLP: project feature_count -> /2 -> /4 -> /8 -> class_count, with a
# ReLU and a residual block after every projection, then a softmax over the
# class scores.
model = nn.Sequential(
  nn.Linear(feature_count, feature_count // 2),
  nn.ReLU(),
  ResBlock(feature_count // 2),
  nn.Linear(feature_count // 2, feature_count // 4),
  nn.ReLU(),
  ResBlock(feature_count // 4),
  nn.Linear(feature_count // 4, feature_count // 8),
  nn.ReLU(),
  ResBlock(feature_count // 8),
  nn.Linear(feature_count // 8, class_count),
  nn.ReLU(),
  ResBlock(class_count),
  # Explicit dim=1: the class axis of the 2-D (batch, class_count) output.
  # A bare nn.Softmax() relies on the deprecated implicit-dim choice and emits
  # a UserWarning at forward time. Softmax has no parameters, so saved
  # checkpoints load unchanged.
  nn.Softmax(dim=1)
  )


In [42]:
# map_location="cpu": nothing in this notebook moves the model to `device`, so
# load the tensors onto the CPU. This also lets a checkpoint saved from a GPU
# run load on a CPU-only machine. weights_only=True limits unpickling to
# tensors and plain containers, so a tampered model.pt cannot run code.
state_dict = torch.load("model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [43]:
# Switch every submodule to inference behaviour: BatchNorm1d uses its running
# statistics and Dropout becomes a no-op. eval() is defined as train(False).
model.train(False)

Sequential(
  (0): Linear(in_features=894, out_features=447, bias=True)
  (1): ReLU()
  (2): ResBlock(
    (net): Sequential(
      (0): Linear(in_features=447, out_features=447, bias=True)
      (1): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.3, inplace=False)
      (4): Linear(in_features=447, out_features=447, bias=True)
      (5): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.3, inplace=False)
      (8): Linear(in_features=447, out_features=447, bias=True)
      (9): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.3, inplace=False)
      (12): Linear(in_features=447, out_features=447, bias=True)
      (13): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Dropout(p=0.3, inplace=False)
    )
  

In [44]:
# Export to ONNX. The example input only drives tracing; dynamic_axes marks
# dim 0 of the input and output as a symbolic batch size, so the exported graph
# is not locked to a batch of exactly one row.
torch.onnx.export(model,                                 # model being run
                  torch.randn(1, feature_count),         # example input used for tracing
                  "model.onnx",                          # destination path (or file-like object)
                  input_names=['input'],                 # the model's input names
                  output_names=['output'],               # the model's output names
                  dynamic_axes={'input': {0: 'batch'},   # variable batch dimension
                                'output': {0: 'batch'}})


verbose: False, log level: Level.ERROR



In [31]:
# NOTE(review): execution count 31 shows this cell ran before the export cell
# (43). When the notebook is run top to bottom, it only repeats the earlier
# switch to inference mode (eval() is train(False)).
model.train(mode=False)

Sequential(
  (0): Linear(in_features=894, out_features=447, bias=True)
  (1): ReLU()
  (2): ResBlock(
    (net): Sequential(
      (0): Linear(in_features=447, out_features=447, bias=True)
      (1): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.3, inplace=False)
      (4): Linear(in_features=447, out_features=447, bias=True)
      (5): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.3, inplace=False)
      (8): Linear(in_features=447, out_features=447, bias=True)
      (9): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.3, inplace=False)
      (12): Linear(in_features=447, out_features=447, bias=True)
      (13): BatchNorm1d(447, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Dropout(p=0.3, inplace=False)
    )
  

In [32]:
# Build a single multi-hot query: a (1, feature_count) float tensor with the
# listed evidence positions set to 1.
# NOTE(review): `input` shadows the Python builtin; it is kept because later
# cells read this name.
input = torch.zeros(1, feature_count)
tobe_active = [580, 838, 834, 574, 581, 812, 796, 809,   7,   3]

# Advanced indexing sets every listed position in one assignment, so no
# per-index loop is needed.
input[0, tobe_active] = 1

In [33]:
# Sanity check: expect one row of feature_count evidence flags.
input.size()

torch.Size([1, 894])

In [34]:
# Predicted class index for the query. no_grad() skips building the autograd
# graph, which inference never uses. argmax without `dim` flattens its input,
# which is fine for this single-row batch.
with torch.no_grad():
  prediction = torch.argmax(model(input))
prediction

  input = module(input)


tensor(10)