In [14]:
import torch
from torch.utils.data import DataLoader
from torch import nn

from functools import partial

from toss.utils import flatten
from toss.layers import Lambda
from toss.train import Trainer
from toss.data import get_mnist, get_gens, DataSet, DataBunch, normalize
from toss.utils import listify
from toss.optimize import Optimizer

In [4]:
x_train, y_train, x_valid, y_valid = get_mnist()

In [7]:
x_train, x_valid = normalize(x_train, x_valid)

In [9]:
train_ds = DataSet(x_train, y_train)
valid_ds = DataSet(x_valid, y_valid)

In [10]:
n, m = x_train.shape
c = y_train.max() + 1
bs = 32
n, m, c

(50000, 784, tensor(10))

In [12]:
data = DataBunch(*get_gens(train_ds, valid_ds, bs), c)

In [10]:
def mnist_resize(x):
    return x.view(-1, 1, 28, 28)

def get_cnn_model(data):
    return nn.Sequential(
      Lambda(mnist_resize),
      nn.Conv2d(1, 8, 5, padding=2, stride=2), nn.ReLU(),
      nn.Conv2d(8, 16, 3, padding=1, stride=2), nn.ReLU(),
      nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.ReLU(),
      nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.ReLU(),
      nn.AdaptiveAvgPool2d(1),
      Lambda(flatten),
      nn.Linear(32, data.c)
  )

In [11]:
model = get_cnn_model(data)

In [16]:
class Optimizer():
    def __init__(self, params, lr=0.5):
        self.params = list(params)
        self.lr = lr
    
    def step(self):
        with torch.no_grad():
            for p in self.params:
                p -= p.grad * self.lr
            
    def zero_grad(self):
        for p in self.params:
            p.grad.data.zero_()

In [17]:
class Callback():
    def set_learner(self, learner):
        self.learner = learner
    
    def __getattr__(self, k):
        return getattr(self.learner, k)
    
    def __call__(self, cb_name):
        f = getattr(self, cb_name, None)
        if f and f():
            return True
        return False

In [18]:
class TrainEvalCallback(Callback):
    _order = 1
    def begin_epoch(self):
        self.model.train()
        self.learner.in_train= True
    
    def begin_validate(self):
        self.model.eval()
        self.learner.in_train= False

In [32]:
class AvgStats():
    def __init__(self, metrics, in_train):
        self.metrics = listify(metrics)
        self.in_train = in_train
    
    def reset(self):
        self.tot_loss = 0
        self.count = 0
        self.tot_mets = [0.] * len(self.metrics)
    
    @property
    def all_stats(self):
        return [self.tot_loss.item()] + self.tot_mets
    
    @property
    def avg_stats(self):
        return [o / self.count for o in self.all_stats]
    
    def __repr__(self):
        if not self.count:
            return ""
        return f"{'train' if self.in_train else 'valid'} : {self.avg_stats}"
    
    def accumulate(self, learner):
        bn = learner.xb.shape[0]
        self.tot_loss += learner.loss * bn
        self.count += bn
        for i, metric in enumerate(self.metrics):
            self.tot_mets[i] += metric(learner.pred, learner.yb) * bn
    
class AvgStatsCallback(Callback):
    _order = 10
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)
    
    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
    
    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.learner)
    
    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)
    

In [20]:
def accuracy(out, yb):
    return (torch.argmax(out, dim=1) == yb).float().mean()

In [33]:
import torch.nn.functional as F
loss_func = F.cross_entropy
opt = Optimizer(model.parameters())
metrics = [accuracy]
cbfs = [AvgStatsCallback(metrics)]
learner = Learner(model, data, loss_func, opt, cbs=cbfs)

In [34]:
learner.fit(1)

train : [0.094985537109375, tensor(0.9728)]
valid : [0.07639669189453124, tensor(0.9786)]
