In [None]:
# Auto-reload edited modules (e.g. nb_005) before each cell executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from nb_005 import *

# STL-10

## Basic data augmentation

In [None]:
PATH = Path('data/stl10')

In [None]:
data_mean, data_std = map(tensor, ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
data_norm,data_denorm = normalize_funcs(data_mean,data_std)

In [None]:
train_ds = FilesDataset.from_folder(PATH/'train')
valid_ds = FilesDataset.from_folder(PATH/'valid')

In [None]:
# Peek at one raw validation image before any transforms are applied.
# Uses its own name so it does not clash with the batch tensor `x` created further down.
sample_img = Image(valid_ds[2][0])
sample_img.show()
sample_img.shape

In [None]:
tfms = get_transforms(do_flip=True, max_rotate=5, max_zoom=1.25, max_lighting=0.4, max_warp=0.15)

In [None]:
size=96
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)

In [None]:
data = DataBunch(*tds, bs=64, num_workers=8, tfms=data_norm)

In [None]:
(x,y) = next(iter(data.valid_dl))

_,axs = plt.subplots(4,4,figsize=(12,12))
for i,ax in enumerate(axs.flatten()): show_image(data_denorm(x[i].cpu()), ax)

In [None]:
_,axs = plt.subplots(4,4,figsize=(12,12))
for ax in axs.flat: show_image(tds[0][1][0], ax)

## Train

In [None]:
from torchvision.models import resnet18, resnet34, resnet50
arch = resnet18

In [None]:
# Uses nb_005's ConvLearner (the class is redefined later in this notebook).
# The third positional argument is presumably `cut`, as in that later redefinition — confirm against nb_005.
learn = ConvLearner(data, arch, 2, wd=1e-3)
learn.metrics = [accuracy]
# NOTE(review): presumably freezes the pretrained backbone so only the head is trained first — confirm.
learn.freeze()

In [None]:
# Run the learning-rate finder and plot what the recorder captured.
lr_find(learn)
learn.recorder.plot()

In [None]:
# NOTE(review): identical re-run of the previous cell — presumably to check the LR curve is stable; remove if redundant.
lr_find(learn)
learn.recorder.plot()

In [None]:
# Presumably picked by eye from the plots above.
lr = 1e-3

In [None]:
fit_one_cycle(learn, lr, 2)

In [None]:
# Checkpoint '0' is reloaded at the start of the gradual-unfreezing section.
learn.save('0')

## Gradual unfreezing

In [None]:
class ConvLearner(Learner):
    "Learner made of a backbone `skeleton` built from `arch` plus a freshly created classifier `head`."
    def __init__(self, data, arch, cut, pretrained=True, lin_ftrs=None, dps=None, **kwargs):
        # arch(pretrained) builds the backbone; create_skeleton truncates it according to `cut`.
        self.skeleton = create_skeleton(arch(pretrained), cut)
        # NOTE(review): the *2 presumably matches a concat (avg + max) pooling layer inside create_head — confirm.
        nf = num_features(self.skeleton) * 2
        # XXX: better way to get num classes
        self.head = create_head(nf, len(data.train_ds.ds.classes), lin_ftrs, dps)
        model = nn.Sequential(self.skeleton, self.head)
        super().__init__(data, model, **kwargs)

    def freeze_to(self, n):
        "Freeze layer groups `[:n]` and unfreeze groups `[n:]`; a negative `n` counts from the end."
        for g in self.layer_groups[:n]:
            for p in g.parameters(): p.requires_grad = False
        for g in self.layer_groups[n:]:
            for p in g.parameters(): p.requires_grad = True

    def freeze(self):
        "Freeze every layer group except the last one (the head, given a split such as `lambda m: (m[0][6], m[1])`)."
        # freeze_to(len(self.layer_groups)) would also freeze the last group, leaving no trainable
        # parameters at all. With a single (unsplit) layer group nothing gets frozen here.
        self.freeze_to(-1)

    def unfreeze(self):
        "Make every layer group trainable."
        self.freeze_to(0)

In [None]:
# New instance of the ConvLearner class redefined above, so its freeze_to/unfreeze methods are used.
learn = ConvLearner(data, arch, 2, wd=1e-3)
learn.metrics = [accuracy]

In [None]:
# Restore the weights saved after the frozen training stage.
learn.load('0')

In [None]:
# Split into layer groups at backbone child 6 and at the head; presumably yields 3 groups, one per entry of `lrs` below.
learn.split(lambda m: (m[0][6], m[1]))

In [None]:
# Freeze only the first layer group.
learn.freeze_to(1)

In [None]:
# Discriminative learning rates, increasing towards later layers (presumably applied in layer-group order).
lrs = np.array([lr/9, lr/3, lr])

In [None]:
fit_one_cycle(learn, lrs/3, 1)

In [None]:
learn.save('1')

In [None]:
learn.load('1')

In [None]:
# freeze_to(0): every layer group becomes trainable.
learn.unfreeze()

In [None]:
%time fit_one_cycle(learn, lrs/9, 1)

In [None]:
# Checkpoint '2' is reloaded for the TTA experiments below.
learn.save('2')

In [None]:
# TODO remove layer groups; use start_layer / start_lr

## Test-time augmentation (TTA)

In [None]:
def transform_datasets(train_ds, valid_ds, tfms, size=None):
    """Wrap the raw datasets with transforms for training, validation and TTA.

    `tfms` is indexed as tfms[0] (training transforms) and tfms[1] (validation transforms).
    Returns a 3-tuple:
      - train_ds with the training transforms,
      - valid_ds with the validation transforms,
      - valid_ds again with the *training* transforms, used for test-time augmentation.
    """
    train_tfms, valid_tfms = tfms[0], tfms[1]
    train_set = DatasetTfm(train_ds, train_tfms, size=size)
    valid_set = DatasetTfm(valid_ds, valid_tfms, size=size)
    augm_set = DatasetTfm(valid_ds, train_tfms, size=size)
    return train_set, valid_set, augm_set

In [None]:
class DataBunch():
    "Train/valid loaders plus an augmented-validation loader (`augm_dl`) used for test-time augmentation."
    def __init__(self, train_ds, valid_ds, augm_ds, bs=64, device=None, num_workers=4, **kwargs):
        # NOTE(review): `device` is stored but never forwarded to the loaders below — confirm whether intended.
        self.device = device if device is not None else default_device

        def _make_dl(ds, batch_size, shuffle):
            # Every loader shares num_workers and any extra kwargs (e.g. batch `tfms`).
            return DeviceDataLoader.create(ds, batch_size, shuffle=shuffle, num_workers=num_workers, **kwargs)

        self.train_dl = _make_dl(train_ds, bs, True)
        self.valid_dl = _make_dl(valid_ds, bs*2, False)
        # Not shuffled, so TTA outputs stay aligned row-for-row with valid_dl's targets.
        self.augm_dl = _make_dl(augm_ds, bs*2, False)

    @classmethod
    def create(cls, train_ds, valid_ds, train_tfm=None, valid_tfm=None, dl_tfms=None, **kwargs):
        "Wrap raw datasets in DatasetTfm; the TTA dataset is `valid_ds` with the training transform."
        train_set = DatasetTfm(train_ds, train_tfm)
        valid_set = DatasetTfm(valid_ds, valid_tfm)
        augm_set = DatasetTfm(valid_ds, train_tfm)
        return cls(train_set, valid_set, augm_set, tfms=dl_tfms, **kwargs)

    @property
    def train_ds(self): return self.train_dl.dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dl.dataset

In [None]:
# Rebuild the data with the 3-dataset transform_datasets/DataBunch defined above (adds the TTA loader).
# bs is left at DataBunch's default of 64.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)
data = DataBunch(*tds, num_workers=8, tfms=data_norm)

In [None]:
learn = ConvLearner(data, arch, 2, wd=1e-3)
learn.metrics = [accuracy]

In [None]:
# Weights saved after the fully-unfrozen fine-tuning stage.
learn.load('2')

In [None]:
# One validation item shown 8 times with the training transforms, i.e. what the TTA loader feeds the model
# (each access presumably re-samples the random augmentation).
_,axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flat: show_image(tds[2][1][0], ax)

In [None]:
model = learn.model
# Inference mode for all the predictions below.
model.eval();

In [None]:
# Baseline: one pass over valid_dl (validation transforms only), collecting model outputs and targets.
with torch.no_grad():
    preds,y = zip(*[(model(xb.detach()), yb.detach()) for xb,yb in data.valid_dl])

preds = torch.cat(preds)
y = torch.cat(y)

In [None]:
accuracy(preds, y)

In [None]:
def get_preds(model, dl):
    """Run `model` over every batch of `dl` and return the concatenated outputs.

    The model is put in eval mode for the pass (dropout off, batch-norm uses its running
    statistics instead of updating them) and restored to its previous mode afterwards,
    so results do not depend on the caller having called `model.eval()`.
    No gradients are tracked.

    Args:
        model: torch.nn.Module applied to each input batch.
        dl: iterable of (xb, yb) pairs; only xb is used.

    Returns:
        Tensor of per-batch outputs concatenated along dim 0, in `dl` iteration order.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            return torch.cat([model(xb) for xb, _ in dl])
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

In [None]:
all_preds = torch.stack([get_preds(model, data.augm_dl) for _ in range(4)])

In [None]:
avg_preds = all_preds.mean(0)
avg_preds.shape

In [None]:
accuracy(avg_preds, y)

In [None]:
beta=0.5
accuracy(preds*beta + avg_preds*(1-beta), y)

## Fin