In [1]:
import torch
import pandas as pd
import numpy as np
import sklearn
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Ligand identifiers passed to GetData/Dataset (9 of them).
ligands = "TNF R84 PIC P3K FLA CpG FSL LPS UST".split()
# Polarization suffixes passed to GetData/Dataset; "" is presumably the unpolarized condition (TODO confirm).
polarization = ["", *"ib ig i0 i3 i4".split()]
replicas = 2  # forwarded unchanged to GetData/Dataset
size = 1288   # forwarded unchanged to GetData/Dataset

In [3]:
from core.getdata import *
from core.dataset import *

<h3>Example of GetData</h3>

In [4]:
# Single-condition example: "TNF" is ligands[0] and "ib" is polarization[1].
TNFib1 = GetData("TNF", "ib", replicas, size)
TNFib1.X.shape

(1288, 98)

<h3>Example of Dataset</h3>

In [5]:
# Full dataset built from the complete ligand and polarization lists. `Dataset` must resolve to the
# class star-imported from core.dataset in the cell above; if a later cell rebinds the name `Dataset`,
# re-running this cell will call the wrong class. Arrays are exposed as `data.data` / `data.labels`.
data = Dataset(ligands, polarization, replicas, size)

In [6]:
%%time
for _ in data:
    pass

Wall time: 37 ms


In [7]:
np.shape(data.data)  # (n_samples, n_timesteps, n_features)

(69552, 98, 1)

In [8]:
# Both containers are plain numpy arrays.
print(*map(type, (data.data, data.labels)))

<class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [9]:
# Labels are integer class indices stored as 0-d values (output: `8 ()`), not multi-hot vectors.
print(data.labels[65000], data.labels[65000].shape)

8 ()


<h3>Initializing Dataloaders</h3>

In [10]:
# Alias torch's Dataset so it does not shadow core.dataset.Dataset, which builds `data` above.
# Without the alias, re-running that cell after this one would call torch's Dataset instead.
from torch.utils.data import Dataset as TorchDataset
from torchvision import datasets
from torchvision.transforms import ToTensor

In [11]:
# 75/25 split of all samples into train+val (X) and test, then 75/25 of X into train/val.
# The second length of each pair is the remainder, so the parts always sum to the whole.
# random_split requires that; int(n*0.75) + int(n*0.25) can come out one short.
n_samples = len(data.data)
X_len = int(n_samples * 0.75)
test_len = n_samples - X_len
train_len = int(X_len * 0.75)
val_len = X_len - train_len
print(X_len, test_len, train_len, val_len) #lengths
print(X_len + test_len, len(data.data))

52164 17388 39123 13041
69552 69552


In [12]:
# data.data is already (n_samples, n_timesteps, 1) -- see the data.data.shape output above -- so this
# reshape is a no-op view. It would only add the trailing feature axis if data.data were 2-D.
# NOTE(review): training_data is not used by any visible cell below.
training_data = data.data.reshape(data.data.shape[0], data.data.shape[1], 1)

In [13]:
# Only a cell's last expression is displayed, so print the per-sample shape explicitly.
print(data.data[1000].shape)  # one sample: (n_timesteps, n_features)
len(data)

69552

In [14]:
# Hold out test_len samples; a seeded generator makes the split reproducible across kernel restarts.
# NOTE(review): this held-out split serves both as the validation set during training (dataloader_val)
# and for the final evaluation below, and train_len/val_len from the previous cell are unused.
# TODO: split into separate train/val/test subsets.
dataset_X, dataset_val = torch.utils.data.random_split(
    data, [X_len, test_len], generator=torch.Generator().manual_seed(0)
)

In [15]:
# Only the training loader is shuffled. Evaluation is order-independent, and a fixed order keeps
# validation batch composition identical across epochs.
# batch_size=64 should match the `batch_size` passed to trainer.train later.
dataloader_train = torch.utils.data.DataLoader(dataset_X, batch_size=64, shuffle=True, num_workers=4)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=64, shuffle=False, num_workers=4)

In [16]:
n_train_batches, n_val_batches = len(dataloader_train), len(dataloader_val)
# Batches per epoch for training and validation, followed by their total.
print(n_train_batches, n_val_batches, n_train_batches + n_val_batches)

816 272 1088


<h3>Model Training and Evaluation</h3>

In [17]:
from core.network import *
from core.trainer import *

In [18]:
cuda_available = torch.cuda.is_available()  # the model and trainer below are pinned to "cuda:0"
cuda_available

True

In [19]:
#model parameters
input_size = data.data.shape[-1]  # features per time step: trailing axis of data.data (1 here)
hidden_sizes = [8, 16]  # per the model repr below: LSTM hidden size 8, then a 16-unit linear layer
output_size = len(ligands)  # one output per ligand (9), replacing the hard-coded 9

#training parameters
n_epochs = 55
batch_size = 64  # keep equal to the batch_size used to build the DataLoaders above
# NOTE(review): learning_rate is not passed to LSTMTrainer or trainer.train in this notebook,
# so it likely has no effect -- TODO confirm the trainer's own default.
learning_rate = 1e-3

In [20]:
# Build the classifier on the GPU and switch it to training mode. train(True) returns the module,
# so the cell displays its repr.
# NOTE(review): the repr shows the inner nn.LSTM without `num_layers=2`, which suggests
# core.network.LSTM does not forward num_layers to nn.LSTM -- TODO confirm.
model = LSTM(input_size, hidden_sizes, output_size, device="cuda:0", num_layers=2)
model.train(True)

LSTM(
  (lstm): LSTM(1, 8, batch_first=True)
  (fc1): Linear(in_features=8, out_features=16, bias=True)
  (fc2): Linear(in_features=16, out_features=9, bias=True)
)

In [21]:
# Trainer bound to the model built above, on the same device.
trainer = LSTMTrainer(device="cuda:0", model=model)

In [None]:
# dataloader_val is the held-out split created by random_split above.
trainer.train(
    dataloader_train,
    dataloader_val,
    n_epochs=n_epochs,
    batch_size=batch_size,
)

  2%|█▌                                                                                 | 1/55 [00:30<27:47, 30.88s/it]

Epoch 001: | Training Loss: 2.1332293157191837 | Validation Loss: 2.0692616766866516


  4%|███                                                                                | 2/55 [01:02<27:30, 31.15s/it]

Epoch 002: | Training Loss: 2.0607521633307138 | Validation Loss: 2.047986217719667


  5%|████▌                                                                              | 3/55 [01:32<26:43, 30.83s/it]

Epoch 003: | Training Loss: 2.0867508107832835 | Validation Loss: 2.0922131472650696


  7%|██████                                                                             | 4/55 [02:02<25:52, 30.44s/it]

Epoch 004: | Training Loss: 2.070156151611431 | Validation Loss: 2.0685699932715473


  9%|███████▌                                                                           | 5/55 [02:33<25:24, 30.49s/it]

Epoch 005: | Training Loss: 2.0475830960215307 | Validation Loss: 2.039106902830741


 11%|█████████                                                                          | 6/55 [03:02<24:36, 30.13s/it]

Epoch 006: | Training Loss: 2.084438876632382 | Validation Loss: 2.0815005876562176


 13%|██████████▌                                                                        | 7/55 [03:32<24:04, 30.10s/it]

Epoch 007: | Training Loss: 2.0759993367919733 | Validation Loss: 2.073856258217026


 15%|████████████                                                                       | 8/55 [04:01<23:24, 29.87s/it]

Epoch 008: | Training Loss: 2.0729815600549473 | Validation Loss: 2.07822387796991


 16%|█████████████▌                                                                     | 9/55 [04:31<22:50, 29.80s/it]

Epoch 009: | Training Loss: 2.078182556027291 | Validation Loss: 2.088716547717066


 18%|██████████████▉                                                                   | 10/55 [05:01<22:29, 29.98s/it]

Epoch 010: | Training Loss: 2.095988966524601 | Validation Loss: 2.09807893195573


 20%|████████████████▍                                                                 | 11/55 [05:32<22:04, 30.10s/it]

Epoch 011: | Training Loss: 2.1151560590255496 | Validation Loss: 2.130501796217526


 22%|█████████████████▉                                                                | 12/55 [06:01<21:26, 29.93s/it]

Epoch 012: | Training Loss: 2.1091154638458702 | Validation Loss: 2.0869246962315895


 24%|███████████████████▍                                                              | 13/55 [06:31<20:48, 29.74s/it]

Epoch 013: | Training Loss: 2.1440831631714223 | Validation Loss: 2.1602816572960686


 25%|████████████████████▊                                                             | 14/55 [07:04<21:02, 30.79s/it]

Epoch 014: | Training Loss: 2.155123578274951 | Validation Loss: 2.1532078323995365


 27%|██████████████████████▎                                                           | 15/55 [07:38<21:06, 31.67s/it]

Epoch 015: | Training Loss: 2.148520204366422 | Validation Loss: 2.1514930821516933


 29%|███████████████████████▊                                                          | 16/55 [08:09<20:29, 31.53s/it]

Epoch 016: | Training Loss: 2.1475355899801443 | Validation Loss: 2.1478071651038


 31%|█████████████████████████▎                                                        | 17/55 [08:40<19:59, 31.57s/it]

Epoch 017: | Training Loss: 2.146405831855886 | Validation Loss: 2.145888783475932


 33%|██████████████████████████▊                                                       | 18/55 [09:14<19:44, 32.02s/it]

Epoch 018: | Training Loss: 2.1503147628961825 | Validation Loss: 2.1431754608364666


 35%|████████████████████████████▎                                                     | 19/55 [09:47<19:25, 32.38s/it]

Epoch 019: | Training Loss: 2.1186047972125164 | Validation Loss: 2.111833350623355


 36%|█████████████████████████████▊                                                    | 20/55 [10:20<18:57, 32.50s/it]

Epoch 020: | Training Loss: 2.0938580147191588 | Validation Loss: 2.072431746212875


 38%|███████████████████████████████▎                                                  | 21/55 [10:51<18:13, 32.17s/it]

Epoch 021: | Training Loss: 2.1612986118770112 | Validation Loss: 2.151713851620169


 40%|████████████████████████████████▊                                                 | 22/55 [11:23<17:44, 32.25s/it]

Epoch 022: | Training Loss: 2.137763762590932 | Validation Loss: 2.160586160771987


 42%|██████████████████████████████████▎                                               | 23/55 [11:55<17:06, 32.07s/it]

Epoch 023: | Training Loss: 2.1278300649102997 | Validation Loss: 2.1036323514931343


 44%|███████████████████████████████████▊                                              | 24/55 [12:27<16:33, 32.04s/it]

Epoch 024: | Training Loss: 2.0965589655964982 | Validation Loss: 2.0782321047256973


 45%|█████████████████████████████████████▎                                            | 25/55 [12:59<16:01, 32.07s/it]

Epoch 025: | Training Loss: 2.0736647226354656 | Validation Loss: 2.0836501962998333


 47%|██████████████████████████████████████▊                                           | 26/55 [13:32<15:33, 32.18s/it]

Epoch 026: | Training Loss: 2.0758433853294336 | Validation Loss: 2.0628608377540814


 49%|████████████████████████████████████████▎                                         | 27/55 [14:03<14:53, 31.90s/it]

Epoch 027: | Training Loss: 2.073740473857113 | Validation Loss: 2.071895223330049


 51%|█████████████████████████████████████████▋                                        | 28/55 [14:34<14:18, 31.78s/it]

Epoch 028: | Training Loss: 2.0857216499599756 | Validation Loss: 2.078623380292864


 53%|███████████████████████████████████████████▏                                      | 29/55 [15:06<13:44, 31.71s/it]

Epoch 029: | Training Loss: 2.0785388115282153 | Validation Loss: 2.062289711307077


 55%|████████████████████████████████████████████▋                                     | 30/55 [15:37<13:10, 31.61s/it]

Epoch 030: | Training Loss: 2.0750265212035646 | Validation Loss: 2.083890999064726


 56%|██████████████████████████████████████████████▏                                   | 31/55 [16:09<12:37, 31.55s/it]

Epoch 031: | Training Loss: 2.125489899194708 | Validation Loss: 2.0861482585177704


 58%|███████████████████████████████████████████████▋                                  | 32/55 [16:41<12:12, 31.87s/it]

Epoch 032: | Training Loss: 2.0824370911600543 | Validation Loss: 2.086531546624268


 60%|█████████████████████████████████████████████████▏                                | 33/55 [17:13<11:40, 31.85s/it]

Epoch 033: | Training Loss: 2.077543146177834 | Validation Loss: 2.0546576832147205


 62%|██████████████████████████████████████████████████▋                               | 34/55 [17:45<11:08, 31.84s/it]

Epoch 034: | Training Loss: 2.0504665639178428 | Validation Loss: 2.0440893694758415


 64%|████████████████████████████████████████████████████▏                             | 35/55 [18:17<10:37, 31.88s/it]

Epoch 035: | Training Loss: 2.0415974609992085 | Validation Loss: 2.0357464363469795


 65%|█████████████████████████████████████████████████████▋                            | 36/55 [18:49<10:05, 31.88s/it]

Epoch 036: | Training Loss: 2.059981928590466 | Validation Loss: 2.106528951402973


 67%|███████████████████████████████████████████████████████▏                          | 37/55 [19:21<09:34, 31.92s/it]

Epoch 037: | Training Loss: 2.049786928968102 | Validation Loss: 2.0294663739555023


 69%|████████████████████████████████████████████████████████▋                         | 38/55 [19:53<09:02, 31.92s/it]

Epoch 038: | Training Loss: 2.0277996774689826 | Validation Loss: 2.0229412203325943


 71%|██████████████████████████████████████████████████████████▏                       | 39/55 [20:24<08:30, 31.88s/it]

Epoch 039: | Training Loss: 2.0375321009287646 | Validation Loss: 2.0814727720092323


 73%|███████████████████████████████████████████████████████████▋                      | 40/55 [20:56<07:58, 31.91s/it]

Epoch 040: | Training Loss: 2.0467810196911587 | Validation Loss: 2.0353603244704357


 75%|█████████████████████████████████████████████████████████████▏                    | 41/55 [21:30<07:32, 32.34s/it]

Epoch 041: | Training Loss: 2.0304443520658157 | Validation Loss: 2.0230367074117943


 76%|██████████████████████████████████████████████████████████████▌                   | 42/55 [22:02<07:01, 32.39s/it]

Epoch 042: | Training Loss: 2.0128204898506987 | Validation Loss: 1.9974002360421068


 78%|████████████████████████████████████████████████████████████████                  | 43/55 [22:35<06:28, 32.36s/it]

Epoch 043: | Training Loss: 2.0059830248355865 | Validation Loss: 1.999757150078521


 80%|█████████████████████████████████████████████████████████████████▌                | 44/55 [23:08<05:59, 32.72s/it]

Epoch 044: | Training Loss: 2.003213495162188 | Validation Loss: 1.9927260805578793


 82%|███████████████████████████████████████████████████████████████████               | 45/55 [23:41<05:28, 32.83s/it]

Epoch 045: | Training Loss: 1.9899409338831902 | Validation Loss: 1.9933400903554523


 84%|████████████████████████████████████████████████████████████████████▌             | 46/55 [24:14<04:54, 32.67s/it]

Epoch 046: | Training Loss: 1.98520324130853 | Validation Loss: 1.9806803486803


 85%|██████████████████████████████████████████████████████████████████████            | 47/55 [24:47<04:23, 32.90s/it]

Epoch 047: | Training Loss: 1.9843419774198066 | Validation Loss: 1.9833346692954792


 87%|███████████████████████████████████████████████████████████████████████▌          | 48/55 [25:23<03:56, 33.79s/it]

Epoch 048: | Training Loss: 1.9780989130045854 | Validation Loss: 1.9726786687970161


 89%|█████████████████████████████████████████████████████████████████████████         | 49/55 [25:56<03:21, 33.58s/it]

Epoch 049: | Training Loss: 1.9802090459886719 | Validation Loss: 1.9925547706730224


 91%|██████████████████████████████████████████████████████████████████████████▌       | 50/55 [26:29<02:46, 33.33s/it]

Epoch 050: | Training Loss: 1.9795347056260295 | Validation Loss: 1.9631944276830728


 93%|████████████████████████████████████████████████████████████████████████████      | 51/55 [27:02<02:13, 33.38s/it]

Epoch 051: | Training Loss: 1.9631725814120442 | Validation Loss: 1.9620239401564878


 95%|█████████████████████████████████████████████████████████████████████████████▌    | 52/55 [27:41<01:45, 35.01s/it]

Epoch 052: | Training Loss: 1.9630371900457961 | Validation Loss: 1.9441814396311254


 96%|███████████████████████████████████████████████████████████████████████████████   | 53/55 [28:18<01:11, 35.67s/it]

Epoch 053: | Training Loss: 1.9692647484879868 | Validation Loss: 1.9475742635481499


 98%|████████████████████████████████████████████████████████████████████████████████▌ | 54/55 [28:53<00:35, 35.38s/it]

Epoch 054: | Training Loss: 1.9387424848243302 | Validation Loss: 1.9257610760190909


In [None]:
def predict_labels(model, dataloader, device="cuda:0"):
    """Return predicted and true labels for every batch yielded by ``dataloader``.

    Working replacement for an earlier commented-out draft of this loop. The prediction is the
    argmax over the class axis (dim 1) of the model output; softmax is skipped because it does
    not change the argmax.

    Args:
        model: torch module mapping an input batch to per-class scores of shape (batch, classes).
        dataloader: iterable of ``(x_batch, y_batch)`` pairs, e.g. ``dataloader_val``.
        device: device the inputs are moved to; must be where ``model``'s parameters live.

    Returns:
        DataFrame with columns ``y_pred`` (int) and ``y_true``, one row per sample, in the order
        the dataloader yielded them. An empty dataloader yields an empty frame.

    The model runs in eval mode under ``torch.no_grad()``; its previous train/eval mode is
    restored afterwards, even if prediction raises.
    """
    was_training = model.training
    model.eval()
    preds, targets = [], []
    try:
        with torch.no_grad():  # inference only: no autograd graph
            for x_batch, y_batch in dataloader:
                scores = model(x_batch.to(device))
                preds.append(scores.argmax(dim=1).cpu().numpy())
                targets.append(torch.as_tensor(y_batch).cpu().numpy())
    finally:
        model.train(was_training)
    if not preds:  # empty dataloader: np.concatenate would fail on an empty list
        return pd.DataFrame({"y_pred": np.array([], dtype=np.int64), "y_true": np.array([])})
    return pd.DataFrame({"y_pred": np.concatenate(preds), "y_true": np.concatenate(targets)})

In [None]:
# Predict labels for (up to) the first 10000 samples of the held-out split.
torch.cuda.empty_cache()
n_eval = min(10000, len(dataset_val))
# Slicing a Subset forwards the selected indices to the underlying Dataset; this assumes
# core.dataset.Dataset accepts a list of indices and returns an (inputs, labels) pair -- TODO confirm.
x_batch, y_batch = dataset_val[0:n_eval]
# Cast to float32 to match the model's default parameter dtype. The dtype of data.data isn't shown
# here, and numpy arrays default to float64, which the model would reject.
x_batch = torch.as_tensor(x_batch, dtype=torch.float32, device=torch.device("cuda:0"))
was_training = trainer.model.training
trainer.model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph over n_eval sequences
    logits = trainer.model(x_batch)
trainer.model.train(was_training)
# Argmax over the class axis; softmax is monotonic, so it would not change the predicted label.
y_pred = logits.argmax(dim=1).cpu().numpy()
y_true = y_batch.cpu().numpy() if torch.is_tensor(y_batch) else np.asarray(y_batch)
dic = {"y_pred": y_pred, "y_true": y_true}

df = pd.DataFrame(dic)

In [None]:
df.head()  # first rows of predicted vs. true labels; the full frame has one row per evaluated sample

In [None]:
# Accuracy is the fraction of rows where the prediction matches the label. .mean() divides by the
# actual row count rather than a hard-coded 10000.
(df["y_pred"] == df["y_true"]).mean()

In [None]:
# Chance-level accuracy for guessing among len(ligands) classes (assumes roughly balanced classes).
1 / len(ligands)

<h3>Classification Report</h3>

In [None]:
import sklearn.metrics