# Trainer base class

> This module defines `Trainer`, a base class that runs the training and validation loops for a PyTorch model.

In [None]:
#| default_exp trainers.trainer

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
import torch

In [None]:
#| export
import torch
from torch import nn
from torch.utils.data import DataLoader

class Trainer:
    """Minimal supervised training loop around a PyTorch model.

    Holds a model, its data loaders, a loss function and an optimizer, and
    exposes one-epoch training / evaluation steps plus a `fit` driver that
    prints a progress line per epoch.
    """

    def __init__(self, cfg, model: nn.Module, train_loader: DataLoader, val_loader=None,
                 criterion=None, optimizer=None, device=None):
        """Store the training components and move `model` to the target device.

        Args:
            cfg: Configuration object; `fit` reads `cfg.epochs` from it.
            model: Module to train.
            train_loader: Iterable yielding `(inputs, targets)` batches.
            val_loader: Optional iterable yielding `(inputs, targets)` batches;
                when `None`, `eval_epoch` returns `None`.
            criterion: Loss callable `criterion(outputs, targets)`; defaults to
                `nn.CrossEntropyLoss()`.
            optimizer: Optimizer over the model parameters; defaults to
                `torch.optim.Adam(model.parameters(), lr=1e-3)`.
            device: Device to run on; defaults to `'cuda'` when available,
                otherwise `'cpu'`.
        """
        self.cfg = cfg
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Test against None explicitly: a caller-supplied object that happens
        # to be falsy (e.g. a module defining __len__) must not be replaced.
        self.criterion = nn.CrossEntropyLoss() if criterion is None else criterion
        self.optimizer = (
            torch.optim.Adam(model.parameters(), lr=1e-3) if optimizer is None else optimizer
        )
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.model.to(self.device)

    def train_epoch(self, epoch=None) -> float:
        """Run one optimisation pass over `train_loader`.

        Args:
            epoch: Current epoch number as supplied by `fit`. Unused here; it is
                available for subclasses that override this method.

        Returns:
            The per-sample average loss over the samples actually seen, or
            `nan` if the loader yielded no samples.
        """
        self.model.train()
        total_loss = 0.0
        n_seen = 0
        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            # Re-weight by batch size so a smaller final batch is not
            # over-counted; exact when the criterion returns the batch mean.
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            n_seen += batch_size

        # Divide by the samples seen rather than len(dataset): the two differ
        # with drop_last=True or a subset sampler, and some datasets have no
        # length at all.
        return total_loss / n_seen if n_seen else float("nan")

    def eval_epoch(self):
        """Evaluate the model on `val_loader` without tracking gradients.

        Returns:
            `None` when no validation loader was given; otherwise a tuple
            `(avg_loss, accuracy)` averaged over the samples actually seen.
            Accuracy compares `outputs.argmax(dim=1)` with `targets`. Both
            values are `nan` if the loader yielded no samples.
        """
        if self.val_loader is None:
            return None

        self.model.eval()
        total_loss = 0.0
        correct = 0
        n_seen = 0
        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                batch_size = inputs.size(0)
                total_loss += loss.item() * batch_size
                n_seen += batch_size

                preds = outputs.argmax(dim=1)
                correct += (preds == targets).sum().item()

        if not n_seen:
            return float("nan"), float("nan")
        return total_loss / n_seen, correct / n_seen

    def fit(self) -> None:
        """Train for `cfg.epochs` epochs, printing one summary line per epoch."""
        for epoch in range(1, self.cfg.epochs + 1):
            # Forward the epoch counter so overriding subclasses can use it.
            train_loss = self.train_epoch(epoch=epoch)
            val_result = self.eval_epoch()

            msg = f"Epoch {epoch}/{self.cfg.epochs} - Train loss: {train_loss:.4f}"
            if val_result is not None:
                val_loss, val_acc = val_result
                msg += f" | Val loss: {val_loss:.4f} | Val acc: {val_acc:.4f}"
            print(msg)


In [None]:
#| hide
# Run nbdev's export step for this project.
import nbdev

nbdev.nbdev_export()