In [1]:
import sys
# Make the parent directory importable; the `core.*` modules imported in later
# cells are presumably located there (TODO confirm the repository layout).
sys.path.append('..')

In [2]:
import torch
import pandas as pd
import numpy as np
import sklearn
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Experiment configuration shared by the GetData / Dataset calls below.
# Ligand codes; the label output further down shows nine classes (0-8).
ligands = [
    "TNF", "R84", "PIC",
    "P3K", "FLA", "CpG",
    "FSL", "LPS", "UST",
]
# Polarization codes; "" is presumably the unpolarized condition (TODO confirm).
polarization = [
    "", "ib", "ig",
    "i0", "i3", "i4",
]
replicas = 1  # unreplicated
size = 1288  # rows per condition (matches the first dimension of TNFib1.X below)

In [4]:
from core.getdata import *
from core.dataset import *

<h3>Example of GetData</h3>

In [5]:
# Load a single condition: ligand "TNF" (ligands[0]) with polarization "ib" (polarization[1]).
TNFib1 = GetData(ligands[0], polarization[1], replicas, size)
# Display the shape of the loaded matrix: `size` rows by 98 columns.
TNFib1.X.shape

(1288, 98)

<h3>Example of Dataset</h3>

In [6]:
# Build the combined dataset over every ligand/polarization pair.
# 9 ligands * 6 polarizations * 1288 rows = 69552, which matches the shape printed below.
data = Dataset(ligands, polarization, replicas, size)

In [7]:
%%time
# Walk the whole dataset once, discarding the items, purely to time per-item access.
for _ in data:
    pass

Wall time: 40 ms


In [8]:
# Shape of the stacked sample array; note the trailing singleton dimension is already present.
data.data.shape

(69552, 98, 1)

In [9]:
# Check the container types of the samples and the labels (both NumPy arrays per the output).
print(type(data.data), type(data.labels))

<class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [10]:
print(data.labels, data.labels.shape)  # labels are integer class indices (0-8), one per sample; they are not multi-hot encoded

[0 0 0 ... 8 8 8] (69552,)


<h3>Initializing Dataloaders</h3>

In [11]:
# WARNING: this rebinds the name `Dataset` to torch's base class, shadowing the
# project `Dataset` used in In [6]. Re-running In [6] after this cell without
# re-running the `core` imports will therefore pick up the wrong class.
from torch.utils.data import Dataset
# NOTE(review): `datasets` and `ToTensor` are not referenced by any cell visible in this notebook.
from torchvision import datasets
from torchvision.transforms import ToTensor

In [12]:
# Nominal 90/10 split sizes. int() truncates, so X_len + test_len can fall one
# short of the full length (see the second printed line).
X_len = int(0.9 * len(data.data))
test_len = int(0.1 * len(data.data))

# Nominal 90/10 split of the X_len portion into train/validation sizes.
train_len = int(0.9 * X_len)
val_len = int(0.1 * X_len)

print(X_len, test_len, train_len, val_len)  # lengths
print(X_len + test_len, len(data.data))

62596 6955 56336 6259
69551 69552


In [13]:
# NOTE(review): data.data is already (69552, 98, 1) per the shape shown earlier, so this
# reshape to (N, 98, 1) does not add a dimension. `training_data` is also not referenced
# by any later cell visible in this notebook.
training_data = data.data.reshape(data.data.shape[0], data.data.shape[1], 1)

In [14]:
data.data[1000].shape  # evaluated but not displayed: a cell only echoes its last expression
len(data)  # number of samples exposed by the Dataset wrapper

69552

In [15]:
# Randomly split the dataset into a training part of X_len samples and a held-out
# part made of everything that remains. Deriving the second length from len(data)
# guarantees the two lengths sum to the dataset size (random_split requires this),
# instead of patching the int() truncation with a hard-coded "+ 1".
# NOTE(review): the held-out part is used both as the validation loader during
# training and for the final accuracy below; train_len / val_len are not used.
# Open question from the original author: need separate data and labels for LSTM?
dataset_X, dataset_val = torch.utils.data.random_split(data, [X_len, len(data) - X_len])

In [16]:
# Mini-batch loaders over the two splits. The batch size is written literally here
# because the `batch_size` hyperparameter is only defined in a later cell (same value, 64).
# Training batches are reshuffled every epoch.
dataloader_train = torch.utils.data.DataLoader(dataset_X, batch_size=64, shuffle=True)
# Validation needs no shuffling: a fixed order keeps the batch composition, and hence
# the averaged validation loss, deterministic for a given model state.
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=64, shuffle=False)

In [17]:
# len() of a DataLoader is its number of batches: training, validation, and their total.
print(len(dataloader_train), len(dataloader_val), len(dataloader_train) + len(dataloader_val))

979 109 1088


<h3>Model Training</h3>

In [18]:
from core.network import *
from core.trainer import *

In [19]:
# Confirm a CUDA device is visible: the model, trainer and evaluation cells below hard-code "cuda:0".
torch.cuda.is_available()

True

In [20]:
#model parameters
input_size = 1  # features per time step; matches the trailing dimension of data.data (69552, 98, 1)
hidden_sizes = 98  # LSTM hidden width; NOTE(review): equals the sequence length (98) - confirm this is intentional
output_size = 9  # one output per class; the labels printed above range over 0-8

#training parameters
n_epochs = 80
batch_size = 64  # same value as the literal batch_size=64 used when the DataLoaders were built
learning_rate = 1e-3  # NOTE(review): not passed to any call visible in this notebook - confirm the trainer's actual learning rate

In [21]:
# Instantiate the 3-layer LSTM classifier (class comes from the `core` star imports) on the first GPU.
model = LSTM(input_size, hidden_sizes, output_size, num_layers=3, device="cuda:0")
# Switch to training mode; as the last expression, its return value is what the cell displays (the module repr below).
model.train()

LSTM(
  (lstm): LSTM(1, 98, num_layers=3, batch_first=True)
  (fc1): Linear(in_features=98, out_features=9, bias=True)
)

In [22]:
# Wrap the model in the project trainer, targeting the same device the model was created on.
trainer = LSTMTrainer(model=model, device="cuda:0")

In [None]:
# Run training; the log below reports one training/validation loss pair per epoch.
# `batch_size` matches the value the loaders were built with (64).
trainer.train(dataloader_train, dataloader_val, batch_size=batch_size, n_epochs=n_epochs)

  1%|█                                                                                  | 1/80 [00:14<19:19, 14.68s/it]

Epoch 001: | Training Loss: 2.1598007090853963 | Validation Loss: 2.1578694308569673


  2%|██                                                                                 | 2/80 [00:28<18:13, 14.02s/it]

Epoch 002: | Training Loss: 1.9780745944645601 | Validation Loss: 1.8644620158256742


  4%|███                                                                                | 3/80 [00:42<18:06, 14.11s/it]

Epoch 003: | Training Loss: 1.743507109940113 | Validation Loss: 1.666335723815708


  5%|████▏                                                                              | 4/80 [00:59<19:30, 15.40s/it]

Epoch 004: | Training Loss: 1.6142943435109793 | Validation Loss: 1.5762905221466625


  6%|█████▏                                                                             | 5/80 [01:16<19:52, 15.91s/it]

Epoch 005: | Training Loss: 1.5872845356262253 | Validation Loss: 1.5551306162405452


  8%|██████▏                                                                            | 6/80 [01:33<19:58, 16.19s/it]

Epoch 006: | Training Loss: 1.5378254978357224 | Validation Loss: 1.487369628127562


  9%|███████▎                                                                           | 7/80 [01:51<20:21, 16.73s/it]

Epoch 007: | Training Loss: 1.502789656280619 | Validation Loss: 1.4817006631728706


 10%|████████▎                                                                          | 8/80 [02:07<20:02, 16.70s/it]

Epoch 008: | Training Loss: 1.4516066127705014 | Validation Loss: 1.4186193603988086


 11%|█████████▎                                                                         | 9/80 [02:24<19:52, 16.80s/it]

Epoch 009: | Training Loss: 1.4259741259424874 | Validation Loss: 1.3973809535350274


 12%|██████████▎                                                                       | 10/80 [02:42<19:47, 16.96s/it]

Epoch 010: | Training Loss: 1.3840691516056003 | Validation Loss: 1.3787433521463237


 14%|███████████▎                                                                      | 11/80 [02:59<19:41, 17.13s/it]

Epoch 011: | Training Loss: 1.4478476393332398 | Validation Loss: 1.422013785860954


 15%|████████████▎                                                                     | 12/80 [03:16<19:20, 17.07s/it]

Epoch 012: | Training Loss: 1.3930190016227308 | Validation Loss: 1.3537268758913792


 16%|█████████████▎                                                                    | 13/80 [03:33<19:04, 17.08s/it]

Epoch 013: | Training Loss: 1.3619366343588823 | Validation Loss: 1.3496188406550556


 18%|██████████████▎                                                                   | 14/80 [03:51<19:05, 17.35s/it]

Epoch 014: | Training Loss: 1.3270730888271234 | Validation Loss: 1.3090142219438465


 19%|███████████████▍                                                                  | 15/80 [04:08<18:35, 17.16s/it]

Epoch 015: | Training Loss: 1.28345515078982 | Validation Loss: 1.291894266911603


 20%|████████████████▍                                                                 | 16/80 [04:25<18:10, 17.04s/it]

Epoch 016: | Training Loss: 1.2446769934387323 | Validation Loss: 1.2694698801828086


 21%|█████████████████▍                                                                | 17/80 [04:41<17:42, 16.86s/it]

Epoch 017: | Training Loss: 1.223891994317535 | Validation Loss: 1.2698885831264182


 22%|██████████████████▍                                                               | 18/80 [04:58<17:28, 16.91s/it]

Epoch 018: | Training Loss: 1.1937604571509044 | Validation Loss: 1.2002680044655407


 24%|███████████████████▍                                                              | 19/80 [05:15<17:09, 16.88s/it]

Epoch 019: | Training Loss: 1.1677049224296312 | Validation Loss: 1.2090222485568545


 25%|████████████████████▌                                                             | 20/80 [05:32<16:49, 16.83s/it]

Epoch 020: | Training Loss: 1.1432626631578946 | Validation Loss: 1.1974168819025022


 26%|█████████████████████▌                                                            | 21/80 [05:49<16:48, 17.10s/it]

Epoch 021: | Training Loss: 1.1874581976005079 | Validation Loss: 1.2079636651441592


 28%|██████████████████████▌                                                           | 22/80 [06:06<16:27, 17.03s/it]

Epoch 022: | Training Loss: 1.1984812590513336 | Validation Loss: 1.2254272278295744


 29%|███████████████████████▌                                                          | 23/80 [06:23<16:03, 16.90s/it]

Epoch 023: | Training Loss: 1.1525419368441916 | Validation Loss: 1.2019341287262943


 30%|████████████████████████▌                                                         | 24/80 [06:40<15:43, 16.85s/it]

Epoch 024: | Training Loss: 1.1077981980268266 | Validation Loss: 1.1680383381493595


 31%|█████████████████████████▋                                                        | 25/80 [06:57<15:28, 16.88s/it]

Epoch 025: | Training Loss: 1.125674228259078 | Validation Loss: 1.1705148384111737


 32%|██████████████████████████▋                                                       | 26/80 [07:13<15:08, 16.83s/it]

Epoch 026: | Training Loss: 1.0937590546700513 | Validation Loss: 1.1658927749056336


 34%|███████████████████████████▋                                                      | 27/80 [07:30<14:52, 16.84s/it]

Epoch 027: | Training Loss: 1.0601894763565647 | Validation Loss: 1.1297222796930086


 35%|████████████████████████████▋                                                     | 28/80 [07:47<14:35, 16.84s/it]

Epoch 028: | Training Loss: 1.1544388353641966 | Validation Loss: 1.5204974762890318


 36%|█████████████████████████████▋                                                    | 29/80 [08:04<14:15, 16.78s/it]

Epoch 029: | Training Loss: 1.3922556476159529 | Validation Loss: 1.3197878249194643


 38%|██████████████████████████████▊                                                   | 30/80 [08:21<14:02, 16.85s/it]

Epoch 030: | Training Loss: 1.2800799015228304 | Validation Loss: 1.269384565156534


 39%|███████████████████████████████▊                                                  | 31/80 [08:38<13:47, 16.89s/it]

Epoch 031: | Training Loss: 1.2106933202879908 | Validation Loss: 1.196385176903611


 40%|████████████████████████████████▊                                                 | 32/80 [08:55<13:30, 16.89s/it]

Epoch 032: | Training Loss: 1.1612094645602467 | Validation Loss: 1.1781352260790834


 41%|█████████████████████████████████▊                                                | 33/80 [09:11<13:07, 16.75s/it]

Epoch 033: | Training Loss: 1.118311533523653 | Validation Loss: 1.138872607585487


 42%|██████████████████████████████████▊                                               | 34/80 [09:28<12:53, 16.81s/it]

Epoch 034: | Training Loss: 1.0789916058726403 | Validation Loss: 1.1169183478442901


 44%|███████████████████████████████████▉                                              | 35/80 [09:46<12:51, 17.15s/it]

Epoch 035: | Training Loss: 1.0585906825221474 | Validation Loss: 1.1014986989694997


 45%|████████████████████████████████████▉                                             | 36/80 [10:03<12:29, 17.02s/it]

Epoch 036: | Training Loss: 1.0042459697627193 | Validation Loss: 1.0767526566435437


 46%|█████████████████████████████████████▉                                            | 37/80 [10:19<12:06, 16.90s/it]

Epoch 037: | Training Loss: 0.9780911079953226 | Validation Loss: 1.083727750756325


 48%|██████████████████████████████████████▉                                           | 38/80 [10:36<11:49, 16.89s/it]

Epoch 038: | Training Loss: 0.9406924291577597 | Validation Loss: 1.0490426295394197


 49%|███████████████████████████████████████▉                                          | 39/80 [10:53<11:31, 16.86s/it]

Epoch 039: | Training Loss: 0.935351922270957 | Validation Loss: 1.0361218419643716


 50%|█████████████████████████████████████████                                         | 40/80 [11:10<11:13, 16.84s/it]

Epoch 040: | Training Loss: 0.8855489613269517 | Validation Loss: 1.0186585220721884


 51%|██████████████████████████████████████████                                        | 41/80 [11:26<10:55, 16.80s/it]

Epoch 041: | Training Loss: 0.8508554427019788 | Validation Loss: 1.0019755554855416


 52%|███████████████████████████████████████████                                       | 42/80 [11:43<10:35, 16.72s/it]

Epoch 042: | Training Loss: 0.8269959952419698 | Validation Loss: 0.9957441851633404


<h3>Evaluation</h3>

In [None]:
# Predict a class for every held-out sample and collect predictions next to the true labels.
torch.cuda.empty_cache()

# Take the whole held-out split rather than relying on an arbitrary upper bound.
# Slicing a Subset forwards a list of indices to the underlying dataset, which is
# assumed to support array-style indexing and return (samples, labels) - as the
# original slice-based code already required.
x_batch, y_batch = dataset_val[0:len(dataset_val)]
x_batch = torch.tensor(x_batch, device=torch.device("cuda:0"))

# Inference only: disable autograd so no graph is kept for the full split,
# which keeps GPU memory usage down. The printed architecture has no dropout or
# batch-norm layers, so train/eval mode does not change the outputs here.
with torch.no_grad():
    logits = trainer.network(x_batch)

# Softmax is monotonic, so the arg-max of the raw logits is already the predicted class.
y_pred = logits.argmax(dim=1).cpu().numpy()
dic = {"y_pred": y_pred, "y_true": y_batch}

df = pd.DataFrame(dic)

In [None]:
df

In [None]:
# Accuracy = fraction of rows where the predicted class equals the true label.
# Taking the mean of the boolean match vector divides by the actual number of
# evaluated samples instead of a hard-coded count.
print(f' Accuracy: {(df["y_pred"] == df["y_true"]).mean()}')

In [None]:
dir_save = '../models/'
# NOTE(review): the configuration cell sets replicas = 1 and marks it "unreplicated",
# which conflicts with the remark on the next line - confirm which dataset this
# checkpoint corresponds to before relying on the file name.
trainer.save(dir_save + 'lstm2.pth')  # model trained on replicated dataset

<h3>Plots</h3>

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Plot the loss histories recorded by the trainer on one labelled axis so the two
# curves can be told apart when the notebook is skimmed.
# Assumes the trainer stores one value per epoch in each list - TODO confirm.
fig, ax = plt.subplots()
ax.plot(trainer.val_losses, label="Validation loss")
ax.plot(trainer.train_losses, label="Training loss")
ax.set(xlabel="Epoch", ylabel="Loss", title="LSTM training vs. validation loss")
ax.legend();  # trailing semicolon suppresses the Legend repr in the cell output

Overfitting seems to set in at around 50 epochs. Overall accuracy of the model seems to be the same across the replicated and unreplicated datasets.