In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_004b import *
import torchvision.models as tvm

# Dogs and cats

## Basic data augmentation

In [None]:
PATH = Path('data/dogscats')

# Raw (untransformed) datasets read from the train/ and valid/ sub-folders of PATH.
train_ds, valid_ds = (FilesDataset.from_folder(PATH/sub) for sub in ('train', 'valid'))

arch = tvm.resnet34

In [None]:
#export
def uniform_int(low, high, size=None):
    """Draw random integer(s) uniformly from the closed interval [`low`, `high`].

    With `size=None` returns a Python int from `random.randint`, which includes `high`.
    Otherwise returns a tensor of shape `size` from `torch.randint`. That function excludes
    its upper bound, so `high+1` is passed to keep both branches inclusive and consistent.
    """
    if size is None: return random.randint(low, high)
    return torch.randint(low, high+1, size)

@TfmPixel
def dihedral(x, k:partial(uniform_int,0,7)):
    """Apply one of the 8 flip/transpose symmetries of a square to image tensor `x`.

    The bits of `k` (0..7) select the operations: bit 0 flips dim 1, bit 1 flips dim 2,
    bit 2 swaps dims 1 and 2. This assumes `x` is (C, H, W), so dims 1/2 are rows/cols
    (TODO confirm). The annotation is the random generator for `k`. Without `size`,
    `uniform_int` uses `random.randint`, which is inclusive, so (0,7) gives exactly the
    8 cases. An upper bound of 8 would add k=8, a second identity.
    """
    flips=[]
    if k&1: flips.append(1)
    if k&2: flips.append(2)
    if flips: x = torch.flip(x,flips)
    if k&4: x = x.transpose(1,2)
    # transpose returns a strided view; make it contiguous for downstream ops
    return x.contiguous()

In [None]:
x = valid_ds[2][0]
_, axes = plt.subplots(2, 4, figsize=(12,6))
# One panel per k = 0..7: every flip/transpose combination of the same image.
for i, ax in enumerate(axes.flatten()):
    dihedral(x, i).show(ax)

In [None]:
#export
def get_transforms(do_flip=False, flip_vert=False, max_rotate=0., max_zoom=1., max_lighting=0., max_warp=0.,
                   p_affine=0.75, p_lighting=0.5, xtra_tfms=None):
    """Return a `(train_tfms, valid_tfms)` pair of transform lists.

    Train transforms, in order:
    - `rand_crop`;
    - if enabled, a flip (`dihedral` when `flip_vert`, else `flip_lr` with p=0.5);
    - `symmetric_warp`, `rotate` (degrees within +/-`max_rotate`) and `rand_zoom`
      (scale 1..`max_zoom`), each with probability `p_affine`;
    - `brightness`/`contrast` jitter with probability `p_lighting`;
    - finally `xtra_tfms`.

    Valid transforms: just `crop_pad`.

    Raises ValueError if `max_lighting >= 1`. The contrast range uses `1/(1-max_lighting)`,
    which divides by zero at 1 and yields negative contrast factors above it.
    """
    # Check first so a bad value fails clearly before any transform is constructed.
    if max_lighting >= 1:
        raise ValueError(f"max_lighting must be < 1, got {max_lighting}")
    res = [rand_crop()]
    if do_flip:    res.append(dihedral() if flip_vert else flip_lr(p=0.5))
    if max_warp:   res.append(symmetric_warp(magnitude=(-max_warp,max_warp), p=p_affine))
    if max_rotate: res.append(rotate(degrees=(-max_rotate,max_rotate), p=p_affine))
    if max_zoom>1: res.append(rand_zoom(scale=(1.,max_zoom), p=p_affine))
    if max_lighting:
        res.append(brightness(change=(0.5*(1-max_lighting), 0.5*(1+max_lighting)), p=p_lighting))
        res.append(contrast(scale=(1-max_lighting, 1/(1-max_lighting)), p=p_lighting))
    #       train                   , valid
    return (res + listify(xtra_tfms), [crop_pad()])

def transform_datasets(train_ds, valid_ds, tfms, **kwargs):
    """Wrap the raw datasets in `DatasetTfm`s and return them as a 3-tuple.

    `tfms` is a (train, valid) pair such as the one `get_transforms` returns. The result is
    (train_ds + train tfms, valid_ds + valid tfms, valid_ds + train tfms). The third entry
    is the augmented validation set used for test-time augmentation. `kwargs` are
    forwarded to every `DatasetTfm`.
    """
    train_tfms, valid_tfms = tfms[0], tfms[1]
    wrapped_train = DatasetTfm(train_ds, train_tfms, **kwargs)
    wrapped_valid = DatasetTfm(valid_ds, valid_tfms, **kwargs)
    augmented_valid = DatasetTfm(valid_ds, train_tfms, **kwargs)
    return (wrapped_train, wrapped_valid, augmented_valid)

# (mean, std) per channel: the standard ImageNet statistics used by pretrained torchvision models.
imagenet_stats = (tensor([0.485, 0.456, 0.406]),
                  tensor([0.229, 0.224, 0.225]))

In [None]:
# Normalize / denormalize functions built from the ImageNet (mean, std) pair.
data_norm, data_denorm = normalize_funcs(imagenet_stats[0], imagenet_stats[1])

In [None]:
#export
class DataBunch():
    "Hold the train, valid and (optional) augmented-valid dataloaders, each wrapped in a `DeviceDataLoader`."
    def __init__(self, train_dl:DataLoader, valid_dl:DataLoader, augm_dl:DataLoader=None,
                 device:torch.device=None, tfms=None):
        # `tfms` (e.g. normalization) is handed to each `DeviceDataLoader`.
        self.device = default_device if device is None else device
        self.train_dl = DeviceDataLoader(train_dl, self.device, tfms=tfms)
        self.valid_dl = DeviceDataLoader(valid_dl, self.device, tfms=tfms)
        # Test against None, not truthiness: a DataLoader has __len__, so an empty one is falsy.
        # The attribute is always defined, so a missing loader reads as None rather than
        # raising AttributeError.
        self.augm_dl = None if augm_dl is None else DeviceDataLoader(augm_dl, self.device, tfms=tfms)

    @classmethod
    def create(cls, train_ds, valid_ds, augm_ds=None, bs=64, train_tfm=None, valid_tfm=None, num_workers=4,
               tfms=None, device=None, **kwargs):
        """Build a `DataBunch` from datasets.

        `train_ds`/`valid_ds` are wrapped in `DatasetTfm` with `train_tfm`/`valid_tfm`. The
        wrap is skipped when a dataset is already a `DatasetTfm` and no transform is given.
        If `augm_ds` is None, it defaults to the original valid data with `train_tfm` (the
        augmented set used for TTA). `kwargs` go to every `DatasetTfm`. The train loader
        shuffles with batch size `bs`; the valid and augmented loaders use `bs*2` without
        shuffling.
        """
        # Keep the caller's valid data: wrapping the already-wrapped `valid_ds` for the default
        # augm_ds would apply valid_tfm and then train_tfm (the same recipe as transform_datasets
        # wraps the raw valid set directly).
        raw_valid_ds = valid_ds
        if train_tfm or not isinstance(train_ds, DatasetTfm): train_ds = DatasetTfm(train_ds,train_tfm, **kwargs)
        if valid_tfm or not isinstance(valid_ds, DatasetTfm): valid_ds = DatasetTfm(valid_ds,valid_tfm, **kwargs)
        if augm_ds is None: augm_ds = DatasetTfm(raw_valid_ds, train_tfm, **kwargs)
        return cls(DataLoader(train_ds, bs,   shuffle=True,  num_workers=num_workers),
                   DataLoader(valid_ds, bs*2, shuffle=False, num_workers=num_workers),
                   DataLoader(augm_ds,  bs*2, shuffle=False, num_workers=num_workers),
                   device=device, tfms=tfms)

    @property
    def train_ds(self): return self.train_dl.dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dl.dataset
    @property
    def c(self):
        "`c` of the training dataset (presumably the number of classes)."
        return self.train_ds.c

In [None]:
size = 224

tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)
# (train, valid, augmented valid) datasets; `size` is forwarded to every DatasetTfm.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)
data = DataBunch.create(tds[0], tds[1], tds[2], bs=64, num_workers=8, tfms=data_norm)

In [None]:
x, y = next(iter(data.valid_dl))

# Undo the normalization before displaying the first 16 validation images.
_, axs = plt.subplots(4, 4, figsize=(12,12))
for i, ax in enumerate(axs.flat): show_image(data_denorm(x[i].cpu()), ax)

In [None]:
# The same training item fetched 16 times (presumably a different random augmentation each time).
_, axs = plt.subplots(4, 4, figsize=(12,12))
for ax in axs.flatten():
    tds[0][2][0].show(ax)

## ConvLearner

In [None]:
#export
def train_epoch(model, dl, opt, loss_func):
    "Simple training of `model` for 1 epoch of `dl` using optim `opt` and loss function `loss_func`"
    model.train()
    for inputs, targets in dl:
        preds = model(inputs)
        loss_func(preds, targets).backward()
        opt.step()
        # clear gradients so the next batch starts from zero
        opt.zero_grad()

In [None]:
#export
class AdaptiveConcatPool2d(nn.Module):
    "Concatenate adaptive max pooling and adaptive average pooling (in that order) along dim 1."
    def __init__(self, sz=None):
        super().__init__()
        output_size = sz if sz else 1
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        # Output has twice as many channels as the input.
        return torch.cat((self.mp(x), self.ap(x)), dim=1)

def create_body(model, cut=None, body_fn=None):
    """Return the feature-extractor part of `model`.

    If `cut` is truthy, return a `nn.Sequential` of all of `model`'s children except the
    last `cut`. Otherwise, return `body_fn(model)` if `body_fn` is given, else `model` itself.
    """
    if cut:
        return nn.Sequential(*list(model.children())[:-cut])
    if body_fn:
        return body_fn(model)
    return model

def num_features(m):
    "Return `num_features` of the last layer of `flatten_model(m)` that has it, or None if no layer does."
    return next((layer.num_features for layer in reversed(flatten_model(m))
                 if hasattr(layer, 'num_features')), None)

In [None]:
model = create_body(arch(), cut=2)
num_features(model)

In [None]:
#export
def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    """Return the layers of one head block: [BatchNorm1d(n_in)], [Dropout(p)], Linear(n_in, n_out), [actn].

    Each bracketed layer is included only when, respectively, `bn` is truthy, `p != 0`, or
    `actn is not None`. A plain list is returned so callers can concatenate several blocks.
    """
    layers = [*([nn.BatchNorm1d(n_in)] if bn else []),
              *([nn.Dropout(p)] if p != 0 else []),
              nn.Linear(n_in, n_out)]
    if actn is not None: layers.append(actn)
    return layers

def create_head(nf, nc, lin_ftrs=None, ps=0.2):
    """Build a classifier head: concat pooling, flatten, then BatchNorm/Dropout/Linear(/ReLU) blocks.

    nf: feature count after `AdaptiveConcatPool2d` (ConvLearner passes twice the body's
        `num_features`, because concat pooling doubles the channels).
    nc: size of the final Linear output.
    lin_ftrs: sizes of the hidden Linear layers (default [512]); any sequence is accepted.
    ps: dropout probability per Linear layer. A single value p gives p/2 for every hidden
        block and p for the last one.
    Raises ValueError if fewer dropout values than Linear layers are given, because zip
    would otherwise silently drop the trailing layers, including the final `nc` output.
    """
    lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]
    n_lin = len(lin_ftrs) - 1
    ps = listify(ps)
    if len(ps)==1: ps = [ps[0]/2] * (n_lin-1) + ps
    if len(ps) < n_lin:
        raise ValueError(f"create_head: got {len(ps)} dropout value(s) for {n_lin} linear layers")
    # ReLU after every hidden Linear; no activation after the output layer.
    actns = [nn.ReLU(inplace=True)] * (n_lin-1) + [None]
    layers = [AdaptiveConcatPool2d(), Flatten()]
    for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
        layers += bn_drop_lin(ni,no,True,p,actn)
    return nn.Sequential(*layers)

In [None]:
create_head(nf=512, nc=2)

In [None]:
#export
def cond_init(m, init_fn):
    """Re-initialise module `m` in place, unless it is a `bn_types` layer or `requires_grad(m)` is falsy.

    The weight (if `m` has one) goes through `init_fn`; a bias with `.data` is zero-filled.
    """
    if isinstance(m, bn_types) or not requires_grad(m): return
    if hasattr(m, 'weight'): init_fn(m.weight)
    if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)

def apply_leaf(m, f):
    "Call `f` on `m` (when it is an `nn.Module`), then recurse depth-first into every element of `children(m)`."
    sub_modules = children(m)
    if isinstance(m, nn.Module):
        f(m)
    for sub in sub_modules:
        apply_leaf(sub, f)

def apply_init(m, init_fn):
    "Run `cond_init(..., init_fn=init_fn)` on `m` and every module reachable from it via `apply_leaf`."
    init_one = partial(cond_init, init_fn=init_fn)
    apply_leaf(m, init_one)

def _init(learn, init):
    "Patched onto `Learner` as `init`: re-initialise `learn.model` with `init` via `apply_init`."
    apply_init(learn.model, init)

Learner.init = _init

class ConvLearner(Learner):
    """`Learner` whose model is `nn.Sequential(body, head)`.

    `body` is `arch(pretrained)` with its last `cut` children removed (see `create_body`).
    `head` is `custom_head` if given, else `create_head(...)` for `data.c` outputs. The head
    is passed to `split` as its own group. The model is frozen when `pretrained` (presumably
    leaving only the head trainable). The head is then initialised with Kaiming normal
    init via `apply_init`.
    """
    def __init__(self, data, arch, cut, pretrained=True, lin_ftrs=None, ps=0.2, custom_head=None, **kwargs):
        body = create_body(arch(pretrained), cut)
        # `is None` rather than `or`: an nn.Sequential defines __len__, so an empty custom
        # head would be falsy and silently replaced.
        if custom_head is None:
            # AdaptiveConcatPool2d concatenates max and avg pooling, doubling the feature count.
            nf = num_features(body) * 2
            head = create_head(nf, data.c, lin_ftrs, ps)
        else:
            head = custom_head
        model = nn.Sequential(body, head)
        super().__init__(data, model, **kwargs)
        self.split([model[1]])
        if pretrained: self.freeze()
        apply_init(model[1], nn.init.kaiming_normal_)

In [None]:
lr = 0.003  # max learning rate for the frozen-body training cells

## Train

In [None]:
learn = ConvLearner(data, arch, cut=2, wd=0.01, metrics=accuracy)

In [None]:
# Run the learning-rate finder on `learn` and plot what the recorder captured (used to pick `lr`).
lr_find(learn)
learn.recorder.plot()

In [None]:
# One one-cycle epoch at `lr`; the body is presumably frozen (ConvLearner calls freeze() when pretrained).
learn.fit_one_cycle(1, slice(lr))

In [None]:
# Three more one-cycle epochs at the same `lr`.
learn.fit_one_cycle(3, slice(lr))

In [None]:
# Checkpoint the model under the name '0'.
learn.save('0')

## Unfreeze

In [None]:
# Restore the checkpoint saved as '0'.
learn.load('0')

In [None]:
# Presumably makes the pretrained body trainable again so the whole network is fine-tuned.
learn.unfreeze()

In [None]:
lr = 0.0006  # lower max learning rate for fine-tuning the unfrozen model

In [None]:
# Six epochs of fine-tuning. slice(lr/25, lr) presumably spreads learning rates across layer groups
# (lowest for the body); pct_start=0.05 makes the one-cycle warm-up phase short.
learn.fit_one_cycle(6, slice(lr/25,lr), pct_start=0.05)

In [None]:
# Checkpoint the fine-tuned model under the name '1'.
learn.save('1')

In [None]:
# Restore the fine-tuned checkpoint '1'.
learn.load('1')

## Record activation statistics

In [None]:
#export
class Hook():
    """Register a hook on module `m` that stores `hook_func(module, input, output)` in `self.stored`.

    A forward hook is registered when `is_forward`, otherwise a backward hook. Tensors are
    detached before `hook_func` sees them, so stored values are cut off from the autograd graph.
    """
    def __init__(self, m, hook_func, is_forward=True):
        self.hook_func,self.stored = hook_func,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module, input, output):
        "Detach `input`/`output` and store the result of `hook_func` on them."
        def detach(o): return o.detach() if isinstance(o, torch.Tensor) else o
        # Lists, not generator expressions: a generator can be iterated only once, so a stored
        # sequence (e.g. `hook_output` on a module returning a tuple) would be empty on reuse.
        # Non-tensor entries (e.g. None) are passed through unchanged instead of failing.
        input  = [detach(o) for o in input ] if is_listy(input ) else input.detach()
        output = [detach(o) for o in output] if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Unregister the hook; calling it again is a no-op."
        if not self.removed:
            self.hook.remove()
            self.removed=True

class Hooks():
    "One `Hook` per module in `ms`, all using `hook_func`; behaves like a read-only list of hooks."
    def __init__(self, ms, hook_func, is_forward=True):
        self.hooks = [Hook(module, hook_func, is_forward) for module in ms]

    def __getitem__(self, i): return self.hooks[i]
    def __len__(self): return len(self.hooks)
    def __iter__(self): return iter(self.hooks)

    @property
    def stored(self):
        "The `stored` value of each hook, in module order."
        return [hook.stored for hook in self]

    def remove(self):
        "Remove every hook."
        for hook in self.hooks:
            hook.remove()

def hook_output(module):
    "Forward-hook `module`; the returned `Hook`'s `stored` holds the module's (detached) output."
    return Hook(module, lambda m, i, o: o)

def hook_outputs(modules):
    "Like `hook_output`, for every module in `modules`, returned as a `Hooks`."
    return Hooks(modules, lambda m, i, o: o)

In [None]:
#export
class HookCallback(LearnerCallback):
    """Callback that registers `self.hook` on `modules` when training begins.

    `modules` defaults to every module in `flatten_model(learn.model)` that has a `weight`
    attribute. Subclasses must define `hook(m, i, o)`; its results are available through
    `self.hooks.stored`. When `do_remove`, the hooks are removed at the end of training.
    """
    def __init__(self, learn, modules=None, do_remove=True):
        super().__init__(learn)
        self.modules,self.do_remove = modules,do_remove

    def on_train_begin(self, **kwargs):
        if not self.modules:
            self.modules = [m for m in flatten_model(self.learn.model)
                            if hasattr(m, 'weight')]
        self.hooks = Hooks(self.modules, self.hook)

    def on_train_end(self, **kwargs):
        if self.do_remove: self.remove()

    def remove(self):
        "Remove the registered hooks; a no-op if `on_train_begin` never ran."
        # getattr guard: `__del__` can run on an instance whose hooks were never created.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()
    def __del__(self): self.remove()

class ActivationStats(HookCallback):
    "Record the mean and std of every hooked module's output each time `on_batch_end` fires."
    def on_train_begin(self, **kwargs):
        super().on_train_begin(**kwargs)
        self.stats = []

    def hook(self, m,i,o): return o.mean().item(),o.std().item()
    def on_batch_end(self, **kwargs): self.stats.append(self.hooks.stored)
    def on_train_end(self, **kwargs):
        # Run HookCallback's cleanup too, so the hooks are removed when `do_remove` is set.
        super().on_train_end(**kwargs)
        # (n_batches, n_modules, 2) -> (2, n_modules, n_batches): stats[0] = means, stats[1] = stds.
        self.stats = tensor(self.stats).permute(2,1,0)

def idx_dict(a):
    "Map each element of `a` to its position (the last position wins for repeated elements)."
    return dict((item, idx) for idx, item in enumerate(a))

In [None]:
learn = ConvLearner(data, arch, cut=2, wd=0.01, metrics=accuracy, callback_fns=ActivationStats)

In [None]:
# One epoch; the ActivationStats callback records per-batch output statistics of the hooked modules.
learn.fit_one_cycle(1, lr)

In [None]:
ms = learn.activation_stats.modules
d = idx_dict(ms)
# Position, among the hooked modules, of head layer 8 (the final Linear of the default head).
ln = d[learn.model[1][8]]
ln

In [None]:
# Per-batch std of that layer's output (stats[0] would give the means).
plt.plot(learn.activation_stats.stats[1, ln].numpy());

## Test-time augmentation (TTA)

In [None]:
# Eight draws of one item from the augmented validation set (valid data + train transforms).
_, axs = plt.subplots(2, 4, figsize=(12,6))
for ax in axs.flatten():
    tds[2][1][0].show(ax)

In [None]:
# Shorthand for the learner's model, used by the prediction cells that follow.
model = learn.model

In [None]:
#export
def get_preds(model, dl, pbar=None):
    "Run `validate` over `dl` and concatenate each of its outputs (presumably predictions, targets) into one CPU tensor."
    return list(map(lambda chunks: torch.cat(chunks).cpu(), validate(model, dl, pbar=pbar)))

In [None]:
preds, y = get_preds(model, dl=data.valid_dl)

In [None]:
# Accuracy of the plain (non-augmented) validation predictions.
accuracy(preds, y)

In [None]:
# Four passes over the augmented validation set; `[0]` keeps the predictions of each pass.
pbar = master_bar(range(4))
all_preds = torch.stack([get_preds(model, data.augm_dl, pbar=pbar)[0]
                         for _ in pbar])

In [None]:
avg_preds = all_preds.mean(dim=0)  # average over the augmented passes
avg_preds.shape

In [None]:
# Accuracy using only the averaged augmented predictions.
accuracy(avg_preds, y)

In [None]:
beta = 0.5  # weight of the plain predictions in the blend
accuracy(beta*preds + (1-beta)*avg_preds, y)

In [None]:
def TTA(model, valid_dl, augm_dl, n=4, beta=0.5):
    """Test-time augmentation: blend plain validation predictions with augmented ones.

    Returns `beta * preds + (1-beta) * avg`. `preds` are the predictions on `valid_dl`, and
    `avg` is the mean of `n` prediction runs over `augm_dl`. `augm_dl` is assumed to yield
    the same items in the same order as `valid_dl`; TODO confirm.
    """
    preds,_ = get_preds(model, valid_dl)
    pbar = master_bar(range(n))
    # Use this local average; the original returned the notebook-global `avg_preds`, giving stale results.
    avg_preds = torch.stack([get_preds(model, augm_dl, pbar=pbar)[0]
                             for _ in pbar]).mean(0)
    return preds*beta + avg_preds*(1-beta)

def _learn_TTA(learn, n=4, beta=0.5):
    "Patched onto `Learner` as `TTA`: run `TTA` with the learner's model and its valid/augmented dataloaders."
    model = learn.model
    valid_dl, augm_dl = learn.data.valid_dl, learn.data.augm_dl
    return TTA(model, valid_dl, augm_dl, n=n, beta=beta)

Learner.TTA = _learn_TTA

In [None]:
learn = ConvLearner(data, arch, cut=2, metrics=accuracy)

In [None]:
# Train the fresh learner for one epoch before evaluating it with learn.TTA().
learn.fit_one_cycle(1, lr)

In [None]:
tta_preds = learn.TTA(n=4, beta=0.5)

In [None]:
# `y` holds the validation targets returned earlier by get_preds(model, data.valid_dl).
accuracy(tta_preds, y)

## Fin