### 5-Fold Cross-Validation for a 1-Layer LSTM Model
##### 80:20 train/validation split at each cross-validation step

In [1]:
import os
import sys

# Make the project root (parent of the notebook's working directory) importable
# so the `core.*` modules below resolve. '..' is resolved to an absolute path so a
# later os.chdir() cannot break imports. The membership check keeps re-runs of
# this cell from piling duplicate entries onto sys.path.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import sklearn
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
from core.getdata import *
from core.dataset import *
from core.network import *
from core.trainer import *
from core.visualization import *

from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

In [4]:
# Dataset selection parameters, passed positionally to DatasetPolar further down.
ligands = "TNF R84 PIC P3K FLA CpG FSL LPS UST".split()  # 9 ligand names
# NOTE(review): the empty string is kept as its own entry — presumably the
# baseline / no-polarization condition; confirm against core.dataset.
polarization = ["", "ib", "ig", "i0", "i3", "i4"]
replicas = 2
size = 1288

In [5]:
# Checkpoint locations for this CV experiment; loading and saving share one directory.
# NOTE(review): load_dir/save_dir are not passed to kfcv() in this notebook (only
# save_name is) — confirm where the 'mmkfcv' checkpoints actually end up.
cv_model_dir = '../models/cfv/'
load_dir = cv_model_dir
save_dir = cv_model_dir
save_name = 'mmkfcv'

In [6]:
# Model parameters
input_size = 1       # features per time step (the printed LSTM repr shows LSTM(1, 98))
hidden_sizes = 98    # LSTM hidden units
output_size = 9      # output logits; presumably one per ligand (len(ligands) == 9) — confirm
num_layers = 1       # single-layer LSTM, as in this notebook's title

# Training parameters
n_epochs = 80
batch_size = 65
learning_rate = 1e-3  # NOTE(review): not passed to LSTMTrainer/kfcv here — confirm it is used

# Device: previously torch.cuda.is_available() was evaluated and discarded, and
# "cuda:0" was hard-coded. Use the GPU when present, otherwise fall back to CPU so
# the notebook also runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Empty (untrained) model for CV
* Only needs to be executed once: this step saves the untrained model to `../models/empty.pth`.

In [7]:
import os

# One-off step: build an untrained LSTM trainer and save it as the "empty" checkpoint.
# Replaces a cell that was entirely commented-out code. The step is opt-in (flip the
# flag to True once) and is skipped when the file already exists, so Restart & Run
# All never silently overwrites the saved initial weights.
CREATE_EMPTY_MODEL = False
EMPTY_MODEL_PATH = '../models/' + 'empty.pth'

if CREATE_EMPTY_MODEL and not os.path.exists(EMPTY_MODEL_PATH):
    empty_net = LSTM(input_size, hidden_sizes, output_size, num_layers=num_layers, device=device)
    empty_net.train()
    empty_model = LSTMTrainer(model=empty_net, device=device)
    empty_model.save(EMPTY_MODEL_PATH)

In [8]:
# One-off step: save the untrained `model` trainer as ../models/empty.pth.
# Left commented out so that re-running the notebook does not overwrite the
# saved checkpoint.
# model.save('../models/' + 'empty.pth')

### Cross-Fold Validation (CFV) Run 1

In [9]:
# Number of cross-validation folds: with k = 5, each step holds out 1/5 of the
# data, giving the 80:20 split described in the title.
k = 5
data1 = DatasetPolar(ligands, polarization, replicas, size)

In [10]:
# Fresh LSTM network for this CV run; its layer summary is printed below.
net1 = LSTM(
    input_size,
    hidden_sizes,
    output_size,
    num_layers=num_layers,
    device=device,
)
net1.train()  # switch to training mode; the returned module is the cell's displayed output

LSTM(
  (lstm): LSTM(1, 98, batch_first=True)
  (fc1): Linear(in_features=98, out_features=9, bias=True)
)

In [11]:
# Wrap the network in the project trainer, which runs the k-fold CV in the next cell.
model1 = LSTMTrainer(device=device, model=net1)

In [12]:
# Run k-fold cross-validation on data1, using save_name for the checkpoints.
# NOTE(review): the recorded run of this cell ended with
#   UnboundLocalError: local variable 'y_pred' referenced before assignment
# right after the first 80-epoch training run. `y_pred` does not appear in this
# notebook, so the failure is inside the trainer code (core.trainer, not shown);
# fix it there before trusting kfcv_histories in the cells below.
model1.kfcv(
    data1,
    k,
    save_name,
    n_epochs=n_epochs,
    batch_size=batch_size,
)

  1%|█                                                                                  | 1/80 [00:06<08:15,  6.27s/it]

Epoch 001: | Training Loss: 2.1207124706462164 | Validation Loss: 2.11233077985104


  2%|██                                                                                 | 2/80 [00:11<07:34,  5.83s/it]

Epoch 002: | Training Loss: 2.132062260653371 | Validation Loss: 2.098253051254237


  4%|███                                                                                | 3/80 [00:17<07:14,  5.64s/it]

Epoch 003: | Training Loss: 2.101598839614993 | Validation Loss: 2.1891270541699135


  5%|████▏                                                                              | 4/80 [00:23<07:37,  6.02s/it]

Epoch 004: | Training Loss: 2.1804367788484162 | Validation Loss: 2.0970092020302173


  6%|█████▏                                                                             | 5/80 [00:31<08:25,  6.74s/it]

Epoch 005: | Training Loss: 1.9556865271563841 | Validation Loss: 1.8488152890561897


  8%|██████▏                                                                            | 6/80 [00:40<08:57,  7.26s/it]

Epoch 006: | Training Loss: 1.8265183999716679 | Validation Loss: 1.8013377356752056


  9%|███████▎                                                                           | 7/80 [00:48<09:15,  7.61s/it]

Epoch 007: | Training Loss: 1.747670546034786 | Validation Loss: 1.7302524776102226


 10%|████████▎                                                                          | 8/80 [00:56<09:26,  7.87s/it]

Epoch 008: | Training Loss: 1.698553997083245 | Validation Loss: 1.6761747492808048


 11%|█████████▎                                                                         | 9/80 [01:05<09:28,  8.01s/it]

Epoch 009: | Training Loss: 1.657294596745589 | Validation Loss: 1.6233002108948253


 12%|██████████▎                                                                       | 10/80 [01:13<09:27,  8.10s/it]

Epoch 010: | Training Loss: 1.625351560450046 | Validation Loss: 1.6111224421830934


 14%|███████████▎                                                                      | 11/80 [01:21<09:24,  8.18s/it]

Epoch 011: | Training Loss: 1.5963335787860033 | Validation Loss: 1.578599421220405


 15%|████████████▎                                                                     | 12/80 [01:30<09:19,  8.22s/it]

Epoch 012: | Training Loss: 1.5711725989234782 | Validation Loss: 1.560860651118733


 16%|█████████████▎                                                                    | 13/80 [01:38<09:12,  8.25s/it]

Epoch 013: | Training Loss: 1.5528079126482812 | Validation Loss: 1.544118105808151


 18%|██████████████▎                                                                   | 14/80 [01:46<09:06,  8.29s/it]

Epoch 014: | Training Loss: 1.5330818908236852 | Validation Loss: 1.5268523793354212


 19%|███████████████▍                                                                  | 15/80 [01:55<08:59,  8.30s/it]

Epoch 015: | Training Loss: 1.514660596708271 | Validation Loss: 1.5241350542719119


 20%|████████████████▍                                                                 | 16/80 [02:03<08:51,  8.31s/it]

Epoch 016: | Training Loss: 1.486078790972166 | Validation Loss: 1.4743105011565663


 21%|█████████████████▍                                                                | 17/80 [02:11<08:44,  8.33s/it]

Epoch 017: | Training Loss: 1.468662980978734 | Validation Loss: 1.4788475465551716


 22%|██████████████████▍                                                               | 18/80 [02:20<08:35,  8.32s/it]

Epoch 018: | Training Loss: 1.4422072553746055 | Validation Loss: 1.425377448585546


 24%|███████████████████▍                                                              | 19/80 [02:28<08:26,  8.31s/it]

Epoch 019: | Training Loss: 1.4136068646874382 | Validation Loss: 1.4316605825290503


 25%|████████████████████▌                                                             | 20/80 [02:36<08:20,  8.34s/it]

Epoch 020: | Training Loss: 1.3965894784604278 | Validation Loss: 1.4019093658322486


 26%|█████████████████████▌                                                            | 21/80 [02:45<08:13,  8.36s/it]

Epoch 021: | Training Loss: 1.3807145431637764 | Validation Loss: 1.391339893653014


 28%|██████████████████████▌                                                           | 22/80 [02:53<08:09,  8.43s/it]

Epoch 022: | Training Loss: 1.3803057371873722 | Validation Loss: 1.3870128309615304


 29%|███████████████████████▌                                                          | 23/80 [03:02<07:58,  8.39s/it]

Epoch 023: | Training Loss: 1.353253546371081 | Validation Loss: 1.34834999235991


 30%|████████████████████████▌                                                         | 24/80 [03:10<07:49,  8.38s/it]

Epoch 024: | Training Loss: 1.3390435164219865 | Validation Loss: 1.346453967216973


 31%|█████████████████████████▋                                                        | 25/80 [03:18<07:41,  8.39s/it]

Epoch 025: | Training Loss: 1.3292847496466103 | Validation Loss: 1.3900124358239574


 32%|██████████████████████████▋                                                       | 26/80 [03:27<07:30,  8.35s/it]

Epoch 026: | Training Loss: 1.3119919110681408 | Validation Loss: 1.3181236570126542


 34%|███████████████████████████▋                                                      | 27/80 [03:35<07:18,  8.28s/it]

Epoch 027: | Training Loss: 1.2990895659427777 | Validation Loss: 1.3147291529958494


 35%|████████████████████████████▋                                                     | 28/80 [03:43<07:12,  8.33s/it]

Epoch 028: | Training Loss: 1.286265109952922 | Validation Loss: 1.3120735830792756


 36%|█████████████████████████████▋                                                    | 29/80 [03:52<07:05,  8.34s/it]

Epoch 029: | Training Loss: 1.272100763830626 | Validation Loss: 1.3220991236027155


 38%|██████████████████████████████▊                                                   | 30/80 [04:00<06:54,  8.29s/it]

Epoch 030: | Training Loss: 1.2635982064443214 | Validation Loss: 1.290941465402318


 39%|███████████████████████████████▊                                                  | 31/80 [04:08<06:47,  8.31s/it]

Epoch 031: | Training Loss: 1.2558398554536785 | Validation Loss: 1.3346634977888838


 40%|████████████████████████████████▊                                                 | 32/80 [04:16<06:31,  8.16s/it]

Epoch 032: | Training Loss: 1.2489665348396124 | Validation Loss: 1.307119829353885


 41%|█████████████████████████████████▊                                                | 33/80 [04:23<06:03,  7.74s/it]

Epoch 033: | Training Loss: 1.240866489499529 | Validation Loss: 1.288164649889848


 42%|██████████████████████████████████▊                                               | 34/80 [04:31<06:05,  7.95s/it]

Epoch 034: | Training Loss: 1.2289494105449348 | Validation Loss: 1.2810651910639255


 44%|███████████████████████████████████▉                                              | 35/80 [04:40<06:03,  8.07s/it]

Epoch 035: | Training Loss: 1.212344937305027 | Validation Loss: 1.2864832883683321


 45%|████████████████████████████████████▉                                             | 36/80 [04:48<05:58,  8.15s/it]

Epoch 036: | Training Loss: 1.2054648180709822 | Validation Loss: 1.2810211086941656


 46%|█████████████████████████████████████▉                                            | 37/80 [04:56<05:53,  8.21s/it]

Epoch 037: | Training Loss: 1.2071102268506433 | Validation Loss: 1.2596194161989978


 48%|██████████████████████████████████████▉                                           | 38/80 [05:04<05:45,  8.23s/it]

Epoch 038: | Training Loss: 1.1979128922396731 | Validation Loss: 1.2891026100265646


 49%|███████████████████████████████████████▉                                          | 39/80 [05:13<05:38,  8.26s/it]

Epoch 039: | Training Loss: 1.1758263076995021 | Validation Loss: 1.2694933846175114


 50%|█████████████████████████████████████████                                         | 40/80 [05:19<05:03,  7.59s/it]

Epoch 040: | Training Loss: 1.169534950682493 | Validation Loss: 1.2517876870164246


 51%|██████████████████████████████████████████                                        | 41/80 [05:24<04:28,  6.89s/it]

Epoch 041: | Training Loss: 1.1615635222781484 | Validation Loss: 1.2257434121359174


 52%|███████████████████████████████████████████                                       | 42/80 [05:29<03:58,  6.28s/it]

Epoch 042: | Training Loss: 1.1602036230196462 | Validation Loss: 1.2371724027896596


 54%|████████████████████████████████████████████                                      | 43/80 [05:34<03:34,  5.81s/it]

Epoch 043: | Training Loss: 1.1474990046052176 | Validation Loss: 1.2250280427598508


 55%|█████████████████████████████████████████████                                     | 44/80 [05:38<03:17,  5.49s/it]

Epoch 044: | Training Loss: 1.1370205036669134 | Validation Loss: 1.2355804014428753


 56%|██████████████████████████████████████████████▏                                   | 45/80 [05:43<03:03,  5.25s/it]

Epoch 045: | Training Loss: 1.1354194885102389 | Validation Loss: 1.2230330577520567


 57%|███████████████████████████████████████████████▏                                  | 46/80 [05:48<02:53,  5.09s/it]

Epoch 046: | Training Loss: 1.1256475243195194 | Validation Loss: 1.2031756578204789


 59%|████████████████████████████████████████████████▏                                 | 47/80 [05:52<02:43,  4.94s/it]

Epoch 047: | Training Loss: 1.126217465952178 | Validation Loss: 1.2344933553276776


 60%|█████████████████████████████████████████████████▏                                | 48/80 [05:57<02:34,  4.83s/it]

Epoch 048: | Training Loss: 1.1165662661334064 | Validation Loss: 1.2043365295802322


 61%|██████████████████████████████████████████████████▏                               | 49/80 [06:02<02:28,  4.78s/it]

Epoch 049: | Training Loss: 1.1035861954371506 | Validation Loss: 1.20322667911788


 62%|███████████████████████████████████████████████████▎                              | 50/80 [06:06<02:21,  4.72s/it]

Epoch 050: | Training Loss: 1.1054779180438719 | Validation Loss: 1.1978138350994787


 64%|████████████████████████████████████████████████████▎                             | 51/80 [06:11<02:15,  4.68s/it]

Epoch 051: | Training Loss: 1.1069441430061777 | Validation Loss: 1.2001950988702685


 65%|█████████████████████████████████████████████████████▎                            | 52/80 [06:15<02:10,  4.64s/it]

Epoch 052: | Training Loss: 1.1096438743243708 | Validation Loss: 1.1974669970641627


 66%|██████████████████████████████████████████████████████▎                           | 53/80 [06:20<02:04,  4.62s/it]

Epoch 053: | Training Loss: 1.0829471949244214 | Validation Loss: 1.2023855929619798


 68%|███████████████████████████████████████████████████████▎                          | 54/80 [06:25<01:59,  4.62s/it]

Epoch 054: | Training Loss: 1.073767597499852 | Validation Loss: 1.2526953858192835


 69%|████████████████████████████████████████████████████████▍                         | 55/80 [06:29<01:55,  4.61s/it]

Epoch 055: | Training Loss: 1.083275028199793 | Validation Loss: 1.2207103327055957


 70%|█████████████████████████████████████████████████████████▍                        | 56/80 [06:34<01:50,  4.59s/it]

Epoch 056: | Training Loss: 1.069570557327471 | Validation Loss: 1.2100936826144424


 71%|██████████████████████████████████████████████████████████▍                       | 57/80 [06:38<01:45,  4.59s/it]

Epoch 057: | Training Loss: 1.0689461179007993 | Validation Loss: 1.174155316898756


 72%|███████████████████████████████████████████████████████████▍                      | 58/80 [06:43<01:40,  4.58s/it]

Epoch 058: | Training Loss: 1.0610322811614687 | Validation Loss: 1.1689147475723909


 74%|████████████████████████████████████████████████████████████▍                     | 59/80 [06:47<01:36,  4.58s/it]

Epoch 059: | Training Loss: 1.0591305607111654 | Validation Loss: 1.1984591567627738


 75%|█████████████████████████████████████████████████████████████▌                    | 60/80 [06:52<01:31,  4.57s/it]

Epoch 060: | Training Loss: 1.0403312091654706 | Validation Loss: 1.160025977085684


 76%|██████████████████████████████████████████████████████████████▌                   | 61/80 [06:57<01:27,  4.58s/it]

Epoch 061: | Training Loss: 1.0261457016534894 | Validation Loss: 1.1567900526189359


 78%|███████████████████████████████████████████████████████████████▌                  | 62/80 [07:01<01:22,  4.59s/it]

Epoch 062: | Training Loss: 1.048266715674757 | Validation Loss: 1.1676971892887187


 79%|████████████████████████████████████████████████████████████████▌                 | 63/80 [07:06<01:18,  4.59s/it]

Epoch 063: | Training Loss: 1.0553454852828354 | Validation Loss: 1.1862611868114115


 80%|█████████████████████████████████████████████████████████████████▌                | 64/80 [07:10<01:13,  4.58s/it]

Epoch 064: | Training Loss: 1.017795017110967 | Validation Loss: 1.1519318983376583


 81%|██████████████████████████████████████████████████████████████████▋               | 65/80 [07:15<01:08,  4.56s/it]

Epoch 065: | Training Loss: 1.0128401164835859 | Validation Loss: 1.182074023741428


 82%|███████████████████████████████████████████████████████████████████▋              | 66/80 [07:19<01:03,  4.56s/it]

Epoch 066: | Training Loss: 1.0070143027979637 | Validation Loss: 1.1579293046042183


 84%|████████████████████████████████████████████████████████████████████▋             | 67/80 [07:24<00:59,  4.55s/it]

Epoch 067: | Training Loss: 1.0331815307106926 | Validation Loss: 1.5483866406378346


 85%|█████████████████████████████████████████████████████████████████████▋            | 68/80 [07:29<00:54,  4.58s/it]

Epoch 068: | Training Loss: 1.161755219589327 | Validation Loss: 1.2058014532673025


 86%|██████████████████████████████████████████████████████████████████████▋           | 69/80 [07:33<00:50,  4.62s/it]

Epoch 069: | Training Loss: 1.0217089022849208 | Validation Loss: 1.1478056345030525


 88%|███████████████████████████████████████████████████████████████████████▊          | 70/80 [07:38<00:46,  4.60s/it]

Epoch 070: | Training Loss: 0.9895100096397311 | Validation Loss: 1.147673562865391


 89%|████████████████████████████████████████████████████████████████████████▊         | 71/80 [07:42<00:41,  4.59s/it]

Epoch 071: | Training Loss: 0.9910596274187632 | Validation Loss: 1.1612109631021446


 90%|█████████████████████████████████████████████████████████████████████████▊        | 72/80 [07:47<00:36,  4.60s/it]

Epoch 072: | Training Loss: 0.9769249717208827 | Validation Loss: 1.1643375698651108


 91%|██████████████████████████████████████████████████████████████████████████▊       | 73/80 [07:52<00:32,  4.62s/it]

Epoch 073: | Training Loss: 0.9736132075157121 | Validation Loss: 1.127799083974874


 92%|███████████████████████████████████████████████████████████████████████████▊      | 74/80 [07:57<00:28,  4.69s/it]

Epoch 074: | Training Loss: 0.9944235613552209 | Validation Loss: 1.1324278854321097


 94%|████████████████████████████████████████████████████████████████████████████▉     | 75/80 [08:01<00:23,  4.66s/it]

Epoch 075: | Training Loss: 0.995011392110419 | Validation Loss: 1.1795660741975373


 95%|█████████████████████████████████████████████████████████████████████████████▉    | 76/80 [08:06<00:18,  4.62s/it]

Epoch 076: | Training Loss: 0.9647735509081422 | Validation Loss: 1.128770132766706


 96%|██████████████████████████████████████████████████████████████████████████████▉   | 77/80 [08:10<00:13,  4.60s/it]

Epoch 077: | Training Loss: 0.9711453840136528 | Validation Loss: 1.130139658384234


 98%|███████████████████████████████████████████████████████████████████████████████▉  | 78/80 [08:15<00:09,  4.59s/it]

Epoch 078: | Training Loss: 0.9537389584790881 | Validation Loss: 1.1113199090289179


 99%|████████████████████████████████████████████████████████████████████████████████▉ | 79/80 [08:19<00:04,  4.58s/it]

Epoch 079: | Training Loss: 0.9420805161940717 | Validation Loss: 1.120396478153835


100%|██████████████████████████████████████████████████████████████████████████████████| 80/80 [08:24<00:00,  6.31s/it]

Epoch 080: | Training Loss: 0.9583158612112018 | Validation Loss: 1.1293964976462247





UnboundLocalError: local variable 'y_pred' referenced before assignment

In [None]:
# Sanity check: how many training histories kfcv() recorded (presumably one per fold, i.e. k).
n_recorded_histories = len(model1.kfcv_histories)
n_recorded_histories

In [None]:
# Loss curves for one cross-validation fold (change `fold` to inspect another).
# NOTE(review): assumes kfcv_histories[fold] is (training losses, validation losses),
# one value per epoch — inferred from the "Training Loss | Validation Loss" epoch
# log printed by kfcv(); confirm in core.trainer.
fold = 0
fold_history = model1.kfcv_histories[fold]

fig, ax = plt.subplots()
ax.plot(fold_history[0], label="training loss")
ax.plot(fold_history[1], label="validation loss")
ax.set(xlabel="epoch", ylabel="loss", title=f"Fold {fold} loss curves")
ax.legend();