In [None]:
# Reload imported modules automatically before each cell executes, so edits
# to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from nb_004c import *

# Dogs and cats

## Basic data augmentation

In [None]:
# Root folder of the dogs-vs-cats data; later cells read its 'train' and
# 'valid' subfolders.
PATH = Path('../../data/dogscats')

In [None]:
# Mean and std tensors, three values each (presumably the standard ImageNet
# per-channel RGB statistics -- TODO confirm).
data_mean, data_std = map(tensor, ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
# Pair of functions built from those statistics: one to normalize, one to undo it.
data_norm,data_denorm = normalize_funcs(data_mean,data_std)

In [None]:
# One dataset per split, each built from its own subfolder of PATH.
train_ds = FilesDataset.from_folder(PATH/'train')
valid_ds = FilesDataset.from_folder(PATH/'valid')

In [None]:
#export
def get_transforms(do_flip=False, max_rotate=0., max_zoom=1., p_affine=0.75):
    """Build the (train, valid) pair of transform lists.

    The training list always starts with `rand_crop()`; the optional
    augmentations are appended in a fixed order (flip, rotate, zoom).
    The validation list only contains `crop_pad()`.

    do_flip: if truthy, add `flip_lr(p=0.5)`.
    max_rotate: if truthy, add `rotate` with degrees in (-max_rotate, max_rotate).
    max_zoom: if greater than 1, add `rand_zoom` with scale in (1., max_zoom).
    p_affine: value passed as `p` to the rotate and zoom transforms.
    """
    train_tfms = [rand_crop()]
    # TODO: dihedral, lighting, warp
    train_tfms += [flip_lr(p=0.5)] if do_flip else []
    train_tfms += [rotate(degrees=(-max_rotate, max_rotate), p=p_affine)] if max_rotate else []
    train_tfms += [rand_zoom(scale=(1., max_zoom), p=p_affine)] if max_zoom > 1 else []
    valid_tfms = [crop_pad()]
    return (train_tfms, valid_tfms)  #train,valid

def transform_datasets(train_ds, valid_ds, tfms, size=None):
    """Wrap both datasets in `DatasetTfm`.

    tfms: indexable pair; `tfms[0]` goes to the training set and `tfms[1]`
        to the validation set.
    size: forwarded unchanged to both `DatasetTfm` wrappers.

    Returns a `(train, valid)` tuple of wrapped datasets.
    """
    train_tfm_ds = DatasetTfm(train_ds, tfms[0], size=size)
    valid_tfm_ds = DatasetTfm(valid_ds, tfms[1], size=size)
    return (train_tfm_ds, valid_tfm_ds)

In [None]:
# Training list: rand_crop plus flip_lr, rotation in (-10, 10) degrees and
# zoom in (1., 1.1). The validation list only holds crop_pad().
tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.1)

In [None]:
# Size handed to both DatasetTfm wrappers.
size=224
# (train, valid) datasets wrapped with their respective transform lists.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)

In [None]:
# Bundle the two transformed datasets (batch size 64, 8 workers); the
# normalization function is passed as `tfms`.
data = DataBunch(*tds, bs=64, num_workers=8, tfms=data_norm)

In [None]:
# Grab the first batch from the validation loader.
(x,y) = next(iter(data.valid_dl))

# Show the first 16 items of the batch on a 4x4 grid, undoing the normalization first.
_,axs = plt.subplots(4,4,figsize=(12,12))
for i,ax in enumerate(axs.flatten()): show_image(data_denorm(x[i].cpu()), ax)

In [None]:
# Grab one batch from the training loader (these go through the training transforms).
(x,y) = next(iter(data.train_dl))

# Show the first 16 items of the batch on a 4x4 grid, undoing the normalization first.
_,axs = plt.subplots(4,4,figsize=(12,12))
for i,ax in enumerate(axs.flatten()): show_image(data_denorm(x[i].cpu()), ax)

## Train

In [None]:
from torchvision.models import resnet34
# Architecture constructor used by the cells below.
arch = resnet34

In [None]:
# Baseline: the architecture built with its default arguments (no pretrained
# flag is passed here), trained as-is.
model = arch()
# SGD with momentum 0.9 as the optimizer factory.
opt_fn = partial(optim.SGD, momentum=0.9)
learn = Learner(data, model, opt_fn=opt_fn, true_wd=True)
# Report accuracy alongside the loss.
learn.metrics = [accuracy]

In [None]:
# Run the learning-rate finder on the baseline learner, then plot what the
# recorder captured.
lr_find(learn)
learn.recorder.plot()

In [None]:
# One-cycle training of the baseline with wd=1e-2 (presumably 1e-2 is the
# learning rate and 1 the number of cycles/epochs -- confirm against
# fit_one_cycle's signature).
fit_one_cycle(learn, 1e-2, 1, wd=1e-2)

## Model with a new head

In [None]:
class AdaptiveConcatPool2d(nn.Module):
    """Adaptive max pooling and adaptive average pooling, concatenated on dim 1.

    sz: target output size given to both pooling layers; a falsy value
        (including the default `None`) means 1.

    The output has twice as many channels as the input: the max-pooled
    channels come first, followed by the average-pooled ones.
    """
    def __init__(self, sz=None):
        super().__init__()
        out_size = sz if sz else 1
        self.ap = nn.AdaptiveAvgPool2d(out_size)
        self.mp = nn.AdaptiveMaxPool2d(out_size)

    def forward(self, x):
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)

In [None]:
def create_skeleton(model, cut):
    """Turn `model` into a feature extractor ending in concat-pooling + flatten.

    cut: number of trailing children of `model` to drop. When falsy, `model`
        is kept whole and used as the single first layer.

    Returns an `nn.Sequential` of the kept layers followed by
    `AdaptiveConcatPool2d()` and `Flatten()`.
    """
    if cut:
        body = list(model.children())[:-cut]
    else:
        body = [model]
    return nn.Sequential(*body, AdaptiveConcatPool2d(), Flatten())

In [None]:
def num_features(m):
    """Return the `num_features` attribute of the last descendant of `m` that has one.

    Children are searched from last to first, depth-first. `m` itself is
    not inspected, only its descendants. Returns `None` when `m` has no
    children or nothing in the tree exposes `num_features`.
    """
    for child in reversed(list(m.children())):
        if hasattr(child, 'num_features'):
            return child.num_features
        nested = num_features(child)
        if nested is not None:
            return nested
    return None

In [None]:
# Drop the last two children of a freshly built architecture, then look up the
# feature count reported by the remaining body.
model = create_skeleton(arch(), 2)
num_features(model)

In [None]:
def bn_dp_lin(n_in, n_out, bn=True, dp=0., actn=None):
    """Return the list of layers for one block: [BatchNorm1d] -> [Dropout] -> Linear -> [actn].

    n_in, n_out: input and output sizes of the linear layer.
    bn: include `nn.BatchNorm1d(n_in)` when truthy.
    dp: dropout probability; the dropout layer is omitted when it equals 0.
    actn: optional activation module appended as-is after the linear layer.
    """
    norm = [nn.BatchNorm1d(n_in)] if bn else []
    drop = [nn.Dropout(dp)] if dp != 0 else []
    linear = [nn.Linear(n_in, n_out)]
    activation = [] if actn is None else [actn]
    return norm + drop + linear + activation

def create_head(nf, nc, lin_ftrs=None, dps=None):
    """Build the classifier head: a stack of `bn_dp_lin` groups ending in `nc` outputs.

    nf: number of input features.
    nc: number of outputs of the final linear layer.
    lin_ftrs: sizes of the hidden linear layers (default: one hidden layer of 512).
    dps: dropout probability per linear layer. Either a sequence with one
        entry per linear layer, or a single number applied to every layer.
        Default: 0.25 for each hidden layer and 0.5 for the last one.

    Every linear layer except the last is followed by `nn.ReLU(inplace=True)`.

    Returns an `nn.Sequential`.
    Raises `ValueError` if `dps` has fewer entries than there are linear layers.
    """
    sizes = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]
    n_lin = len(sizes) - 1
    if dps is None: dps = [0.25] * (n_lin-1) + [0.5]
    elif isinstance(dps, (int, float)): dps = [dps] * n_lin
    else: dps = list(dps)
    # zip() below stops at the shortest input, so a too-short `dps` would
    # silently drop trailing layers (including the final classifier).
    if len(dps) < n_lin:
        raise ValueError("dps has {} entries but the head has {} linear layers".format(len(dps), n_lin))
    # One ReLU instance per hidden layer: a module object registered several
    # times in an nn.Sequential is yielded only once by children()/modules().
    actns = [nn.ReLU(inplace=True) for _ in range(n_lin-1)] + [None]
    layers = []
    for ni,no,dp,actn in zip(sizes[:-1], sizes[1:], dps, actns):
        layers += bn_dp_lin(ni, no, True, dp, actn)
    return nn.Sequential(*layers)

In [None]:
# Display the default head for 512 input features and 2 outputs.
create_head(512, 2)

In [None]:
class ConvLearner(Learner):
    """Learner whose model is a truncated `arch` body followed by a new head.

    data: data object handed to `Learner`; the number of classes is read from
        `data.train_ds.ds.classes`.
    arch: callable building the base network; it is called as `arch(pretrained)`.
    cut: number of trailing children of the base network to drop
        (see `create_skeleton`).
    pretrained: passed positionally to `arch`.
    lin_ftrs, dps: forwarded to `create_head`.
    kwargs: forwarded to `Learner.__init__`.

    Attributes:
        skeleton: the truncated base network plus concat-pooling and flatten.
        head: the newly created classifier layers.
    """
    def __init__(self, data, arch, cut, pretrained=True, lin_ftrs=None, dps=None, **kwargs):
        self.skeleton = create_skeleton(arch(pretrained), cut)
        # The skeleton ends with AdaptiveConcatPool2d, which concatenates max-
        # and average-pooled features, doubling the channel count.
        nf = num_features(self.skeleton) * 2
        # XXX: better way to get num classes
        self.head = create_head(nf, len(data.train_ds.ds.classes), lin_ftrs, dps)
        model = nn.Sequential(self.skeleton, self.head)
        super().__init__(data, model, **kwargs)

    def freeze(self):
        """Stop gradient computation for every skeleton parameter; the head is left untouched."""
        # The attribute autograd reads is `requires_grad`; assigning to any
        # other name just creates an unused Python attribute on the tensor.
        for p in self.skeleton.parameters(): p.requires_grad = False

    def unfreeze(self):
        """Re-enable gradient computation for every skeleton parameter."""
        for p in self.skeleton.parameters(): p.requires_grad = True

In [None]:
# Learner built from the architecture with its last two children replaced by a
# new head (ConvLearner's default is pretrained=True).
learn = ConvLearner(data, arch, 2)
learn.metrics = [accuracy]
# Call freeze() on the skeleton before training.
learn.freeze()

In [None]:
# Run the learning-rate finder on the ConvLearner, then plot what the recorder
# captured.
lr_find(learn)
learn.recorder.plot()

In [None]:
# Train the ConvLearner (presumably 1 epoch at learning rate 1e-2 -- confirm
# against Learner.fit's signature).
learn.fit(1, 1e-2)

In [None]:
# Adam optimizer factory with betas=(0.9, 0.99); no cell visible here uses it yet.
opt_fn=partial(optim.Adam, betas=(0.9,0.99))