In [87]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager

from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset,load_dataset_builder

from miniai.datasets import *
from miniai.conv import *

from fastprogress import progress_bar,master_bar



In [2]:
from fastcore.test import test_close

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

import logging
logging.disable(logging.WARNING)

## Learner

In [3]:
x , y = 'image', 'label'
name = "fashion_mnist"
dsd = load_dataset(name)

  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
@inplace
def transformi(b): b[x] = [torch.flatten(TF.to_tensor(o)) for o in b[x]]

In [5]:
bs = 1024
tds = dsd.with_transform(transformi)

In [6]:
class DataLoaders:
    def __init__(self, *dls):
        self.train, self.valid = dls[:2]
    
    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True):
        '''from dataset dictionary'''
        return cls(*[DataLoader(ds, batch_size, num_workers=4, collate_fn=collate_dict(ds)) for ds in dd.values()])

In [7]:
dls = DataLoaders.from_dd(tds, bs)
dt = dls.train
xb, yb =next(iter(dt))
xb.shape, yb[:10]

(torch.Size([1024, 784]), tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]))

In [8]:
#|export
class Learner:
    def __init__(self, model, dls, loss_func, lr, opt_func=optim.SGD): fc.store_attr()

    def one_batch(self):
        self.xb,self.yb = to_device(self.batch)
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        with torch.no_grad(): self.calc_stats()
    
    def calc_stats(self):
        acc = (self.preds.argmax(dim=1)==self.yb).float().sum()
        self.accs.append(acc)
        n = len(self.xb)
        self.losses.append(self.loss*n)
        self.ns.append(n)

    def one_epoch(self, train):
        self.accs,self.losses,self.ns = [],[],[]
        self.model.training = train
        dl = self.dls.train if train else self.dls.valid
        for self.num,self.batch in enumerate(dl): self.one_batch()
        n = sum(self.ns)
        print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)
    
    def fit(self, n_epochs):
        #self.accs,self.losses,self.ns = [],[],[]
        self.model.to(def_device)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.one_epoch(True)
            self.one_epoch(False)

In [9]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [10]:
learn = Learner(model, dls, F.cross_entropy, lr=0.2)
learn.fit(1)

0 True 1.1852178385416667 0.5956333333333333
0 False 0.8332626953125 0.6878


## Basic callbacks learner

In [44]:
class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass

In [45]:
def run_cbs(cbs, method_nm):
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None:
            method()

In [46]:
class Callback():
    order = 0

In [47]:
class CompletionCB(Callback):
    def before_fit(self): self.count = 0
    def after_batch(self): self.count += 1
    def after_fit(self): print(f'completed {self.count} batches')

In [48]:
cbs = [CompletionCB()]
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')

completed 1 batches


In [55]:
class Learner():
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func = optim.SGD):
        fc.store_attr()
        for cb in cbs: cb.learn = self
    
    def one_batch(self):
        self.preds = self.model(self.batch[0])
        self.loss = self.loss_func(self.preds, self.batch[1])
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
            
    def one_epoch(self, train):
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        try:
            self.callback('before_epoch')
            for self.iter,self.batch in enumerate(self.dl):
                try:
                    self.callback('before_batch')
                    self.one_batch()
                    self.callback('after_batch')
                except CancelBatchException: pass
            self.callback('after_epoch')
        except CancelEpochException: pass
        
    def fit(self, n_epochs):
        self.n_epochs =  n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        try:
            self.callback('before_fit')
            for self.epoch in self.epochs:
                self.one_epoch(True)
                self.one_epoch(False)
            self.callback('after_fit')
        except CancelFitException: pass
        
    def callback(self, method_nm):
        run_cbs(self.cbs, method_nm)

In [56]:
m, nh = 28*28, 50

def get_model():
    return nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10))

In [57]:
model = get_model()
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=[CompletionCB()])
learn.fit(1)

completed 69 batches


In [66]:
class SingleBatchCallback(Callback):
    order = 1
    def after_batch(self): 
        raise CancelEpochException()

In [67]:
learn = Learner(get_model(), dls, F.cross_entropy, lr=0.2, cbs=[SingleBatchCallback(), CompletionCB()])
learn.fit(1)

completed 2 batches


## Metrics

In [68]:
class Metric:
    def __init__(self):
      self.reset()
    def reset(self):
        self.vals, self.ns = [], []
    def add(self, inp, targ=None, n=1):
        self.last = self.calc(inp, targ)
        self.vals.append(self.last)
        self.ns.append(n)
    @property
    def value(self):
        ns = tensor(self.ns)
        return (tensor(self.vals)*ns).sum()/ns.sum()
    def calc(self, inps, targs):
        return inps

In [69]:
class Accuracy(Metric):
    def calc(self, inps, targs):
        return (inps==targs).float().mean()

In [70]:
acc = Accuracy()
acc.add(tensor([0,1,2,0,1,2]), tensor([0,1,1,2,1,0]))
acc.add(tensor([1,1,2,0,1]), tensor([0,1,1,2,1]))
acc.value

tensor(0.45)

In [71]:
loss = Metric()
loss.add(0.6, n=32)
loss.add(0.9, n=2)
loss.value, round((0.6*32+0.9*2)/(32+2), 2)

(tensor(0.62), 0.62)

## Some callbacks

In [76]:
class DeviceCB(Callback):
    def __init__(self, device=def_device):
        fc.store_attr()
    def before_fit(self):
        self.learn.model.to(self.device)
    def before_batch(self):
        self.learn.batch = to_device(self.learn.batch, device=self.device)

### using torcheval metrics package (pip install torcheval)

In [78]:
from torcheval.metrics import MulticlassAccuracy, Mean

In [79]:
metric = MulticlassAccuracy()
metric.update(tensor([0 ,2, 1, 3]), tensor([0, 1, 2, 3]))
metric.compute()

tensor(0.50)

In [80]:
metric.reset()
metric.compute()

tensor(nan)

In [92]:
def to_cpu(x):
    if isinstance(x, Mapping): return {k:to_cpu(v) for k,v in x.items()}
    if isinstance(x, list): return [to_cpu(o) for o in x]
    if isinstance(x, tuple): return tuple(to_cpu(list(x)))
    return x.detach().cpu()

In [95]:
class MetricsCB(Callback):
    def __init__(self, *ms, **metrics):
        for o in ms:
            #if there are arguments in ms we add them to metrics dictionary. 
            #we use the object type as key for the dict.
            metrics[type(o).__name__] = o
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss = Mean()
        
    def _log(self, d): 
        print(d)
        
    def before_fit(self):
        self.learn.metrics = self
    
    def before_epoch(self):
        [o.reset() for o in self.all_metrics.values()]
    
    def after_epoch(self):
        log = {k:f'{v.compute()}:.3f' for k,v in self.all_metrics.items()}
        log['epoch'] = self.learn.epoch
        log['train'] = self.learn.model.training
        self._log(log)
        
    def after_batch(self):
        x, y = to_cpu(self.learn.batch)
        for m in self.metrics.values():
            m.update(to_cpu(self.learn.preds), y)
        self.loss.update(to_cpu(self.learn.loss), weight=len(x))

In [97]:
model = get_model()
metrics = MetricsCB(accuracy = MulticlassAccuracy())
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=[DeviceCB(), metrics])
learn.fit(5)

{'accuracy': '0.5989166498184204:.3f', 'loss': '1.1985833644866943:.3f', 'epoch': 0, 'train': True}
{'accuracy': '0.6873999834060669:.3f', 'loss': '0.8357239961624146:.3f', 'epoch': 0, 'train': False}
{'accuracy': '0.7438666820526123:.3f', 'loss': '0.7178865671157837:.3f', 'epoch': 1, 'train': True}
{'accuracy': '0.7598000168800354:.3f', 'loss': '0.6535628437995911:.3f', 'epoch': 1, 'train': False}
{'accuracy': '0.7845166921615601:.3f', 'loss': '0.6146935820579529:.3f', 'epoch': 2, 'train': True}
{'accuracy': '0.7857999801635742:.3f', 'loss': '0.5967615246772766:.3f', 'epoch': 2, 'train': False}
{'accuracy': '0.8023666739463806:.3f', 'loss': '0.5637335181236267:.3f', 'epoch': 3, 'train': True}
{'accuracy': '0.8082000017166138:.3f', 'loss': '0.547622561454773:.3f', 'epoch': 3, 'train': False}
{'accuracy': '0.8141000270843506:.3f', 'loss': '0.5293701887130737:.3f', 'epoch': 4, 'train': True}
{'accuracy': '0.8209999799728394:.3f', 'loss': '0.5174703001976013:.3f', 'epoch': 4, 'train': Fal