In [1]:
import os
import sys

project_directory = os.path.abspath('..')
sys.path.append(project_directory)

In [2]:
import torch
import torch.nn as nn

In [3]:
from src.dataloaders.dataloader import create_dataloaders
train_loader, val_loader, test_loader = create_dataloaders(data_root='../data')

In [4]:
train_dataset, val_dataset, test_dataset = create_dataloaders(data_root='../data', return_datasets=True)

In [5]:
from src.models.modern_cnn_v1 import ModernCNNv1

In [6]:
def train_one_epoch(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device
):
    
    model.train()
    
    running_loss = 0
    correct = 0
    total = 0
    
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * xb.size(0)
        predicted_labels = preds.argmax(dim=1)
        correct += (predicted_labels == yb).sum().item()
        total += yb.size(0)
        
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    
    return epoch_loss, epoch_acc

In [7]:
def eval_one_epoch(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    device: torch.device
):
    
    model.eval()
    
    running_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            
            preds = model(xb)
            loss = criterion(preds, yb)
            
            running_loss += loss.item() * xb.size(0)
            predicted_labels = preds.argmax(dim=1)
            correct += (predicted_labels == yb).sum().item()
            total += yb.size(0)
            
    epoch_loss = running_loss / total if total > 0 else 0.0
    epoch_acc  = correct / total if total > 0 else 0.0
    
    return epoch_loss, epoch_acc
            

In [8]:
from torch.utils.data import Subset, DataLoader

train_subset = Subset(test_dataset, range(64))
val_subset   = Subset(val_dataset, range(64))

In [9]:
train_loader_small = DataLoader(
    train_subset,
    batch_size=16,
    shuffle=True
)

val_loader_small = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False
)

In [10]:
model = ModernCNNv1()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
subset_train_loss, subset_train_acc = train_one_epoch(
    model=model,
    loader=train_loader_small,
    optimizer=optimizer,
    criterion=criterion,
    device=DEVICE
)

In [18]:
subset_val_loss, subset_val_acc = eval_one_epoch(
    model=model,
    loader=val_loader_small,
    criterion=criterion,
    device=DEVICE
)

In [19]:
subset_train_loss, subset_train_acc, subset_val_loss, subset_val_acc

(1.6429251432418823, 0.4375, 2.548455617904663, 0.1152)

In [None]:
class EarlyStopper():
    def __init__(self,):
        pass

In [None]:
def fit(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
    epochs :int=20, 
    early_stopper: EarlyStopper | None=None, 
    scheduler=None,
    verbose: bool=True
):
    """
    Full traning loop with optional early stopping.
    Returns:
        history dict with loss/acc cures.
    """
    train_loss_curve, val_loss_curve = [], []
    train_acc_curve,  val_acc_curve  = [], []
    
    for epoch in epochs:
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            device=device
        )
        
        val_loss, val_acc = eval_one_epoch(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device
        )
        
        # if scheduler is not None:
        #     scheduler.step(val_loss)
        
        train_loss_curve.append(train_loss)
        train_acc_curve.append(train_acc)
        val_loss_curve.append(val_loss)
        val_acc_curve.append(val_acc)
        
        if verbose:
            print(
                f"EPOCH {epoch+1}/{epochs} |"
                f"Train Loss: {train_loss:.6f} Acc: {train_acc:.6f}"
                f"Val Loss: {val_loss:.6f} Acc: {val_acc:.6f}"
            )
        
        if EarlyStopper:
            pass
        
        history = {
            'train_loss': train_loss_curve,
            'val_loss': val_loss_curve,
            'train_acc': train_acc_curve,
            'val_acc': val_acc_curve
        }
        
        return history