In [1]:
import os
import sys

project_directory = os.path.abspath('..')
sys.path.append(project_directory)

In [None]:
import torch
import torch.nn as nn
from src.utils.common import set_seed

set_seed(42)

ModuleNotFoundError: No module named 'utils'

In [3]:
from src.dataloaders.dataloader import create_dataloaders
train_loader, val_loader, test_loader = create_dataloaders(data_root='../data')

In [4]:
train_dataset, val_dataset, test_dataset = create_dataloaders(data_root='../data', return_datasets=True)

In [5]:
from src.models.modern_cnn_v1 import ModernCNNv1

In [6]:
def train_one_epoch(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device
):
    
    model.train()
    
    running_loss = 0
    correct = 0
    total = 0
    
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * xb.size(0)
        predicted_labels = preds.argmax(dim=1)
        correct += (predicted_labels == yb).sum().item()
        total += yb.size(0)
        
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    
    return epoch_loss, epoch_acc

In [7]:
def eval_one_epoch(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    device: torch.device
):
    
    model.eval()
    
    running_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            
            preds = model(xb)
            loss = criterion(preds, yb)
            
            running_loss += loss.item() * xb.size(0)
            predicted_labels = preds.argmax(dim=1)
            correct += (predicted_labels == yb).sum().item()
            total += yb.size(0)
            
    epoch_loss = running_loss / total if total > 0 else 0.0
    epoch_acc  = correct / total if total > 0 else 0.0
    
    return epoch_loss, epoch_acc
            

In [8]:
from torch.utils.data import Subset, DataLoader

train_subset = Subset(test_dataset, range(64))
val_subset   = Subset(val_dataset, range(64))

In [9]:
train_loader_small = DataLoader(
    train_subset,
    batch_size=16,
    shuffle=True
)

val_loader_small = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False
)

In [10]:
model = ModernCNNv1()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
import copy
class EarlyStopper():
    def __init__(self, patience: int=50, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.best_epoch = -1
        self.best_state = None
        self.count = 0
    
    def step(self, model, val_loss, epoch):
        improved = val_loss < (self.best - self.min_delta)
        if improved:
            self.best = val_loss
            self.count = 0
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(model.state_dict())
            return False
        else:
            self.count += 1
            return self.count >= self.patience

In [None]:
def fit(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
    epochs :int=20, 
    early_stopper: EarlyStopper | None=None, 
    scheduler=None,
    verbose: bool=True
):
    """
    Full training loop with optional early stopping.
    Returns:
        history dict with loss/acc curves.
    """
    train_loss_curve, val_loss_curve = [], []
    train_acc_curve,  val_acc_curve  = [], []
    
    for epoch in range(epochs):
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            device=device
        )
        
        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device
        )
        
        if scheduler is not None:
            if hasattr(scheduler, 'step'):
                try:  
                    scheduler.step(val_loss)
                except TypeError:
                    scheduler.step()
        
        train_loss_curve.append(train_loss)
        train_acc_curve.append(train_acc)
        val_loss_curve.append(val_loss)
        val_acc_curve.append(val_acc)
        
        if verbose:
            print(
                f"EPOCH {epoch+1}/{epochs} | "
                f"Train Loss: {train_loss:.6f} Acc: {train_acc:.6f} | "
                f"Val Loss: {val_loss:.6f} Acc: {val_acc:.6f} "
            )
        
        if early_stopper is not None:
            if early_stopper.step(model, val_loss, epoch):
                if verbose:
                    print(
                        f"Early stop at epoch {epoch+1} | "
                        f"Best epoch {early_stopper.best_epoch+1}"
                    )
                break
        
    if early_stopper is not None and early_stopper.best_state is not None:
        model.load_state_dict(early_stopper.best_state)

    history = {
        'train_loss': train_loss_curve,
        'val_loss': val_loss_curve,
        'train_acc': train_acc_curve,
        'val_acc': val_acc_curve
    }
        
    return history

In [16]:
early_stopper = EarlyStopper(patience=2)

test_result = fit(
    model=model,
    train_loader=train_loader_small,
    val_loader=val_loader_small,
    optimizer=optimizer,
    criterion=criterion,
    device=DEVICE,
    epochs=5,
    early_stopper=early_stopper
)

EPOCH 1/5 | Train Loss: 1.720966 Acc: 0.406250 | Val Loss: 2.307184 Acc: 0.151200 
EPOCH 2/5 | Train Loss: 1.365673 Acc: 0.578125 | Val Loss: 2.466031 Acc: 0.126600 
EPOCH 3/5 | Train Loss: 0.980003 Acc: 0.843750 | Val Loss: 2.663698 Acc: 0.117200 
Early stop at epoch 3 | best epoch 1
