In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from os.path import join

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from callback import *

In [3]:
#export
class AvgStats():
    def __init__(self, metrics, training): 
        self.metrics = metrics
        self.training = training
    
    def reset(self):
        self.count = 0
        self.total_loss = torch.Tensor([0])
        self.totals = [torch.Tensor([0])] * len(self.metrics)
        
    @property
    def all_stats(self): return [self.total_loss] + self.totals
    
    @property
    def avg_stats(self): return [s.item()/self.count for s in self.all_stats]
    
    def __repr__(self):
        if not self.count: 
            return ''
        return f"{'train' if self.training else 'valid'} metrics - {self.avg_stats}"

    def accumulate(self, runner):
        batch_size = runner.x_batch.shape[0]
        self.count += batch_size
        self.total_loss = runner.loss * batch_size
        for i, metric in enumerate(self.metrics):
            self.totals[i] += metric(runner.pred, runner.y_batch) * batch_size

class StatsLogging(Callback):
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)
        
    def before_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
        
    def after_loss(self):
        stats = self.train_stats if self.model.training else self.valid_stats
        stats.accumulate(self.runner)
    
    def after_epoch(self):
        print(f'Epoch - {self.epoch}\n{self.train_stats}\n{self.valid_stats}\n')

In [4]:
num_hidden = 50
batch_size = 64
num_epochs = 20
learning_rate = 0.1

data_bunch = get_data_bunch(*get_mnist_data(), batch_size)
model, optimizer = get_model(data_bunch, learning_rate, num_hidden)
loss_fn = CrossEntropy()
learner = Learner(model, optimizer, loss_fn, data_bunch)

In [5]:
runner = Runner(learner, [StatsLogging([compute_accuracy])])
print(runner)

(DataBunch) 
	(DataLoader) 
		(Dataset) x: (50000, 784), y: (50000,)
		(Sampler) total: 50000, batch_size: 64, shuffle: True
	(DataLoader) 
		(Dataset) x: (10000, 784), y: (10000,)
		(Sampler) total: 10000, batch_size: 128, shuffle: False
(Sequential)
	(Layer1) Linear(784, 50)
	(Layer2) ReLU()
	(Layer3) Linear(50, 10)
(CrossEntropy)
(Optimizer) num_params: 4, learning_rate: 0.1
(Callbacks) TrainEval StatsLogging


In [6]:
runner.fit(5)

Epoch - 1
train metrics - [0.00021775859832763673, 0.91726]
valid metrics - [0.00012849409580230713, 0.9443]

Epoch - 2
train metrics - [4.575530052185059e-05, 0.9597]
valid metrics - [2.5315427780151367e-05, 0.9651]

Epoch - 3
train metrics - [1.922694206237793e-05, 0.97048]
valid metrics - [1.1928081512451172e-05, 0.9666]

Epoch - 4
train metrics - [3.533562660217285e-05, 0.97454]
valid metrics - [9.833717346191407e-06, 0.9662]

Epoch - 5
train metrics - [7.881155014038086e-06, 0.97778]
valid metrics - [1.0901546478271485e-05, 0.969]

