In [1]:
#|default_exp learner

In [2]:
#|export
import torch, torch.nn as nn, torch.nn.functional as F
from torch import tensor, optim
import numpy as np, pandas as pd, matplotlib.pyplot as plt, matplotlib as mpl, math
import fastcore.all as fc
from operator import attrgetter
from collections.abc import Mapping
from functools import partial
from miniai.callbacks import *
from miniai.hooks import Hooks

In [3]:
from miniai.data import *
from torch.utils.data import DataLoader
from miniai.plotting import *
from miniai.callbacks import *
from miniai.custom_modules import conv
from torcheval.metrics import MulticlassAccuracy
from tqdm import tqdm

mpl.rcParams['image.cmap'] = 'Greys'
torch.set_printoptions(precision=2, linewidth=100)
# Seed the global torch RNG early so weight initialisation and DataLoader shuffling
# repeat identically on Restart & Run All.
torch.manual_seed(42)

In [4]:
# Fashion-MNIST training CSV; the next cell reads column 0 as the label and columns 1: as pixels.
data = pd.read_csv(filepath_or_buffer='data/fashion_mnist/train.csv')

In [5]:
# Rows [0, 50000) train, rows [50000, end) validate; column 0 is the label, columns 1: the pixels.
X_train, y_train = tensor(data.iloc[:50000, 1:].values), tensor(data.iloc[:50000, 0].values)
X_valid, y_valid = tensor(data.iloc[50000:, 1:].values), tensor(data.iloc[50000:, 0].values)

In [6]:
# Small subsets for quick experiments: pixel values divided by 255 and reshaped to (N, 1, 28, 28).
dset_train, dset_valid = (
    Dataset((X_train[:5000] / 255.).view(-1, 1, 28, 28), y_train[:5000]),
    Dataset((X_valid[:500] / 255.).view(-1, 1, 28, 28), y_valid[:500]),
)

In [7]:
# Shuffle only the training data; the validation set is served as one full-size batch.
dl_train, dl_valid = (
    DataLoader(dset_train, shuffle=True, batch_size=64),
    DataLoader(dset_valid, shuffle=False, batch_size=len(dset_valid)),
)

In [8]:
# Bundle the two loaders; Learner reads them back as `dls.train` / `dls.valid`.
dls = DataLoaders(dl_train, dl_valid)

## Learner with callbacks

In [9]:
#|export
class Learner:
    """Minimal training loop whose steps are customised through callbacks.

    `predict`, `get_loss`, `backward`, `step` and `zero_grad` are not defined here:
    `__getattr__` routes them to the callbacks in `self.cbs`. Subclasses such as
    `TrainLearner` may define them as real methods instead.

    Args:
        dls: object exposing `.train` and `.valid` data loaders.
        model: the `nn.Module` being trained.
        loss_func: loss function, stored for use by subclasses/callbacks.
        opt_func: optimizer factory, called as `opt_func(model.parameters(), lr)`.
        lr: default learning rate used by `fit`.
        cbs: list of callbacks; a fresh empty list is used when omitted.
    """
    def __init__(self, dls, model, loss_func, opt_func=optim.Adam, lr=None, cbs=None):
        fc.store_attr()
        # `cbs=None` rather than `cbs=[]`: a mutable default would be a single list
        # shared by every Learner created without explicit callbacks.
        if cbs is None: self.cbs = []

    @with_cbs('batch')
    def _one_batch(self):
        # Forward pass and loss; in training mode also backprop, update and reset grads.
        self.predict()
        self.get_loss()
        if self.training:
            self.backward()
            self.step()
            self.zero_grad()

    @with_cbs('epoch')
    def _one_epoch(self):
        # The loop targets are learner attributes so callbacks can read the current batch.
        for self.batch_iter, self.batch in enumerate(self.dl): self._one_batch()

    @with_cbs('fit')
    def _fit(self):
        for self.epoch in self.epochs:
            self.one_epoch(True)
            with torch.no_grad(): self.one_epoch(False)

    def one_epoch(self, train):
        """Run one pass over the training (`train=True`) or validation loader."""
        self.model.train(train)
        self.training = train
        self.dl = self.dls.train if train else self.dls.valid
        self._one_epoch()

    def fit(self, epochs, lr=None, cbs=None):
        """Train for `epochs` epochs, each followed by a no-grad validation pass.

        A non-None `lr` replaces `self.lr` for this and later calls. A non-None `cbs`
        replaces the learner's callbacks for this call only; the original list is
        restored afterwards, even if training raises.
        """
        if lr is not None: self.lr = lr
        orig_cbs = self.cbs
        if cbs is not None: self.cbs = cbs
        try:
            self.opt = self.opt_func(self.model.parameters(), self.lr)
            self.epochs = range(epochs)
            self._fit()
        finally:
            if cbs is not None: self.cbs = orig_cbs

    def callback(self, method_nm):
        """Invoke `method_nm` on the current callbacks via `run_cbs`."""
        run_cbs(self.cbs, method_nm, self)

    def __getattr__(self, nm):
        # Only reached when normal lookup fails: the training-step names fall back to
        # callbacks; anything else is a genuinely missing attribute.
        if nm in ['predict', 'get_loss', 'backward', 'step', 'zero_grad']: return partial(self.callback, nm)
        raise AttributeError(nm)

    def lr_find(self, start_lr=1e-7, coef=1.3):
        """Run a one-epoch learning-rate sweep with `LRFinderCB` and plot the result.

        `self.lr` and the model weights are restored afterwards. Otherwise later `fit`
        calls would default to `start_lr`, and the model would keep the sweep's updates.
        """
        lr_finder = LRFinderCB(coef)
        orig_lr = self.lr
        # state_dict() tensors share storage with the live parameters, so clone them
        # before the sweep updates the weights in place.
        orig_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
        try:
            # NOTE(review): this list replaces self.cbs for the sweep, so the learner's own
            # callbacks (e.g. any device-placement callback) do not run — confirm intended.
            self.fit(1, start_lr, cbs=[TrainCB(), lr_finder, MetricCB(), ProgressCB()])
        finally:
            self.lr = orig_lr
            self.model.load_state_dict(orig_state)
        lr_finder.plot()

In [10]:
#|export
class TrainLearner(Learner):
    """`Learner` whose training steps are ordinary methods rather than callbacks.

    Inputs are read from `self.batch[0]` and targets from `self.batch[1]`.
    """
    def predict(self):
        xb = self.batch[0]
        self.preds = self.model(xb)

    def get_loss(self):
        yb = self.batch[1]
        self.loss = self.loss_func(self.preds, yb)

    def backward(self):
        self.loss.backward()

    def step(self):
        self.opt.step()

    def zero_grad(self):
        self.opt.zero_grad()

In [11]:
#|export
class MomentumLearner(TrainLearner):
    """`TrainLearner` that decays gradients instead of clearing them.

    `zero_grad` (run after each optimizer step) multiplies every gradient by `mom`.
    The next backward pass therefore accumulates onto `mom * previous_grad`.

    Args:
        opt: optimizer factory (forwarded as `opt_func`).
        cbs: list of callbacks; a fresh empty list is used when omitted.
        mom: gradient decay factor applied in `zero_grad`.
    """
    def __init__(self, dls, model, loss_func, opt=optim.Adam, lr=None, cbs=None, mom=0.85):
        # Avoid a shared mutable default while still handing the base class a list.
        super().__init__(dls, model, loss_func, opt_func=opt, lr=lr,
                         cbs=[] if cbs is None else cbs)
        self.mom = mom

    def zero_grad(self):
        with torch.no_grad():
            for p in self.model.parameters():
                # Parameters that took no part in the backward pass have `grad is None`;
                # scaling those would raise, so leave them alone.
                if p.grad is not None: p.grad.mul_(self.mom)

In [12]:
# Four ReLU conv blocks widening 1→256 channels, then a 10-channel conv and Flatten.
lr_find_model = nn.Sequential(
    *[conv(c_in, c_out, act=nn.ReLU())
      for c_in, c_out in [(1, 32),       # 14×14
                          (32, 64),      # 7×7
                          (64, 128),     # 4×4
                          (128, 256)]],  # 2×2
    conv(256, 10, act=None),             # 1×1
    nn.Flatten(),
)

In [13]:
# Four ReLU conv blocks widening 1→256 channels, then a 10-channel conv and Flatten.
model = nn.Sequential(
    *[conv(c_in, c_out, act=nn.ReLU())
      for c_in, c_out in [(1, 32),       # 14×14
                          (32, 64),      # 7×7
                          (64, 128),     # 4×4
                          (128, 256)]],  # 2×2
    conv(256, 10, act=None),             # 1×1
    nn.Flatten(),
)

In [15]:
# Track multiclass accuracy alongside the losses and display training progress.
metrics = MetricCB(accuracy=MulticlassAccuracy())
learn = MomentumLearner(
    dls, model, F.cross_entropy,
    opt=optim.Adam,
    cbs=[metrics, ProgressCB()],
)

In [16]:
# Two epochs at lr=5e-3 (fit also stores this as the learner's default lr).
learn.fit(2, lr=5e-3)

epoch,train_loss,valid_loss,accuracy
0,1.517,1.335,0.512
1,1.062,0.943,0.64


In [18]:
#|export
def _flops(x, h, w):
    if x.dim()<3: return x.numel()
    if x.dim()==4: return x.numel()*h*w
    
@fc.patch
def summary(self:Learner):
    """Tabulate per-module input/output shapes, parameter counts and MFLOPS.

    Pushes the first training batch through the model with a forward hook on each
    module hooked by `Hooks`. Returns a `Markdown` table inside a notebook; otherwise
    prints the table and returns None.
    """
    res = '|Module|Input shape|Output shape|Param count|MFLOPS|\n|--|--|--|--|--|\n'
    total_params = 0; total_MFLOPS = 0
    def _get_summary(hook, mod, inp, out):
        nonlocal res, total_params, total_MFLOPS
        param_cnt = 0; flops = 0
        for p in mod.parameters():
            param_cnt += torch.numel(p)
            flops += _flops(p, out.shape[-2], out.shape[-1])
        flops /= 1e6
        res += f'|{type(mod).__name__:14}|{str(list(inp[0].shape)):20}|{str(list(out.shape)):16}|{param_cnt}|{flops:.1f}|\n'
        total_params += param_cnt; total_MFLOPS += flops
    # Only shapes and parameter counts are needed, so skip building an autograd graph.
    # NOTE(review): the model runs in whatever train/eval mode it is currently in.
    with Hooks(self.model, _get_summary), torch.no_grad():
        self.model(next(iter(self.dls.train))[0])
    res += f'|--|--|--|{total_params}|{total_MFLOPS:.1f}|\n'
    if fc.IN_NOTEBOOK:
        from IPython.display import Markdown
        return Markdown(res)
    # `.1f` (the old `1.f` was an invalid format spec and raised ValueError here).
    print(res)

In [19]:
#|export
@fc.patch
def show_batch_images(self:Learner, n=None, **kwargs):
    """Plot the first `n` images of one training batch, titled with their labels.

    Args:
        n: number of images to show; defaults to the whole batch.
        **kwargs: forwarded to `show_images`.
    """
    batch = next(iter(self.dls.train))
    if n is None: n = batch[0].shape[0]
    # Slice the labels too, so there is exactly one title per displayed image.
    # NOTE(review): `show_images` is only star-imported in a non-exported cell of this
    # notebook; confirm it resolves inside the exported `learner` module.
    show_images(batch[0][:n], titles=batch[1][:n].tolist(), **kwargs)

In [21]:
import nbdev

# Regenerate the library modules from this notebook's `#|export` cells.
nbdev.nbdev_export()