In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from nb_005 import *

# STL-10

## Basic data augmentation

In [None]:
# Root folder of the STL-10 data (expects `train/` and `valid/` sub-folders).
PATH = Path('data')/'stl10'

In [None]:
# Per-channel (RGB) mean/std; these are the standard ImageNet statistics, presumably
# chosen to match the pretrained torchvision backbone used below.
data_mean = tensor([0.485, 0.456, 0.406])
data_std = tensor([0.229, 0.224, 0.225])
data_norm, data_denorm = normalize_funcs(data_mean, data_std)

In [None]:
# Raw (untransformed) datasets read from PATH/train and PATH/valid.
train_ds, valid_ds = (FilesDataset.from_folder(PATH/split) for split in ('train', 'valid'))

In [None]:
# Peek at the first validation image before any transforms are applied, and show its shape.
x = Image(valid_ds[0][0])
x.show()
x.shape

In [None]:
def transform_datasets(train_ds, valid_ds, tfms, **kwargs):
    """Wrap the train and valid datasets in `DatasetTfm`.

    `tfms[0]` is used for `train_ds` and `tfms[1]` for `valid_ds`. Any extra keyword
    arguments (e.g. `size`, `padding_mode`) are forwarded unchanged to both wrappers.
    Returns a `(train, valid)` tuple.
    """
    wrapped_train = DatasetTfm(train_ds, tfms[0], **kwargs)
    wrapped_valid = DatasetTfm(valid_ds, tfms[1], **kwargs)
    return wrapped_train, wrapped_valid

In [None]:
# Target image size forwarded to DatasetTfm (presumably the output side length in pixels).
size = 96
# Augmentation settings: flips, rotation (max_rotate=5), lighting (0.2) and warp (0.15).
# Variants tried earlier: adding max_zoom=1.25; flips only (optionally with
# max_rotate=5, max_lighting=0.1).
tfms = get_transforms(do_flip=True,
                      max_rotate=5,
                      max_lighting=0.2,
                      max_warp=0.15)
tds = transform_datasets(train_ds, valid_ds, tfms,
                         size=size, padding_mode='zeros')
# Batches of 32 loaded by 8 workers; data_norm normalises each batch.
data = DataBunch(*tds, bs=32, num_workers=8, tfms=data_norm)

In [None]:
# Show (up to) the first 16 images of one validation batch after `data_denorm`.
(x,y) = next(iter(data.valid_dl))

_,axs = plt.subplots(4,4,figsize=(12,12))
# zip stops at the shorter of batch/axes, so a batch with fewer than 16 images
# leaves panels empty instead of raising IndexError on x[i].
for img, ax in zip(x, axs.flatten()): show_image(data_denorm(img.cpu()), ax)

In [None]:
# Show (up to) the first 16 images of one augmented training batch after `data_denorm`.
(x,y) = next(iter(data.train_dl))

_,axs = plt.subplots(4,4,figsize=(12,12))
# zip stops at the shorter of batch/axes, so a batch with fewer than 16 images
# leaves panels empty instead of raising IndexError on x[i].
for img, ax in zip(x, axs.flatten()): show_image(data_denorm(img.cpu()), ax)

In [None]:
# Draw training sample #1 in all 16 panels. Presumably every `tds[0][1]` access
# re-applies the random train transforms, so each panel shows a different
# augmentation of the same image; confirm against DatasetTfm.
_,axs = plt.subplots(4,4,figsize=(12,12))
for ax in axs.flat:
    show_image(tds[0][1][0], ax)

## Train

In [None]:
from torchvision.models import resnet18, resnet34, resnet50
# Backbone constructor passed to ConvLearner; resnet18/resnet34 are imported alongside as alternatives.
arch = resnet50

In [None]:
def set_bn_eval(m):
    """Recursively switch frozen BatchNorm layers of `m` to eval mode.

    A descendant is switched when it is an instance of `bn_types` (assumed to be the
    BatchNorm classes exported by nb_005; confirm) and its first parameter has
    `requires_grad == False`. BN layers with no parameters (e.g. built with
    affine=False) are skipped; previously `next(l.parameters())` raised StopIteration
    on them.

    NOTE(review): this definition is shadowed by the `set_bn_eval` redefinition in the
    following cell, so after a full run that later version is the one called by name.
    NOTE(review): a later `model.train()` call would undo `.eval()`; check the
    training loop if this version is revived.
    """
    for l in m.children():
        if isinstance(l, bn_types):
            first_param = next(l.parameters(), None)
            if first_param is not None and not first_param.requires_grad:
                l.eval()
        set_bn_eval(l)

In [None]:
def set_bn_eval(m):
    """Recursively set BatchNorm running-statistics momentum based on trainability.

    For every descendant of `m` that is an instance of `bn_types` (assumed to be the
    BatchNorm classes exported by nb_005; confirm), `momentum` becomes 0.1 if its
    first parameter requires grad, else 0.0.

    Per PyTorch's BatchNorm docs, momentum 0.0 stops running_mean/running_var from
    being updated. In train mode the layer still normalises with batch statistics,
    though.

    BN layers with no parameters (affine=False) are left unchanged; previously
    `next(l.parameters())` raised StopIteration on them.
    """
    for l in m.children():
        set_bn_eval(l)
        if isinstance(l, bn_types):
            first_param = next(l.parameters(), None)
            if first_param is not None:
                l.momentum = 0.1 if first_param.requires_grad else 0.0

In [None]:
@dataclass
class BnFreeze(Callback):
    "Callback that calls `set_bn_eval` on the learner's model when training begins."
    learn:Learner
    def on_train_begin(self, **kwargs):
        model = self.learn.model
        set_bn_eval(model)

In [None]:
class ConvLearner(Learner):
    """Learner for a (by default pretrained) conv backbone topped by a new classifier head.

    The model is `nn.Sequential(skeleton, head)`. It is split into layer groups at the
    head, so presumably `freeze()` trains only the head and `unfreeze()` trains everything.
    """
    def __init__(self, data, arch, cut, pretrained=True, lin_ftrs=None, dps=None, **kwargs):
        backbone = create_skeleton(arch(pretrained), cut)
        # Doubled, presumably because create_head concatenates two pooled feature
        # vectors (e.g. average + max pooling); confirm against create_head.
        n_in = num_features(backbone) * 2
        # TODO: find a better way to get the number of classes than data.train_ds.ds.
        n_out = len(data.train_ds.ds.classes)
        head = create_head(n_in, n_out, lin_ftrs, dps)
        super().__init__(data, nn.Sequential(backbone, head), **kwargs)
        # Single split point at the head (model[1]).
        self.split(lambda m: (m[1],))

    def freeze_to(self, n):
        "Disable gradients for `layer_groups[:n]` and enable them for `layer_groups[n:]`."
        frozen, trainable = self.layer_groups[:n], self.layer_groups[n:]
        for flag, groups in ((False, frozen), (True, trainable)):
            for group in groups:
                for param in group.parameters(): param.requires_grad = flag

    def freeze(self):
        "Train only the last layer group; requires more than one group."
        assert len(self.layer_groups) > 1
        self.freeze_to(-1)

    def unfreeze(self):
        "Make every layer group trainable."
        self.freeze_to(0)

In [None]:
def cond_init(m, init_fn, skip_types=None):
    """Initialise one module: `init_fn` on its weight, zeros on its bias.

    Modules that are instances of `skip_types` are left untouched. The default is
    `bn_types`, presumably the BatchNorm classes exported by nb_005.

    A `weight` attribute that exists but is None is skipped rather than passed to
    `init_fn`; previously this crashed. Examples are LayerNorm(elementwise_affine=False)
    and InstanceNorm(affine=False). A missing or None `bias` is likewise skipped.
    """
    if skip_types is None: skip_types = bn_types
    if isinstance(m, skip_types): return
    weight = getattr(m, 'weight', None)
    if weight is not None: init_fn(weight)
    bias = getattr(m, 'bias', None)
    if hasattr(bias, 'data'): bias.data.fill_(0.)

def apply_init(m, init_fn, skip_types=None):
    "Run `cond_init(..., init_fn, skip_types)` on `m` and every submodule via `nn.Module.apply`."
    m.apply(lambda x: cond_init(x, init_fn, skip_types))

In [None]:
# Base learning rate shared by the fit calls below.
lr = 0.001

In [None]:
# ConvLearner with cut=2, wd=1e-8 (forwarded to Learner) and SGD with momentum 0.9.
# Variants tried earlier: callback_fns=[BnFreeze]; dps=[0.01,0.02].
learn = ConvLearner(
    data, arch, 2,
    wd=1e-8,
    opt_fn=partial(optim.SGD, momentum=0.9),
)
learn.metrics = [accuracy]

In [None]:
# Re-split with two split points: skeleton child 6 and the head, i.e. presumably three
# layer groups [skeleton[:6], skeleton[6:], head]; confirm against Learner.split.
learn.split(lambda m: (m[0][6], m[1]))
# freeze() -> freeze_to(-1): every group except the last (presumably the head) stops training.
learn.freeze()

In [None]:
# Re-initialise the head (model[1]): Kaiming-normal weights, zero biases; modules in bn_types are skipped.
apply_init(learn.model[1], nn.init.kaiming_normal_)

In [None]:
# Presumably an LR range test; inspect the recorder plot to sanity-check `lr`.
lr_find(learn)
learn.recorder.plot()

In [None]:
# Two epochs at `lr`, training only the groups left trainable by freeze().
learn.fit(2, lr)

In [None]:
# Checkpoint this stage under the name '0'.
learn.save('0')

In [None]:
def requires_grad(l):
    """Return `requires_grad` of the first parameter of module `l`.

    Returns None when `l` has no parameters at all.
    """
    first = next(iter(l.parameters()), None)
    return None if first is None else first.requires_grad

## Gradual unfreezing

In [None]:
# NOTE(review): the next cell rebinds `learn` to a brand-new ConvLearner, so the weights
# loaded here are discarded before any training in this section; move or drop this.
learn.load('0')

In [None]:
# Fresh learner for the gradual-unfreezing run. Unlike the first learner it registers
# BnFreeze and does not pass wd.
learn = ConvLearner(data, arch, 2, callback_fns=[BnFreeze]
                    , opt_fn=partial(optim.SGD, momentum=0.9))
learn.metrics = [accuracy]
apply_init(learn.model[1], nn.init.kaiming_normal_)
# Split points must be a sequence: `(m[1])` is just m[1], not a tuple. This matches
# the `(m[1],)` used in ConvLearner.__init__.
learn.split(lambda m: (m[1],))
learn.freeze()

In [None]:
# One epoch with only the last layer group (the head) trainable.
learn.fit(1, lr)

In [None]:
# freeze_to(0): every layer group becomes trainable.
learn.unfreeze()

In [None]:
# Two more epochs at the same lr, now training the whole network.
learn.fit(2, lr)

In [None]:
# `lrs` is never defined in this notebook (NameError on Restart & Run All); use half
# the base `lr` as the one-cycle max learning rate over 12 epochs.
fit_one_cycle(learn, lr/2, 12, div_factor=20, pct_end=0.35)

In [None]:
# Checkpoint the gradual-unfreezing run under the name '1'.
learn.save('1')

## Fin

In [None]:
# Read the split CSV. Column '2' names the split; columns '0'/'1' are used below as
# relative file paths and labels.
import pandas as pd
csv = pd.read_csv(PATH/'default.csv')
is_valid = csv['2'].eq('valid')
valid_df = csv[is_valid]
train_df = csv[~is_valid]
len(valid_df), len(train_df)

In [None]:
# Compare with the CSV-derived validation row count shown above.
len(valid_ds)

In [None]:
# Labels from column '1'; file paths from column '0', made absolute by prefixing PATH.
train_lbls = np.array(train_df['1'])
valid_lbls = np.array(valid_df['1'])
train_fns = [PATH/fn for fn in np.array(train_df['0'])]
valid_fns = [PATH/fn for fn in np.array(valid_df['0'])]

# Pass the training class list to the validation set, presumably so label indices agree.
train_ds = FilesDataset(train_fns, train_lbls)
valid_ds = FilesDataset(valid_fns, valid_lbls, classes=train_ds.classes)