#### Normal Trainer

In [None]:
class Trainer:
    """Train and validate a classifier with best-checkpointing and early stopping.

    Each epoch of :meth:`fit` runs one training pass, one validation pass,
    optionally steps the LR scheduler, and writes the model's ``state_dict``
    to ``save_path`` whenever the validation loss improves on the best seen.
    """

    def __init__(self, model, device, criterion, optimizer, early_stopping_patience,
                 save_path, wscheduler=None, scheduler=None):
        """
        Args:
            model: ``torch.nn.Module`` to train; moved to ``device`` here.
            device: device that each batch of images/labels is moved to.
            criterion: loss called as ``criterion(outputs, labels)``.
            optimizer: optimizer stepped after every training batch
                (presumably built over ``model``'s parameters by the caller).
            early_stopping_patience: consecutive non-improving epochs tolerated;
                training stops once this count is exceeded. A falsy value
                (0 / None) disables early stopping.
            save_path: file the best (lowest validation loss) ``state_dict``
                is written to.
            wscheduler: legacy, misspelled name for ``scheduler``; kept so
                existing callers keep working.
            scheduler: optional LR scheduler, stepped once per epoch.
        """
        self.model = model.to(device)
        self.device = device
        self.criterion = criterion
        self.optimizer = optimizer
        # The original signature misspelled this parameter as `wscheduler`
        # while the body read `scheduler` (NameError). Accept both spellings.
        self.scheduler = scheduler if scheduler is not None else wscheduler
        self.early_stopping_patience = early_stopping_patience
        self.save_path = save_path
        # Start at +inf so the first epoch always counts as an improvement;
        # starting at 0.0 meant no non-negative loss could ever be "better".
        self.best_val_loss = float("inf")
        self.epochs_no_improve = 0

    def _step_scheduler(self, val_loss):
        """Advance the LR scheduler by one epoch (no-op when none is set).

        ``ReduceLROnPlateau`` needs the monitored metric; other schedulers'
        ``step()`` would treat a positional value as the deprecated ``epoch``.
        """
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def _save_checkpoint(self, path):
        """Write the model's ``state_dict`` to ``path``, creating parent directories."""
        import os

        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.model.state_dict(), path)

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(train_loss, train_acc)``: sample-weighted mean loss and
            accuracy in percent.
        """
        self.model.train()

        running_loss = 0.0
        correct, total = 0, 0

        for images, labels in tqdm(train_loader, desc="Training", leave=False):
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # loss.item() is a per-batch mean; weight by batch size so the
            # epoch average is exact even when the last batch is smaller.
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        train_acc = 100 * correct / total
        train_loss = running_loss / total

        return train_loss, train_acc

    def validate_per_epoch(self, val_loader):
        """Evaluate the model on ``val_loader`` without tracking gradients.

        Returns:
            ``(val_loss, val_acc)``: sample-weighted mean loss and accuracy
            in percent.
        """
        # eval() switches dropout/batch-norm layers to inference behaviour.
        self.model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc="Validation"):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (preds == labels).sum().item()

        epoch_loss = val_loss / val_total
        epoch_acc = 100 * val_correct / val_total

        return epoch_loss, epoch_acc

    def fit(self, train_loader, val_loader, num_epochs):
        """Train for up to ``num_epochs`` epochs.

        Returns:
            dict with per-epoch lists under ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc``.
        """
        history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

        for epoch in range(num_epochs):
            print(f"epochs {epoch + 1}/{num_epochs}")
            train_loss, train_acc = self.train_one_epoch(train_loader)
            val_loss, val_acc = self.validate_per_epoch(val_loader)
            self._step_scheduler(val_loss)

            history["train_loss"].append(train_loss)
            history["train_acc"].append(train_acc)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%"
                  f"\nValidation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self._save_checkpoint(self.save_path)
                self.epochs_no_improve = 0
            else:
                self.epochs_no_improve += 1
                if self.early_stopping_patience and self.epochs_no_improve > self.early_stopping_patience:
                    # Keep the most recent weights under a separate fixed path
                    # without overwriting self.save_path (the best checkpoint).
                    latest_path = "model/finetuned-model.pth"
                    self._save_checkpoint(latest_path)
                    print(f"Early stopped latest model saved {latest_path}")
                    break

        return history




#### K-Fold Trainer

In [None]:
import torch
import torch.nn as nn
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import numpy as np
import os


class Trainer:
    """Train/validate a classifier with LR scheduling, best-checkpointing and early stopping.

    Each epoch of :meth:`fit` runs one training pass and one validation pass,
    steps the scheduler (if any), and writes the model's ``state_dict`` to
    ``save_path`` whenever the validation loss improves on the best seen.
    """

    def __init__(self, model, device, criterion, optimizer, scheduler=None,
                 early_stopping_patience=5, save_path="model/best_model.pth"):
        """
        Args:
            model: ``torch.nn.Module`` to train; moved to ``device`` here.
            device: device that each batch of images/labels is moved to.
            criterion: loss called as ``criterion(outputs, labels)``.
            optimizer: optimizer stepped after every training batch.
            scheduler: optional LR scheduler, stepped once per epoch.
            early_stopping_patience: consecutive non-improving epochs tolerated;
                training stops once this count is exceeded.
            save_path: file the best (lowest validation loss) ``state_dict``
                is written to; parent directories are created on save.
        """
        self.model = model.to(device)
        self.device = device
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.early_stopping_patience = early_stopping_patience
        self.save_path = save_path

        self.best_val_loss = float("inf")
        self.epochs_no_improve = 0

    def _step_scheduler(self, val_loss):
        """Advance the LR scheduler by one epoch (no-op when none is set).

        Only ``ReduceLROnPlateau.step`` takes the monitored metric. Passing
        ``val_loss`` to any other scheduler's ``step`` would be read as the
        deprecated ``epoch`` argument, setting ``last_epoch`` to the loss value.
        """
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def _save_checkpoint(self, path):
        """Write the model's ``state_dict`` to ``path``, creating parent directories."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.model.state_dict(), path)

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(train_loss, train_acc)``: sample-weighted mean loss and
            accuracy in percent.
        """
        self.model.train()
        running_loss = 0.0
        correct, total = 0, 0

        for images, labels in tqdm(train_loader, desc="Training", leave=False):
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()

            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # Weight the per-batch mean loss by batch size for an exact epoch mean.
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        train_acc = 100 * correct / total
        train_loss = running_loss / total

        return train_loss, train_acc

    def validate_per_epoch(self, val_loader):
        """Evaluate on ``val_loader`` in eval mode without tracking gradients.

        Returns:
            ``(val_loss, val_acc)``: sample-weighted mean loss and accuracy
            in percent.
        """
        self.model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc="Validation", leave=False):
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (preds == labels).sum().item()

        epoch_loss = val_loss / val_total
        epoch_acc = 100 * val_correct / val_total
        return epoch_loss, epoch_acc

    def fit(self, train_loader, val_loader, num_epochs):
        """Train for up to ``num_epochs`` epochs, stopping early on a plateau.

        Note that after early stopping the in-memory model holds the latest
        weights; the best weights are the ones saved at ``save_path``.

        Returns:
            dict with per-epoch lists under ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc``.
        """
        history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            train_loss, train_acc = self.train_one_epoch(train_loader)
            val_loss, val_acc = self.validate_per_epoch(val_loader)

            self._step_scheduler(val_loss)

            history["train_loss"].append(train_loss)
            history["train_acc"].append(train_acc)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            # --- Early Stopping ---
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self._save_checkpoint(self.save_path)
                self.epochs_no_improve = 0
                print(f"✅ Best model saved at {self.save_path}")
            else:
                self.epochs_no_improve += 1
                if self.epochs_no_improve > self.early_stopping_patience:
                    print(f"⛔ Early Stopping triggered. Best model at {self.save_path}")
                    break

        return history


In [None]:
# Example usage. This cannot run here: `TrainerKFold` is defined below, and
# the call also needs the required `device`, `criterion` and `optimizer_class`
# arguments. Define `MyModel` / `my_dataset`, then run in a later cell:
# fold_accs = TrainerKFold(
#     model_class=MyModel,
#     dataset=my_dataset,
#     device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#     criterion=nn.CrossEntropyLoss(),
#     optimizer_class=torch.optim.Adam,
# ).run()



class TrainerKFold:
    """Stratified K-fold cross-validation that trains a fresh model per fold.

    For each fold, a new model is built from ``model_class``. It is trained
    with :class:`Trainer`, and its best checkpoint is saved as
    ``<save_dir>/best_fold<k>.pth``.
    """

    def __init__(self, model_class, dataset, device, criterion, optimizer_class,
                 scheduler_class=None, n_splits=5, epochs=10, batch_size=32,
                 patience=5, save_dir="model_kfold/", lr=1e-4, random_state=42):
        """
        Args:
            model_class: zero-argument callable returning a new ``nn.Module``.
            dataset: indexable dataset exposing ``.targets`` or ``.labels``.
            device: device models and batches are placed on.
            criterion: loss shared by all folds.
            optimizer_class: called as ``optimizer_class(params, lr=lr)``.
            scheduler_class: optional; called as ``scheduler_class(optimizer)``.
            n_splits: number of folds.
            epochs: maximum epochs per fold.
            batch_size: batch size for both loaders.
            patience: early-stopping patience passed to :class:`Trainer`.
            save_dir: directory for per-fold checkpoints (created here).
            lr: learning rate given to every fold's optimizer.
            random_state: seed for the fold shuffle, for reproducible splits.
        """
        self.model_class = model_class
        self.dataset = dataset
        self.device = device
        self.criterion = criterion
        self.optimizer_class = optimizer_class
        self.scheduler_class = scheduler_class
        self.n_splits = n_splits
        self.epochs = epochs
        self.batch_size = batch_size
        self.patience = patience
        self.save_dir = save_dir
        self.lr = lr
        self.random_state = random_state

        os.makedirs(save_dir, exist_ok=True)

    def _get_targets(self):
        """Return the dataset's class labels as a NumPy array for stratification."""
        if hasattr(self.dataset, "targets"):
            return np.array(self.dataset.targets)
        elif hasattr(self.dataset, "labels"):
            return np.array(self.dataset.labels)
        else:
            raise ValueError("Dataset must have `.targets` or `.labels` attribute")

    def run(self):
        """Train and evaluate every fold.

        Returns:
            list with, for each fold, the validation accuracy (percent) at the
            epoch with the lowest validation loss. That is the epoch whose
            weights :class:`Trainer` saved as the fold's best checkpoint.
        """
        y = self._get_targets()
        skf = StratifiedKFold(n_splits=self.n_splits, shuffle=True,
                              random_state=self.random_state)
        fold_results = []

        # StratifiedKFold only needs the sample count from X; y drives the split.
        for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(y)), y)):
            print(f"\n===== Fold {fold+1}/{self.n_splits} =====")

            train_subset = torch.utils.data.Subset(self.dataset, train_idx)
            val_subset = torch.utils.data.Subset(self.dataset, val_idx)

            train_loader = torch.utils.data.DataLoader(train_subset, batch_size=self.batch_size, shuffle=True)
            val_loader = torch.utils.data.DataLoader(val_subset, batch_size=self.batch_size, shuffle=False)

            # Move the model to its device BEFORE building the optimizer, as
            # PyTorch recommends, so the optimizer tracks the final parameters.
            model = self.model_class().to(self.device)
            optimizer = self.optimizer_class(model.parameters(), lr=self.lr)
            scheduler = self.scheduler_class(optimizer) if self.scheduler_class else None

            save_path = os.path.join(self.save_dir, f"best_fold{fold+1}.pth")
            trainer = Trainer(model, self.device, self.criterion, optimizer, scheduler,
                              early_stopping_patience=self.patience, save_path=save_path)

            history = trainer.fit(train_loader, val_loader, self.epochs)
            # Report the accuracy of the saved (lowest-val-loss) checkpoint,
            # not of the last epoch, which may be several epochs past it.
            # argmin picks the first minimum, matching Trainer's strict `<`.
            best_epoch = int(np.argmin(history["val_loss"]))
            fold_results.append(history["val_acc"][best_epoch])

        print(f"\nAverage Validation Accuracy across folds: {np.mean(fold_results):.2f}%")
        return fold_results
