In [1]:
import time
import pickle
import torch
import copy
import torch.nn          as nn
import numpy             as np
import pandas            as pd
import matplotlib.pyplot as plt 
import seaborn           as sns

from typing           import List
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyG_Dataloader

from config import (
    PATH_TO_FEATURES,
    PATH_TO_SAVED_DRUG_FEATURES,
    PATH_SUMMARY_DATASETS
)
WITHOUT_MISSING_FOLDER = '/without_missing/'

torch.manual_seed(42)
sns.set_theme(style="white")

---

# Experiments on the `TabGraph` approach

In this notebook we are going to expirment the approach of 
- having the cell-line branch using tabular input (`Tab`)
- replacing the drug branch by a GNN (`Graph`)

## Base Datasets

In [2]:
# (Tab) Reading cell-line gene matrix.
with open(f'{PATH_SUMMARY_DATASETS}{WITHOUT_MISSING_FOLDER}cell_line_gene_matrix.pkl', 'rb') as f: cl_gene_mat = pickle.load(f)
# (Graph) Reading drug SMILES graph.
with open(f'{PATH_SUMMARY_DATASETS}drug_graphs_dict.pkl', 'rb') as f: drug_graphs = pickle.load(f)
# (Tab) Reading drug response matrix.
with open(f'{PATH_SUMMARY_DATASETS}{WITHOUT_MISSING_FOLDER}drug_response_matrix__gdsc2.pkl', 'rb') as f: drug_response_matrix = pickle.load(f)  

In [3]:
print(f"Cell-line gene matrix\n{21*'='}")
cl_gene_mat.set_index(['CELL_LINE_NAME'], inplace=True)
print(cl_gene_mat.shape)
cl_gene_mat.head(3)

Cell-line gene matrix
(732, 3432)


Unnamed: 0_level_0,FBXL12_gexpr,PIN1_gexpr,PAK4_gexpr,GNA15_gexpr,ARPP19_gexpr,EAPP_gexpr,MOK_gexpr,MTHFD2_gexpr,TIPARP_gexpr,CASP3_gexpr,...,PDHX_mut,DFFB_mut,FOSL1_mut,ETS1_mut,EBNA1BP2_mut,MYL9_mut,MLLT11_mut,PFKL_mut,FGFR4_mut,SDHB_mut
CELL_LINE_NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
22RV1,7.023759,6.067534,4.31875,3.261427,6.297582,8.313991,5.514912,10.594112,5.222366,6.635925,...,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0
23132-87,6.714387,5.695096,4.536146,3.295886,7.021037,8.50008,4.862145,10.609245,6.528668,7.238143,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
42-MG-BA,7.752402,5.475753,4.033714,3.176525,7.279671,8.013367,4.957332,11.266705,7.445954,6.312424,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
print(f"Drug SMILES fingerprint graphs\n{30*'='}")
print(f"Number of drugs: {len(list(drug_graphs.keys()))}")
for drug, G in drug_graphs.items():
    assert G.edge_index.max() < G.num_nodes, f'FAIL for drug: {drug}'
print("SUCCESS: All drugs succeeded according to this issue: https://github.com/pyg-team/pytorch_geometric/issues/4588")
print(f"Examples:\n{9*'-'}")
for drug_id in [1003, 1004, 1006]:  
    drug_name = drug_response_matrix.loc[drug_response_matrix.DRUG_ID==drug_id].DRUG_NAME.unique()[0]
    print(f"drug_id {drug_id} = {drug_name:13s} has graph: {drug_graphs[drug_id]}")

Drug SMILES fingerprint graphs
Number of drugs: 152
SUCCESS: All drugs succeeded according to this issue: https://github.com/pyg-team/pytorch_geometric/issues/4588
Examples:
---------
drug_id 1003 = Camptothecin  has graph: Data(x=[26, 9], edge_index=[2, 60], edge_attr=[60, 3], smiles='CC[C@@]1(c2cc3c4c(cc5ccccc5n4)Cn3c(=O)c2COC1=O)O')
drug_id 1004 = Vinblastine   has graph: Data(x=[59, 9], edge_index=[2, 134], edge_attr=[134, 3], smiles='CC[C@@]1(C[C@@H]2C[C@](c3cc4c(cc3OC)N(C)[C@@H]3[C@@]54CCN4CC=C[C@](CC)([C@@H]54)[C@H]([C@@]3(C(=O)OC)O)OC(=O)C)(c3c(CCN(C2)C1)c1ccccc1[nH]3)C(=O)OC)O')
drug_id 1006 = Cytarabine    has graph: Data(x=[17, 9], edge_index=[2, 36], edge_attr=[36, 3], smiles='c1cn([C@H]2[C@H]([C@@H]([C@@H](CO)O2)O)O)c(nc1=N)O')


In [5]:
print(f"Drug response matrix\n{20*'='}")
print(drug_response_matrix.shape)
print("Number of unique cell lines:", len(drug_response_matrix.CELL_LINE_NAME.unique()))
print("Number of unique drug id's:", len(drug_response_matrix.DRUG_ID.unique()))
print("Number of unique drug names's:", len(drug_response_matrix.DRUG_NAME.unique()))
print(drug_response_matrix.isna().sum())
drug_response_matrix.head(3)

Drug response matrix
(91991, 5)
Number of unique cell lines: 732
Number of unique drug id's: 152
Number of unique drug names's: 152
CELL_LINE_NAME    0
DRUG_ID           0
DRUG_NAME         0
DATASET           0
LN_IC50           0
dtype: int64


Unnamed: 0,CELL_LINE_NAME,DRUG_ID,DRUG_NAME,DATASET,LN_IC50
3441054,22RV1,1003,Camptothecin,GDSC2,-3.142631
3459252,22RV1,1004,Vinblastine,GDSC2,-4.459259
3508920,22RV1,1006,Cytarabine,GDSC2,3.826935


## Build PyTorch Dataset

In [6]:
from torch_geometric.data import Dataset

class TabGraphDataset(Dataset): 
    def __init__(self, cl_gene_mat, drug_graphs, drm):
        super().__init__()

        # Cell-line gene matrix and drug SMILES fingerprints graphs.
        self.cl_gene_mat = cl_gene_mat
        self.drug_graphs = drug_graphs

        # Lookup datasets for the response values.
        drm.reset_index(drop=True, inplace=True)
        self.cell_lines = drm['CELL_LINE_NAME']
        self.drug_ids = drm['DRUG_ID']
        self.drug_names = drm['DRUG_NAME']
        self.ic50s = drm['LN_IC50']

    def __len__(self):
        return len(self.ic50s)

    def __getitem__(self, idx: int):
        """
        Returns a tuple of cell-line, drug and the corresponding ln(IC50)
        value for a given index.

        Args:
            idx (`int`): Index to specify the row in the drug response matrix.  
        Returns:
            `np.ndarray, Tuple[torch_geometric.data.data.Data], np.float64]`:
            Tuple of a cell-line gene values, drug SMILES fingerprint graph and 
            the corresponding ln(IC50) value.
        """
        return (self.cl_gene_mat.loc[self.cell_lines.iloc[idx]].values.tolist(),
                self.drug_graphs[self.drug_ids.iloc[idx]],
                self.ic50s.iloc[idx])

    def print_dataset_summary(self):
        print(f"TabGraphDataset Summary")
        print(f"{23*'='}")
        print(f"# observations : {len(self.ic50s)}")
        print(f"# cell-lines   : {len(np.unique(self.cell_lines))}")
        print(f"# drugs        : {len(self.drug_graphs.keys())}")
        print(f"# genes        : {len(self.cl_gene_mat.columns)/4}")

In [7]:
tab_graph_dataset = TabGraphDataset(cl_gene_mat=cl_gene_mat, drug_graphs=drug_graphs, drm=drug_response_matrix)
tab_graph_dataset.print_dataset_summary()   

TabGraphDataset Summary
# observations : 91991
# cell-lines   : 732
# drugs        : 152
# genes        : 858.0


## Set Hyperparamaters

In [8]:
class Args:
    def __init__(self, batch_size, lr, train_ratio, val_ratio, num_epochs):
        self.BATCH_SIZE = batch_size
        self.LR = lr
        self.TRAIN_RATIO = train_ratio
        self.TEST_VAL_RATIO = 1-self.TRAIN_RATIO
        self.VAL_RATIO = val_ratio
        self.NUM_EPOCHS = num_epochs
        self.RANDOM_SEED = 12345      

args = Args(batch_size=1_000, 
            lr=0.0001, 
            train_ratio=0.8, 
            val_ratio=0.5, 
            num_epochs=5)

## Create `DataLoader` Datasets

In [9]:
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader as PyG_DataLoader


def _collate_tab_graph(samples):
    cls, drugs, ic50s = map(list, zip(*samples))
    cls = [torch.tensor(cl, dtype=torch.float64) for cl in cls]
    return torch.stack(cls, 0), Batch.from_data_list(drugs), torch.tensor(ic50s)

def create_datasets(drm, cl_gene_mat, drug_graphs, args):
    print(f"Full     shape: {drm.shape}")
    train_set, test_val_set = train_test_split(drm, 
                                               test_size=args.TEST_VAL_RATIO, 
                                               random_state=args.RANDOM_SEED,
                                               stratify=drm['CELL_LINE_NAME'])
    test_set, val_set = train_test_split(test_val_set,
                                         test_size=args.VAL_RATIO,
                                         random_state=args.RANDOM_SEED,
                                         stratify=test_val_set['CELL_LINE_NAME'])
    print(f"train    shape: {train_set.shape}")
    print(f"test_val shape: {test_val_set.shape}")
    print(f"test     shape: {test_set.shape}")
    print(f"val      shape: {val_set.shape}")

    train_dataset = TabGraphDataset(cl_gene_mat=cl_gene_mat, drug_graphs=drug_graphs, drm=train_set)
    test_dataset = TabGraphDataset(cl_gene_mat=cl_gene_mat, drug_graphs=drug_graphs, drm=test_set)
    val_dataset = TabGraphDataset(cl_gene_mat=cl_gene_mat, drug_graphs=drug_graphs, drm=val_set)

    print("\ntrain_dataset:")
    train_dataset.print_dataset_summary()
    print("\n\ntest_dataset:")
    test_dataset.print_dataset_summary()
    print("\n\nval_dataset:")
    val_dataset.print_dataset_summary()

    # TODO: try out different `num_workers` (by using external python files).
    train_loader = PyG_DataLoader(dataset=train_dataset, batch_size=args.BATCH_SIZE, shuffle=True)
    test_loader = PyG_DataLoader(dataset=test_dataset, batch_size=args.BATCH_SIZE, shuffle=True)
    val_loader = PyG_DataLoader(dataset=val_dataset, batch_size=args.BATCH_SIZE, shuffle=True)

    return train_loader, test_loader, val_loader

train_loader, test_loader, val_loader = create_datasets(drug_response_matrix, cl_gene_mat, drug_graphs, args)

Full     shape: (91991, 5)
train    shape: (73592, 5)
test_val shape: (18399, 5)
test     shape: (9199, 5)
val      shape: (9200, 5)

train_dataset:
TabGraphDataset Summary
# observations : 73592
# cell-lines   : 732
# drugs        : 152
# genes        : 858.0


test_dataset:
TabGraphDataset Summary
# observations : 9199
# cell-lines   : 732
# drugs        : 152
# genes        : 858.0


val_dataset:
TabGraphDataset Summary
# observations : 9200
# cell-lines   : 732
# drugs        : 152
# genes        : 858.0


In [10]:
print("Number of batches per dataset:")
print(f"  train : {len(train_loader)}")
print(f"  test  : {len(test_loader)}")
print(f"  val   : {len(val_loader)}")

Number of batches per dataset:
  train : 74
  test  : 10
  val   : 10


In [11]:
for step, data in enumerate(train_loader):
    if (step > 2) & (step < len(train_loader)-1):
        if step % 10 == 0: 
            print("... step", step) 
        continue
    else:    
        cl_genes, drugs, targets = data
        print(f'Step {step + 1}:')
        print(f'=======')
        print("Number of graphs in the batch:", drugs.num_graphs)
        print(len(cl_genes))
        print(targets.shape)

Step 1:
Number of graphs in the batch: 1000
3432
torch.Size([1000])
Step 2:
Number of graphs in the batch: 1000
3432
torch.Size([1000])
Step 3:
Number of graphs in the batch: 1000
3432
torch.Size([1000])
... step 10
... step 20
... step 30
... step 40
... step 50
... step 60
... step 70
Step 74:
Number of graphs in the batch: 592
3432
torch.Size([592])


## Model development

In [12]:
from tqdm import tqdm
from time import sleep
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import pearsonr

class BuildModel():
    def __init__(self, model, criterion, optimizer, num_epochs, 
        train_loader, test_loader, val_loader, device):
        self.train_losses = []
        self.test_losses = []
        self.val_losses = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

    def train(self, loader): 
        train_epoch_losses, val_epoch_losses = [], []
        all_batch_losses = [] # TODO: this is just for monitoring
        n_batches = len(loader)

        self.model = self.model.float() # TODO: maybe remove
        for epoch in range(self.num_epochs):
            self.model.train()
            batch_losses = []
            for i, data in enumerate(tqdm(loader, desc='Iteration')):
                sleep(0.01)
                cell, drug, targets = data
                cell = torch.stack(cell, 0).transpose(1, 0) # Note that this is only neede when geometric 
                                                            # Dataloader is used and no collate.
                cell, drug, targets = cell.to(device), drug.to(device), targets.to(device)

                self.optimizer.zero_grad()

                #print('cell.shape    : ', cell.size)
                # print('drug.shape    : ', drug.shape)
                # print('targets.size  : ', targets.shape)

                # Models predictions of the ic50s for a batch of cell-lines and drugs
                preds = self.model(cell.float(), drug).unsqueeze(1)
                # print(100*"=")
                # print(targets)
                # print(targets.view(-1, 1))
                loss = self.criterion(preds, targets.view(-1, 1).float()) # =train_loss
                batch_losses.append(loss)

                loss.backward()
                self.optimizer.step()

            all_batch_losses.append(batch_losses) # TODO: this is just for monitoring
            total_epoch_loss = sum(batch_losses)
            train_epoch_losses.append(total_epoch_loss / n_batches)

            mse, _, _, _, _ = self.validate(self.val_loader)
            val_epoch_losses.append(mse)

            print("=====Epoch ", epoch)
            print(f"Train      | MSE: {train_epoch_losses[-1]:2.5f}")
            print(f"Validation | MSE: {mse:2.5f}")

        return train_epoch_losses, val_epoch_losses            

    def validate(self, loader):
        self.model.eval()
        y_true, y_pred = [], []
        total_loss = 0
        with torch.no_grad():
            for data in tqdm(loader, desc='Iter', position=0, leave=True):
                sleep(0.01)
                cl, dr, ic50 = data
                cl = torch.stack(cl, 0).transpose(1, 0)

                preds = self.model(cl.float(), dr).unsqueeze(1)
                ic50 = ic50.to(self.device)
                total_loss += self.criterion(preds, ic50.view(-1,1).float())
                y_true.append(ic50.view(-1, 1))
                y_pred.append(preds)
        
        y_true = torch.cat(y_true, dim=0)
        y_pred = torch.cat(y_pred, dim=0)
        mse = total_loss / len(loader)
        rmse = torch.sqrt(mse)
        mae = mean_absolute_error(y_true.cpu(), y_pred.cpu())
        r2 = r2_score(y_true.cpu(), y_pred.cpu())
        pearson_corr_coef, _ = pearsonr(y_true.cpu().numpy().flatten(), 
                                        y_pred.cpu().numpy().flatten())

        return mse, rmse, mae, r2, pearson_corr_coef

In [13]:
%load_ext autoreload
%autoreload
# from v3_GCN import GraphTab_v1
# from my_utils.model_helpers import train_and_test_model
from torch_geometric.nn import Sequential, GINConv, global_mean_pool, global_max_pool


class TabGraph_v1(torch.nn.Module):
    def __init__(self):
        super(TabGraph_v1, self).__init__()
        # torch.manual_seed(12345)

        self.cell_emb = nn.Sequential(
            nn.Linear(3432, 516),
            nn.BatchNorm1d(516),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(516, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()         
        )

        self.drug_emb = Sequential('x, edge_index, batch', 
            [
                (GINConv(
                    nn.Sequential(
                        nn.Linear(9, 128), # 9 = num_node_features
                        nn.BatchNorm1d(128),
                        nn.ReLU(),
                        nn.Linear(128, 128)
                    )
                ), 'x, edge_index -> x1'),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(128),
                (GINConv(
                    nn.Sequential(
                        nn.Linear(128, 128),
                        nn.BatchNorm1d(128),
                        nn.ReLU(),
                        nn.Linear(128, 128)
                    )
                ), 'x1, edge_index -> x2'),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(128),
                # TODO: research maybe JumpingKnowledge at this point
                (global_max_pool, 'x2, batch -> x3'),
                nn.Linear(128, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(p=0.1),
                nn.Linear(128, 128),
                nn.ReLU()
            ]
        )

        self.fcn = nn.Sequential(
            nn.Linear(2*128, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(64, 1)
        )

    def forward(self, cell, drug):
        cell_emb = self.cell_emb(cell)
        drug_emb = self.drug_emb(drug.x.float(), drug.edge_index, drug.batch)
        concat = torch.cat([cell_emb, drug_emb], -1)
        y_pred = self.fcn(concat)
        y_pred = y_pred.reshape(y_pred.shape[0])
        return y_pred

torch.manual_seed(args.RANDOM_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

model = TabGraph_v1().to(device)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.LR) # TODO: include weight_decay of lr

build_model = BuildModel(model=model,
                         criterion=loss_func,
                         optimizer=optimizer,
                         num_epochs=20,
                         train_loader=train_loader,
                         test_loader=test_loader,
                         val_loader=val_loader, 
                         device=device)

device: cpu


In [150]:
train_losses, val_losses = build_model.train(build_model.train_loader)

Iteration: 100%|██████████| 74/74 [03:25<00:00,  2.78s/it]
Iter: 100%|██████████| 10/10 [00:19<00:00,  1.93s/it]


=====Epoch  0
Train      | MSE: 9.92948
Validation | MSE: 7.26444


Iteration: 100%|██████████| 74/74 [03:16<00:00,  2.66s/it]
Iter: 100%|██████████| 10/10 [00:20<00:00,  2.08s/it]


=====Epoch  1
Train      | MSE: 7.04456
Validation | MSE: 6.17102


Iteration: 100%|██████████| 74/74 [03:06<00:00,  2.53s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.60s/it]


=====Epoch  2
Train      | MSE: 6.25747
Validation | MSE: 5.91426


Iteration: 100%|██████████| 74/74 [02:48<00:00,  2.28s/it]
Iter: 100%|██████████| 10/10 [00:15<00:00,  1.59s/it]


=====Epoch  3
Train      | MSE: 5.64175
Validation | MSE: 5.25848


Iteration: 100%|██████████| 74/74 [03:05<00:00,  2.50s/it]
Iter: 100%|██████████| 10/10 [00:17<00:00,  1.70s/it]


=====Epoch  4
Train      | MSE: 5.12141
Validation | MSE: 4.23384


Iteration: 100%|██████████| 74/74 [02:59<00:00,  2.42s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.64s/it]


=====Epoch  5
Train      | MSE: 4.65983
Validation | MSE: 3.87459


Iteration: 100%|██████████| 74/74 [03:00<00:00,  2.44s/it]
Iter: 100%|██████████| 10/10 [00:17<00:00,  1.77s/it]


=====Epoch  6
Train      | MSE: 4.22909
Validation | MSE: 4.34660


Iteration: 100%|██████████| 74/74 [02:49<00:00,  2.29s/it]
Iter: 100%|██████████| 10/10 [00:17<00:00,  1.72s/it]


=====Epoch  7
Train      | MSE: 3.83737
Validation | MSE: 3.67671


Iteration: 100%|██████████| 74/74 [02:59<00:00,  2.43s/it]
Iter: 100%|██████████| 10/10 [00:17<00:00,  1.75s/it]


=====Epoch  8
Train      | MSE: 3.46817
Validation | MSE: 3.00270


Iteration: 100%|██████████| 74/74 [02:50<00:00,  2.30s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.64s/it]


=====Epoch  9
Train      | MSE: 3.15483
Validation | MSE: 3.12171


Iteration: 100%|██████████| 74/74 [03:00<00:00,  2.43s/it]
Iter: 100%|██████████| 10/10 [00:18<00:00,  1.81s/it]


=====Epoch  10
Train      | MSE: 2.84782
Validation | MSE: 2.74759


Iteration: 100%|██████████| 74/74 [02:57<00:00,  2.40s/it]
Iter: 100%|██████████| 10/10 [00:19<00:00,  1.94s/it]


=====Epoch  11
Train      | MSE: 2.59350
Validation | MSE: 2.33532


Iteration: 100%|██████████| 74/74 [02:58<00:00,  2.41s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.65s/it]


=====Epoch  12
Train      | MSE: 2.36423
Validation | MSE: 2.11116


Iteration: 100%|██████████| 74/74 [02:55<00:00,  2.37s/it]
Iter: 100%|██████████| 10/10 [00:18<00:00,  1.85s/it]


=====Epoch  13
Train      | MSE: 2.16472
Validation | MSE: 1.76590


Iteration: 100%|██████████| 74/74 [03:04<00:00,  2.50s/it]
Iter: 100%|██████████| 10/10 [00:17<00:00,  1.73s/it]


=====Epoch  14
Train      | MSE: 1.99389
Validation | MSE: 1.87217


Iteration: 100%|██████████| 74/74 [03:06<00:00,  2.52s/it]
Iter: 100%|██████████| 10/10 [00:18<00:00,  1.83s/it]


=====Epoch  15
Train      | MSE: 1.85516
Validation | MSE: 1.63568


Iteration: 100%|██████████| 74/74 [02:59<00:00,  2.42s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.64s/it]


=====Epoch  16
Train      | MSE: 1.74592
Validation | MSE: 1.55687


Iteration: 100%|██████████| 74/74 [02:53<00:00,  2.35s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.65s/it]


=====Epoch  17
Train      | MSE: 1.65012
Validation | MSE: 1.55923


Iteration: 100%|██████████| 74/74 [02:41<00:00,  2.19s/it]
Iter: 100%|██████████| 10/10 [00:15<00:00,  1.58s/it]


=====Epoch  18
Train      | MSE: 1.58180
Validation | MSE: 1.58497


Iteration: 100%|██████████| 74/74 [03:01<00:00,  2.46s/it]
Iter: 100%|██████████| 10/10 [00:23<00:00,  2.35s/it]


=====Epoch  19
Train      | MSE: 1.52831
Validation | MSE: 1.40334


Iteration: 100%|██████████| 74/74 [03:04<00:00,  2.50s/it]
Iter: 100%|██████████| 10/10 [00:15<00:00,  1.59s/it]


=====Epoch  20
Train      | MSE: 1.47629
Validation | MSE: 1.39805


Iteration: 100%|██████████| 74/74 [02:43<00:00,  2.21s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.66s/it]


=====Epoch  21
Train      | MSE: 1.45808
Validation | MSE: 1.40369


Iteration: 100%|██████████| 74/74 [03:01<00:00,  2.46s/it]
Iter: 100%|██████████| 10/10 [00:20<00:00,  2.00s/it]


=====Epoch  22
Train      | MSE: 1.41938
Validation | MSE: 1.34122


Iteration: 100%|██████████| 74/74 [03:10<00:00,  2.57s/it]
Iter: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]


=====Epoch  23
Train      | MSE: 1.39795
Validation | MSE: 1.32620


Iteration: 100%|██████████| 74/74 [03:08<00:00,  2.55s/it]
Iter: 100%|██████████| 10/10 [00:20<00:00,  2.03s/it]


=====Epoch  24
Train      | MSE: 1.37905
Validation | MSE: 1.27177


Iteration: 100%|██████████| 74/74 [03:19<00:00,  2.70s/it]
Iter: 100%|██████████| 10/10 [00:16<00:00,  1.69s/it]


=====Epoch  25
Train      | MSE: 1.37930
Validation | MSE: 1.27288


Iteration: 100%|██████████| 74/74 [02:53<00:00,  2.34s/it]
Iter: 100%|██████████| 10/10 [00:15<00:00,  1.58s/it]


=====Epoch  26
Train      | MSE: 1.35394
Validation | MSE: 1.29869


Iteration: 100%|██████████| 74/74 [02:37<00:00,  2.13s/it]
Iter: 100%|██████████| 10/10 [00:15<00:00,  1.55s/it]


=====Epoch  27
Train      | MSE: 1.34701
Validation | MSE: 1.28977


Iteration: 100%|██████████| 74/74 [02:39<00:00,  2.15s/it]
Iter: 100%|██████████| 10/10 [00:15<00:00,  1.53s/it]


=====Epoch  28
Train      | MSE: 1.33672
Validation | MSE: 1.28424


Iteration: 100%|██████████| 74/74 [02:46<00:00,  2.25s/it]
Iter: 100%|██████████| 10/10 [00:18<00:00,  1.83s/it]


=====Epoch  29
Train      | MSE: 1.33391
Validation | MSE: 1.24056


Iteration: 100%|██████████| 74/74 [03:08<00:00,  2.55s/it]
Iter: 100%|██████████| 10/10 [00:20<00:00,  2.08s/it]


=====Epoch  30
Train      | MSE: 1.32257
Validation | MSE: 1.31365


Iteration: 100%|██████████| 74/74 [03:09<00:00,  2.56s/it]
Iter: 100%|██████████| 10/10 [00:18<00:00,  1.80s/it]


=====Epoch  31
Train      | MSE: 1.30753
Validation | MSE: 1.27190


Iteration: 100%|██████████| 74/74 [03:11<00:00,  2.59s/it]
Iter: 100%|██████████| 10/10 [00:22<00:00,  2.27s/it]


=====Epoch  32
Train      | MSE: 1.29936
Validation | MSE: 1.23165


Iteration: 100%|██████████| 74/74 [03:17<00:00,  2.67s/it]
Iter: 100%|██████████| 10/10 [00:22<00:00,  2.29s/it]


=====Epoch  33
Train      | MSE: 1.30104
Validation | MSE: 1.26342


Iteration:  59%|█████▉    | 44/74 [02:02<01:32,  3.09s/it]