In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils import data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
from tqdm import tqdm
# Pass tensors between processes using the "file_system" sharing strategy.
# NOTE(review): presumably chosen for the multi-worker DataLoaders below
# (commonly used to avoid "too many open files" errors) — confirm.
torch.multiprocessing.set_sharing_strategy("file_system")

# Inputs: the training table, and the residue-encoding table indexed by residue symbol (one column per residue).
dt = pd.read_csv("../data/all_train.csv")
aaindex = pd.read_csv("../data/aaindex1_pca.csv")

class pep_hla_Dataset(Dataset):
    """Map-style dataset yielding ``(feature_vector, label)`` tensors per row.

    Each residue of ``pep + hla`` is replaced by its column of ``aaindex`` and
    the per-residue vectors are concatenated into one flat feature vector.

    Args:
        pep_hla_dt: DataFrame with one row per example, read positionally via
            ``iloc``.
        aaindex: DataFrame whose column names are residue symbols; each column
            holds that residue's encoding vector. It is snapshotted into a
            lookup table at construction time.
        hla_col: positional column of the HLA sequence (default 0).
        label_col: positional column of the label (default 2).
        pep_col: positional column of the peptide sequence (default 4).
            The defaults reproduce the original hard-coded layout.
    """

    def __init__(self, pep_hla_dt, aaindex, hla_col=0, label_col=2, pep_col=4):
        self.data_pep = pep_hla_dt
        self.aaindex = aaindex
        self.hla_col = hla_col
        self.label_col = label_col
        self.pep_col = pep_col
        # Extract every residue column once. __getitem__ runs for every sample in
        # every epoch, and a pandas column lookup per residue is comparatively slow.
        self._aa_vectors = {aa: aaindex[aa].to_numpy() for aa in aaindex.columns}

    def __len__(self):
        return len(self.data_pep)

    def aa_encoding(self, pep, hla):
        """Concatenate the encoding vectors of every residue in ``pep + hla``.

        Raises:
            KeyError: if a residue symbol is not a column of ``aaindex``.
        """
        return np.concatenate([self._aa_vectors[aa] for aa in pep + hla])

    def __getitem__(self, index):
        """Return ``(features, label)`` as tensors for row ``index``."""
        hla = self.data_pep.iloc[index, self.hla_col]
        pep = self.data_pep.iloc[index, self.pep_col]
        label = self.data_pep.iloc[index, self.label_col]
        input_feature = self.aa_encoding(pep, hla)
        return torch.tensor(input_feature), torch.tensor(label)
    
# Wrap the full training table; the cross-validation cell splits it into folds by index.
all_dt = pep_hla_Dataset(pep_hla_dt=dt, aaindex=aaindex)

import torch.nn.functional as F 
class immune_net(torch.nn.Module):
    """Fully connected binary classifier producing one raw logit per sample.

    Architecture: three hidden blocks of
    Linear -> LeakyReLU -> BatchNorm1d -> Dropout, then a Linear projection to a
    single output. No sigmoid is applied, so pair it with a logits-based loss
    such as ``BCEWithLogitsLoss``.

    Args:
        feature_size: width of the flat input vector.
        model_params: dict with required keys ``model_hidden1``,
            ``model_hidden2`` and ``model_hidden3`` (hidden widths), plus an
            optional ``model_dropout`` (dropout probability, default 0.5).
    """

    def __init__(self, feature_size, model_params):
        super().__init__()
        widths = [
            feature_size,
            model_params["model_hidden1"],
            model_params["model_hidden2"],
            model_params["model_hidden3"],
        ]
        dropout = model_params.get("model_dropout", 0.5)
        layers = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            layers += [
                torch.nn.Linear(in_dim, out_dim),
                torch.nn.LeakyReLU(),
                torch.nn.BatchNorm1d(out_dim),
                torch.nn.Dropout(dropout),
            ]
        layers.append(torch.nn.Linear(widths[-1], 1))
        # Same attribute name and layer order as the hand-written Sequential, so
        # state_dict keys ("linear.0.weight", ..., "linear.12.bias") are unchanged.
        self.linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (batch, feature_size) to logits of shape (batch, 1)."""
        return self.linear(x)
    
def count_parameters(model):
    """Return the total number of trainable (``requires_grad``) scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def train_one_epoch(epoch, model, train_loader, optimizer, loss_fn, epoch_num):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: index of the current epoch.
        model: module mapping a float batch to logits of shape (batch, 1).
        train_loader: iterable of batches whose element 0 is the input and
            element 1 is the target.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss called as ``loss_fn(logits, targets)``, with targets
            reshaped to (batch, 1) floats.
        epoch_num: the epoch whose per-sample predictions are returned.

    Returns:
        When ``epoch == epoch_num``: ``(mean_loss, preds, labels, probs)``.
        ``probs`` are sigmoid outputs and ``preds = np.rint(probs)``. All
        three are flat numpy arrays in loader order.
        Otherwise: ``(mean_loss, n, n, n)``, where ``n`` is the number of
        batches. This placeholder is kept for backward compatibility.

    Raises:
        ValueError: if ``train_loader`` yields no batches.

    Note:
        Each batch's predictions come from the forward pass made before that
        batch's optimizer step, in whatever mode the caller set (e.g. train
        mode with dropout active). They are not a clean train-set evaluation.
    """
    # Only the requested epoch's predictions are used, so skip collecting the others.
    collect = epoch == epoch_num
    all_preds, all_labels, all_preds_raw = [], [], []
    running_loss = 0.0
    step = 0
    for batch in tqdm(train_loader):
        input_x, target = batch[0], batch[1]
        optimizer.zero_grad()
        pred = model(input_x.float())
        loss = loss_fn(pred, target.reshape(pred.shape[0], 1).float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        step += 1
        if collect:
            probs = torch.sigmoid(pred).detach().cpu().numpy()
            all_preds_raw.append(probs)
            all_preds.append(np.rint(probs))
            all_labels.append(target.detach().cpu().numpy())
    if step == 0:
        raise ValueError("train_loader yielded no batches")
    mean_loss = running_loss / step
    if not collect:
        return mean_loss, step, step, step
    return (
        mean_loss,
        np.concatenate(all_preds).ravel(),
        np.concatenate(all_labels).ravel(),
        np.concatenate(all_preds_raw).ravel(),
    )

def test(epoch, model, test_loader, loss_fn):
    all_preds = []
    all_labels = []
    all_preds_raw = []
    running_loss = 0.0
    step = 0
    for batch in test_loader:
        target = batch[1]
        input_x = batch[0]
        pred = model(input_x.float()) 
        loss = loss_fn(pred, target.reshape(pred.shape[0],1).float())
         # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds.append(np.rint(torch.sigmoid(pred).cpu().detach().numpy()))
        all_preds_raw.append(torch.sigmoid(pred).cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
        
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    all_preds_raw = np.concatenate(all_preds_raw).ravel()
    return all_preds, all_labels, all_preds_raw


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def run_one_training(params, dataset, loss_fn, train_subsampler, test_subsampler):
    """Train a fresh ``immune_net`` on one CV fold and return per-sample predictions.

    Args:
        params: hyperparameter dict.
            Required keys: ``batch_size``, ``learning_rate``, ``weight_decay``,
            and the ``model_*`` keys forwarded to ``immune_net``.
            Optional keys: ``n_epochs`` (default 26), ``num_workers`` (default 18)
            and ``feature_size`` (network input width, default 900).
        dataset: map-style dataset yielding ``(features, label)`` pairs.
        loss_fn: loss called as ``loss_fn(logits, targets)``.
        train_subsampler: sampler selecting this fold's training indices.
        test_subsampler: sampler selecting this fold's held-out indices.

    Returns:
        ``(train_pred, test_pred)``: DataFrames with columns ``all_preds``,
        ``all_labels`` and ``all_preds_raw``.
        ``train_pred`` holds the predictions recorded during the final training
        epoch; ``test_pred`` holds the held-out predictions made after it.

    Raises:
        ValueError: if ``n_epochs`` is smaller than 1.
    """
    n_epochs = params.get("n_epochs", 26)
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be >= 1, got {n_epochs}")
    final_epoch = n_epochs - 1
    num_workers = params.get("num_workers", 18)

    # Prepare training
    train_loader = DataLoader(dataset, batch_size=params["batch_size"], sampler=train_subsampler,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(dataset, batch_size=params["batch_size"], sampler=test_subsampler,
                             num_workers=num_workers, pin_memory=True)

    # Loading the model
    print("Loading model...")
    model_params = {k: v for k, v in params.items() if k.startswith("model_")}
    model = immune_net(feature_size=params.get("feature_size", 900), model_params=model_params)
    print(f"Number of parameters: {count_parameters(model)}")

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=params["learning_rate"],
                                 weight_decay=params["weight_decay"])

    # Start training. train_one_epoch returns real prediction arrays only for final_epoch.
    for epoch in range(n_epochs):
        model.train()
        loss, all_preds, all_labels, all_preds_raw = train_one_epoch(
            epoch, model, train_loader, optimizer, loss_fn, epoch_num=final_epoch)
        print(f"Epoch {epoch} | Train Loss {loss}")

    # Evaluate once, after the final epoch.
    model.eval()
    all_preds_test, all_labels_test, all_preds_raw_test = test(final_epoch, model, test_loader, loss_fn)
    train_pred = pd.DataFrame({"all_preds": all_preds, "all_labels": all_labels, "all_preds_raw": all_preds_raw})
    test_pred = pd.DataFrame({"all_preds": all_preds_test, "all_labels": all_labels_test, "all_preds_raw": all_preds_raw_test})
    return train_pred, test_pred

In [3]:
HYPERPARAMETERS = {
    "batch_size": 16,
    "learning_rate": 0.0001,
    "weight_decay": 0.0001,
    "model_hidden1": 400,
    "model_hidden2": 400,
    "model_hidden3": 200
}
# pos_weight scales the loss on positive examples: values > 1 favour recall,
# values < 1 favour precision.
# NOTE(review): 24 presumably approximates the negative:positive ratio of the
# training data — confirm.
# Float literal, so the weight has the same float dtype as the logits.
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(24.0))

In [4]:
from sklearn.model_selection import KFold

# Seed the fold assignment and torch's global RNG (weight init, dropout masks,
# SubsetRandomSampler shuffling) so that Restart & Run All reproduces the same
# folds and models.
SEED = 42
torch.manual_seed(SEED)
kfold = KFold(n_splits=5, shuffle=True, random_state=SEED)
for fold, (train_idx, test_idx) in enumerate(kfold.split(all_dt)):
    print(f'------------fold no---------{fold}----------------------')
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_idx)
    train_pred, test_pred = run_one_training(HYPERPARAMETERS, all_dt, loss_fn, train_subsampler, test_subsampler)
    train_pred.to_csv(f"../data/models/train_res_{fold}.csv")
    test_pred.to_csv(f"../data/models/test_res_{fold}.csv")

------------fold no---------0----------------------
Loading model...
Number of parameters: 603201


100%|██████████| 3489/3489 [00:45<00:00, 76.85it/s]


Epoch 0 | Train Loss 1.1440690280236976


100%|██████████| 3489/3489 [00:44<00:00, 78.30it/s]


Epoch 1 | Train Loss 0.938993820806874


100%|██████████| 3489/3489 [00:44<00:00, 79.21it/s]


Epoch 2 | Train Loss 0.8699418904020096


100%|██████████| 3489/3489 [00:44<00:00, 78.78it/s]


Epoch 3 | Train Loss 0.8054898581025798


100%|██████████| 3489/3489 [00:44<00:00, 78.59it/s]


Epoch 4 | Train Loss 0.7831382107498926


100%|██████████| 3489/3489 [00:44<00:00, 78.24it/s]


Epoch 5 | Train Loss 0.7419796589410691


100%|██████████| 3489/3489 [00:41<00:00, 83.38it/s]


Epoch 6 | Train Loss 0.708671648152478


100%|██████████| 3489/3489 [00:41<00:00, 83.53it/s]


Epoch 7 | Train Loss 0.699180876420302


100%|██████████| 3489/3489 [00:44<00:00, 77.90it/s]


Epoch 8 | Train Loss 0.6729087726276834


100%|██████████| 3489/3489 [00:41<00:00, 84.27it/s]


Epoch 9 | Train Loss 0.6742542065225177


100%|██████████| 3489/3489 [00:41<00:00, 83.27it/s]


Epoch 10 | Train Loss 0.6532275637691293


100%|██████████| 3489/3489 [00:41<00:00, 83.75it/s]


Epoch 11 | Train Loss 0.6400336255552981


100%|██████████| 3489/3489 [00:44<00:00, 77.98it/s]


Epoch 12 | Train Loss 0.6356990389989283


100%|██████████| 3489/3489 [00:44<00:00, 78.03it/s]


Epoch 13 | Train Loss 0.6189546784469855


100%|██████████| 3489/3489 [00:45<00:00, 76.39it/s]


Epoch 14 | Train Loss 0.6009057038542397


100%|██████████| 3489/3489 [00:44<00:00, 78.48it/s]


Epoch 15 | Train Loss 0.622291074155497


100%|██████████| 3489/3489 [00:45<00:00, 77.09it/s]


Epoch 16 | Train Loss 0.6069610346547509


100%|██████████| 3489/3489 [00:42<00:00, 82.12it/s]


Epoch 17 | Train Loss 0.5924213199902141


100%|██████████| 3489/3489 [00:45<00:00, 77.39it/s]


Epoch 18 | Train Loss 0.5848564396626493


100%|██████████| 3489/3489 [00:44<00:00, 77.66it/s]


Epoch 19 | Train Loss 0.5728211039354276


100%|██████████| 3489/3489 [00:44<00:00, 77.80it/s]


Epoch 20 | Train Loss 0.5753284990296353


100%|██████████| 3489/3489 [00:45<00:00, 77.12it/s]


Epoch 21 | Train Loss 0.5702449323139138


100%|██████████| 3489/3489 [00:45<00:00, 77.51it/s]


Epoch 22 | Train Loss 0.556767254642871


100%|██████████| 3489/3489 [00:44<00:00, 77.59it/s]


Epoch 23 | Train Loss 0.5442274558779913


100%|██████████| 3489/3489 [00:44<00:00, 78.26it/s]


Epoch 24 | Train Loss 0.5484283181667055


100%|██████████| 3489/3489 [00:44<00:00, 78.07it/s]


Epoch 25 | Train Loss 0.5429701846029944
------------fold no---------1----------------------
Loading model...
Number of parameters: 603201


100%|██████████| 3489/3489 [00:40<00:00, 86.79it/s]


Epoch 0 | Train Loss 1.167466047979351


100%|██████████| 3489/3489 [00:40<00:00, 87.09it/s]


Epoch 1 | Train Loss 0.9655971257008055


100%|██████████| 3489/3489 [00:40<00:00, 86.95it/s]


Epoch 2 | Train Loss 0.8710792568632647


100%|██████████| 3489/3489 [00:39<00:00, 87.42it/s]


Epoch 3 | Train Loss 0.8142041190796713


100%|██████████| 3489/3489 [00:40<00:00, 86.27it/s]


Epoch 4 | Train Loss 0.7761366219054529


100%|██████████| 3489/3489 [00:39<00:00, 87.28it/s]


Epoch 5 | Train Loss 0.743342198194457


100%|██████████| 3489/3489 [00:40<00:00, 86.29it/s]


Epoch 6 | Train Loss 0.7264724479265902


100%|██████████| 3489/3489 [00:40<00:00, 86.90it/s]


Epoch 7 | Train Loss 0.7023698577278834


100%|██████████| 3489/3489 [00:40<00:00, 87.19it/s]


Epoch 8 | Train Loss 0.6874341769413539


100%|██████████| 3489/3489 [00:39<00:00, 88.08it/s]


Epoch 9 | Train Loss 0.6785263733873739


100%|██████████| 3489/3489 [00:40<00:00, 85.45it/s]


Epoch 10 | Train Loss 0.6612366458711257


100%|██████████| 3489/3489 [00:41<00:00, 84.22it/s]


Epoch 11 | Train Loss 0.6471341470697952


100%|██████████| 3489/3489 [00:40<00:00, 85.33it/s]


Epoch 12 | Train Loss 0.6339767341069316


100%|██████████| 3489/3489 [00:40<00:00, 85.22it/s]


Epoch 13 | Train Loss 0.6173054520868236


100%|██████████| 3489/3489 [00:39<00:00, 87.31it/s]


Epoch 14 | Train Loss 0.6293190832477512


100%|██████████| 3489/3489 [00:40<00:00, 86.13it/s]


Epoch 15 | Train Loss 0.6146234892720853


100%|██████████| 3489/3489 [00:40<00:00, 86.38it/s]


Epoch 16 | Train Loss 0.5939726069565316


100%|██████████| 3489/3489 [00:40<00:00, 85.21it/s]


Epoch 17 | Train Loss 0.5872068751158124


100%|██████████| 3489/3489 [00:40<00:00, 86.79it/s] 


Epoch 18 | Train Loss 0.5948736057877028


100%|██████████| 3489/3489 [00:40<00:00, 87.13it/s]


Epoch 19 | Train Loss 0.5830852059472867


100%|██████████| 3489/3489 [00:41<00:00, 83.89it/s]


Epoch 20 | Train Loss 0.5830787751523633


100%|██████████| 3489/3489 [00:42<00:00, 82.51it/s]


Epoch 21 | Train Loss 0.5841556530489591


100%|██████████| 3489/3489 [00:40<00:00, 85.42it/s]


Epoch 22 | Train Loss 0.5673101478544303


100%|██████████| 3489/3489 [00:40<00:00, 86.09it/s]


Epoch 23 | Train Loss 0.5724917364414321


100%|██████████| 3489/3489 [00:39<00:00, 87.40it/s]


Epoch 24 | Train Loss 0.5526371796413355


100%|██████████| 3489/3489 [00:41<00:00, 83.92it/s]


Epoch 25 | Train Loss 0.5514641452498475
------------fold no---------2----------------------
Loading model...
Number of parameters: 603201


100%|██████████| 3489/3489 [00:41<00:00, 83.09it/s]


Epoch 0 | Train Loss 1.177031256350312


100%|██████████| 3489/3489 [00:42<00:00, 82.85it/s]


Epoch 1 | Train Loss 0.9671811308635959


100%|██████████| 3489/3489 [00:42<00:00, 82.48it/s]


Epoch 2 | Train Loss 0.8747912851624312


100%|██████████| 3489/3489 [00:40<00:00, 85.20it/s]


Epoch 3 | Train Loss 0.8351691210464681


100%|██████████| 3489/3489 [00:40<00:00, 86.84it/s]


Epoch 4 | Train Loss 0.7886793133427674


100%|██████████| 3489/3489 [00:39<00:00, 87.55it/s]


Epoch 5 | Train Loss 0.7497611338735549


100%|██████████| 3489/3489 [00:40<00:00, 85.60it/s]


Epoch 6 | Train Loss 0.7321076782248908


100%|██████████| 3489/3489 [00:40<00:00, 86.45it/s]


Epoch 7 | Train Loss 0.7066455634882608


100%|██████████| 3489/3489 [00:40<00:00, 85.91it/s]


Epoch 8 | Train Loss 0.708902508058839


100%|██████████| 3489/3489 [00:40<00:00, 86.97it/s]


Epoch 9 | Train Loss 0.6896535676815134


100%|██████████| 3489/3489 [00:39<00:00, 87.34it/s]


Epoch 10 | Train Loss 0.6681176896003366


100%|██████████| 3489/3489 [00:40<00:00, 86.04it/s]


Epoch 11 | Train Loss 0.6569029415262808


100%|██████████| 3489/3489 [00:41<00:00, 84.83it/s]


Epoch 12 | Train Loss 0.6426882435369881


100%|██████████| 3489/3489 [00:40<00:00, 85.68it/s]


Epoch 13 | Train Loss 0.6436633554885981


100%|██████████| 3489/3489 [00:40<00:00, 85.55it/s]


Epoch 14 | Train Loss 0.6334925521155523


100%|██████████| 3489/3489 [00:40<00:00, 85.59it/s]


Epoch 15 | Train Loss 0.621194542605322


100%|██████████| 3489/3489 [00:40<00:00, 86.73it/s]


Epoch 16 | Train Loss 0.6124637624396743


100%|██████████| 3489/3489 [00:40<00:00, 86.84it/s]


Epoch 17 | Train Loss 0.6192589832747963


100%|██████████| 3489/3489 [00:40<00:00, 86.40it/s]


Epoch 18 | Train Loss 0.6078869634275622


100%|██████████| 3489/3489 [00:40<00:00, 87.11it/s]


Epoch 19 | Train Loss 0.6034575132367501


100%|██████████| 3489/3489 [00:40<00:00, 86.68it/s]


Epoch 20 | Train Loss 0.5903534471655548


100%|██████████| 3489/3489 [00:40<00:00, 85.41it/s]


Epoch 21 | Train Loss 0.5782737185237365


100%|██████████| 3489/3489 [00:40<00:00, 86.37it/s]


Epoch 22 | Train Loss 0.5890736763137875


100%|██████████| 3489/3489 [00:39<00:00, 87.27it/s] 


Epoch 23 | Train Loss 0.5698961969833383


100%|██████████| 3489/3489 [00:40<00:00, 86.71it/s]


Epoch 24 | Train Loss 0.5666036839880141


100%|██████████| 3489/3489 [00:40<00:00, 86.42it/s]


Epoch 25 | Train Loss 0.5716678582561088
------------fold no---------3----------------------
Loading model...
Number of parameters: 603201


100%|██████████| 3489/3489 [00:40<00:00, 86.49it/s]


Epoch 0 | Train Loss 1.1363732136484885


100%|██████████| 3489/3489 [00:40<00:00, 85.29it/s]


Epoch 1 | Train Loss 0.9373501475850229


100%|██████████| 3489/3489 [00:40<00:00, 87.01it/s]


Epoch 2 | Train Loss 0.8444001801754007


100%|██████████| 3489/3489 [00:39<00:00, 87.53it/s]


Epoch 3 | Train Loss 0.8041024273498997


100%|██████████| 3489/3489 [00:39<00:00, 87.82it/s]


Epoch 4 | Train Loss 0.7669701090461034


100%|██████████| 3489/3489 [00:40<00:00, 86.97it/s]


Epoch 5 | Train Loss 0.7254588407896487


100%|██████████| 3489/3489 [00:40<00:00, 86.76it/s]


Epoch 6 | Train Loss 0.7024088729996161


100%|██████████| 3489/3489 [00:40<00:00, 86.28it/s]


Epoch 7 | Train Loss 0.6980083983986246


100%|██████████| 3489/3489 [00:40<00:00, 86.37it/s]


Epoch 8 | Train Loss 0.6846149569042571


100%|██████████| 3489/3489 [00:39<00:00, 88.03it/s]


Epoch 9 | Train Loss 0.6622887215664887


100%|██████████| 3489/3489 [00:41<00:00, 84.16it/s]


Epoch 10 | Train Loss 0.6523379901096793


100%|██████████| 3489/3489 [00:40<00:00, 86.69it/s]


Epoch 11 | Train Loss 0.6442390635682232


100%|██████████| 3489/3489 [00:40<00:00, 86.00it/s]


Epoch 12 | Train Loss 0.6300030524467937


100%|██████████| 3489/3489 [00:40<00:00, 85.12it/s]


Epoch 13 | Train Loss 0.6235618568395229


100%|██████████| 3489/3489 [00:40<00:00, 86.20it/s]


Epoch 14 | Train Loss 0.6130721245455379


100%|██████████| 3489/3489 [00:40<00:00, 86.43it/s]


Epoch 15 | Train Loss 0.6130888758979441


100%|██████████| 3489/3489 [00:40<00:00, 85.73it/s]


Epoch 16 | Train Loss 0.6019142514255403


100%|██████████| 3489/3489 [00:40<00:00, 86.02it/s]


Epoch 17 | Train Loss 0.6077766040696152


100%|██████████| 3489/3489 [00:40<00:00, 86.25it/s]


Epoch 18 | Train Loss 0.5967436968878779


100%|██████████| 3489/3489 [00:40<00:00, 85.58it/s]


Epoch 19 | Train Loss 0.5820718788808518


100%|██████████| 3489/3489 [00:40<00:00, 86.59it/s]


Epoch 20 | Train Loss 0.5768373218528109


100%|██████████| 3489/3489 [00:40<00:00, 85.98it/s]


Epoch 21 | Train Loss 0.5821122490478198


100%|██████████| 3489/3489 [00:40<00:00, 86.50it/s]


Epoch 22 | Train Loss 0.5709719275673982


100%|██████████| 3489/3489 [00:40<00:00, 86.34it/s]


Epoch 23 | Train Loss 0.5761011024281985


100%|██████████| 3489/3489 [00:40<00:00, 86.40it/s]


Epoch 24 | Train Loss 0.571330629193055


100%|██████████| 3489/3489 [00:40<00:00, 86.92it/s]


Epoch 25 | Train Loss 0.5526464006327804
------------fold no---------4----------------------
Loading model...
Number of parameters: 603201


100%|██████████| 3489/3489 [00:40<00:00, 86.96it/s]


Epoch 0 | Train Loss 1.1376556822702857


100%|██████████| 3489/3489 [00:40<00:00, 86.10it/s]


Epoch 1 | Train Loss 0.9551861236279248


100%|██████████| 3489/3489 [00:40<00:00, 86.27it/s]


Epoch 2 | Train Loss 0.8680471067324037


100%|██████████| 3489/3489 [00:40<00:00, 86.31it/s]


Epoch 3 | Train Loss 0.8219330529747736


100%|██████████| 3489/3489 [00:40<00:00, 87.07it/s]


Epoch 4 | Train Loss 0.777846456223212


100%|██████████| 3489/3489 [00:40<00:00, 85.99it/s]


Epoch 5 | Train Loss 0.7571967741067173


100%|██████████| 3489/3489 [00:41<00:00, 84.73it/s]


Epoch 6 | Train Loss 0.7306617027403257


100%|██████████| 3489/3489 [00:40<00:00, 86.59it/s]


Epoch 7 | Train Loss 0.7265330404171324


100%|██████████| 3489/3489 [00:39<00:00, 87.34it/s]


Epoch 8 | Train Loss 0.7045888543538974


100%|██████████| 3489/3489 [00:40<00:00, 85.69it/s]


Epoch 9 | Train Loss 0.6839661728596066


100%|██████████| 3489/3489 [00:40<00:00, 86.78it/s]


Epoch 10 | Train Loss 0.6614851817193979


100%|██████████| 3489/3489 [00:40<00:00, 87.15it/s]


Epoch 11 | Train Loss 0.6492462450076256


100%|██████████| 3489/3489 [00:39<00:00, 87.28it/s]


Epoch 12 | Train Loss 0.6476490008656771


100%|██████████| 3489/3489 [00:41<00:00, 84.32it/s]


Epoch 13 | Train Loss 0.6407257342584566


100%|██████████| 3489/3489 [00:40<00:00, 86.94it/s]


Epoch 14 | Train Loss 0.6256165124986328


100%|██████████| 3489/3489 [00:40<00:00, 86.60it/s]


Epoch 15 | Train Loss 0.6233129410265885


100%|██████████| 3489/3489 [00:40<00:00, 87.02it/s]


Epoch 16 | Train Loss 0.6132558261397473


100%|██████████| 3489/3489 [00:40<00:00, 87.10it/s]


Epoch 17 | Train Loss 0.5991500763932078


100%|██████████| 3489/3489 [00:40<00:00, 86.37it/s]


Epoch 18 | Train Loss 0.595588879101886


100%|██████████| 3489/3489 [00:39<00:00, 87.36it/s]


Epoch 19 | Train Loss 0.5806588030802756


100%|██████████| 3489/3489 [00:39<00:00, 87.29it/s]


Epoch 20 | Train Loss 0.5898111105396571


100%|██████████| 3489/3489 [00:39<00:00, 87.70it/s]


Epoch 21 | Train Loss 0.5892774240207488


100%|██████████| 3489/3489 [00:40<00:00, 86.34it/s]


Epoch 22 | Train Loss 0.5745638343795182


100%|██████████| 3489/3489 [00:40<00:00, 87.21it/s]


Epoch 23 | Train Loss 0.5671578931365152


100%|██████████| 3489/3489 [00:39<00:00, 87.31it/s]


Epoch 24 | Train Loss 0.5740221715026409


100%|██████████| 3489/3489 [00:39<00:00, 87.29it/s]

Epoch 25 | Train Loss 0.561385283310166



