In [None]:
#| default_exp learner

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export 
import math, torch, matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy

from torch import optim
import torch.nn.functional as F

from miniai.conv import *

from fastprogress import progress_bar, master_bar

In [None]:
import matplotlib as mpl
import torchvision.transforms.functional as TF
from contextlib import contextmanager
from torch import nn, tensor
from datasets import load_dataset, load_dataset_builder
from miniai.datasets import *
import logging
from fastcore.test import test_close

In [None]:
# Compact tensor printing: 2 decimal places, 160-char lines, no scientific notation
torch.set_printoptions(precision=2, linewidth=160, sci_mode=False)
# Seed PyTorch's RNG so random draws are repeatable between runs
torch.manual_seed(1)
# Default matplotlib colormap for images
mpl.rcParams['image.cmap'] = 'gray'

In [None]:
# Suppress all log records at WARNING severity and below for the rest of the notebook
logging.disable(logging.WARNING)

## Learner

In [None]:
# Names of the input and target fields; `x` is used as a key into examples/batches elsewhere in the notebook
x, y = 'image', 'label'
name = 'fashion_mnist'
# Load the dataset by name; the result is indexed by split name elsewhere in the notebook (dsd['train'])
dsd = load_dataset(name)

  0%|          | 0/2 [00:00<?, ?it/s]

Grab a single example from the dataset and check its shape. We will be working with flattened images, so we convert the PIL image to a tensor and flatten it into a tensor of length 784 (28x28).

In [None]:
# Convert the first training image to a tensor and display its shape
ex = TF.to_tensor(dsd['train'][0][x])
ex.shape

torch.Size([1, 28, 28])

In [None]:
# Same example, flattened into a 1-D tensor; the cell displays the resulting shape
torch.flatten(TF.to_tensor(dsd['train'][0][x])).shape

torch.Size([784])

In [None]:
@inplace
def transformi(b):
    """Overwrite the images stored under key `x` of batch dict `b` with flattened tensors.

    `b` is a dictionary holding the images and labels of a batch. Each image is
    converted with `TF.to_tensor` and then flattened to 1-D.
    NOTE(review): `inplace` comes from `miniai.datasets`; presumably it makes the
    decorated function return the mutated batch -- confirm against its definition.
    """
    flattened = []
    for img in b[x]:
        flattened.append(torch.flatten(TF.to_tensor(img)))
    b[x] = flattened

In [None]:
# Batch size handed to the DataLoaders
bs = 1024
# Attach transformi to the dataset dict (presumably applied lazily when items are accessed -- see the datasets docs)
tds = dsd.with_transform(transformi)

In [None]:
# Build train/valid DataLoaders from the transformed dataset dict.
# Remove `num_workers` while debugging.
dls = DataLoaders.from_dd(tds, bs, num_workers=4)

# Pull one batch from the training loader and display the input shape plus the first 10 labels
dt = dls.train
xb, yb = next(iter(dt))
xb.shape, yb[:10]

(torch.Size([1024, 784]), tensor([8, 0, 4, 0, 5, 8, 3, 8, 3, 4]))

Let's remind ourselves how `DataLoaders.from_dd` actually works. 

```python
def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        f = collate_dict(dd['train'])
        return cls(*get_dls(*dd.values(), bs=batch_size, collate_fn=f, **kwargs))
```
Let's first deal with the first line of code and the `f` function. `f` is the collation function returned by `collate_dict`: given a batch, it returns the collated `x` and `y` values. This is how it is done:
1. `collate_dict` takes the feature names from `dd['train']` (`image` and `label`) and puts them into an `itemgetter`
2. this `itemgetter` is applied to the result of `default_collate` on a given batch, returning a tuple of `x` and `y` values

```python
def collate_dict(ds):
    get = itemgetter(*ds.features) # get x and y values
    def _f(b): return get(default_collate(b))
    return _f
```

The second line of code (the `return` statement) can be broken down into two parts: call `get_dls`, then wrap its return value in `cls`. Let's tackle them one by one.

```python
def get_dls(train_ds, valid_ds, bs, **kwargs):
    return (DataLoader(train_ds, bs, shuffle=True, **kwargs), DataLoader(valid_ds, bs*2, **kwargs))
```

`get_dls` takes the train and valid datasets and turns them into the respective (PyTorch) DataLoaders. The function returned by `collate_dict` (assigned to `f`) is passed as `collate_fn` to the DataLoader `__init__` method via `**kwargs`.

Wrapping the resulting tuple of DataLoaders inside `cls` simply allows us to access them as `.train` and `.valid` on our `DataLoaders` object.

```python
def __init__(self, *dls): self.train, self.valid = dls[:2]
```

Now we should fully understand the initial code above. We create the dataloaders (train and valid), then we select the train dataloader and get one batch of size 1024 from it.

```python
dls = DataLoaders.from_dd(tds, bs)
dt = dls.train
xb, yb = next(iter(dt))
```
During the call to the iterator, our collation function (`f`) comes into play and merges the 1024-element list of dictionaries into a tuple of `x` and `y` tensors.

Now let's move on to creating our first Learner framework.

##  Learner

First we create a simplified learner that still has some nice structure to it.

In [None]:
class Learner:
    """Minimal training loop: fits `model` on `dls` and prints metrics after each epoch phase.

    Args:
        model: the `nn.Module` to train.
        dls: object exposing `.train` and `.valid` iterables of `(xb, yb)` batches.
        loss_func: callable `loss_func(preds, yb)` returning a scalar tensor.
        lr: learning rate passed to the optimizer.
        opt_func: optimizer constructor, called as `opt_func(params, lr)`.
    """
    def __init__(self, model, dls, loss_func, lr, opt_func=optim.SGD):
        # store_attr saves every constructor argument as an attribute of the same name
        fc.store_attr()

    def one_batch(self):
        """Run the forward pass on `self.batch`; in training mode also backprop and step."""
        # to_device comes from miniai.conv; presumably moves the batch to the default device -- confirm
        self.xb, self.yb = to_device(self.batch)
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        # statistics never need gradients
        with torch.no_grad(): self.calc_stats()

    def calc_stats(self):
        """Record the number of correct predictions, the size-weighted loss and the batch size."""
        acc = (self.preds.argmax(dim=-1)==self.yb).float().sum()
        self.accs.append(acc)
        n = len(self.xb)
        # weight by batch size so the epoch figure is a per-sample average
        # (assumes loss_func returns a per-batch mean -- confirm for custom losses)
        self.losses.append(self.loss * n)
        self.ns.append(n)

    def one_epoch(self, train):
        """Run one pass over the train (`train=True`) or valid loader and print its metrics.

        Prints: epoch index, training flag, average loss, accuracy.
        """
        # Module.train(mode) propagates the flag to every submodule; assigning
        # `model.training` directly would only change the top-level module.
        self.model.train(train)
        # Fresh accumulators, so the printed numbers describe this phase only
        # rather than a running mix of all earlier train and valid batches.
        self.accs, self.losses, self.ns = [], [], []
        dl = self.dls.train if train else self.dls.valid
        for self.batch in dl: self.one_batch()
        n = sum(self.ns)
        print(self.epoch, self.model.training,
              sum(self.losses).item()/n, sum(self.accs).item()/n)

    def fit(self, n_epochs):
        """Train for `n_epochs`, running a training pass then a validation pass each epoch."""
        self.accs, self.losses, self.ns = [], [], []
        # def_device comes from miniai.conv
        self.model.to(def_device)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.one_epoch(True)
            with torch.no_grad(): self.one_epoch(False)

In [None]:
# m: number of inputs (one per pixel of a flattened 28x28 image); nh: hidden layer width
m, nh = 28*28, 50
# Linear -> ReLU -> Linear network with 10 outputs
model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10))

In [None]:
# Train for a single epoch with cross-entropy loss and the default optimizer (optim.SGD)
learn = Learner(model, dls, F.cross_entropy, lr=0.2)
learn.fit(1)

0 True 1.154941015625 0.6162166666666666
0 False 1.10222109375 0.6292


## Basic Callbacks Learner 

Now let's improve greatly upon our previous iteration and add Callbacks to our Learner.

In [None]:
#|export
class CancelFitException(Exception):
    """Signal exception; by its name, meant to cancel a whole fit (usage not shown in this section)."""

class CancelBatchException(Exception):
    """Signal exception; by its name, meant to cancel the current batch (usage not shown in this section)."""

class CancelEpochException(Exception):
    """Signal exception; by its name, meant to cancel the current epoch (usage not shown in this section)."""

In [None]:
#|export
class Callback:
    """Base class for callbacks; `run_cbs` sorts callbacks by their `order` attribute."""
    # default sort key; callbacks with a lower `order` run first
    order = 0

In [None]:
#|export
def run_cbs(cbs, method_nm, learn=None):
    """Call the method named `method_nm` on each callback in `cbs`, in ascending `order`.

    Callbacks that do not define `method_nm` are skipped. Each hook receives
    `learn` as its only argument.
    """
    ordered = sorted(cbs, key=attrgetter('order'))
    for cb in ordered:
        hook = getattr(cb, method_nm, None)
        # a callback only implements the events it cares about
        if hook is None: continue
        hook(learn)

In [None]:
class CompletionCB(Callback):
    """Counts `after_batch` events during a fit and prints the total in `after_fit`."""
    def before_fit(self, learn):
        # every fit starts counting from zero
        self.count = 0

    def after_batch(self, learn):
        self.count = self.count + 1

    def after_fit(self, learn):
        print(f'Completed {self.count} batches')

In [None]:
# Drive the callback by hand: one before_fit, one after_batch, one after_fit.
# `learn` is left at its default (None), which CompletionCB never reads.
cbs = [CompletionCB()]
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')

Completed 1 batches
