In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import gc

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/dataset/y_train.npy
/kaggle/input/dataset/y_test.npy
/kaggle/input/dataset/x_test.npy
/kaggle/input/dataset/x_train.npy


In [18]:
x_train = np.load('/kaggle/input/dataset/x_train.npy')
x_test = np.load('/kaggle/input/dataset/x_test.npy')
y_train = np.load('/kaggle/input/dataset/y_train.npy')
y_test = np.load('/kaggle/input/dataset/y_test.npy')

In [19]:
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

(53241, 240)
(53241, 140)
(17747, 240)
(17747, 140)


In [20]:
import os, gc, pickle
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from colorama import Fore, Back, Style
from matplotlib.ticker import MaxNLocator
from tqdm import tqdm

from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

DATA_DIR = "/kaggle/input/open-problems-multimodal/"
FP_CELL_METADATA = os.path.join(DATA_DIR,"metadata.csv")

FP_CITE_TRAIN_INPUTS = os.path.join(DATA_DIR,"train_cite_inputs.h5")
FP_CITE_TRAIN_TARGETS = os.path.join(DATA_DIR,"train_cite_targets.h5")
FP_CITE_TEST_INPUTS = os.path.join(DATA_DIR,"test_cite_inputs.h5")

FP_MULTIOME_TRAIN_INPUTS = os.path.join(DATA_DIR,"train_multi_inputs.h5")
FP_MULTIOME_TRAIN_TARGETS = os.path.join(DATA_DIR,"train_multi_targets.h5")
FP_MULTIOME_TEST_INPUTS = os.path.join(DATA_DIR,"test_multi_inputs.h5")

FP_SUBMISSION = os.path.join(DATA_DIR,"sample_submission.csv")
FP_EVALUATION_IDS = os.path.join(DATA_DIR,"evaluation_ids.csv")

cpu


In [21]:
class CFG:
    tr_batch_size = 16 # 16
    va_batch_size = 128 # 32
    
    optimizer = "AdamW"
    lr = 1e-5
    weight_decay = 0.1
    betas = (0.9, 0.999)
    epochs = 50
    

In [22]:
class CtieseqDataset(Dataset):
    """
    Train, Validation or Test dataset for CITEseq samples
    Prepares data for simple vector to vector NN
    """
    def __init__(self, X, y=None):
        self.train = False 
        if y is not None:
            self.train = True
        self.X = X
        self.y = y
            
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        X = self.X[idx]
        
        if self.train:
            y = self.y[idx]
            return {
                "X" : torch.tensor(X).to(device),
                "y" : torch.tensor(y).to(device)
            }
        else:
            return {
                "X" : torch.tensor(X).to(device)
            }

In [23]:
def criterion(outputs, labels):
    """ MSE Loss function"""
    return nn.MSELoss()(outputs, labels)

def correlation_score(y_true, y_pred):
    """
    Scores the predictions according to the competition rules. 
    It is assumed that the predictions are not constant.
    Returns the average of each sample's Pearson correlation coefficient
    """
    
    if type(y_true) == pd.DataFrame: y_true = y_true.values
    if type(y_pred) == pd.DataFrame: y_pred = y_pred.values
    corrsum = 0
    for i in range(len(y_true)):
        corrsum += np.corrcoef(y_true[i], y_pred[i])[1, 0]
    return corrsum / len(y_true)

def get_optimizer(model, lr, weight_decay, betas):
    """ Gets AdamW optimizer """
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': weight_decay},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=lr,
                      weight_decay=weight_decay,
                      betas=betas,
                     )
    return optimizer

def get_scheduler(optimizer, T_max=300):
    """ Gets Consine scheduler """
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=T_max)
    return scheduler

In [24]:
class FCBlock(nn.Module):
    """
    A Pytorch Block for a fully connected Layer
    Includes Linear, Activation Function, and Dropout
    """
    def __init__(self, input_dim, hidden_dim, dropout):
        super().__init__()
        self.fc = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc(x)
        x = F.selu(x)
        x = self.dropout(x)
        return x

class Encoder(nn.Module):
    """
    Encoder module to generate embeddings of a RNA vector
    """
    def __init__(self):
        super().__init__()
        self.l0 = FCBlock(240, 120, 0.05)
        self.l1 = FCBlock(120, 60, 0.05)
        self.l2 = FCBlock(60, 30, 0.05)
        
    def forward(self, x):
        x = self.l0(x)
        x = self.l1(x)
        x = self.l2(x)
        return x
    
class Decoder(nn.Module):
    """
    Decoder module to extract Protein sequences from RNA embeddings
    """
    def __init__(self):
        super().__init__()
        self.l0 = FCBlock(30, 70, 0.05)
        self.l1 = FCBlock(70, 100, 0.05)
        self.l2 = FCBlock(100, 140, 0.05)
        
    def forward(self, x):
        x = self.l0(x)
        x = self.l1(x)
        x = self.l2(x)
        return x
    
class CtieseqModel(nn.Module):
    """
    Wrapper for the Encoder and Decoder modules
    Converts RNA sequence to Protein sequence
    """
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self, x):
        embeddings = self.encoder(x)
        outputs = self.decoder(embeddings)
        return outputs

In [25]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [26]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """ Trains one epoch and returns loss """
    model.train()
    
    losses = AverageMeter()
    corr = AverageMeter()
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        X, y = data["X"], data["y"]
        
        batch_size = X.size(0)

        outputs = model(X)

        n = outputs.size(0)
        loss = criterion(outputs, y)
        losses.update(loss.item(), n)
        loss.backward()
        
        outputs = outputs.detach().cpu().numpy()
        y = y.detach().cpu().numpy()
        corr_score = correlation_score(y, outputs)
        corr.update(corr_score, n)
        
        optimizer.step()
        optimizer.zero_grad()

        if scheduler is not None:
            scheduler.step()
        
        bar.set_postfix(Epoch=epoch, Train_Loss=losses.avg, Corr=corr.avg,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    
    return losses.avg

In [27]:
@torch.no_grad()
def valid_one_epoch(model, optimizer, dataloader, device, epoch):
    """ Evaluates one epoch and returns loss """
    model.eval()
    
    losses = AverageMeter()
    corr = AverageMeter()
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:        
        X, y = data["X"], data["y"]
        
        batch_size = X.size(0)

        outputs = model(X)
        
        n = outputs.size(0)
        loss = criterion(outputs, y)
        losses.update(loss.item(), n)
        
        outputs = outputs.detach().cpu().numpy()
        y = y.detach().cpu().numpy()
        corr_score = correlation_score(y, outputs)
        corr.update(corr_score, n)
        
        bar.set_postfix(Epoch=epoch, Valid_Loss=losses.avg, Corr=corr.avg,
                        LR=optimizer.param_groups[0]['lr'])   
    
    gc.collect()
    
    return losses.avg

In [28]:
def train_one_fold(model, 
                   optimizer, 
                   scheduler, 
                   train_loader, 
                   valid_loader, 
                   fold):
    """ Trains and saves a full fold of a pytorch model """
    best_epoch_loss = np.inf
    model.to(device)

    for epoch in range(CFG.epochs):
        gc.collect()
        train_epoch_loss = train_one_epoch(model, 
                                           optimizer, 
                                           scheduler, 
                                           dataloader=train_loader, 
                                           device=device, 
                                           epoch=epoch)

        val_epoch_loss = valid_one_epoch(model,
                                         optimizer, 
                                         valid_loader, 
                                         device=device, epoch=epoch)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_epoch_loss,
            }, '/kaggle/working/latest_model_training_stage')
        
        if val_epoch_loss <= best_epoch_loss:
            print(f"Validation Loss Improved ({best_epoch_loss} ---> {val_epoch_loss})")
            best_epoch_loss = val_epoch_loss
            torch.save(model.state_dict(), f"model_f{fold}.bin")
            
    print("Best Loss: {:.4f}".format(best_epoch_loss))

In [29]:
PATH = '/kaggle/working/latest_model_training_stage.bin'
kf = KFold(n_splits=3, shuffle=True, random_state=42)
score_list = []
for fold, (idx_tr, idx_va) in enumerate(kf.split(x_train)):
    print(f"\nfold = {fold}")
    X_tr = x_train[idx_tr] 
    y_tr = y_train[idx_tr]
    
    X_va = x_train[idx_va]
    y_va = y_train[idx_va]
    
    ds_tr = CtieseqDataset(X_tr, y_tr)
    ds_va = CtieseqDataset(X_tr, y_tr)
    dl_tr = DataLoader(ds_tr, batch_size=CFG.tr_batch_size, shuffle=True)
    dl_va = DataLoader(ds_va, batch_size=CFG.va_batch_size, shuffle=False)
    
    model = CtieseqModel()
    optimizer = get_optimizer(model, CFG.lr, CFG.weight_decay, CFG.betas)
    scheduler = get_scheduler(optimizer) 

    train_one_fold(model, optimizer, scheduler, dl_tr, dl_va, fold)
    


fold = 0


100%|██████████| 2219/2219 [00:14<00:00, 148.75it/s, Corr=0.162, Epoch=0, LR=3.41e-6, Train_Loss=14.6]  
100%|██████████| 278/278 [00:05<00:00, 54.42it/s, Corr=0.472, Epoch=0, LR=3.41e-6, Valid_Loss=12.8]


Validation Loss Improved (inf ---> 12.848822223869465)


100%|██████████| 2219/2219 [00:14<00:00, 148.61it/s, Corr=0.627, Epoch=1, LR=1.02e-6, Train_Loss=9.58] 
100%|██████████| 278/278 [00:05<00:00, 53.10it/s, Corr=0.762, Epoch=1, LR=1.02e-6, Valid_Loss=6.45]


Validation Loss Improved (12.848822223869465 ---> 6.445521283151063)


100%|██████████| 2219/2219 [00:14<00:00, 149.27it/s, Corr=0.743, Epoch=2, LR=9.14e-6, Train_Loss=5.96] 
100%|██████████| 278/278 [00:04<00:00, 55.81it/s, Corr=0.799, Epoch=2, LR=9.14e-6, Valid_Loss=4.83]


Validation Loss Improved (6.445521283151063 ---> 4.828890238116141)


100%|██████████| 2219/2219 [00:13<00:00, 158.62it/s, Corr=0.765, Epoch=3, LR=6.34e-6, Train_Loss=5.32] 
100%|██████████| 278/278 [00:05<00:00, 54.73it/s, Corr=0.806, Epoch=3, LR=6.34e-6, Valid_Loss=4.58]


Validation Loss Improved (4.828890238116141 ---> 4.579173266588752)


100%|██████████| 2219/2219 [00:14<00:00, 148.91it/s, Corr=0.77, Epoch=4, LR=6.85e-9, Train_Loss=5.18]  
100%|██████████| 278/278 [00:05<00:00, 53.90it/s, Corr=0.81, Epoch=4, LR=6.85e-9, Valid_Loss=4.45] 


Validation Loss Improved (4.579173266588752 ---> 4.450706352735899)


100%|██████████| 2219/2219 [00:15<00:00, 146.33it/s, Corr=0.776, Epoch=5, LR=6.84e-6, Train_Loss=5.05] 
100%|██████████| 278/278 [00:05<00:00, 54.99it/s, Corr=0.818, Epoch=5, LR=6.84e-6, Valid_Loss=4.31]


Validation Loss Improved (4.450706352735899 ---> 4.308395615603115)


100%|██████████| 2219/2219 [00:14<00:00, 151.80it/s, Corr=0.785, Epoch=6, LR=8.82e-6, Train_Loss=4.88] 
100%|██████████| 278/278 [00:04<00:00, 56.21it/s, Corr=0.828, Epoch=6, LR=8.82e-6, Valid_Loss=4.13]


Validation Loss Improved (4.308395615603115 ---> 4.12564015383115)


100%|██████████| 2219/2219 [00:14<00:00, 157.12it/s, Corr=0.796, Epoch=7, LR=7.23e-7, Train_Loss=4.69] 
100%|██████████| 278/278 [00:05<00:00, 54.63it/s, Corr=0.839, Epoch=7, LR=7.23e-7, Valid_Loss=3.93]


Validation Loss Improved (4.12564015383115 ---> 3.9253286881709344)


100%|██████████| 2219/2219 [00:14<00:00, 153.86it/s, Corr=0.803, Epoch=8, LR=3.91e-6, Train_Loss=4.55] 
100%|██████████| 278/278 [00:05<00:00, 51.62it/s, Corr=0.844, Epoch=8, LR=3.91e-6, Valid_Loss=3.79]


Validation Loss Improved (3.9253286881709344 ---> 3.7927118090231224)


100%|██████████| 2219/2219 [00:14<00:00, 150.27it/s, Corr=0.807, Epoch=9, LR=9.97e-6, Train_Loss=4.47] 
100%|██████████| 278/278 [00:05<00:00, 54.02it/s, Corr=0.846, Epoch=9, LR=9.97e-6, Valid_Loss=3.74]


Validation Loss Improved (3.7927118090231224 ---> 3.739588717282601)


100%|██████████| 2219/2219 [00:15<00:00, 146.97it/s, Corr=0.808, Epoch=10, LR=2.92e-6, Train_Loss=4.43] 
100%|██████████| 278/278 [00:05<00:00, 52.84it/s, Corr=0.847, Epoch=10, LR=2.92e-6, Valid_Loss=3.71]


Validation Loss Improved (3.739588717282601 ---> 3.7064409375681087)


100%|██████████| 2219/2219 [00:14<00:00, 153.73it/s, Corr=0.809, Epoch=11, LR=1.36e-6, Train_Loss=4.4]  
100%|██████████| 278/278 [00:05<00:00, 51.46it/s, Corr=0.848, Epoch=11, LR=1.36e-6, Valid_Loss=3.68]


Validation Loss Improved (3.7064409375681087 ---> 3.6806139521862935)


100%|██████████| 2219/2219 [00:14<00:00, 153.78it/s, Corr=0.81, Epoch=12, LR=9.41e-6, Train_Loss=4.38]  
100%|██████████| 278/278 [00:05<00:00, 54.66it/s, Corr=0.848, Epoch=12, LR=9.41e-6, Valid_Loss=3.66]


Validation Loss Improved (3.6806139521862935 ---> 3.6575788645809615)


100%|██████████| 2219/2219 [00:15<00:00, 146.86it/s, Corr=0.811, Epoch=13, LR=5.83e-6, Train_Loss=4.35] 
100%|██████████| 278/278 [00:05<00:00, 53.84it/s, Corr=0.849, Epoch=13, LR=5.83e-6, Valid_Loss=3.63]


Validation Loss Improved (3.6575788645809615 ---> 3.6332440727857453)


100%|██████████| 2219/2219 [00:14<00:00, 150.55it/s, Corr=0.812, Epoch=14, LR=6.16e-8, Train_Loss=4.33] 
100%|██████████| 278/278 [00:05<00:00, 52.15it/s, Corr=0.85, Epoch=14, LR=6.16e-8, Valid_Loss=3.62] 


Validation Loss Improved (3.6332440727857453 ---> 3.615335349721721)


100%|██████████| 2219/2219 [00:14<00:00, 151.70it/s, Corr=0.813, Epoch=15, LR=7.32e-6, Train_Loss=4.3]  
100%|██████████| 278/278 [00:05<00:00, 54.66it/s, Corr=0.85, Epoch=15, LR=7.32e-6, Valid_Loss=3.59] 


Validation Loss Improved (3.615335349721721 ---> 3.5932314973579578)


100%|██████████| 2219/2219 [00:15<00:00, 146.72it/s, Corr=0.814, Epoch=16, LR=8.46e-6, Train_Loss=4.28] 
100%|██████████| 278/278 [00:05<00:00, 53.41it/s, Corr=0.851, Epoch=16, LR=8.46e-6, Valid_Loss=3.57]


Validation Loss Improved (3.5932314973579578 ---> 3.5706495713024213)


100%|██████████| 2219/2219 [00:14<00:00, 152.12it/s, Corr=0.815, Epoch=17, LR=4.76e-7, Train_Loss=4.26] 
100%|██████████| 278/278 [00:05<00:00, 55.16it/s, Corr=0.852, Epoch=17, LR=4.76e-7, Valid_Loss=3.55]


Validation Loss Improved (3.5706495713024213 ---> 3.5539582230013416)


100%|██████████| 2219/2219 [00:14<00:00, 150.64it/s, Corr=0.816, Epoch=18, LR=4.43e-6, Train_Loss=4.24] 
100%|██████████| 278/278 [00:05<00:00, 55.41it/s, Corr=0.853, Epoch=18, LR=4.43e-6, Valid_Loss=3.53]


Validation Loss Improved (3.5539582230013416 ---> 3.5317419658494735)


100%|██████████| 2219/2219 [00:15<00:00, 143.19it/s, Corr=0.816, Epoch=19, LR=9.89e-6, Train_Loss=4.22] 
100%|██████████| 278/278 [00:05<00:00, 52.35it/s, Corr=0.853, Epoch=19, LR=9.89e-6, Valid_Loss=3.52]


Validation Loss Improved (3.5317419658494735 ---> 3.517273052098644)


100%|██████████| 2219/2219 [00:15<00:00, 145.80it/s, Corr=0.817, Epoch=20, LR=2.45e-6, Train_Loss=4.2]  
100%|██████████| 278/278 [00:05<00:00, 55.28it/s, Corr=0.854, Epoch=20, LR=2.45e-6, Valid_Loss=3.5] 


Validation Loss Improved (3.517273052098644 ---> 3.497806122062938)


100%|██████████| 2219/2219 [00:15<00:00, 147.53it/s, Corr=0.818, Epoch=21, LR=1.73e-6, Train_Loss=4.18] 
100%|██████████| 278/278 [00:05<00:00, 54.36it/s, Corr=0.856, Epoch=21, LR=1.73e-6, Valid_Loss=3.47]


Validation Loss Improved (3.497806122062938 ---> 3.468079904237183)


100%|██████████| 2219/2219 [00:15<00:00, 141.85it/s, Corr=0.82, Epoch=22, LR=9.63e-6, Train_Loss=4.14] 
100%|██████████| 278/278 [00:05<00:00, 54.14it/s, Corr=0.857, Epoch=22, LR=9.63e-6, Valid_Loss=3.44]


Validation Loss Improved (3.468079904237183 ---> 3.4404190470616554)


100%|██████████| 2219/2219 [00:14<00:00, 153.64it/s, Corr=0.822, Epoch=23, LR=5.31e-6, Train_Loss=4.12] 
100%|██████████| 278/278 [00:05<00:00, 54.57it/s, Corr=0.858, Epoch=23, LR=5.31e-6, Valid_Loss=3.41]


Validation Loss Improved (3.4404190470616554 ---> 3.4097210856338966)


100%|██████████| 2219/2219 [00:14<00:00, 156.70it/s, Corr=0.822, Epoch=24, LR=1.7e-7, Train_Loss=4.1]   
100%|██████████| 278/278 [00:05<00:00, 54.89it/s, Corr=0.86, Epoch=24, LR=1.7e-7, Valid_Loss=3.38] 


Validation Loss Improved (3.4097210856338966 ---> 3.3787418268106686)


100%|██████████| 2219/2219 [00:14<00:00, 154.32it/s, Corr=0.824, Epoch=25, LR=7.77e-6, Train_Loss=4.05] 
100%|██████████| 278/278 [00:04<00:00, 55.74it/s, Corr=0.861, Epoch=25, LR=7.77e-6, Valid_Loss=3.35]


Validation Loss Improved (3.3787418268106686 ---> 3.354326858381866)


100%|██████████| 2219/2219 [00:14<00:00, 154.68it/s, Corr=0.825, Epoch=26, LR=8.06e-6, Train_Loss=4.04] 
100%|██████████| 278/278 [00:05<00:00, 54.93it/s, Corr=0.862, Epoch=26, LR=8.06e-6, Valid_Loss=3.33]


Validation Loss Improved (3.354326858381866 ---> 3.330126841974949)


100%|██████████| 2219/2219 [00:14<00:00, 150.56it/s, Corr=0.826, Epoch=27, LR=2.78e-7, Train_Loss=4.02] 
100%|██████████| 278/278 [00:04<00:00, 56.16it/s, Corr=0.863, Epoch=27, LR=2.78e-7, Valid_Loss=3.31]


Validation Loss Improved (3.330126841974949 ---> 3.3073441378397357)


100%|██████████| 2219/2219 [00:14<00:00, 153.64it/s, Corr=0.827, Epoch=28, LR=4.95e-6, Train_Loss=3.99] 
100%|██████████| 278/278 [00:05<00:00, 52.96it/s, Corr=0.864, Epoch=28, LR=4.95e-6, Valid_Loss=3.29]


Validation Loss Improved (3.3073441378397357 ---> 3.2891211270708336)


100%|██████████| 2219/2219 [00:13<00:00, 160.06it/s, Corr=0.828, Epoch=29, LR=9.76e-6, Train_Loss=3.97] 
100%|██████████| 278/278 [00:04<00:00, 56.36it/s, Corr=0.865, Epoch=29, LR=9.76e-6, Valid_Loss=3.27]


Validation Loss Improved (3.2891211270708336 ---> 3.26916372598027)


100%|██████████| 2219/2219 [00:13<00:00, 158.59it/s, Corr=0.828, Epoch=30, LR=2.02e-6, Train_Loss=3.96] 
100%|██████████| 278/278 [00:05<00:00, 55.26it/s, Corr=0.865, Epoch=30, LR=2.02e-6, Valid_Loss=3.25]


Validation Loss Improved (3.26916372598027 ---> 3.2506668539445434)


100%|██████████| 2219/2219 [00:13<00:00, 159.24it/s, Corr=0.83, Epoch=31, LR=2.15e-6, Train_Loss=3.93]  
100%|██████████| 278/278 [00:04<00:00, 55.87it/s, Corr=0.866, Epoch=31, LR=2.15e-6, Valid_Loss=3.23]


Validation Loss Improved (3.2506668539445434 ---> 3.231149919524612)


100%|██████████| 2219/2219 [00:14<00:00, 154.06it/s, Corr=0.83, Epoch=32, LR=9.8e-6, Train_Loss=3.92]  
100%|██████████| 278/278 [00:05<00:00, 54.97it/s, Corr=0.867, Epoch=32, LR=9.8e-6, Valid_Loss=3.21]


Validation Loss Improved (3.231149919524612 ---> 3.2100339765823303)


100%|██████████| 2219/2219 [00:14<00:00, 157.39it/s, Corr=0.831, Epoch=33, LR=4.79e-6, Train_Loss=3.9]  
100%|██████████| 278/278 [00:05<00:00, 54.65it/s, Corr=0.868, Epoch=33, LR=4.79e-6, Valid_Loss=3.19]


Validation Loss Improved (3.2100339765823303 ---> 3.191712962034125)


100%|██████████| 2219/2219 [00:14<00:00, 152.37it/s, Corr=0.832, Epoch=34, LR=3.32e-7, Train_Loss=3.88] 
100%|██████████| 278/278 [00:05<00:00, 55.37it/s, Corr=0.868, Epoch=34, LR=3.32e-7, Valid_Loss=3.18]


Validation Loss Improved (3.191712962034125 ---> 3.17572362707616)


100%|██████████| 2219/2219 [00:14<00:00, 149.08it/s, Corr=0.833, Epoch=35, LR=8.19e-6, Train_Loss=3.87] 
100%|██████████| 278/278 [00:05<00:00, 54.79it/s, Corr=0.869, Epoch=35, LR=8.19e-6, Valid_Loss=3.16]


Validation Loss Improved (3.17572362707616 ---> 3.1565779685275537)


100%|██████████| 2219/2219 [00:14<00:00, 150.10it/s, Corr=0.833, Epoch=36, LR=7.63e-6, Train_Loss=3.86] 
100%|██████████| 278/278 [00:05<00:00, 53.08it/s, Corr=0.87, Epoch=36, LR=7.63e-6, Valid_Loss=3.14] 


Validation Loss Improved (3.1565779685275537 ---> 3.1366676250336747)


100%|██████████| 2219/2219 [00:14<00:00, 152.05it/s, Corr=0.834, Epoch=37, LR=1.32e-7, Train_Loss=3.83] 
100%|██████████| 278/278 [00:04<00:00, 55.92it/s, Corr=0.871, Epoch=37, LR=1.32e-7, Valid_Loss=3.12]


Validation Loss Improved (3.1366676250336747 ---> 3.1185148847521598)


100%|██████████| 2219/2219 [00:14<00:00, 156.23it/s, Corr=0.835, Epoch=38, LR=5.47e-6, Train_Loss=3.82] 
100%|██████████| 278/278 [00:05<00:00, 54.48it/s, Corr=0.871, Epoch=38, LR=5.47e-6, Valid_Loss=3.1] 


Validation Loss Improved (3.1185148847521598 ---> 3.0967229558842924)


100%|██████████| 2219/2219 [00:14<00:00, 156.05it/s, Corr=0.836, Epoch=39, LR=9.57e-6, Train_Loss=3.79] 
100%|██████████| 278/278 [00:05<00:00, 50.97it/s, Corr=0.872, Epoch=39, LR=9.57e-6, Valid_Loss=3.08]


Validation Loss Improved (3.0967229558842924 ---> 3.0842976512630575)


100%|██████████| 2219/2219 [00:15<00:00, 147.73it/s, Corr=0.836, Epoch=40, LR=1.62e-6, Train_Loss=3.78] 
100%|██████████| 278/278 [00:05<00:00, 52.97it/s, Corr=0.873, Epoch=40, LR=1.62e-6, Valid_Loss=3.07]


Validation Loss Improved (3.0842976512630575 ---> 3.066330948285454)


100%|██████████| 2219/2219 [00:14<00:00, 150.44it/s, Corr=0.837, Epoch=41, LR=2.59e-6, Train_Loss=3.77] 
100%|██████████| 278/278 [00:05<00:00, 54.43it/s, Corr=0.873, Epoch=41, LR=2.59e-6, Valid_Loss=3.05]


Validation Loss Improved (3.066330948285454 ---> 3.0482910923966693)


100%|██████████| 2219/2219 [00:14<00:00, 149.72it/s, Corr=0.837, Epoch=42, LR=9.92e-6, Train_Loss=3.76] 
100%|██████████| 278/278 [00:05<00:00, 53.92it/s, Corr=0.874, Epoch=42, LR=9.92e-6, Valid_Loss=3.04]


Validation Loss Improved (3.0482910923966693 ---> 3.04005609309901)


100%|██████████| 2219/2219 [00:14<00:00, 152.33it/s, Corr=0.838, Epoch=43, LR=4.27e-6, Train_Loss=3.74] 
100%|██████████| 278/278 [00:05<00:00, 54.26it/s, Corr=0.874, Epoch=43, LR=4.27e-6, Valid_Loss=3.03]


Validation Loss Improved (3.04005609309901 ---> 3.0286280715499805)


100%|██████████| 2219/2219 [00:14<00:00, 149.51it/s, Corr=0.839, Epoch=44, LR=5.45e-7, Train_Loss=3.72] 
100%|██████████| 278/278 [00:05<00:00, 54.31it/s, Corr=0.875, Epoch=44, LR=5.45e-7, Valid_Loss=3.01]


Validation Loss Improved (3.0286280715499805 ---> 3.0112322752070155)


100%|██████████| 2219/2219 [00:14<00:00, 149.09it/s, Corr=0.839, Epoch=45, LR=8.57e-6, Train_Loss=3.71] 
100%|██████████| 278/278 [00:05<00:00, 50.20it/s, Corr=0.875, Epoch=45, LR=8.57e-6, Valid_Loss=3]   


Validation Loss Improved (3.0112322752070155 ---> 2.9975843694355544)


100%|██████████| 2219/2219 [00:15<00:00, 146.87it/s, Corr=0.839, Epoch=46, LR=7.18e-6, Train_Loss=3.71] 
100%|██████████| 278/278 [00:05<00:00, 53.34it/s, Corr=0.876, Epoch=46, LR=7.18e-6, Valid_Loss=2.99]


Validation Loss Improved (2.9975843694355544 ---> 2.985598493161091)


100%|██████████| 2219/2219 [00:14<00:00, 147.95it/s, Corr=0.84, Epoch=47, LR=3.94e-8, Train_Loss=3.69]  
100%|██████████| 278/278 [00:05<00:00, 53.20it/s, Corr=0.876, Epoch=47, LR=3.94e-8, Valid_Loss=2.97]


Validation Loss Improved (2.985598493161091 ---> 2.9711871637642173)


100%|██████████| 2219/2219 [00:15<00:00, 140.18it/s, Corr=0.84, Epoch=48, LR=5.99e-6, Train_Loss=3.69] 
100%|██████████| 278/278 [00:05<00:00, 50.42it/s, Corr=0.877, Epoch=48, LR=5.99e-6, Valid_Loss=2.97]


Validation Loss Improved (2.9711871637642173 ---> 2.966742152381012)


100%|██████████| 2219/2219 [00:16<00:00, 134.37it/s, Corr=0.841, Epoch=49, LR=9.33e-6, Train_Loss=3.67] 
100%|██████████| 278/278 [00:05<00:00, 52.18it/s, Corr=0.877, Epoch=49, LR=9.33e-6, Valid_Loss=2.95]


Validation Loss Improved (2.966742152381012 ---> 2.9541534792438537)
Best Loss: 2.9542

fold = 1


100%|██████████| 2219/2219 [00:16<00:00, 134.90it/s, Corr=0.172, Epoch=0, LR=3.41e-6, Train_Loss=14.6]  
100%|██████████| 278/278 [00:05<00:00, 52.38it/s, Corr=0.438, Epoch=0, LR=3.41e-6, Valid_Loss=13.1]


Validation Loss Improved (inf ---> 13.060626913408042)


100%|██████████| 2219/2219 [00:15<00:00, 146.06it/s, Corr=0.607, Epoch=1, LR=1.02e-6, Train_Loss=9.8]  
100%|██████████| 278/278 [00:05<00:00, 52.08it/s, Corr=0.762, Epoch=1, LR=1.02e-6, Valid_Loss=6.45]


Validation Loss Improved (13.060626913408042 ---> 6.453230519949527)


100%|██████████| 2219/2219 [00:15<00:00, 146.61it/s, Corr=0.744, Epoch=2, LR=9.14e-6, Train_Loss=5.88] 
100%|██████████| 278/278 [00:05<00:00, 53.69it/s, Corr=0.801, Epoch=2, LR=9.14e-6, Valid_Loss=4.73]


Validation Loss Improved (6.453230519949527 ---> 4.734828843810884)


100%|██████████| 2219/2219 [00:16<00:00, 134.08it/s, Corr=0.766, Epoch=3, LR=6.34e-6, Train_Loss=5.27] 
100%|██████████| 278/278 [00:05<00:00, 52.97it/s, Corr=0.808, Epoch=3, LR=6.34e-6, Valid_Loss=4.53]


Validation Loss Improved (4.734828843810884 ---> 4.530509101670782)


100%|██████████| 2219/2219 [00:16<00:00, 137.22it/s, Corr=0.773, Epoch=4, LR=6.85e-9, Train_Loss=5.13] 
100%|██████████| 278/278 [00:05<00:00, 53.24it/s, Corr=0.814, Epoch=4, LR=6.85e-9, Valid_Loss=4.39]


Validation Loss Improved (4.530509101670782 ---> 4.385432575926778)


100%|██████████| 2219/2219 [00:17<00:00, 123.77it/s, Corr=0.781, Epoch=5, LR=6.84e-6, Train_Loss=4.98] 
100%|██████████| 278/278 [00:05<00:00, 51.63it/s, Corr=0.825, Epoch=5, LR=6.84e-6, Valid_Loss=4.2] 


Validation Loss Improved (4.385432575926778 ---> 4.204516177769345)


100%|██████████| 2219/2219 [00:14<00:00, 150.46it/s, Corr=0.791, Epoch=6, LR=8.82e-6, Train_Loss=4.79] 
100%|██████████| 278/278 [00:05<00:00, 54.78it/s, Corr=0.835, Epoch=6, LR=8.82e-6, Valid_Loss=4.02]


Validation Loss Improved (4.204516177769345 ---> 4.016201454397322)


100%|██████████| 2219/2219 [00:15<00:00, 139.94it/s, Corr=0.8, Epoch=7, LR=7.23e-7, Train_Loss=4.63]   
100%|██████████| 278/278 [00:05<00:00, 54.16it/s, Corr=0.842, Epoch=7, LR=7.23e-7, Valid_Loss=3.86]


Validation Loss Improved (4.016201454397322 ---> 3.8584957249891243)


100%|██████████| 2219/2219 [00:15<00:00, 144.24it/s, Corr=0.805, Epoch=8, LR=3.91e-6, Train_Loss=4.51] 
100%|██████████| 278/278 [00:05<00:00, 52.90it/s, Corr=0.845, Epoch=8, LR=3.91e-6, Valid_Loss=3.77]


Validation Loss Improved (3.8584957249891243 ---> 3.7694560549538214)


100%|██████████| 2219/2219 [00:14<00:00, 148.42it/s, Corr=0.808, Epoch=9, LR=9.97e-6, Train_Loss=4.45] 
100%|██████████| 278/278 [00:05<00:00, 54.02it/s, Corr=0.847, Epoch=9, LR=9.97e-6, Valid_Loss=3.72]


Validation Loss Improved (3.7694560549538214 ---> 3.719016355400711)


100%|██████████| 2219/2219 [00:14<00:00, 148.11it/s, Corr=0.809, Epoch=10, LR=2.92e-6, Train_Loss=4.42] 
100%|██████████| 278/278 [00:05<00:00, 54.02it/s, Corr=0.848, Epoch=10, LR=2.92e-6, Valid_Loss=3.68]


Validation Loss Improved (3.719016355400711 ---> 3.684438521260039)


100%|██████████| 2219/2219 [00:14<00:00, 150.92it/s, Corr=0.811, Epoch=11, LR=1.36e-6, Train_Loss=4.39] 
100%|██████████| 278/278 [00:05<00:00, 53.85it/s, Corr=0.849, Epoch=11, LR=1.36e-6, Valid_Loss=3.65]


Validation Loss Improved (3.684438521260039 ---> 3.65047872549461)


100%|██████████| 2219/2219 [00:14<00:00, 151.74it/s, Corr=0.812, Epoch=12, LR=9.41e-6, Train_Loss=4.35] 
100%|██████████| 278/278 [00:05<00:00, 53.39it/s, Corr=0.851, Epoch=12, LR=9.41e-6, Valid_Loss=3.62]


Validation Loss Improved (3.65047872549461 ---> 3.6209941789480156)


100%|██████████| 2219/2219 [00:14<00:00, 150.61it/s, Corr=0.813, Epoch=13, LR=5.83e-6, Train_Loss=4.32] 
100%|██████████| 278/278 [00:04<00:00, 55.91it/s, Corr=0.852, Epoch=13, LR=5.83e-6, Valid_Loss=3.6] 


Validation Loss Improved (3.6209941789480156 ---> 3.596803823524699)


100%|██████████| 2219/2219 [00:16<00:00, 137.18it/s, Corr=0.815, Epoch=14, LR=6.16e-8, Train_Loss=4.28] 
100%|██████████| 278/278 [00:05<00:00, 52.72it/s, Corr=0.853, Epoch=14, LR=6.16e-8, Valid_Loss=3.56]


Validation Loss Improved (3.596803823524699 ---> 3.555849860310588)


100%|██████████| 2219/2219 [00:14<00:00, 148.91it/s, Corr=0.817, Epoch=15, LR=7.32e-6, Train_Loss=4.24] 
100%|██████████| 278/278 [00:05<00:00, 55.43it/s, Corr=0.855, Epoch=15, LR=7.32e-6, Valid_Loss=3.52]


Validation Loss Improved (3.555849860310588 ---> 3.5212732852596815)


100%|██████████| 2219/2219 [00:14<00:00, 149.79it/s, Corr=0.818, Epoch=16, LR=8.46e-6, Train_Loss=4.21] 
100%|██████████| 278/278 [00:04<00:00, 55.75it/s, Corr=0.856, Epoch=16, LR=8.46e-6, Valid_Loss=3.49]


Validation Loss Improved (3.5212732852596815 ---> 3.4890069273711526)


100%|██████████| 2219/2219 [00:14<00:00, 152.12it/s, Corr=0.82, Epoch=17, LR=4.76e-7, Train_Loss=4.16] 
100%|██████████| 278/278 [00:04<00:00, 56.07it/s, Corr=0.858, Epoch=17, LR=4.76e-7, Valid_Loss=3.45]


Validation Loss Improved (3.4890069273711526 ---> 3.44957640586574)


100%|██████████| 2219/2219 [00:13<00:00, 160.19it/s, Corr=0.821, Epoch=18, LR=4.43e-6, Train_Loss=4.15] 
100%|██████████| 278/278 [00:04<00:00, 56.43it/s, Corr=0.859, Epoch=18, LR=4.43e-6, Valid_Loss=3.42]


Validation Loss Improved (3.44957640586574 ---> 3.4169244562033945)


100%|██████████| 2219/2219 [00:13<00:00, 161.05it/s, Corr=0.823, Epoch=19, LR=9.89e-6, Train_Loss=4.12] 
100%|██████████| 278/278 [00:05<00:00, 50.98it/s, Corr=0.86, Epoch=19, LR=9.89e-6, Valid_Loss=3.4]  


Validation Loss Improved (3.4169244562033945 ---> 3.399154110559915)


100%|██████████| 2219/2219 [00:14<00:00, 154.59it/s, Corr=0.824, Epoch=20, LR=2.45e-6, Train_Loss=4.08] 
100%|██████████| 278/278 [00:04<00:00, 55.84it/s, Corr=0.861, Epoch=20, LR=2.45e-6, Valid_Loss=3.37]


Validation Loss Improved (3.399154110559915 ---> 3.370791022916065)


100%|██████████| 2219/2219 [00:14<00:00, 148.54it/s, Corr=0.824, Epoch=21, LR=1.73e-6, Train_Loss=4.07] 
100%|██████████| 278/278 [00:04<00:00, 55.98it/s, Corr=0.862, Epoch=21, LR=1.73e-6, Valid_Loss=3.35]


Validation Loss Improved (3.370791022916065 ---> 3.3462596650848244)


100%|██████████| 2219/2219 [00:13<00:00, 162.44it/s, Corr=0.825, Epoch=22, LR=9.63e-6, Train_Loss=4.04] 
100%|██████████| 278/278 [00:04<00:00, 56.28it/s, Corr=0.863, Epoch=22, LR=9.63e-6, Valid_Loss=3.33]


Validation Loss Improved (3.3462596650848244 ---> 3.327770222971095)


100%|██████████| 2219/2219 [00:14<00:00, 150.83it/s, Corr=0.826, Epoch=23, LR=5.31e-6, Train_Loss=4.04] 
100%|██████████| 278/278 [00:04<00:00, 55.75it/s, Corr=0.863, Epoch=23, LR=5.31e-6, Valid_Loss=3.31]


Validation Loss Improved (3.327770222971095 ---> 3.309871700506328)


100%|██████████| 2219/2219 [00:14<00:00, 152.22it/s, Corr=0.827, Epoch=24, LR=1.7e-7, Train_Loss=4.01]  
100%|██████████| 278/278 [00:04<00:00, 57.99it/s, Corr=0.864, Epoch=24, LR=1.7e-7, Valid_Loss=3.29]


Validation Loss Improved (3.309871700506328 ---> 3.2883570078709377)


100%|██████████| 2219/2219 [00:13<00:00, 165.07it/s, Corr=0.827, Epoch=25, LR=7.77e-6, Train_Loss=3.99] 
100%|██████████| 278/278 [00:05<00:00, 54.86it/s, Corr=0.865, Epoch=25, LR=7.77e-6, Valid_Loss=3.28]


Validation Loss Improved (3.2883570078709377 ---> 3.277234881913124)


100%|██████████| 2219/2219 [00:13<00:00, 159.57it/s, Corr=0.828, Epoch=26, LR=8.06e-6, Train_Loss=3.97] 
100%|██████████| 278/278 [00:04<00:00, 56.22it/s, Corr=0.865, Epoch=26, LR=8.06e-6, Valid_Loss=3.26]


Validation Loss Improved (3.277234881913124 ---> 3.2554390070948487)


100%|██████████| 2219/2219 [00:14<00:00, 154.61it/s, Corr=0.829, Epoch=27, LR=2.78e-7, Train_Loss=3.95] 
100%|██████████| 278/278 [00:04<00:00, 56.61it/s, Corr=0.866, Epoch=27, LR=2.78e-7, Valid_Loss=3.24]


Validation Loss Improved (3.2554390070948487 ---> 3.239040879442328)


100%|██████████| 2219/2219 [00:13<00:00, 159.11it/s, Corr=0.83, Epoch=28, LR=4.95e-6, Train_Loss=3.93] 
100%|██████████| 278/278 [00:05<00:00, 55.23it/s, Corr=0.867, Epoch=28, LR=4.95e-6, Valid_Loss=3.22]


Validation Loss Improved (3.239040879442328 ---> 3.218011938552289)


100%|██████████| 2219/2219 [00:14<00:00, 156.95it/s, Corr=0.83, Epoch=29, LR=9.76e-6, Train_Loss=3.92] 
100%|██████████| 278/278 [00:05<00:00, 55.33it/s, Corr=0.868, Epoch=29, LR=9.76e-6, Valid_Loss=3.2] 


Validation Loss Improved (3.218011938552289 ---> 3.201093125693958)


100%|██████████| 2219/2219 [00:13<00:00, 159.49it/s, Corr=0.831, Epoch=30, LR=2.02e-6, Train_Loss=3.89] 
100%|██████████| 278/278 [00:05<00:00, 53.75it/s, Corr=0.868, Epoch=30, LR=2.02e-6, Valid_Loss=3.18]


Validation Loss Improved (3.201093125693958 ---> 3.1807921137508557)


100%|██████████| 2219/2219 [00:13<00:00, 161.41it/s, Corr=0.832, Epoch=31, LR=2.15e-6, Train_Loss=3.89] 
100%|██████████| 278/278 [00:05<00:00, 55.55it/s, Corr=0.869, Epoch=31, LR=2.15e-6, Valid_Loss=3.16]


Validation Loss Improved (3.1807921137508557 ---> 3.1589832742926087)


100%|██████████| 2219/2219 [00:14<00:00, 153.29it/s, Corr=0.833, Epoch=32, LR=9.8e-6, Train_Loss=3.86]  
100%|██████████| 278/278 [00:04<00:00, 56.60it/s, Corr=0.87, Epoch=32, LR=9.8e-6, Valid_Loss=3.14] 


Validation Loss Improved (3.1589832742926087 ---> 3.139766588005246)


100%|██████████| 2219/2219 [00:13<00:00, 160.25it/s, Corr=0.834, Epoch=33, LR=4.79e-6, Train_Loss=3.84] 
100%|██████████| 278/278 [00:05<00:00, 55.28it/s, Corr=0.871, Epoch=33, LR=4.79e-6, Valid_Loss=3.12]


Validation Loss Improved (3.139766588005246 ---> 3.1209456723273865)


100%|██████████| 2219/2219 [00:14<00:00, 152.04it/s, Corr=0.835, Epoch=34, LR=3.32e-7, Train_Loss=3.82] 
100%|██████████| 278/278 [00:04<00:00, 55.97it/s, Corr=0.872, Epoch=34, LR=3.32e-7, Valid_Loss=3.11]


Validation Loss Improved (3.1209456723273865 ---> 3.1060733199220327)


100%|██████████| 2219/2219 [00:14<00:00, 158.12it/s, Corr=0.835, Epoch=35, LR=8.19e-6, Train_Loss=3.8]  
100%|██████████| 278/278 [00:04<00:00, 56.89it/s, Corr=0.872, Epoch=35, LR=8.19e-6, Valid_Loss=3.09]


Validation Loss Improved (3.1060733199220327 ---> 3.0876187075425077)


100%|██████████| 2219/2219 [00:14<00:00, 152.18it/s, Corr=0.836, Epoch=36, LR=7.63e-6, Train_Loss=3.79] 
100%|██████████| 278/278 [00:04<00:00, 56.64it/s, Corr=0.873, Epoch=36, LR=7.63e-6, Valid_Loss=3.07]


Validation Loss Improved (3.0876187075425077 ---> 3.069286052694507)


100%|██████████| 2219/2219 [00:14<00:00, 154.92it/s, Corr=0.837, Epoch=37, LR=1.32e-7, Train_Loss=3.78] 
100%|██████████| 278/278 [00:05<00:00, 54.46it/s, Corr=0.873, Epoch=37, LR=1.32e-7, Valid_Loss=3.05]


Validation Loss Improved (3.069286052694507 ---> 3.052036856223934)


100%|██████████| 2219/2219 [00:14<00:00, 155.71it/s, Corr=0.837, Epoch=38, LR=5.47e-6, Train_Loss=3.76] 
100%|██████████| 278/278 [00:05<00:00, 53.90it/s, Corr=0.874, Epoch=38, LR=5.47e-6, Valid_Loss=3.04]


Validation Loss Improved (3.052036856223934 ---> 3.0398802365653244)


100%|██████████| 2219/2219 [00:14<00:00, 153.69it/s, Corr=0.838, Epoch=39, LR=9.57e-6, Train_Loss=3.75] 
100%|██████████| 278/278 [00:04<00:00, 55.83it/s, Corr=0.874, Epoch=39, LR=9.57e-6, Valid_Loss=3.03]


Validation Loss Improved (3.0398802365653244 ---> 3.0266790151958793)


100%|██████████| 2219/2219 [00:14<00:00, 148.41it/s, Corr=0.838, Epoch=40, LR=1.62e-6, Train_Loss=3.73] 
100%|██████████| 278/278 [00:05<00:00, 54.37it/s, Corr=0.875, Epoch=40, LR=1.62e-6, Valid_Loss=3.02]


Validation Loss Improved (3.0266790151958793 ---> 3.015134453296581)


100%|██████████| 2219/2219 [00:14<00:00, 155.61it/s, Corr=0.838, Epoch=41, LR=2.59e-6, Train_Loss=3.73] 
100%|██████████| 278/278 [00:05<00:00, 52.37it/s, Corr=0.875, Epoch=41, LR=2.59e-6, Valid_Loss=3]   


Validation Loss Improved (3.015134453296581 ---> 3.0014706928709725)


100%|██████████| 2219/2219 [00:14<00:00, 152.09it/s, Corr=0.839, Epoch=42, LR=9.92e-6, Train_Loss=3.71]
100%|██████████| 278/278 [00:05<00:00, 53.66it/s, Corr=0.876, Epoch=42, LR=9.92e-6, Valid_Loss=2.99]


Validation Loss Improved (3.0014706928709725 ---> 2.9913509152416125)


100%|██████████| 2219/2219 [00:14<00:00, 148.58it/s, Corr=0.839, Epoch=43, LR=4.27e-6, Train_Loss=3.7] 
100%|██████████| 278/278 [00:05<00:00, 55.17it/s, Corr=0.876, Epoch=43, LR=4.27e-6, Valid_Loss=2.98]


Validation Loss Improved (2.9913509152416125 ---> 2.9774598080910257)


100%|██████████| 2219/2219 [00:14<00:00, 154.56it/s, Corr=0.84, Epoch=44, LR=5.45e-7, Train_Loss=3.68] 
100%|██████████| 278/278 [00:05<00:00, 55.32it/s, Corr=0.877, Epoch=44, LR=5.45e-7, Valid_Loss=2.97]


Validation Loss Improved (2.9774598080910257 ---> 2.9728047394205728)


100%|██████████| 2219/2219 [00:14<00:00, 149.00it/s, Corr=0.84, Epoch=45, LR=8.57e-6, Train_Loss=3.67] 
100%|██████████| 278/278 [00:04<00:00, 55.89it/s, Corr=0.877, Epoch=45, LR=8.57e-6, Valid_Loss=2.96]


Validation Loss Improved (2.9728047394205728 ---> 2.958345487445967)


100%|██████████| 2219/2219 [00:15<00:00, 145.80it/s, Corr=0.84, Epoch=46, LR=7.18e-6, Train_Loss=3.68] 
100%|██████████| 278/278 [00:05<00:00, 52.75it/s, Corr=0.877, Epoch=46, LR=7.18e-6, Valid_Loss=2.95]


Validation Loss Improved (2.958345487445967 ---> 2.948406625292419)


100%|██████████| 2219/2219 [00:14<00:00, 153.66it/s, Corr=0.841, Epoch=47, LR=3.94e-8, Train_Loss=3.66] 
100%|██████████| 278/278 [00:05<00:00, 54.56it/s, Corr=0.878, Epoch=47, LR=3.94e-8, Valid_Loss=2.94]


Validation Loss Improved (2.948406625292419 ---> 2.940699591383824)


100%|██████████| 2219/2219 [00:14<00:00, 157.20it/s, Corr=0.841, Epoch=48, LR=5.99e-6, Train_Loss=3.65] 
100%|██████████| 278/278 [00:04<00:00, 55.86it/s, Corr=0.878, Epoch=48, LR=5.99e-6, Valid_Loss=2.93]


Validation Loss Improved (2.940699591383824 ---> 2.931108274614065)


100%|██████████| 2219/2219 [00:14<00:00, 156.16it/s, Corr=0.842, Epoch=49, LR=9.33e-6, Train_Loss=3.64] 
100%|██████████| 278/278 [00:05<00:00, 54.11it/s, Corr=0.878, Epoch=49, LR=9.33e-6, Valid_Loss=2.92]


Validation Loss Improved (2.931108274614065 ---> 2.9239683264227607)
Best Loss: 2.9240

fold = 2


100%|██████████| 2219/2219 [00:14<00:00, 149.33it/s, Corr=0.19, Epoch=0, LR=3.41e-6, Train_Loss=14.7]  
100%|██████████| 278/278 [00:05<00:00, 54.52it/s, Corr=0.471, Epoch=0, LR=3.41e-6, Valid_Loss=13.2]


Validation Loss Improved (inf ---> 13.223989607173166)


100%|██████████| 2219/2219 [00:15<00:00, 147.68it/s, Corr=0.621, Epoch=1, LR=1.02e-6, Train_Loss=9.9]  
100%|██████████| 278/278 [00:05<00:00, 54.74it/s, Corr=0.764, Epoch=1, LR=1.02e-6, Valid_Loss=6.49]


Validation Loss Improved (13.223989607173166 ---> 6.492581515377486)


100%|██████████| 2219/2219 [00:14<00:00, 155.09it/s, Corr=0.744, Epoch=2, LR=9.14e-6, Train_Loss=5.9]  
100%|██████████| 278/278 [00:05<00:00, 52.62it/s, Corr=0.801, Epoch=2, LR=9.14e-6, Valid_Loss=4.74]


Validation Loss Improved (6.492581515377486 ---> 4.741882689470586)


100%|██████████| 2219/2219 [00:13<00:00, 158.85it/s, Corr=0.767, Epoch=3, LR=6.34e-6, Train_Loss=5.26] 
100%|██████████| 278/278 [00:05<00:00, 54.54it/s, Corr=0.809, Epoch=3, LR=6.34e-6, Valid_Loss=4.51]


Validation Loss Improved (4.741882689470586 ---> 4.508651130144486)


100%|██████████| 2219/2219 [00:14<00:00, 154.75it/s, Corr=0.774, Epoch=4, LR=6.85e-9, Train_Loss=5.1]  
100%|██████████| 278/278 [00:04<00:00, 56.00it/s, Corr=0.816, Epoch=4, LR=6.85e-9, Valid_Loss=4.36]


Validation Loss Improved (4.508651130144486 ---> 4.35695511985646)


100%|██████████| 2219/2219 [00:14<00:00, 156.78it/s, Corr=0.782, Epoch=5, LR=6.84e-6, Train_Loss=4.95] 
100%|██████████| 278/278 [00:05<00:00, 52.52it/s, Corr=0.826, Epoch=5, LR=6.84e-6, Valid_Loss=4.17]


Validation Loss Improved (4.35695511985646 ---> 4.168572415229185)


100%|██████████| 2219/2219 [00:14<00:00, 157.17it/s, Corr=0.792, Epoch=6, LR=8.82e-6, Train_Loss=4.75] 
100%|██████████| 278/278 [00:05<00:00, 54.70it/s, Corr=0.836, Epoch=6, LR=8.82e-6, Valid_Loss=3.99]


Validation Loss Improved (4.168572415229185 ---> 3.990958674172855)


100%|██████████| 2219/2219 [00:14<00:00, 153.97it/s, Corr=0.801, Epoch=7, LR=7.23e-7, Train_Loss=4.6]  
100%|██████████| 278/278 [00:04<00:00, 58.71it/s, Corr=0.843, Epoch=7, LR=7.23e-7, Valid_Loss=3.83]


Validation Loss Improved (3.990958674172855 ---> 3.8346733976405942)


100%|██████████| 2219/2219 [00:13<00:00, 162.07it/s, Corr=0.806, Epoch=8, LR=3.91e-6, Train_Loss=4.49] 
100%|██████████| 278/278 [00:04<00:00, 56.09it/s, Corr=0.846, Epoch=8, LR=3.91e-6, Valid_Loss=3.75]


Validation Loss Improved (3.8346733976405942 ---> 3.746552112062609)


100%|██████████| 2219/2219 [00:14<00:00, 157.79it/s, Corr=0.808, Epoch=9, LR=9.97e-6, Train_Loss=4.43] 
100%|██████████| 278/278 [00:04<00:00, 57.16it/s, Corr=0.847, Epoch=9, LR=9.97e-6, Valid_Loss=3.7] 


Validation Loss Improved (3.746552112062609 ---> 3.703958052033712)


100%|██████████| 2219/2219 [00:13<00:00, 158.89it/s, Corr=0.809, Epoch=10, LR=2.92e-6, Train_Loss=4.39] 
100%|██████████| 278/278 [00:05<00:00, 52.76it/s, Corr=0.848, Epoch=10, LR=2.92e-6, Valid_Loss=3.67]


Validation Loss Improved (3.703958052033712 ---> 3.6730719127742084)


100%|██████████| 2219/2219 [00:14<00:00, 153.16it/s, Corr=0.811, Epoch=11, LR=1.36e-6, Train_Loss=4.36] 
100%|██████████| 278/278 [00:04<00:00, 56.00it/s, Corr=0.849, Epoch=11, LR=1.36e-6, Valid_Loss=3.64]


Validation Loss Improved (3.6730719127742084 ---> 3.6381618056208906)


100%|██████████| 2219/2219 [00:14<00:00, 152.23it/s, Corr=0.812, Epoch=12, LR=9.41e-6, Train_Loss=4.33] 
100%|██████████| 278/278 [00:05<00:00, 52.28it/s, Corr=0.85, Epoch=12, LR=9.41e-6, Valid_Loss=3.62] 


Validation Loss Improved (3.6381618056208906 ---> 3.624163392779457)


100%|██████████| 2219/2219 [00:15<00:00, 141.49it/s, Corr=0.813, Epoch=13, LR=5.83e-6, Train_Loss=4.31] 
100%|██████████| 278/278 [00:05<00:00, 50.90it/s, Corr=0.851, Epoch=13, LR=5.83e-6, Valid_Loss=3.6] 


Validation Loss Improved (3.624163392779457 ---> 3.598309708332125)


100%|██████████| 2219/2219 [00:15<00:00, 143.41it/s, Corr=0.814, Epoch=14, LR=6.16e-8, Train_Loss=4.29] 
100%|██████████| 278/278 [00:05<00:00, 52.69it/s, Corr=0.852, Epoch=14, LR=6.16e-8, Valid_Loss=3.57]


Validation Loss Improved (3.598309708332125 ---> 3.5737597699997465)


100%|██████████| 2219/2219 [00:15<00:00, 145.37it/s, Corr=0.815, Epoch=15, LR=7.32e-6, Train_Loss=4.26] 
100%|██████████| 278/278 [00:05<00:00, 54.19it/s, Corr=0.853, Epoch=15, LR=7.32e-6, Valid_Loss=3.55]


Validation Loss Improved (3.5737597699997465 ---> 3.5493416568020764)


100%|██████████| 2219/2219 [00:15<00:00, 144.75it/s, Corr=0.816, Epoch=16, LR=8.46e-6, Train_Loss=4.24] 
100%|██████████| 278/278 [00:05<00:00, 51.79it/s, Corr=0.854, Epoch=16, LR=8.46e-6, Valid_Loss=3.52]


Validation Loss Improved (3.5493416568020764 ---> 3.5233819818687473)


100%|██████████| 2219/2219 [00:15<00:00, 139.60it/s, Corr=0.818, Epoch=17, LR=4.76e-7, Train_Loss=4.21] 
100%|██████████| 278/278 [00:05<00:00, 54.84it/s, Corr=0.856, Epoch=17, LR=4.76e-7, Valid_Loss=3.49]


Validation Loss Improved (3.5233819818687473 ---> 3.4899463879529433)


100%|██████████| 2219/2219 [00:15<00:00, 139.18it/s, Corr=0.82, Epoch=18, LR=4.43e-6, Train_Loss=4.17]  
100%|██████████| 278/278 [00:05<00:00, 52.80it/s, Corr=0.857, Epoch=18, LR=4.43e-6, Valid_Loss=3.45]


Validation Loss Improved (3.4899463879529433 ---> 3.454836473482699)


100%|██████████| 2219/2219 [00:15<00:00, 140.23it/s, Corr=0.821, Epoch=19, LR=9.89e-6, Train_Loss=4.14] 
100%|██████████| 278/278 [00:05<00:00, 52.92it/s, Corr=0.859, Epoch=19, LR=9.89e-6, Valid_Loss=3.42]


Validation Loss Improved (3.454836473482699 ---> 3.4224768351009276)


100%|██████████| 2219/2219 [00:15<00:00, 143.98it/s, Corr=0.822, Epoch=20, LR=2.45e-6, Train_Loss=4.11] 
100%|██████████| 278/278 [00:05<00:00, 53.52it/s, Corr=0.86, Epoch=20, LR=2.45e-6, Valid_Loss=3.4]  


Validation Loss Improved (3.4224768351009276 ---> 3.3981664998635996)


100%|██████████| 2219/2219 [00:15<00:00, 140.26it/s, Corr=0.824, Epoch=21, LR=1.73e-6, Train_Loss=4.08] 
100%|██████████| 278/278 [00:05<00:00, 52.98it/s, Corr=0.861, Epoch=21, LR=1.73e-6, Valid_Loss=3.37]


Validation Loss Improved (3.3981664998635996 ---> 3.3663940919905224)


100%|██████████| 2219/2219 [00:15<00:00, 141.48it/s, Corr=0.825, Epoch=22, LR=9.63e-6, Train_Loss=4.06] 
100%|██████████| 278/278 [00:05<00:00, 53.61it/s, Corr=0.862, Epoch=22, LR=9.63e-6, Valid_Loss=3.34]


Validation Loss Improved (3.3663940919905224 ---> 3.3448496058993844)


100%|██████████| 2219/2219 [00:15<00:00, 143.18it/s, Corr=0.826, Epoch=23, LR=5.31e-6, Train_Loss=4.05] 
100%|██████████| 278/278 [00:05<00:00, 52.90it/s, Corr=0.863, Epoch=23, LR=5.31e-6, Valid_Loss=3.33]


Validation Loss Improved (3.3448496058993844 ---> 3.3327483132153275)


100%|██████████| 2219/2219 [00:15<00:00, 139.41it/s, Corr=0.827, Epoch=24, LR=1.7e-7, Train_Loss=4.01]  
100%|██████████| 278/278 [00:05<00:00, 52.23it/s, Corr=0.863, Epoch=24, LR=1.7e-7, Valid_Loss=3.32]


Validation Loss Improved (3.3327483132153275 ---> 3.3152460158559753)


100%|██████████| 2219/2219 [00:15<00:00, 139.69it/s, Corr=0.827, Epoch=25, LR=7.77e-6, Train_Loss=4]    
100%|██████████| 278/278 [00:05<00:00, 52.63it/s, Corr=0.864, Epoch=25, LR=7.77e-6, Valid_Loss=3.3] 


Validation Loss Improved (3.3152460158559753 ---> 3.297609955907009)


100%|██████████| 2219/2219 [00:16<00:00, 135.32it/s, Corr=0.828, Epoch=26, LR=8.06e-6, Train_Loss=3.99] 
100%|██████████| 278/278 [00:05<00:00, 50.42it/s, Corr=0.864, Epoch=26, LR=8.06e-6, Valid_Loss=3.28]


Validation Loss Improved (3.297609955907009 ---> 3.277377331123195)


100%|██████████| 2219/2219 [00:16<00:00, 133.58it/s, Corr=0.828, Epoch=27, LR=2.78e-7, Train_Loss=3.97] 
100%|██████████| 278/278 [00:05<00:00, 51.96it/s, Corr=0.865, Epoch=27, LR=2.78e-7, Valid_Loss=3.26]


Validation Loss Improved (3.277377331123195 ---> 3.259449489124595)


100%|██████████| 2219/2219 [00:17<00:00, 128.60it/s, Corr=0.829, Epoch=28, LR=4.95e-6, Train_Loss=3.95] 
100%|██████████| 278/278 [00:05<00:00, 51.42it/s, Corr=0.866, Epoch=28, LR=4.95e-6, Valid_Loss=3.25]


Validation Loss Improved (3.259449489124595 ---> 3.248984311086034)


100%|██████████| 2219/2219 [00:17<00:00, 127.92it/s, Corr=0.829, Epoch=29, LR=9.76e-6, Train_Loss=3.94] 
100%|██████████| 278/278 [00:05<00:00, 51.69it/s, Corr=0.866, Epoch=29, LR=9.76e-6, Valid_Loss=3.23]


Validation Loss Improved (3.248984311086034 ---> 3.228645049908359)


100%|██████████| 2219/2219 [00:16<00:00, 132.87it/s, Corr=0.83, Epoch=30, LR=2.02e-6, Train_Loss=3.93] 
100%|██████████| 278/278 [00:05<00:00, 51.82it/s, Corr=0.867, Epoch=30, LR=2.02e-6, Valid_Loss=3.21]


Validation Loss Improved (3.228645049908359 ---> 3.2117115546328763)


100%|██████████| 2219/2219 [00:17<00:00, 128.63it/s, Corr=0.83, Epoch=31, LR=2.15e-6, Train_Loss=3.92] 
100%|██████████| 278/278 [00:05<00:00, 51.84it/s, Corr=0.868, Epoch=31, LR=2.15e-6, Valid_Loss=3.19]


Validation Loss Improved (3.2117115546328763 ---> 3.1947412299540248)


100%|██████████| 2219/2219 [00:16<00:00, 135.62it/s, Corr=0.831, Epoch=32, LR=9.8e-6, Train_Loss=3.9]   
100%|██████████| 278/278 [00:05<00:00, 52.74it/s, Corr=0.868, Epoch=32, LR=9.8e-6, Valid_Loss=3.18]


Validation Loss Improved (3.1947412299540248 ---> 3.1832306779068396)


100%|██████████| 2219/2219 [00:16<00:00, 135.86it/s, Corr=0.833, Epoch=33, LR=4.79e-6, Train_Loss=3.87] 
100%|██████████| 278/278 [00:05<00:00, 50.80it/s, Corr=0.869, Epoch=33, LR=4.79e-6, Valid_Loss=3.17]


Validation Loss Improved (3.1832306779068396 ---> 3.170793592892896)


100%|██████████| 2219/2219 [00:16<00:00, 135.28it/s, Corr=0.833, Epoch=34, LR=3.32e-7, Train_Loss=3.85] 
100%|██████████| 278/278 [00:05<00:00, 52.60it/s, Corr=0.87, Epoch=34, LR=3.32e-7, Valid_Loss=3.15] 


Validation Loss Improved (3.170793592892896 ---> 3.1471296359755567)


100%|██████████| 2219/2219 [00:17<00:00, 128.42it/s, Corr=0.834, Epoch=35, LR=8.19e-6, Train_Loss=3.84] 
100%|██████████| 278/278 [00:05<00:00, 52.84it/s, Corr=0.87, Epoch=35, LR=8.19e-6, Valid_Loss=3.13] 


Validation Loss Improved (3.1471296359755567 ---> 3.1250434025392444)


100%|██████████| 2219/2219 [00:17<00:00, 129.85it/s, Corr=0.835, Epoch=36, LR=7.63e-6, Train_Loss=3.82] 
100%|██████████| 278/278 [00:05<00:00, 52.14it/s, Corr=0.871, Epoch=36, LR=7.63e-6, Valid_Loss=3.12]


Validation Loss Improved (3.1250434025392444 ---> 3.1150042578530135)


100%|██████████| 2219/2219 [00:16<00:00, 135.36it/s, Corr=0.835, Epoch=37, LR=1.32e-7, Train_Loss=3.81] 
100%|██████████| 278/278 [00:05<00:00, 52.08it/s, Corr=0.872, Epoch=37, LR=1.32e-7, Valid_Loss=3.09]


Validation Loss Improved (3.1150042578530135 ---> 3.092352348172766)


100%|██████████| 2219/2219 [00:16<00:00, 134.18it/s, Corr=0.836, Epoch=38, LR=5.47e-6, Train_Loss=3.78] 
100%|██████████| 278/278 [00:05<00:00, 53.14it/s, Corr=0.873, Epoch=38, LR=5.47e-6, Valid_Loss=3.08]


Validation Loss Improved (3.092352348172766 ---> 3.0762406780718816)


100%|██████████| 2219/2219 [00:17<00:00, 127.12it/s, Corr=0.837, Epoch=39, LR=9.57e-6, Train_Loss=3.78] 
100%|██████████| 278/278 [00:05<00:00, 52.30it/s, Corr=0.873, Epoch=39, LR=9.57e-6, Valid_Loss=3.06]


Validation Loss Improved (3.0762406780718816 ---> 3.0577780504109002)


100%|██████████| 2219/2219 [00:15<00:00, 142.34it/s, Corr=0.837, Epoch=40, LR=1.62e-6, Train_Loss=3.75] 
100%|██████████| 278/278 [00:05<00:00, 50.79it/s, Corr=0.874, Epoch=40, LR=1.62e-6, Valid_Loss=3.04]


Validation Loss Improved (3.0577780504109002 ---> 3.04440128464037)


100%|██████████| 2219/2219 [00:15<00:00, 143.30it/s, Corr=0.838, Epoch=41, LR=2.59e-6, Train_Loss=3.74] 
100%|██████████| 278/278 [00:05<00:00, 52.54it/s, Corr=0.875, Epoch=41, LR=2.59e-6, Valid_Loss=3.02]


Validation Loss Improved (3.04440128464037 ---> 3.023398758304175)


100%|██████████| 2219/2219 [00:16<00:00, 137.59it/s, Corr=0.839, Epoch=42, LR=9.92e-6, Train_Loss=3.72] 
100%|██████████| 278/278 [00:05<00:00, 52.63it/s, Corr=0.875, Epoch=42, LR=9.92e-6, Valid_Loss=3.01]


Validation Loss Improved (3.023398758304175 ---> 3.0097825113011596)


100%|██████████| 2219/2219 [00:17<00:00, 127.48it/s, Corr=0.839, Epoch=43, LR=4.27e-6, Train_Loss=3.71] 
100%|██████████| 278/278 [00:05<00:00, 51.84it/s, Corr=0.876, Epoch=43, LR=4.27e-6, Valid_Loss=2.99]


Validation Loss Improved (3.0097825113011596 ---> 2.991548910432247)


100%|██████████| 2219/2219 [00:16<00:00, 138.25it/s, Corr=0.84, Epoch=44, LR=5.45e-7, Train_Loss=3.7]  
100%|██████████| 278/278 [00:05<00:00, 52.45it/s, Corr=0.876, Epoch=44, LR=5.45e-7, Valid_Loss=2.98]


Validation Loss Improved (2.991548910432247 ---> 2.980344538555123)


100%|██████████| 2219/2219 [00:16<00:00, 132.24it/s, Corr=0.84, Epoch=45, LR=8.57e-6, Train_Loss=3.69]  
100%|██████████| 278/278 [00:05<00:00, 52.12it/s, Corr=0.877, Epoch=45, LR=8.57e-6, Valid_Loss=2.97]


Validation Loss Improved (2.980344538555123 ---> 2.9654022616091233)


100%|██████████| 2219/2219 [00:16<00:00, 135.30it/s, Corr=0.841, Epoch=46, LR=7.18e-6, Train_Loss=3.68]
100%|██████████| 278/278 [00:05<00:00, 52.16it/s, Corr=0.877, Epoch=46, LR=7.18e-6, Valid_Loss=2.96]


Validation Loss Improved (2.9654022616091233 ---> 2.95890826956483)


100%|██████████| 2219/2219 [00:16<00:00, 135.81it/s, Corr=0.841, Epoch=47, LR=3.94e-8, Train_Loss=3.67] 
100%|██████████| 278/278 [00:05<00:00, 50.36it/s, Corr=0.878, Epoch=47, LR=3.94e-8, Valid_Loss=2.95]


Validation Loss Improved (2.95890826956483 ---> 2.947902691829787)


100%|██████████| 2219/2219 [00:16<00:00, 134.62it/s, Corr=0.842, Epoch=48, LR=5.99e-6, Train_Loss=3.65] 
100%|██████████| 278/278 [00:05<00:00, 53.45it/s, Corr=0.878, Epoch=48, LR=5.99e-6, Valid_Loss=2.93]


Validation Loss Improved (2.947902691829787 ---> 2.930954406541818)


100%|██████████| 2219/2219 [00:16<00:00, 134.92it/s, Corr=0.842, Epoch=49, LR=9.33e-6, Train_Loss=3.64] 
100%|██████████| 278/278 [00:05<00:00, 53.11it/s, Corr=0.878, Epoch=49, LR=9.33e-6, Valid_Loss=2.92]


Validation Loss Improved (2.930954406541818 ---> 2.92433877615161)
Best Loss: 2.9243


In [31]:
cite_test_x = torch.Tensor(x_test)
y_pred = model(cite_test_x)
y_pred_num = y_pred.cpu().detach().numpy()
y_pred_num

array([[0.29457086, 0.40730733, 0.7199645 , ..., 0.6051998 , 1.3990179 ,
        2.580193  ],
       [0.24795142, 0.27027118, 0.6212083 , ..., 0.4302716 , 1.1435115 ,
        2.1924307 ],
       [0.36043057, 0.42863062, 0.82495874, ..., 0.53615147, 2.3786008 ,
        2.9371574 ],
       ...,
       [0.22843744, 0.42764997, 0.87325037, ..., 0.46410263, 6.767725  ,
        3.1862857 ],
       [0.09315833, 0.4532459 , 0.97633266, ..., 0.33600792, 3.545335  ,
        2.9931936 ],
       [1.1873575 , 0.35675955, 1.1085684 , ..., 1.8247237 , 3.7981806 ,
        4.4816737 ]], dtype=float32)

In [32]:
from sklearn.metrics import r2_score
r2 = r2_score(y_test, y_pred_num)
r2

0.16824296060644817