In [2]:
import os
%pwd

'/mnt/cb03386d-9344-47b1-82f9-868fbb64b4ae/python_projects/HIV_inhibitors_classification_and_generation/research'

In [3]:
os.chdir("../")
%pwd

'/mnt/cb03386d-9344-47b1-82f9-868fbb64b4ae/python_projects/HIV_inhibitors_classification_and_generation'

In [4]:
from dataclasses import dataclass
from pathlib import Path
from typing import List

@dataclass(frozen=True)
class ModelTrainerConfig:
    root_dir: Path
    models: Path
    stats: Path
    source_root: Path
    processed_root: Path
    source_filename: str
    processed_filename: List[str]
    tuning: bool
    params: dict

In [5]:
from hivclass.constants import *
from hivclass.utils.main_utils import create_directories, read_yaml

class ConfigurationManager:
    def __init__(
        self,
        config_file_path = CONFIG_FILE_PATH,
        params_file_path = PARAMS_FILE_PATH,
        schema_file_path = SCHEMA_FILE_PATH
    ):
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)
        self.schema = read_yaml(schema_file_path)
        
        create_directories([self.config.artifacts_root])
    
    def get_model_trainer_config(self) -> ModelTrainerConfig:
        config = self.config.model_trainer
        params = self.params
        
        create_directories([config.root_dir, config.models, config.stats])
        
        model_trainer_config = ModelTrainerConfig(
            root_dir=config.root_dir,
            models=config.models,
            stats=config.stats,
            source_root=config.source_root,
            processed_root=config.processed_root,
            source_filename=config.source_filename,
            processed_filename=config.processed_filename,
            tuning=config.tuning,
            params=params
        )
        
        return model_trainer_config

In [9]:
from hivclass.utils.molecule_dataset import MoleculeDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score
import yaml
import numpy as np
import pandas as pd
from hivclass.utils.molecule_dataset import MoleculeDataset
from hivclass.utils.mol_gnn import MolGNN
from hivclass.utils.main_utils import save_json, plot_metric, plot_confusion_matrix, plot_roc_curve
import torch 
from torch_geometric.data import DataLoader
from box import ConfigBox
import sys
from tqdm import tqdm
from mango import Tuner

class ModelTrainer:
    def __init__(self, config: ModelTrainerConfig):
        self.config = config
    
    def train_val_separation(self, train_dataset):
        data_df = pd.read_csv(os.path.join(self.config.processed_root, self.config.processed_filename[3]))
        
        # data_name_list = os.listdir(os.path.join(self.config.processed_root, self.config.processed_filename[1]))
        # data_idxs = [int(name.split('.')[0].split('_')[1]) for name in data_name_list]
        # data_labels = [data_df.HIV_active[i] for i in data_idxs]
        
        train_df, val_df = train_test_split(
            data_df,
            test_size=self.config.params.data_transformation.val_size,
            stratify=data_df.HIV_active,
            random_state=42
        )
        
        # train_idxs, val_idxs, _, _ = train_test_split(
        #     data_idxs,
        #     data_labels,
        #     test_size=self.config.model_params.val_size,
        #     stratify=data_labels,
        #     random_state=42
        # )
        
        train_idxs = train_df.index.tolist()
        val_idxs = val_df.index.tolist()
        
        train = train_dataset.index_select(train_idxs)
        val = train_dataset.index_select(val_idxs)
        
        return train, val
    
    def train(self, params, epoch, model, train_loader, optimizer, criterion, device):
        model.train()
        total_loss = 0.0
        train_preds = []
        train_labels = []
        
        for i, batch in tqdm(enumerate(train_loader)):
            batch.to(device)
            optimizer.zero_grad()
            
            preds = model(batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch)
            train_preds.extend(np.rint(torch.sigmoid(preds).cpu().detach().numpy()))
            train_labels.extend(batch.y.cpu().detach().numpy())
            
            loss = criterion(torch.squeeze(preds), batch.y.float())
            loss.backward()
            optimizer.step()
            
            accuracy = accuracy_score(train_labels, train_preds)
            
            total_loss += loss.item()
            
            print()
            sys.stdout.write(
                "Epoch:%2d/%2d - Batch:%2d/%2d - train_loss:%.4f - train_accuracy:%.4f" %(
                    epoch,
                    params.num_epochs,
                    i,
                    len(train_loader),
                    loss.item(),
                    accuracy
                )
            )
            sys.stdout.flush()
        
        return total_loss / len(train_loader), accuracy
    
    def validation(self, epoch, model, val_loader, criterion, best_val_loss, stats_path,  device):
        model.eval()
        total_loss = 0.0
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for batch in tqdm(val_loader):
                batch.to(device)
                
                preds = model(batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch)
                val_preds.extend(torch.round(torch.squeeze(preds)).cpu().detach().numpy())
                val_labels.extend(batch.y.cpu().detach().numpy())
                
                loss = criterion(torch.squeeze(preds), batch.y.float())
                total_loss += loss.item()
                accuracy = accuracy_score(val_labels, val_preds)
            
            epoch_loss = total_loss / len(val_loader)
            
            if epoch_loss < best_val_loss:
                report = classification_report(
                    val_labels,
                    val_preds,
                    zero_division=0,
                    output_dict=True
                )

                save_json(
                    os.path.join(stats_path, f'report_{epoch}.json'),
                    report
                )

                conf_matrix = confusion_matrix(val_labels, val_preds)

                plot_confusion_matrix(
                    conf_matrix,
                    stats_path,
                    epoch,
                    f'Confusion Matrix for epoch: {epoch}'
                )

                auc_score = roc_auc_score(val_labels, val_preds)
                auc_score_dict = {'auc_score': auc_score}

                save_json(
                    os.path.join(stats_path, f'auc_score_{epoch}.json'), 
                    auc_score_dict
                )
                
                plot_roc_curve(
                    val_labels,
                    val_preds,
                    stats_path,
                    epoch
                )
                
        return epoch_loss, accuracy
    
    def train_tuning(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("device:", device)
        
        dataset = MoleculeDataset(
            self.config.source_root,
            self.config.processed_root,
            self.config.source_filename,
            self.config.processed_filename
        )
        
        train_dataset, val_dataset = self.train_val_separation(dataset)
        
        def train_compose(params):
            params = params[0]
            
            if self.config.tuning:
                folder_name = str(len(os.listdir(self.config.stats)) + 1)
            else:
                folder_name = "best_params"
            
            models_path = os.path.join(self.config.models, folder_name)
            stats_path = os.path.join(self.config.stats, folder_name)
            
            create_directories([models_path, stats_path])
            
            with open(os.path.join(stats_path, "params.yaml"), 'w') as file:
                file.write(yaml.dump(params, sort_keys=False))
            
            params = ConfigBox(params)
            
            train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=params['batch_size'], shuffle=False)
            params["model_edge_dim"] = train_dataset[0].edge_attr.shape[1]
            
            print("Loading model...")
            model_params = ConfigBox({k: v for k, v in params.items() if k.startswith("model_")})
            model = MolGNN(feature_size=train_dataset[0].x.shape[1], model_params=model_params)
            model = model.to(device)
            
            weight = torch.tensor([params["pos_weight"]], dtype=torch.float32).to(device)
            criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weight)
            optimizer = torch.optim.SGD(
                model.parameters(),
                lr=params['learning_rate'],
                momentum=params['sgd_momentum'],
                weight_decay=params['weight_decay']
            )
            
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params.scheduler_gamma)
            
            train_losses = []
            val_losses = []
            train_accuracies = []
            val_accuracies = []
            best_val_loss = float('inf')
            early_stopping_counter = 0
            epochs_range = range(1, params.num_epochs + 1)
            
            for epoch in tqdm(range(params.num_epochs)):
                if early_stopping_counter <= 10:
                    train_epoch_loss, train_epoch_acc = self.train(
                        params,
                        epoch,
                        model,
                        train_loader,
                        optimizer,
                        criterion,
                        device
                    )
                    
                    train_losses.append(train_epoch_loss)
                    train_accuracies.append(train_epoch_acc)
                    
                    val_epoch_loss, val_epoch_acc = self.validation(
                        epoch,
                        model,
                        val_loader,
                        criterion,
                        best_val_loss,
                        stats_path,
                        device
                    )
                    
                    val_losses.append(val_epoch_loss)
                    val_accuracies.append(val_epoch_acc)
                    
                    print(f'Epoch [{epoch+1}/{params.num_epochs}], '
                        f'Loss: {train_epoch_loss:.4f}, '
                        f'Validation Loss: {val_epoch_loss:.4f}, '
                        f'Train Accuracy: {train_epoch_acc:.2f}%, '
                        f'Validation Accuracy: {val_epoch_acc:.2f}%')
                    
                    if float(val_epoch_loss) < best_val_loss:
                        torch.save(model.state_dict(), os.path.join(models_path, f'model_{epoch}.pth'))
                        best_val_loss = float(val_epoch_loss)
                        early_stopping_counter = 0
                    else:
                        early_stopping_counter += 1
                    
                    scheduler.step()
                else:
                    print("Early stopping due to no improvement.")
                    
                    plot_metric(
                        stats_path,
                        epochs_range,
                        train_losses,
                        val_losses,
                        'Train Loss',
                        'Validation Loss',
                    )
                    
                    
                    plot_metric(
                        stats_path,
                        epochs_range,
                        train_accuracies,
                        val_accuracies,
                        'Train Accuracies',
                        'Validation Accuracies',
                    )
                    
                    return [best_val_loss]
            
            print(f"Finishing training with best test loss: {best_val_loss}")

            plot_metric(
                stats_path,
                epochs_range,
                train_losses,
                val_losses,
                'Train Loss',
                'Validation Loss',
            )
            
            
            plot_metric(
                stats_path,
                epochs_range,
                train_accuracies,
                val_accuracies,
                'Train Accuracies',
                'Validation Accuracies',
            )
            
            return [best_val_loss]
        
        if self.config.tuning:
            print("Running hyperparameter search...")
            params = self.config.params.HYPERPARAMETERS
            config = dict()
            config["optimizer"] = "Bayesian"
            config["num_iteration"] = params.tuning_iterations
            
            tuner = Tuner(params, objective=train_compose, conf_dict=config)
            
            results = tuner.minimize()
            
            self.config.params['BEST_PARAMETERS'] = results['best_parameters']
            best_params = yaml.save_dump(self.config.params, sort_keys=False)
            
            with open(PARAMS_FILE_PATH, 'w') as file:
                file.write(best_params)
        else:
            params = self.config.params.BEST_PARAMETERS
            best_val_loss = train_compose(params)

In [10]:
try:
    config = ConfigurationManager()
    model_trainer_config = config.get_model_trainer_config()
    model_trainer = ModelTrainer(config=model_trainer_config)
    results = model_trainer.train_tuning()

except Exception as e:
    raise e

[2025-04-08 23:03:21,425: INFO: main_utils: created directory at: artifacts]
[2025-04-08 23:03:21,427: INFO: main_utils: created directory at: artifacts/model_trainer]
[2025-04-08 23:03:21,428: INFO: main_utils: created directory at: artifacts/model_trainer/models]
[2025-04-08 23:03:21,429: INFO: main_utils: created directory at: artifacts/model_trainer/stats]
device: cuda
Running hyperparameter search...
[2025-04-08 23:03:21,812: INFO: main_utils: created directory at: artifacts/model_trainer/models/1]
[2025-04-08 23:03:21,814: INFO: main_utils: created directory at: artifacts/model_trainer/stats/1]
Loading model...


  0%|          | 0/100 [00:00<?, ?it/s]


Epoch: 0/100 - Batch: 0/971 - train_loss:0.8182 - train_accuracy:0.6094




Epoch: 0/100 - Batch: 1/971 - train_loss:1.1937 - train_accuracy:0.5547




Epoch: 0/100 - Batch: 2/971 - train_loss:1.1874 - train_accuracy:0.5469




Epoch: 0/100 - Batch: 3/971 - train_loss:1.2503 - train_accuracy:0.5078




Epoch: 0/100 - Batch: 4/971 - train_loss:0.9714 - train_accuracy:0.5188




Epoch: 0/100 - Batch: 5/971 - train_loss:0.6226 - train_accuracy:0.5495




Epoch: 0/100 - Batch: 6/971 - train_loss:1.0320 - train_accuracy:0.5402




Epoch: 0/100 - Batch: 7/971 - train_loss:1.0393 - train_accuracy:0.5469




Epoch: 0/100 - Batch: 8/971 - train_loss:0.9381 - train_accuracy:0.5451




Epoch: 0/100 - Batch: 9/971 - train_loss:0.8966 - train_accuracy:0.5469




Epoch: 0/100 - Batch:10/971 - train_loss:0.9014 - train_accuracy:0.5526




Epoch: 0/100 - Batch:11/971 - train_loss:1.0795 - train_accuracy:0.5508




Epoch: 0/100 - Batch:12/971 - train_loss:0.8659 - train_accuracy:0.5517




Epoch: 0/100 - Batch:13/971 - train_loss:1.0462 - train_accuracy:0.5458




Epoch: 0/100 - Batch:14/971 - train_loss:1.1165 - train_accuracy:0.5365




Epoch: 0/100 - Batch:15/971 - train_loss:0.9318 - train_accuracy:0.5273




Epoch: 0/100 - Batch:16/971 - train_loss:1.0089 - train_accuracy:0.5221




Epoch: 0/100 - Batch:17/971 - train_loss:0.7601 - train_accuracy:0.5234




Epoch: 0/100 - Batch:18/971 - train_loss:0.8505 - train_accuracy:0.5197




Epoch: 0/100 - Batch:19/971 - train_loss:0.9841 - train_accuracy:0.5141




Epoch: 0/100 - Batch:20/971 - train_loss:0.9207 - train_accuracy:0.5156




Epoch: 0/100 - Batch:21/971 - train_loss:0.9607 - train_accuracy:0.5156




Epoch: 0/100 - Batch:22/971 - train_loss:0.8830 - train_accuracy:0.5136




Epoch: 0/100 - Batch:23/971 - train_loss:0.7277 - train_accuracy:0.5189




Epoch: 0/100 - Batch:24/971 - train_loss:0.7456 - train_accuracy:0.5206




Epoch: 0/100 - Batch:25/971 - train_loss:0.8299 - train_accuracy:0.5162




Epoch: 0/100 - Batch:26/971 - train_loss:0.7190 - train_accuracy:0.5191




Epoch: 0/100 - Batch:27/971 - train_loss:0.7295 - train_accuracy:0.5229




Epoch: 0/100 - Batch:28/971 - train_loss:0.7443 - train_accuracy:0.5221




Epoch: 0/100 - Batch:29/971 - train_loss:0.8027 - train_accuracy:0.5224




Epoch: 0/100 - Batch:30/971 - train_loss:0.9581 - train_accuracy:0.5232




Epoch: 0/100 - Batch:31/971 - train_loss:0.8869 - train_accuracy:0.5225




Epoch: 0/100 - Batch:32/971 - train_loss:0.9309 - train_accuracy:0.5241




Epoch: 0/100 - Batch:33/971 - train_loss:0.9192 - train_accuracy:0.5239




Epoch: 0/100 - Batch:34/971 - train_loss:0.7918 - train_accuracy:0.5237




Epoch: 0/100 - Batch:35/971 - train_loss:0.8362 - train_accuracy:0.5221




Epoch: 0/100 - Batch:36/971 - train_loss:0.8320 - train_accuracy:0.5207




Epoch: 0/100 - Batch:37/971 - train_loss:0.8613 - train_accuracy:0.5206




Epoch: 0/100 - Batch:38/971 - train_loss:0.8119 - train_accuracy:0.5184




Epoch: 0/100 - Batch:39/971 - train_loss:0.7210 - train_accuracy:0.5184




Epoch: 0/100 - Batch:40/971 - train_loss:0.7323 - train_accuracy:0.5194




Epoch: 0/100 - Batch:41/971 - train_loss:0.7462 - train_accuracy:0.5201




Epoch: 0/100 - Batch:42/971 - train_loss:0.7805 - train_accuracy:0.5174




Epoch: 0/100 - Batch:43/971 - train_loss:0.7595 - train_accuracy:0.5178




Epoch: 0/100 - Batch:44/971 - train_loss:0.6954 - train_accuracy:0.5198




Epoch: 0/100 - Batch:45/971 - train_loss:0.7691 - train_accuracy:0.5187




Epoch: 0/100 - Batch:46/971 - train_loss:0.8586 - train_accuracy:0.5166




Epoch: 0/100 - Batch:47/971 - train_loss:0.7153 - train_accuracy:0.5182




Epoch: 0/100 - Batch:48/971 - train_loss:0.7678 - train_accuracy:0.5188




Epoch: 0/100 - Batch:49/971 - train_loss:0.7387 - train_accuracy:0.5203




Epoch: 0/100 - Batch:50/971 - train_loss:0.6947 - train_accuracy:0.5208




Epoch: 0/100 - Batch:51/971 - train_loss:0.7757 - train_accuracy:0.5195




Epoch: 0/100 - Batch:52/971 - train_loss:0.6991 - train_accuracy:0.5200




Epoch: 0/100 - Batch:53/971 - train_loss:0.7721 - train_accuracy:0.5188




Epoch: 0/100 - Batch:54/971 - train_loss:0.6571 - train_accuracy:0.5205




Epoch: 0/100 - Batch:55/971 - train_loss:0.7738 - train_accuracy:0.5190




Epoch: 0/100 - Batch:56/971 - train_loss:0.7242 - train_accuracy:0.5178




Epoch: 0/100 - Batch:57/971 - train_loss:0.8026 - train_accuracy:0.5167




Epoch: 0/100 - Batch:58/971 - train_loss:0.7331 - train_accuracy:0.5164




Epoch: 0/100 - Batch:59/971 - train_loss:0.7790 - train_accuracy:0.5169




Epoch: 0/100 - Batch:60/971 - train_loss:0.7493 - train_accuracy:0.5169




Epoch: 0/100 - Batch:61/971 - train_loss:0.7837 - train_accuracy:0.5159




Epoch: 0/100 - Batch:62/971 - train_loss:0.7238 - train_accuracy:0.5166




Epoch: 0/100 - Batch:63/971 - train_loss:0.7869 - train_accuracy:0.5161




Epoch: 0/100 - Batch:64/971 - train_loss:0.6964 - train_accuracy:0.5168




Epoch: 0/100 - Batch:65/971 - train_loss:0.7750 - train_accuracy:0.5161




Epoch: 0/100 - Batch:66/971 - train_loss:0.7498 - train_accuracy:0.5145




Epoch: 0/100 - Batch:67/971 - train_loss:0.7166 - train_accuracy:0.5145




Epoch: 0/100 - Batch:68/971 - train_loss:0.6664 - train_accuracy:0.5149




Epoch: 0/100 - Batch:69/971 - train_loss:0.7087 - train_accuracy:0.5161




Epoch: 0/100 - Batch:70/971 - train_loss:0.7007 - train_accuracy:0.5163




Epoch: 0/100 - Batch:71/971 - train_loss:0.7464 - train_accuracy:0.5152




Epoch: 0/100 - Batch:72/971 - train_loss:0.6817 - train_accuracy:0.5158




Epoch: 0/100 - Batch:73/971 - train_loss:0.7606 - train_accuracy:0.5148




Epoch: 0/100 - Batch:74/971 - train_loss:0.7257 - train_accuracy:0.5148




Epoch: 0/100 - Batch:75/971 - train_loss:0.7531 - train_accuracy:0.5138




Epoch: 0/100 - Batch:76/971 - train_loss:0.7491 - train_accuracy:0.5132




Epoch: 0/100 - Batch:77/971 - train_loss:0.7223 - train_accuracy:0.5138




Epoch: 0/100 - Batch:78/971 - train_loss:0.7917 - train_accuracy:0.5127




Epoch: 0/100 - Batch:79/971 - train_loss:0.8111 - train_accuracy:0.5123




Epoch: 0/100 - Batch:80/971 - train_loss:0.7758 - train_accuracy:0.5116




Epoch: 0/100 - Batch:81/971 - train_loss:0.7538 - train_accuracy:0.5114




Epoch: 0/100 - Batch:82/971 - train_loss:0.7967 - train_accuracy:0.5107




Epoch: 0/100 - Batch:83/971 - train_loss:0.7123 - train_accuracy:0.5113




Epoch: 0/100 - Batch:84/971 - train_loss:0.7040 - train_accuracy:0.5121




Epoch: 0/100 - Batch:85/971 - train_loss:0.6867 - train_accuracy:0.5125




Epoch: 0/100 - Batch:86/971 - train_loss:0.7995 - train_accuracy:0.5119




Epoch: 0/100 - Batch:87/971 - train_loss:0.6858 - train_accuracy:0.5131




Epoch: 0/100 - Batch:88/971 - train_loss:0.7661 - train_accuracy:0.5123




Epoch: 0/100 - Batch:89/971 - train_loss:0.7318 - train_accuracy:0.5123




Epoch: 0/100 - Batch:90/971 - train_loss:0.7207 - train_accuracy:0.5120




Epoch: 0/100 - Batch:91/971 - train_loss:0.7973 - train_accuracy:0.5110




Epoch: 0/100 - Batch:92/971 - train_loss:0.6573 - train_accuracy:0.5118




Epoch: 0/100 - Batch:93/971 - train_loss:0.6451 - train_accuracy:0.5146




Epoch: 0/100 - Batch:94/971 - train_loss:0.7614 - train_accuracy:0.5146




Epoch: 0/100 - Batch:95/971 - train_loss:0.6586 - train_accuracy:0.5158




Epoch: 0/100 - Batch:96/971 - train_loss:0.6966 - train_accuracy:0.5159




Epoch: 0/100 - Batch:97/971 - train_loss:0.7329 - train_accuracy:0.5158




Epoch: 0/100 - Batch:98/971 - train_loss:0.7441 - train_accuracy:0.5148




Epoch: 0/100 - Batch:99/971 - train_loss:0.7350 - train_accuracy:0.5142




Epoch: 0/100 - Batch:100/971 - train_loss:0.7145 - train_accuracy:0.5147




Epoch: 0/100 - Batch:101/971 - train_loss:0.6946 - train_accuracy:0.5150




Epoch: 0/100 - Batch:102/971 - train_loss:0.7105 - train_accuracy:0.5147




Epoch: 0/100 - Batch:103/971 - train_loss:0.7363 - train_accuracy:0.5140




Epoch: 0/100 - Batch:104/971 - train_loss:0.7177 - train_accuracy:0.5137




Epoch: 0/100 - Batch:105/971 - train_loss:0.6978 - train_accuracy:0.5137




Epoch: 0/100 - Batch:106/971 - train_loss:0.6869 - train_accuracy:0.5139




Epoch: 0/100 - Batch:107/971 - train_loss:0.6728 - train_accuracy:0.5145




Epoch: 0/100 - Batch:108/971 - train_loss:0.7186 - train_accuracy:0.5140




Epoch: 0/100 - Batch:109/971 - train_loss:0.7504 - train_accuracy:0.5132




Epoch: 0/100 - Batch:110/971 - train_loss:0.7249 - train_accuracy:0.5132




Epoch: 0/100 - Batch:111/971 - train_loss:0.7024 - train_accuracy:0.5131




Epoch: 0/100 - Batch:112/971 - train_loss:0.7256 - train_accuracy:0.5130




Epoch: 0/100 - Batch:113/971 - train_loss:0.7328 - train_accuracy:0.5129




Epoch: 0/100 - Batch:114/971 - train_loss:0.7290 - train_accuracy:0.5125




Epoch: 0/100 - Batch:115/971 - train_loss:0.7146 - train_accuracy:0.5128




Epoch: 0/100 - Batch:116/971 - train_loss:0.7295 - train_accuracy:0.5130




Epoch: 0/100 - Batch:117/971 - train_loss:0.6700 - train_accuracy:0.5134




Epoch: 0/100 - Batch:118/971 - train_loss:0.7257 - train_accuracy:0.5131




Epoch: 0/100 - Batch:119/971 - train_loss:0.7122 - train_accuracy:0.5129




Epoch: 0/100 - Batch:120/971 - train_loss:0.7143 - train_accuracy:0.5130




Epoch: 0/100 - Batch:121/971 - train_loss:0.8095 - train_accuracy:0.5126




Epoch: 0/100 - Batch:122/971 - train_loss:0.6991 - train_accuracy:0.5121




Epoch: 0/100 - Batch:123/971 - train_loss:0.7192 - train_accuracy:0.5120




Epoch: 0/100 - Batch:124/971 - train_loss:0.6971 - train_accuracy:0.5119




Epoch: 0/100 - Batch:125/971 - train_loss:0.6735 - train_accuracy:0.5130




Epoch: 0/100 - Batch:126/971 - train_loss:0.6957 - train_accuracy:0.5133




Epoch: 0/100 - Batch:127/971 - train_loss:0.6692 - train_accuracy:0.5140




Epoch: 0/100 - Batch:128/971 - train_loss:0.7312 - train_accuracy:0.5143




Epoch: 0/100 - Batch:129/971 - train_loss:0.7451 - train_accuracy:0.5141




Epoch: 0/100 - Batch:130/971 - train_loss:0.7397 - train_accuracy:0.5134




Epoch: 0/100 - Batch:131/971 - train_loss:0.7258 - train_accuracy:0.5129




Epoch: 0/100 - Batch:132/971 - train_loss:0.6965 - train_accuracy:0.5127




Epoch: 0/100 - Batch:133/971 - train_loss:0.7282 - train_accuracy:0.5126




Epoch: 0/100 - Batch:134/971 - train_loss:0.6652 - train_accuracy:0.5134




Epoch: 0/100 - Batch:135/971 - train_loss:0.7344 - train_accuracy:0.5134




Epoch: 0/100 - Batch:136/971 - train_loss:0.7502 - train_accuracy:0.5122




Epoch: 0/100 - Batch:137/971 - train_loss:0.7245 - train_accuracy:0.5122




Epoch: 0/100 - Batch:138/971 - train_loss:0.7380 - train_accuracy:0.5118




Epoch: 0/100 - Batch:139/971 - train_loss:0.7076 - train_accuracy:0.5119




Epoch: 0/100 - Batch:140/971 - train_loss:0.6682 - train_accuracy:0.5125




Epoch: 0/100 - Batch:141/971 - train_loss:0.6660 - train_accuracy:0.5130




Epoch: 0/100 - Batch:142/971 - train_loss:0.7224 - train_accuracy:0.5130




Epoch: 0/100 - Batch:143/971 - train_loss:0.6989 - train_accuracy:0.5130




Epoch: 0/100 - Batch:144/971 - train_loss:0.7082 - train_accuracy:0.5135




Epoch: 0/100 - Batch:145/971 - train_loss:0.7439 - train_accuracy:0.5126




Epoch: 0/100 - Batch:146/971 - train_loss:0.7324 - train_accuracy:0.5122




Epoch: 0/100 - Batch:147/971 - train_loss:0.6630 - train_accuracy:0.5126




Epoch: 0/100 - Batch:148/971 - train_loss:0.6832 - train_accuracy:0.5122




Epoch: 0/100 - Batch:149/971 - train_loss:0.7440 - train_accuracy:0.5117




Epoch: 0/100 - Batch:150/971 - train_loss:0.6995 - train_accuracy:0.5115




Epoch: 0/100 - Batch:151/971 - train_loss:0.6949 - train_accuracy:0.5113




Epoch: 0/100 - Batch:152/971 - train_loss:0.7457 - train_accuracy:0.5111




Epoch: 0/100 - Batch:153/971 - train_loss:0.7058 - train_accuracy:0.5113




Epoch: 0/100 - Batch:154/971 - train_loss:0.7299 - train_accuracy:0.5112




Epoch: 0/100 - Batch:155/971 - train_loss:0.6841 - train_accuracy:0.5117




Epoch: 0/100 - Batch:156/971 - train_loss:0.6698 - train_accuracy:0.5124




Epoch: 0/100 - Batch:157/971 - train_loss:0.7171 - train_accuracy:0.5120




Epoch: 0/100 - Batch:158/971 - train_loss:0.6985 - train_accuracy:0.5123




Epoch: 0/100 - Batch:159/971 - train_loss:0.7384 - train_accuracy:0.5121




Epoch: 0/100 - Batch:160/971 - train_loss:0.7199 - train_accuracy:0.5115




Epoch: 0/100 - Batch:161/971 - train_loss:0.7203 - train_accuracy:0.5106




Epoch: 0/100 - Batch:162/971 - train_loss:0.6937 - train_accuracy:0.5113




Epoch: 0/100 - Batch:163/971 - train_loss:0.6559 - train_accuracy:0.5116




Epoch: 0/100 - Batch:164/971 - train_loss:0.7074 - train_accuracy:0.5118




Epoch: 0/100 - Batch:165/971 - train_loss:0.7362 - train_accuracy:0.5117




Epoch: 0/100 - Batch:166/971 - train_loss:0.7156 - train_accuracy:0.5114




Epoch: 0/100 - Batch:167/971 - train_loss:0.7102 - train_accuracy:0.5113




Epoch: 0/100 - Batch:168/971 - train_loss:0.6795 - train_accuracy:0.5117




Epoch: 0/100 - Batch:169/971 - train_loss:0.7217 - train_accuracy:0.5117




Epoch: 0/100 - Batch:170/971 - train_loss:0.7305 - train_accuracy:0.5116




Epoch: 0/100 - Batch:171/971 - train_loss:0.7293 - train_accuracy:0.5111




Epoch: 0/100 - Batch:172/971 - train_loss:0.7152 - train_accuracy:0.5105




Epoch: 0/100 - Batch:173/971 - train_loss:0.6557 - train_accuracy:0.5111




Epoch: 0/100 - Batch:174/971 - train_loss:0.7294 - train_accuracy:0.5105




Epoch: 0/100 - Batch:175/971 - train_loss:0.7044 - train_accuracy:0.5104




Epoch: 0/100 - Batch:176/971 - train_loss:0.7175 - train_accuracy:0.5105




Epoch: 0/100 - Batch:177/971 - train_loss:0.6884 - train_accuracy:0.5105




Epoch: 0/100 - Batch:178/971 - train_loss:0.7039 - train_accuracy:0.5104




Epoch: 0/100 - Batch:179/971 - train_loss:0.7064 - train_accuracy:0.5102




Epoch: 0/100 - Batch:180/971 - train_loss:0.6751 - train_accuracy:0.5104




Epoch: 0/100 - Batch:181/971 - train_loss:0.7028 - train_accuracy:0.5107




Epoch: 0/100 - Batch:182/971 - train_loss:0.7352 - train_accuracy:0.5103




Epoch: 0/100 - Batch:183/971 - train_loss:0.7357 - train_accuracy:0.5106




Epoch: 0/100 - Batch:184/971 - train_loss:0.6831 - train_accuracy:0.5108




Epoch: 0/100 - Batch:185/971 - train_loss:0.7073 - train_accuracy:0.5103




Epoch: 0/100 - Batch:186/971 - train_loss:0.6968 - train_accuracy:0.5104




Epoch: 0/100 - Batch:187/971 - train_loss:0.6962 - train_accuracy:0.5106




Epoch: 0/100 - Batch:188/971 - train_loss:0.7023 - train_accuracy:0.5109




Epoch: 0/100 - Batch:189/971 - train_loss:0.6880 - train_accuracy:0.5108




Epoch: 0/100 - Batch:190/971 - train_loss:0.7067 - train_accuracy:0.5104




Epoch: 0/100 - Batch:191/971 - train_loss:0.7120 - train_accuracy:0.5103




Epoch: 0/100 - Batch:192/971 - train_loss:0.7075 - train_accuracy:0.5103




Epoch: 0/100 - Batch:193/971 - train_loss:0.7059 - train_accuracy:0.5101




Epoch: 0/100 - Batch:194/971 - train_loss:0.6890 - train_accuracy:0.5104




Epoch: 0/100 - Batch:195/971 - train_loss:0.6882 - train_accuracy:0.5106




Epoch: 0/100 - Batch:196/971 - train_loss:0.7405 - train_accuracy:0.5108




Epoch: 0/100 - Batch:197/971 - train_loss:0.7214 - train_accuracy:0.5106




Epoch: 0/100 - Batch:198/971 - train_loss:0.6859 - train_accuracy:0.5109




Epoch: 0/100 - Batch:199/971 - train_loss:0.6706 - train_accuracy:0.5115




Epoch: 0/100 - Batch:200/971 - train_loss:0.7056 - train_accuracy:0.5110




Epoch: 0/100 - Batch:201/971 - train_loss:0.6899 - train_accuracy:0.5111




Epoch: 0/100 - Batch:202/971 - train_loss:0.7571 - train_accuracy:0.5105




Epoch: 0/100 - Batch:203/971 - train_loss:0.6800 - train_accuracy:0.5107




Epoch: 0/100 - Batch:204/971 - train_loss:0.7182 - train_accuracy:0.5104




Epoch: 0/100 - Batch:205/971 - train_loss:0.6819 - train_accuracy:0.5110




Epoch: 0/100 - Batch:206/971 - train_loss:0.6935 - train_accuracy:0.5112




Epoch: 0/100 - Batch:207/971 - train_loss:0.6923 - train_accuracy:0.5112




Epoch: 0/100 - Batch:208/971 - train_loss:0.7137 - train_accuracy:0.5110




Epoch: 0/100 - Batch:209/971 - train_loss:0.7226 - train_accuracy:0.5109




Epoch: 0/100 - Batch:210/971 - train_loss:0.7257 - train_accuracy:0.5105




Epoch: 0/100 - Batch:211/971 - train_loss:0.7200 - train_accuracy:0.5104




Epoch: 0/100 - Batch:212/971 - train_loss:0.6505 - train_accuracy:0.5109




Epoch: 0/100 - Batch:213/971 - train_loss:0.7049 - train_accuracy:0.5107




Epoch: 0/100 - Batch:214/971 - train_loss:0.6930 - train_accuracy:0.5110




Epoch: 0/100 - Batch:215/971 - train_loss:0.6986 - train_accuracy:0.5109




Epoch: 0/100 - Batch:216/971 - train_loss:0.6911 - train_accuracy:0.5110




Epoch: 0/100 - Batch:217/971 - train_loss:0.7348 - train_accuracy:0.5104




Epoch: 0/100 - Batch:218/971 - train_loss:0.6939 - train_accuracy:0.5103




Epoch: 0/100 - Batch:219/971 - train_loss:0.7001 - train_accuracy:0.5104




Epoch: 0/100 - Batch:220/971 - train_loss:0.7266 - train_accuracy:0.5101




Epoch: 0/100 - Batch:221/971 - train_loss:0.6966 - train_accuracy:0.5100




Epoch: 0/100 - Batch:222/971 - train_loss:0.6718 - train_accuracy:0.5105




Epoch: 0/100 - Batch:223/971 - train_loss:0.7259 - train_accuracy:0.5100




Epoch: 0/100 - Batch:224/971 - train_loss:0.7031 - train_accuracy:0.5099




Epoch: 0/100 - Batch:225/971 - train_loss:0.6688 - train_accuracy:0.5100




Epoch: 0/100 - Batch:226/971 - train_loss:0.6972 - train_accuracy:0.5103




Epoch: 0/100 - Batch:227/971 - train_loss:0.7096 - train_accuracy:0.5101




Epoch: 0/100 - Batch:228/971 - train_loss:0.7225 - train_accuracy:0.5100




Epoch: 0/100 - Batch:229/971 - train_loss:0.7039 - train_accuracy:0.5098




Epoch: 0/100 - Batch:230/971 - train_loss:0.7306 - train_accuracy:0.5095




Epoch: 0/100 - Batch:231/971 - train_loss:0.7184 - train_accuracy:0.5093




Epoch: 0/100 - Batch:232/971 - train_loss:0.6775 - train_accuracy:0.5096




Epoch: 0/100 - Batch:233/971 - train_loss:0.6859 - train_accuracy:0.5097




Epoch: 0/100 - Batch:234/971 - train_loss:0.6864 - train_accuracy:0.5102

235it [00:34,  6.83it/s]
  0%|          | 0/100 [00:34<?, ?it/s]


KeyboardInterrupt: 