# Learner Module

> Training interface for object detection models, in the style of fastai

In [None]:
#| default_exp learner

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from pathlib import Path
import time
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Tuple, Union, Optional, Callable
from PIL import Image
import torchmetrics

from objdetect.core import plot_boxes
from objdetect.data import ObjectDetectionDataset
from objdetect.models import create_model

In [None]:
#| hide
from nbdev.showdoc import *

## Callback System

In [None]:
#| export
class Callback:
    """No-op base class for training-loop hooks.

    Subclasses override only the hooks they need. Every hook receives the
    learner running the loop as its single argument and returns nothing.
    ``order`` is the sort key ObjectDetectionLearner uses for its callback
    list (ascending, so lower values run first). The call points described
    below are those used by ObjectDetectionLearner.fit.
    """
    order = 0

    def before_fit(self, learner):
        """Called once before the first epoch."""

    def after_fit(self, learner):
        """Called once after the last epoch."""

    def before_epoch(self, learner):
        """Called at the start of each epoch's training phase."""

    def after_epoch(self, learner):
        """Called at the end of each epoch, after any validation pass."""

    def before_batch(self, learner):
        """Called before each batch, in both training and validation."""

    def after_batch(self, learner):
        """Called after each batch, in both training and validation."""

    def before_backward(self, learner):
        """Called after gradients are zeroed, just before ``loss.backward()``."""

    def after_backward(self, learner):
        """Called after ``loss.backward()``, before the optimizer step."""

In [None]:
#| export
class ProgressCallback(Callback):
    """Print training progress and plot loss curves at the end of a fit.

    Every training batch's total loss is recorded. One validation loss is
    recorded per epoch in which the learner ran a validation pass. Both are
    plotted against the global training-batch index in ``after_fit``.
    """
    order = 0

    # Print a per-batch line only every this many training batches, to avoid
    # flooding the output.
    log_every = 10

    def before_fit(self, learner):
        # Reset history so each fit plots only its own losses.
        self.train_losses = []
        self.val_losses = []

    def after_batch(self, learner):
        # after_batch also fires for validation batches; only training batches
        # carry a freshly computed learner.loss / learner.loss_dict.
        if not learner.training:
            return

        if learner.batch_idx % self.log_every == 0:
            loss_str = ", ".join(f"{k}: {v:.4f}" for k, v in learner.loss_dict.items())
            print(f"Epoch {learner.epoch+1}/{learner.n_epochs}, Batch {learner.batch_idx+1}/{len(learner.train_dl)}, {loss_str}")

        # Save loss for plotting
        self.train_losses.append(learner.loss.item())

    def after_epoch(self, learner):
        epoch_tag = f"Epoch {learner.epoch+1}/{learner.n_epochs}"

        # Average over the batches of the epoch that just finished. The summary
        # is printed whether or not validation ran; an empty loader yields no
        # summary instead of a ZeroDivisionError. (Slicing with [-0:] would
        # return the whole list, hence the explicit guard.)
        n_batches = len(learner.train_dl)
        epoch_losses = self.train_losses[-n_batches:] if n_batches else []
        if epoch_losses:
            train_loss = sum(epoch_losses) / len(epoch_losses)
            print(f"{epoch_tag} - Train Loss: {train_loss:.4f}")

        # In ObjectDetectionLearner, learner.training is False at this point
        # only when a validation pass ran during this epoch.
        if not learner.training and hasattr(learner, 'val_loss'):
            self.val_losses.append(learner.val_loss)
            print(f"{epoch_tag} - Val Loss: {learner.val_loss:.4f}")

    def after_fit(self, learner):
        # Plot loss curves
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label='Train Loss')

        if self.val_losses:
            # Validation runs after the last training batch of each epoch, so
            # place point i at global batch index (i + 1) * n_batches - 1
            # instead of spreading the points evenly from batch 0.
            n_batches = len(learner.train_dl)
            x_vals = np.arange(1, len(self.val_losses) + 1) * n_batches - 1
            plt.plot(x_vals, self.val_losses, 'ro-', label='Val Loss')

        plt.xlabel('Batch')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

## Learning Rate Schedule

In [None]:
#| export
class OneCycleScheduler(Callback):
    """Drive the optimizer's learning rate with PyTorch's 1cycle policy.

    A fresh ``torch.optim.lr_scheduler.OneCycleLR`` is built on
    ``learner.optimizer`` at the start of every fit. It is advanced once per
    training batch, so its total step count is ``n_epochs * len(train_dl)``.
    """
    order = 1

    def __init__(self, max_lr, pct_start=0.3, div_factor=25., final_div_factor=1e4):
        # Hyper-parameters, forwarded unchanged to OneCycleLR in before_fit.
        self.max_lr, self.pct_start = max_lr, pct_start
        self.div_factor, self.final_div_factor = div_factor, final_div_factor

    def before_fit(self, learner):
        batches_per_epoch = len(learner.train_dl)
        self.n_steps = learner.n_epochs * batches_per_epoch
        schedule_kwargs = dict(
            max_lr=self.max_lr,
            total_steps=self.n_steps,
            pct_start=self.pct_start,
            div_factor=self.div_factor,
            final_div_factor=self.final_div_factor,
        )
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(learner.optimizer, **schedule_kwargs)

    def after_batch(self, learner):
        # after_batch also fires for validation batches; those must not
        # consume scheduler steps.
        if not learner.training:
            return
        self.scheduler.step()

## Checkpoint Callback

In [None]:
#| export
class SaveModelCallback(Callback):
    """Save model checkpoints during training.

    After every epoch the current weights are written to
    ``<save_path>/<save_name>_last.pth``. After epochs that ran a validation
    pass, the learner attribute named by ``monitor`` is compared with the best
    value seen so far. On improvement, the weights are also written to
    ``<save_path>/<save_name>_best.pth``.
    """
    order = 2

    def __init__(self, save_path='checkpoints', save_name='model', monitor='val_loss', mode='min'):
        """
        Args:
            save_path: Directory for checkpoint files (created in ``before_fit``).
            save_name: File-name prefix for the checkpoints.
            monitor: Name of the learner attribute to track.
            mode: ``'min'`` if lower values are better, ``'max'`` if higher.

        Raises:
            ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
        """
        # An unknown mode would otherwise silently disable best-model saving.
        if mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.save_path = Path(save_path)
        self.save_name = save_name
        self.monitor = monitor
        self.mode = mode
        # Best value seen so far. It is set only here, so it carries over
        # between successive fits that reuse this callback instance.
        self.best_value = float('inf') if mode == 'min' else float('-inf')

    def before_fit(self, learner):
        self.save_path.mkdir(exist_ok=True, parents=True)

    def _is_improvement(self, value):
        """Return True if ``value`` beats ``best_value`` under ``mode``."""
        return value < self.best_value if self.mode == 'min' else value > self.best_value

    def after_epoch(self, learner):
        # In ObjectDetectionLearner, learner.training is False here only when a
        # validation pass ran this epoch, i.e. when a fresh monitored value
        # (e.g. val_loss) is available.
        if not learner.training:
            current_value = getattr(learner, self.monitor, None)
            if current_value is not None and self._is_improvement(current_value):
                self.best_value = current_value
                path = self.save_path / f"{self.save_name}_best.pth"
                torch.save(learner.model.state_dict(), path)
                print(f"Saved model to {path} with {self.monitor}={current_value:.4f}")

        # Save the latest weights every epoch, with or without validation.
        path = self.save_path / f"{self.save_name}_last.pth"
        torch.save(learner.model.state_dict(), path)

## Object Detection Learner

In [None]:
#| export
class ObjectDetectionLearner:
    """Learner for training object detection models.

    Bundles a model, a training DataLoader, an optional validation DataLoader,
    an optimizer (created in ``fit``) and a list of callbacks invoked at fixed
    points of the training loop.
    """

    def __init__(self, dataset, model=None, batch_size=4, num_workers=2,
                 callbacks=None, device=None):
        """
        Args:
            dataset: ObjectDetectionDataset, or a ``(train_ds, val_ds)`` tuple.
                A single dataset, a ``val_ds`` of None, or a ``val_ds`` that is
                the same object as ``train_ds`` means no validation pass.
            model: Object detection model, a model name passed to
                ``create_model``, or None for ``create_model``'s default.
            batch_size: Batch size for both loaders.
            num_workers: Number of workers for data loading.
            callbacks: List of callbacks (defaults to ``[ProgressCallback()]``).
            device: Device to use (CUDA if available, else CPU, when None).
        """
        # Set device
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Handle dataset
        if isinstance(dataset, tuple):
            self.train_ds, self.val_ds = dataset
        else:
            # A single dataset is stored as both train_ds and val_ds, but no
            # validation loader is built for it (see below).
            self.train_ds = dataset
            self.val_ds = dataset

        # Create data loaders
        self.train_dl = DataLoader(
            self.train_ds, batch_size=batch_size, shuffle=True,
            num_workers=num_workers, collate_fn=self.train_ds.collate_fn,
        )

        # Only validate on a dataset distinct from the training one. Identity
        # (not ==) is the intended check, and None means "no validation".
        if self.val_ds is None or self.val_ds is self.train_ds:
            self.val_dl = None
        else:
            self.val_dl = DataLoader(
                self.val_ds, batch_size=batch_size, shuffle=False,
                num_workers=num_workers, collate_fn=self.val_ds.collate_fn,
            )

        # Handle model
        if model is None:
            model = create_model(num_classes=self.train_ds.num_classes)
        elif isinstance(model, str):
            model = create_model(model, num_classes=self.train_ds.num_classes)
        self.model = model.to(self.device)

        # Set up callbacks, sorted by their `order` (stable sort).
        initial = [ProgressCallback()] if callbacks is None else callbacks
        self.callbacks = sorted(initial, key=lambda cb: cb.order)

    def _run_callbacks(self, event):
        """Invoke hook ``event`` on every callback, in list order."""
        for cb in self.callbacks:
            getattr(cb, event)(self)

    def _prepare_batch(self, batch):
        """Move a collated ``(images, targets)`` batch to ``self.device``.

        Tensor values inside each target dict are moved; other values are
        passed through unchanged.
        """
        images, targets = batch

        images = [img.to(self.device) for img in images]
        targets = [{k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in t.items()} for t in targets]

        return images, targets

    def fit(self, n_epochs, lr=1e-4, optimizer=None):
        """Train the model.

        Args:
            n_epochs: Number of epochs
            lr: Learning rate (used only when ``optimizer`` is None)
            optimizer: Optimizer (default: Adam over all model parameters)
        """
        self.n_epochs = n_epochs

        # Set up optimizer
        if optimizer is None:
            optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.optimizer = optimizer

        self._run_callbacks('before_fit')

        for epoch in range(n_epochs):
            self.epoch = epoch
            self._train_epoch()
            if self.val_dl is not None:
                self._validate_epoch()
            self._run_callbacks('after_epoch')

        self._run_callbacks('after_fit')

    def _train_epoch(self):
        """Run one pass over train_dl with gradient updates, firing batch hooks."""
        self.model.train()
        self.training = True
        self._run_callbacks('before_epoch')

        for batch_idx, batch in enumerate(self.train_dl):
            self.batch_idx = batch_idx
            self.batch = batch
            self._run_callbacks('before_batch')

            # Forward pass: the model is expected to return a dict of named
            # loss tensors when called with targets in train mode.
            images, targets = self._prepare_batch(batch)
            loss_dict = self.model(images, targets)

            # Scalar copies for logging; the summed tensor drives backward.
            self.loss_dict = {k: v.item() for k, v in loss_dict.items()}
            self.loss = sum(loss_dict.values())

            self.optimizer.zero_grad()
            self._run_callbacks('before_backward')
            self.loss.backward()
            self._run_callbacks('after_backward')
            self.optimizer.step()

            self._run_callbacks('after_batch')

    def _validate_epoch(self):
        """Compute the mean total loss over val_dl without gradients.

        Stores the result in ``self.val_loss``. It is NaN if the validation
        loader yields no batches.
        """
        self.model.eval()
        self.training = False
        val_losses = []

        # NOTE(review): this expects the model to return a loss dict in eval
        # mode when given targets. Stock torchvision detection models return
        # a list of detections in eval mode instead. Confirm the models
        # produced by create_model support this.
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.val_dl):
                self.batch_idx = batch_idx
                self.batch = batch
                self._run_callbacks('before_batch')

                images, targets = self._prepare_batch(batch)
                loss_dict = self.model(images, targets)
                val_losses.append(sum(loss_dict.values()).item())

                self._run_callbacks('after_batch')

        self.val_loss = sum(val_losses) / len(val_losses) if val_losses else float('nan')

    def fit_one_cycle(self, n_epochs, max_lr=1e-3):
        """Train with the 1cycle policy using SGD.

        If a OneCycleScheduler is already among the callbacks, it is reused
        with its ``max_lr`` set to this call's value. Otherwise a scheduler is
        added for the duration of this call only. This keeps later plain
        ``fit()`` calls from being re-scheduled.

        Args:
            n_epochs: Number of epochs
            max_lr: Maximum learning rate
        """
        optimizer = torch.optim.SGD(self.model.parameters(), lr=max_lr/25)

        existing = next((cb for cb in self.callbacks if isinstance(cb, OneCycleScheduler)), None)
        if existing is not None:
            # Keep the scheduler's other settings but honour this call's max_lr
            # (previously a stale value from an earlier call was used).
            existing.max_lr = max_lr
            self.fit(n_epochs=n_epochs, optimizer=optimizer)
            return

        one_cycle = OneCycleScheduler(max_lr=max_lr)
        self.callbacks = sorted([*self.callbacks, one_cycle], key=lambda cb: cb.order)
        try:
            self.fit(n_epochs=n_epochs, optimizer=optimizer)
        finally:
            # Detach the temporary scheduler even if training raised.
            if one_cycle in self.callbacks:
                self.callbacks.remove(one_cycle)

    def predict(self, img, threshold=0.5):
        """Make prediction on a single image.

        Args:
            img: PIL.Image or file path or tensor
            threshold: Confidence threshold

        Returns:
            Dictionary with prediction results. If the model has its own
            ``predict`` method, the first element of its result is returned.
            Otherwise the raw output is filtered to entries with
            ``scores >= threshold``.
        """
        self.model.eval()

        # Load image if path is provided
        if isinstance(img, (str, Path)):
            img = Image.open(img).convert('RGB')

        with torch.no_grad():
            if hasattr(self.model, 'predict'):
                pred = self.model.predict(img, threshold=threshold)[0]
            else:
                # For models without a predict method
                if not isinstance(img, torch.Tensor):
                    img = torchvision.transforms.ToTensor()(img)
                img = img.to(self.device).unsqueeze(0)
                pred = self.model(img)[0]

                # Filter by threshold
                keep = pred['scores'] >= threshold
                pred = {k: v[keep] for k, v in pred.items()}

        return pred

    def show_results(self, img, pred=None, figsize=(10, 10)):
        """Show prediction results.

        Args:
            img: PIL.Image or file path or tensor
            pred: Prediction dictionary (if None, predict will be called)
            figsize: Figure size

        Returns:
            Whatever ``plot_boxes`` returns
        """
        # Load image if path is provided
        if isinstance(img, (str, Path)):
            img = Image.open(img).convert('RGB')

        # Make prediction if not provided
        if pred is None:
            pred = self.predict(img)

        boxes = pred['boxes'].cpu() if 'boxes' in pred else None
        labels = pred['labels'].cpu() if 'labels' in pred else None
        scores = pred['scores'].cpu() if 'scores' in pred else None

        class_names = getattr(self.train_ds, 'class_names', None)

        return plot_boxes(img, boxes, labels, scores, class_names, figsize=figsize)