# Customize Training with Callbacks
---

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Customize-Training-with-Callbacks" data-toc-modified-id="Customize-Training-with-Callbacks-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Customize Training with Callbacks</a></span><ul class="toc-item"><li><span><a href="#Import-Libraries" data-toc-modified-id="Import-Libraries-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Import Libraries</a></span></li><li><span><a href="#Load-Data" data-toc-modified-id="Load-Data-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Load Data</a></span></li><li><span><a href="#Pre-process-Data" data-toc-modified-id="Pre-process-Data-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>Pre-process Data</a></span></li></ul></li><li><span><a href="#Basic-Callback-system" data-toc-modified-id="Basic-Callback-system-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Basic Callback system</a></span></li></ul></div>

## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data.dataloader import DataLoader

## Load Data

In [2]:
dataset = MNIST(root="../data/")

In [3]:
x, y = dataset.data.float(), dataset.targets
x_train, x_test = x[:50000], x[50000:]
y_train, y_test = y[:50000], y[50000:]

x_train.shape, x_test.shape

(torch.Size([50000, 28, 28]), torch.Size([10000, 28, 28]))

**Flatten**

In [4]:
x_train = x_train.reshape(x_train.shape[0],x_train.shape[1]*x_train.shape[2])
x_test = x_test.reshape(x_test.shape[0],x_test.shape[1]*x_test.shape[2])

## Pre-process Data

In [5]:
def normalize(x, mean, std): return (x-mean)/std

In [6]:
train_mean = x_train.mean()
train_std = x_train.std()

In [7]:
x_train = normalize(x_train, train_mean, train_std)
x_test = normalize(x_test, train_mean, train_std)

In [8]:
class Dataset():
    
    def __init__(self, x, y):
        self.x, self.y = x, y
    
    def __len__(self):
        return self.x.shape[0]
    
    def __getitem__(self, i):
        return self.x[i], self.y[i]

# Basic Callback system
---

In [9]:
config = {
    'epochs': 10,
    'lr': 0.5,
    'bs': 128
}

In [10]:
model = nn.Sequential(nn.Linear(784,300), nn.ReLU(), nn.Linear(300,10))

In [11]:
train_ds, test_ds = Dataset(x_train, y_train), Dataset(x_test, y_test)

In [24]:
class Agent():
    
    def __init__(self, config, model,
                 train_ds, test_ds,
                 loss_fn, metric_fn, optimizer):
        
        self.config, self.model = config, model
        
        self.train_dl = DataLoader(train_ds, self.config['bs'], shuffle=True)
        self.test_dl = DataLoader(test_ds, self.config['bs']*2, shuffle=False)
        
        self.loss_fn, self.metric_fn = loss_fn, metric_fn
        self.opt = optimizer(self.model.parameters(), config['lr'])
        
    def train_one_epoch(self):
        self.model.train()
        losses = 0.
        for xb, yb in self.train_dl:
            yb_pred = self.model(xb)
            loss = self.loss_fn(yb_pred, yb)
            loss.backward()
            self.opt.step()
            self.opt.zero_grad()
            losses += loss
        return losses           
    
    def validate(self):
        self.model.eval()
        losses = 0.
        metrics = 0.
        with torch.no_grad():
            for xb, yb in self.test_dl:
                yb_pred = self.model(xb)
                losses += self.loss_fn(yb_pred, yb)
                metrics += self.metric_fn(yb_pred, yb)
        return losses, metrics
    
    def train(self):
        for epoch in range(self.config['epochs']):
            train_loss = self.train_one_epoch()
            valid_loss, valid_metric = self.validate()
            report = (f"EPOCH#:{epoch} \t"
                      f"Train-Loss: {train_loss:.4f} \t"
                      f"Valid-Loss: {valid_loss:.4f} \t"
                      f"Valid-Metric: {valid_metric:.4f}")
            print(report)

In [25]:
agent = Agent(config, model,
              train_ds, test_ds,
              F.cross_entropy, F.cross_entropy,
              torch.optim.SGD)

In [26]:
agent.train()

EPOCH#:0 	Train-Loss: 1.4344 	Valid-Loss: 2.9827 	Valid-Metric: 2.9827
EPOCH#:1 	Train-Loss: 0.8576 	Valid-Loss: 3.0184 	Valid-Metric: 3.0184
EPOCH#:2 	Train-Loss: 0.6495 	Valid-Loss: 3.0779 	Valid-Metric: 3.0779
EPOCH#:3 	Train-Loss: 0.4975 	Valid-Loss: 3.1153 	Valid-Metric: 3.1153
EPOCH#:4 	Train-Loss: 0.4218 	Valid-Loss: 3.1431 	Valid-Metric: 3.1431
EPOCH#:5 	Train-Loss: 0.3615 	Valid-Loss: 3.1253 	Valid-Metric: 3.1253
EPOCH#:6 	Train-Loss: 0.3272 	Valid-Loss: 3.1588 	Valid-Metric: 3.1588
EPOCH#:7 	Train-Loss: 0.2870 	Valid-Loss: 3.2010 	Valid-Metric: 3.2010
EPOCH#:8 	Train-Loss: 0.2642 	Valid-Loss: 3.2119 	Valid-Metric: 3.2119
EPOCH#:9 	Train-Loss: 0.2444 	Valid-Loss: 3.2481 	Valid-Metric: 3.2481
