In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_004b import *

# Dogs and cats

## Basic data augmentation

In [None]:
# Root folder of the dataset; the `train` and `valid` sub-folders are read below.
PATH = Path('data/dogscats')

# Per-channel mean / std used to build the normalize / denormalize function pair.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
data_mean, data_std = map(tensor, ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
data_norm,data_denorm = normalize_funcs(data_mean,data_std)

# Untransformed datasets; transforms are attached later via `transform_datasets`.
train_ds = FilesDataset.from_folder(PATH/'train')
valid_ds = FilesDataset.from_folder(PATH/'valid')

In [None]:
from torchvision.models import resnet18, resnet34
# Architecture constructor used for the models built in the cells below.
arch = resnet34

In [None]:
#export
def uniform_int(low, high, size=None):
    "Random integer in `[low, high)`: a plain int if `size` is None, else an int tensor of shape `size`."
    # `random.randint` includes `high` whereas `torch.randint` excludes it. Use `randrange` so both
    # branches draw from the same half-open range (e.g. `uniform_int(0,8)` never yields 8).
    return random.randrange(low,high) if size is None else torch.randint(low,high,size)

@TfmPixel
def dihedral(x, k:partial(uniform_int,0,8)):
    "Flip/transpose `x` according to the bits of `k`: `k&1` flips dim 1, `k&2` flips dim 2, `k&4` swaps dims 1 and 2."
    # The bit value and the dimension it flips coincide for 1 and 2.
    flip_dims = [dim for dim in (1,2) if k&dim]
    if flip_dims: x = torch.flip(x, flip_dims)
    if k&4: x = x.transpose(1,2)
    # Flips/transposes can leave a non-contiguous view; return a contiguous tensor.
    return x.contiguous()

In [None]:
# Show `dihedral` applied to one validation image for every k in 0..7 (one per subplot).
x=valid_ds[2][0]
_,axes = plt.subplots(2,4, figsize=(12,6))
for i,ax in enumerate(axes.flat): dihedral(x,i).show(ax)

In [None]:
#export
def get_transforms(do_flip=False, flip_vert=False, max_rotate=0., max_zoom=1., max_lighting=0., max_warp=0.,
                   p_affine=0.75, p_lighting=0.5, xtra_tfms=None):
    "Build a `(train_tfms, valid_tfms)` pair; each augmentation is added only when its argument enables it."
    train_tfms = [rand_crop()]
    if do_flip:
        # All 8 dihedral variants when vertical flips are allowed, otherwise a left/right flip half the time.
        flip_tfm = dihedral() if flip_vert else flip_lr(p=0.5)
        train_tfms.append(flip_tfm)
    if max_warp:
        train_tfms.append(symmetric_warp(magnitude=(-max_warp,max_warp), p=p_affine))
    if max_rotate:
        train_tfms.append(rotate(degrees=(-max_rotate,max_rotate), p=p_affine))
    if max_zoom>1:
        train_tfms.append(rand_zoom(scale=(1.,max_zoom), p=p_affine))
    if max_lighting:
        bright_range = (0.5*(1-max_lighting), 0.5*(1+max_lighting))
        train_tfms.append(brightness(change=bright_range, p=p_lighting))
        contrast_range = (1-max_lighting, 1/(1-max_lighting))
        train_tfms.append(contrast(scale=contrast_range, p=p_lighting))
    # Extra transforms go last on the training side; validation only gets `crop_pad`.
    all_train_tfms = train_tfms + listify(xtra_tfms)
    return (all_train_tfms, [crop_pad()])

def transform_datasets(train_ds, valid_ds, tfms, **kwargs):
    "Return `DatasetTfm`s for (train, valid, augmented valid), where `tfms` is a `(train_tfms, valid_tfms)` pair."
    # (dataset, index into `tfms`): the third entry is the validation set with the training transforms.
    plan = ((train_ds, 0), (valid_ds, 1), (valid_ds, 0))
    return tuple(DatasetTfm(ds, tfms[idx], **kwargs) for ds,idx in plan)

In [None]:
class DataBunch():
    "Bind a training, a validation and an optional augmented-validation `DataLoader` to one `device`."
    def __init__(self, train_dl:DataLoader, valid_dl:DataLoader, augm_dl:DataLoader=None,
                 device:torch.device=None, tfms=None):
        self.device = default_device if device is None else device
        self.train_dl = DeviceDataLoader(train_dl, self.device, tfms=tfms)
        self.valid_dl = DeviceDataLoader(valid_dl, self.device, tfms=tfms)
        # Always define `augm_dl` so reading it never raises `AttributeError`, and compare with `None`
        # explicitly: the truth value of a `DataLoader` is derived from its length.
        self.augm_dl = None if augm_dl is None else DeviceDataLoader(augm_dl, self.device, tfms=tfms)

    @classmethod
    def create(cls, train_ds, valid_ds, augm_ds=None, bs=64, train_tfm=None, valid_tfm=None, num_workers=4,
               tfms=None, device=None, **kwargs):
        """Create a `DataBunch` from datasets, wrapping them in `DatasetTfm` where needed.

        `train_tfm` / `valid_tfm` are the dataset-level transforms and `kwargs` are forwarded to
        `DatasetTfm`; `tfms` is passed on to the `DeviceDataLoader`s. When `augm_ds` is not given
        it is built from `valid_ds` with `train_tfm`. Validation-side loaders use a batch size of
        `bs*2` and are not shuffled.
        """
        # Build the augmented set from `valid_ds` as it was passed in, *before* it may be wrapped with
        # `valid_tfm` below; otherwise the train transforms would be stacked on top of the valid ones.
        if augm_ds is None: augm_ds = DatasetTfm(valid_ds, train_tfm, **kwargs)
        if train_tfm or not isinstance(train_ds, DatasetTfm): train_ds = DatasetTfm(train_ds,train_tfm, **kwargs)
        if valid_tfm or not isinstance(valid_ds, DatasetTfm): valid_ds = DatasetTfm(valid_ds,valid_tfm, **kwargs)
        return cls(DataLoader(train_ds, bs,   shuffle=True,  num_workers=num_workers),
                   DataLoader(valid_ds, bs*2, shuffle=False, num_workers=num_workers),
                   DataLoader(augm_ds,  bs*2, shuffle=False, num_workers=num_workers),
                   device=device, tfms=tfms)

    @property
    def train_ds(self):
        "Dataset behind the training loader."
        return self.train_dl.dl.dataset
    @property
    def valid_ds(self):
        "Dataset behind the validation loader."
        return self.valid_dl.dl.dataset
    @property
    def c(self):
        "The `c` attribute of the training dataset."
        return self.train_ds.c

In [None]:
# Image size forwarded to the transformed datasets.
size=224

tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)
# `tds` is (train, valid, augmented valid) — see `transform_datasets`.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)
# Normalization is passed as `tfms`, which `DataBunch` hands to its `DeviceDataLoader`s.
data = DataBunch.create(*tds, bs=64, num_workers=8, tfms=data_norm)

In [None]:
# Grab the first validation batch and display its first 16 images, denormalized for viewing.
(x,y) = next(iter(data.valid_dl))

_,axs = plt.subplots(4,4,figsize=(12,12))
for i,ax in enumerate(axs.flatten()): show_image(data_denorm(x[i].cpu()), ax)

In [None]:
# Index the same training item 16 times; presumably each access draws fresh random
# augmentations, so the grid shows the spread of the training transforms — confirm.
_,axs = plt.subplots(4,4,figsize=(12,12))
for ax in axs.flat: show_image(tds[0][2][0], ax)

## Train

In [None]:
#export
def train_epoch(model, dl, opt):
    "Simple training of `model` for 1 epoch of `dl` using `opt`; mainly for quick tests"
    model.train()
    for inputs,targets in dl:
        # Forward pass, cross-entropy loss, then backprop in one go.
        F.cross_entropy(model(inputs), targets).backward()
        opt.step()
        # Clear gradients so they do not accumulate into the next batch.
        opt.zero_grad()

In [None]:
#export
class AdaptiveConcatPool2d(nn.Module):
    "Concatenate adaptive max pooling and adaptive average pooling of the input along dim 1."
    def __init__(self, sz=None):
        super().__init__()
        # A missing (or falsy) `sz` means pooling down to 1x1.
        out_sz = sz if sz else 1
        self.ap = nn.AdaptiveAvgPool2d(out_sz)
        self.mp = nn.AdaptiveMaxPool2d(out_sz)

    def forward(self, x):
        # Max-pooled features come first, average-pooled second.
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, 1)

def create_body(model, cut=None, body_fn=None):
    "Body of `model` (children minus the last `cut`, else `body_fn(model)`, else `model`) followed by concat pooling and `Flatten`."
    if cut:       body = list(model.children())[:-cut]
    elif body_fn: body = [body_fn(model)]
    else:         body = [model]
    return nn.Sequential(*body, AdaptiveConcatPool2d(), Flatten())

def num_features(m):
    "`num_features` of the last layer of `flatten_model(m)` that has that attribute, or None if no layer does."
    candidates = (layer.num_features for layer in reversed(flatten_model(m)) if hasattr(layer, 'num_features'))
    return next(candidates, None)

In [None]:
# Build a body from `arch`, dropping its last 2 children, and inspect the feature count it reports.
model = create_body(arch(), 2)
num_features(model)

In [None]:
#export
def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    "List of layers: optional `BatchNorm1d(n_in)`, optional `Dropout(p)`, `Linear(n_in, n_out)`, optional `actn`."
    norm = [nn.BatchNorm1d(n_in)] if bn else []
    # Dropout is only included for a non-zero probability.
    drop = [nn.Dropout(p)] if p != 0 else []
    linear = [nn.Linear(n_in, n_out)]
    activation = [] if actn is None else [actn]
    return norm + drop + linear + activation

def create_head(nf, nc, lin_ftrs=None, ps=None):
    """Build the model head: a stack of `bn_drop_lin` blocks mapping `nf` features to `nc` outputs.

    `lin_ftrs` holds the hidden layer sizes (default: a single hidden layer of 512).
    `ps` holds the dropout probability of each linear layer. It may be a sequence with one
    entry per layer, or a single number `p`, which is expanded like the default (`p/2` for
    every hidden layer, `p` for the last one; the default is `p=0.5`).
    Raises `ValueError` when `ps` has fewer entries than there are linear layers.
    """
    lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]
    n_hidden = len(lin_ftrs)-2
    if ps is None: ps = 0.5
    if isinstance(ps, (int, float)): ps = [ps/2] * n_hidden + [ps]
    ps = list(ps)
    # `zip` below stops at the shortest input, so a too-short `ps` would silently drop the
    # trailing layers (including the final `nc` outputs). Fail loudly instead.
    if len(ps) < n_hidden+1:
        raise ValueError(f"`ps` has {len(ps)} entries but the head has {n_hidden+1} linear layers")
    # One ReLU instance per hidden layer rather than the same module object repeated.
    actns = [nn.ReLU(inplace=True) for _ in range(n_hidden)] + [None]
    layers = []
    for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
        layers += bn_drop_lin(ni,no,True,p,actn)
    return nn.Sequential(*layers)

In [None]:
# Display the default head for 512 input features and 2 classes.
create_head(512, 2)

In [None]:
#export
def cond_init(m, init_fn):
    "Initialize the weight of `m` with `init_fn` and zero its bias, unless `m` is one of `bn_types`."
    if isinstance(m, bn_types): return
    # A layer can expose a `weight` attribute that is `None` (e.g. a norm layer built without
    # affine parameters), so check the value and not just the attribute, as is done for the bias.
    if getattr(m, 'weight', None) is not None: init_fn(m.weight)
    if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
            
def apply_init(m, init_fn):
    "Call `cond_init` with `init_fn` on every module visited by `m.apply`."
    def _init_module(module): return cond_init(module, init_fn)
    m.apply(_init_module)

def _set_mom(m, mom):
    "Set `m.momentum` to `mom` when `m` is an instance of `bn_types`; leave other modules untouched."
    if not isinstance(m, bn_types): return
    m.momentum = mom

def set_mom(m, mom):
    "Call `_set_mom` with `mom` on every module visited by `m.apply`."
    def _apply_mom(module): return _set_mom(module, mom)
    m.apply(_apply_mom)

class ConvLearner(Learner):
    "`Learner` whose model is a body cut from `arch` followed by a freshly initialized head."
    def __init__(self, data, arch, cut, pretrained=True, lin_ftrs=None, ps=None, **kwargs):
        backbone = create_body(arch(pretrained), cut)
        # `create_body` ends with `AdaptiveConcatPool2d`, which concatenates max- and
        # average-pooled features, hence twice the body's feature count.
        head_in = num_features(backbone) * 2
        classifier = create_head(head_in, data.c, lin_ftrs, ps)
        net = nn.Sequential(backbone, classifier)
        super().__init__(data, net, **kwargs)
        # Split at the head, then initialize only the head (`net[1]`).
        self.split([net[1]])
        apply_init(net[1], nn.init.kaiming_normal_)

In [None]:
# Build the learner (body cut 2 children from the end), then redefine the split
# points at `m[0][6]` and the head `m[1]`, and freeze before the first training phase.
learn = ConvLearner(data, arch, 2, wd=1e-2, metrics = accuracy)
learn.split(lambda m: (m[0][6], m[1]))
learn.freeze()

In [None]:
# Base learning rate; the fits below scale it (`lr*4`, `lr/100`).
lr = 1e-3

In [None]:
# Run the learning-rate finder and plot what the recorder captured.
lr_find(learn)
learn.recorder.plot()

In [None]:
# 3 epochs at the base learning rate (the learner was frozen above).
learn.fit(3, slice(lr))

In [None]:
# 3 more epochs with the one-cycle schedule at 4x the base learning rate.
learn.fit_one_cycle(3, slice(lr*4), pct_start=0.05, moms=0.9, pct_end=0.25)

In [None]:
# Checkpoint reloaded at the start of the "Gradual unfreezing" section.
learn.save('0')

## Gradual unfreezing

In [None]:
# Restart from the checkpoint saved as '0'.
learn.load('0')

In [None]:
# Undo the earlier `learn.freeze()` before fine-tuning.
learn.unfreeze()

In [None]:
# Learning rates spanning lr/100 to lr; presumably one per layer group — confirm.
lrs = learn.lr_range(slice(lr/100,lr))
lrs

In [None]:
# 10 one-cycle epochs with the learning rates computed above.
# NOTE(review): called as a function here but as `learn.fit_one_cycle` earlier — confirm both forms exist.
fit_one_cycle(learn, 10, lrs, pct_start=0.01, pct_end=0.35)

In [None]:
# 3 further epochs with every learning rate divided by 12.
fit_one_cycle(learn, 3, lrs/12)

In [None]:
# Checkpoint reloaded at the start of the "TTA" section.
learn.save('2')

## TTA (test-time augmentation)

In [None]:
# Restart from the checkpoint saved as '2'.
learn.load('2')

In [None]:
# Index the same item of the augmented validation dataset (`tds[2]`) 8 times; presumably each
# access draws new random augmentations — confirm.
_,axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flat: show_image(tds[2][1][0], ax)

In [None]:
# Rebinds `model` to the trained network and switches it to eval mode;
# the trailing `;` suppresses the cell output.
model = learn.model
model.eval();

In [None]:
# Predictions and targets over the (non-augmented) validation loader, without tracking gradients.
with torch.no_grad():
    preds,y = zip(*[(model(xb.detach()), yb.detach()) for xb,yb in data.valid_dl])

# Join the per-batch tuples into single tensors (this rebinds `preds` and `y`).
preds = torch.cat(preds)
y = torch.cat(y)

In [None]:
# Accuracy of the plain validation predictions, for comparison with the averaged ones below.
accuracy(preds, y)

In [None]:
def get_preds(model, dl):
    "Concatenate the outputs of `model` over every batch of `dl`, computed with gradients disabled."
    outputs = []
    with torch.no_grad():
        # Targets are ignored; only the inputs are fed to the model.
        for xb,_ in dl: outputs.append(model(xb.detach()))
        return torch.cat(outputs)

In [None]:
# 4 passes over the augmented validation loader, stacked along a new leading dimension.
all_preds = torch.stack([get_preds(model, data.augm_dl) for _ in range(4)])

In [None]:
# Average over the 4 passes (dim 0 of the stacked predictions).
avg_preds = all_preds.mean(0)
avg_preds.shape

In [None]:
# Accuracy of the averaged augmented predictions. Reuses `y` from the validation loader, which
# lines up because both loaders are built unshuffled in `DataBunch.create`.
accuracy(avg_preds, y)

In [None]:
# Blend plain and augmented predictions; `beta` is the weight given to the plain ones.
beta=0.75
accuracy(preds*beta + avg_preds*(1-beta), y)

## Fin