## Inpainting with Variable Dataset Size

Recall that the Image网 dataset consists of:

1. A `/val` folder with 10 classes.
2. A `/train` folder with 20 classes:
   - ~125 images in each class that also exists in `/val`.
   - ~1,300 images in each class that does not exist in `/val`.
3. An `/unsup` folder with 7,750 unlabelled images.

The question we would like to answer with this notebook is:

> What is the effect of dataset size during pretext training on downstream task performance?

To answer this question, we will train the same inpainting pretext task on four different datasets, each built from ImageWang:

1. All data in `/train`, `/unsup` and `/val`
2. All data in `/train` and `/unsup`
3. All data in `/train`
4. Only data in `/train` that has a corresponding class in `/val`
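To get a sense of scale, each of the four datasets can be written as a sum of disjoint folder groups. The sketch below uses the file counts printed later in this notebook (14,669 train images, of which 1,275 belong to classes that also appear in `/val`; 7,750 unsup; 3,929 val); the dictionary and its key names are purely illustrative.

```python
# File counts per disjoint group (taken from the counts printed in this notebook)
counts = {
    "train_val_classes": 1275,     # /train images whose class also appears in /val
    "train_other": 14669 - 1275,   # the remaining /train images
    "unsup": 7750,                 # /unsup images (no labels)
    "val": 3929,                   # /val images
}

# Each pretext dataset is a union of some of the disjoint groups above
datasets = {
    "1. train + unsup + val": ["train_val_classes", "train_other", "unsup", "val"],
    "2. train + unsup": ["train_val_classes", "train_other", "unsup"],
    "3. train": ["train_val_classes", "train_other"],
    "4. train (val classes only)": ["train_val_classes"],
}

sizes = {name: sum(counts[g] for g in groups) for name, groups in datasets.items()}
for name, size in sizes.items():
    print(f"{name:<30} {size:>6} images")
```

Dataset 4 is therefore only about 5% the size of dataset 1.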

In [1]:
import json
import torch

import numpy as np

from functools import partial


from fastai2.layers import MishJit, MaxPool, LabelSmoothingCrossEntropy
from fastai2.basics import DataBlock, RandomSplitter, GrandparentSplitter, CategoryBlock, parent_label
from fastai2.learner import Learner
from fastai2.callback.mixup import MixUp
from fastai2.metrics import accuracy, top_k_accuracy
from fastai2.optimizer import ranger, Adam, SGD, RMSProp

from fastai2.vision.all import ImageBlock, PILMask, get_image_files, PILImage, imagenet_stats
from fastai2.vision.core import get_annotations, Image, TensorBBox, TensorPoint, TensorImage
from fastai2.vision.augment import aug_transforms, RandomResizedCrop, RandTransform, FlipItem, RandomErasing
from fastai2.vision.learner import unet_learner, unet_config
from fastai2.vision.models.xresnet import xresnet50, xresnet34

from fastai2.data.transforms import get_files
from fastai2.data.external import download_url, URLs, untar_data

from fastcore.foundation import L
from fastcore.utils import num_cpus

from torch.nn import MSELoss
from torchvision.models import resnet34

We will train this network with the best hyperparameters, optimizer and settings we know of. These settings come from [training Imagenette](https://github.com/fastai/imagenette/tree/58a63175a2c6457650289d32741940d6a7d58fbf).

Keep in mind that Imagenette is a classification task, so there is no guarantee that these settings will transfer perfectly to our inpainting task. That said, they're probably a very good starting point.

As of January 2020, the [best parameters](https://github.com/fastai/imagenette/blob/58a63175a2c6457650289d32741940d6a7d58fbf/2020-01-train.md) are:

```
--lr 8e-3 
--sqrmom 0.99 
--mom 0.95 
--eps 1e-6 
--bs 64 
--opt ranger 
--sa 1
--fp16 1 
--arch xse_resnext50 
--pool MaxPool
```
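The flag block above is just a list of `--name value` pairs. As an illustration only, the sketch below parses it into a Python dict so the values are easy to compare with the variables set in the next cell; the `parse_flags` helper is hypothetical and not part of the Imagenette training script.

```python
def parse_flags(text: str) -> dict:
    """Parse lines like '--lr 8e-3' into {'lr': 0.008}, converting numbers where possible."""
    flags = {}
    for line in text.strip().splitlines():
        name, value = line.split(maxsplit=1)
        value = value.strip()
        try:
            parsed = int(value)            # integers such as '64' or '1'
        except ValueError:
            try:
                parsed = float(value)      # floats such as '8e-3' or '0.99'
            except ValueError:
                parsed = value             # strings such as 'ranger' stay as-is
        flags[name.lstrip("-")] = parsed
    return flags

best = parse_flags("""
--lr 8e-3
--sqrmom 0.99
--mom 0.95
--eps 1e-6
--bs 64
--opt ranger
--sa 1
--fp16 1
--arch xse_resnext50
--pool MaxPool
""")
print(best)
```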

One change we're making: we're going to use **`xresnet34`** here instead of `xse_resnext50`.

In [2]:
# Default parameters
gpu=None
lr=1e-2
size=128
sqrmom=0.99
mom=0.9
eps=1e-6
epochs=15
bs=64
mixup=0.
opt='ranger'
arch='xresnet50'
sh=0.
sa=0
sym=0
beta=0.
act_fn='MishJit'
fp16=0
pool='AvgPool'
dump=0
runs=3
meta=''

# Chosen parameters
lr=8e-3
sqrmom=0.99
mom=0.95
eps=1e-6
bs=64 
opt='ranger'
sa=1                 #NOTE: NOT USED HERE. Do we need this?
fp16=0               #NOTE: My GPU cannot run fp16 :'(
arch='xresnet34'     #NOTE: not used directly; the model is set via `m` below
pool='MaxPool'

gpu=0

# NOTE: Normally loaded from their corresponding string
m = xresnet34
act_fn = MishJit
pool = MaxPool

In [3]:
# We create this dummy class in order to create a transform that ONLY operates on images of this type
# We will use it to create all input images
class PILImageInput(PILImage): pass

class RandomCutout(RandTransform):
    "Sets a random number of random-sized rectangular holes in input images to 0"
    split_idx = None
    def __init__(self, min_n_holes=5, max_n_holes=10, min_length=5, max_length=50, **kwargs):
        super().__init__(**kwargs)
        self.min_n_holes=min_n_holes
        self.max_n_holes=max_n_holes
        self.min_length=min_length
        self.max_length=max_length
        

    def encodes(self, x:PILImageInput):
        """
        Note that we're accepting our dummy PILImageInput class, so
        fastai2 will only pass images of this type to `encodes`.
        This means that our transform will only be applied to input images and won't
        be run against the target images.
        """
        
        n_holes = np.random.randint(self.min_n_holes, self.max_n_holes)
        pixels = np.array(x) # Convert to mutable numpy array. FeelsBadMan
        h,w = pixels.shape[:2]

        for n in range(n_holes):
            h_length = np.random.randint(self.min_length, self.max_length)
            w_length = np.random.randint(self.min_length, self.max_length)
            h_y = np.random.randint(0, h)
            h_x = np.random.randint(0, w)
            y1 = int(np.clip(h_y - h_length / 2, 0, h))
            y2 = int(np.clip(h_y + h_length / 2, 0, h))
            x1 = int(np.clip(h_x - w_length / 2, 0, w))
            x2 = int(np.clip(h_x + w_length / 2, 0, w))
           
            pixels[y1:y2, x1:x2, :] = 0
            
        return Image.fromarray(pixels, mode='RGB')
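Stripped of the fastai2 plumbing, the hole-punching logic in `RandomCutout.encodes` is plain numpy. The sketch below is a standalone re-expression of it for experimenting outside fastai2; the `cutout` function name and the use of a seeded `np.random.Generator` are our own choices, not part of fastai2.

```python
import numpy as np

def cutout(pixels, rng, min_n_holes=5, max_n_holes=10, min_length=5, max_length=50):
    """Return a copy of an HxWxC array with a random number of rectangles set to 0."""
    out = pixels.copy()  # never modify the caller's array
    h, w = out.shape[:2]
    # Upper bounds are exclusive, matching np.random.randint in RandomCutout
    n_holes = rng.integers(min_n_holes, max_n_holes)
    for _ in range(n_holes):
        h_len = rng.integers(min_length, max_length)
        w_len = rng.integers(min_length, max_length)
        cy, cx = rng.integers(0, h), rng.integers(0, w)  # centre of the hole
        # Clip the rectangle to the image bounds, as RandomCutout does
        y1 = int(np.clip(cy - h_len / 2, 0, h))
        y2 = int(np.clip(cy + h_len / 2, 0, h))
        x1 = int(np.clip(cx - w_len / 2, 0, w))
        x2 = int(np.clip(cx + w_len / 2, 0, w))
        out[y1:y2, x1:x2, :] = 0
    return out

rng = np.random.default_rng(0)
img = np.full((160, 160, 3), 255, dtype=np.uint8)  # all-white test image
holes = cutout(img, rng)
print(f"{(holes == 0).all(axis=-1).mean():.1%} of pixels were cut out")
```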

In [4]:
source = untar_data(URLs.IMAGEWANG_160)

In [5]:
opt_func = partial(ranger, mom=mom, sqr_mom=sqrmom, eps=eps, beta=beta)
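`functools.partial` freezes the optimizer hyperparameters, so whoever builds the optimizer later only has to supply the parameters and the learning rate. The stand-in below shows the mechanics with a hypothetical `fake_ranger` function in place of fastai2's `ranger`; only the binding behaviour is meant to carry over.

```python
from functools import partial

def fake_ranger(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, beta=0.0):
    """Hypothetical stand-in for an optimizer factory: just records its settings."""
    return {"params": list(params), "lr": lr, "mom": mom,
            "sqr_mom": sqr_mom, "eps": eps, "beta": beta}

# Bind the tuned hyperparameters once...
opt_func = partial(fake_ranger, mom=0.95, sqr_mom=0.99, eps=1e-6, beta=0.0)

# ...so a caller only needs to pass the parameters and the learning rate
opt = opt_func([1, 2, 3], lr=8e-3)
print(opt)
```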

In [6]:
# Number of workers for creating the data bunch
workers = min(8, num_cpus())

In [7]:
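# Transforms are the same for each experiment.
# NOTE: `size` is still 128 here (it is only set to 160 in the next cell),
# so the pretext crops are 128x128.
item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5), RandomCutout()]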
batch_tfms=RandomErasing(p=0.9, max_count=3, sh=sh) if sh else None

In [8]:
size = 160
#CHANGE: I can only fit ~32 images in a batch
bs = 32

## Get Items From Folder

Before we do anything else, let's create some helper functions that return only the files belonging to each dataset we'd like to compare.

In [9]:
def get_all_items(path):
    return get_files(path, extensions='.JPEG', recurse=True)

def get_train_items(path):
    return get_files(path/'train', extensions='.JPEG', recurse=True)

def get_unsup_items(path):
    return get_files(path/'unsup', extensions='.JPEG', recurse=True)

def get_valid_items(path):
    return get_files(path/'val', extensions='.JPEG', recurse=True)

def get_train_and_unsup(path):
    return get_train_items(path) + get_unsup_items(path)

def get_train_items_that_are_present_in_val(path):
    """
    We first get a list of all classes in /val
    Then we use that list to get all the examples of each class from /train
    """
    # Use the `path` argument (not the global `source`) and skip any stray files
    validation_classes = [d.name for d in (path/'val').iterdir() if d.is_dir()]
    
    train_files = L()
    for class_name in validation_classes:
        items = get_files(path/'train'/class_name, extensions='.JPEG', recurse=True)
        train_files = train_files + items
        
    return train_files

all_items = get_all_items(source)
train_items = get_train_items(source)
unsup_items = get_unsup_items(source)
valid_items = get_valid_items(source)

print("All Files: {}".format(len(all_items)))
print("Train Files: {}".format(len(train_items)))
print("Unsup Files: {}".format(len(unsup_items)))
print("Valid Files: {}".format(len(valid_items)))
print()

train_and_unsup_items = get_train_and_unsup(source)
print("Train+Unsup Files: {}".format(len(train_and_unsup_items)))
train_in_valid_items = get_train_items_that_are_present_in_val(source)
print("Train (val classes only) Files: {}".format(len(train_in_valid_items)))

All Files: 26348
Train Files: 14669
Unsup Files: 7750
Valid Files: 3929

Train+Unsup Files: 22419
Train (val classes only) Files: 1275
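The idea behind `get_train_items_that_are_present_in_val` can be sketched with nothing but `pathlib`: collect the class folder names under `val/`, then keep only the `train/` files whose parent folder is one of them. The sketch builds a tiny throwaway directory tree (the class names `cat`, `dog` and `fish` are made up) so it runs without the dataset. Unlike fastai2's `get_files`, it matches the `.JPEG` suffix case-sensitively.

```python
import tempfile
from pathlib import Path

def train_files_in_val_classes(root: Path) -> list:
    """Return .JPEG files under root/'train' whose class folder also exists in root/'val'."""
    val_classes = {d.name for d in (root / "val").iterdir() if d.is_dir()}
    return sorted(f for f in (root / "train").rglob("*.JPEG") if f.parent.name in val_classes)

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Hypothetical layout: 'cat' and 'dog' appear in both splits, 'fish' only in train
    layout = {"train/cat": 3, "train/dog": 2, "train/fish": 4, "val/cat": 1, "val/dog": 1}
    for folder, n in layout.items():
        (root / folder).mkdir(parents=True)
        for i in range(n):
            (root / folder / f"img{i}.JPEG").touch()

    kept = train_files_in_val_classes(root)
    kept_classes = sorted({f.parent.name for f in kept})
    print(len(kept), kept_classes)
```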


## Train with all data in `/train`, `/unsup` and `/val`

In [10]:
dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),
                   splitter=RandomSplitter(valid_pct=0),
                   get_items=get_all_items, 
                   get_y=lambda o: o)

dbunch = dblock.databunch(source, path=source, bs=bs, num_workers=workers,
                          item_tfms=item_tfms, batch_tfms=batch_tfms)

#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel
dbunch.vocab = ['R', 'G', 'B']

print("Training Size:", len(dbunch.train_ds))
print("Validation Size:", len(dbunch.valid_ds))

Training Size: 26348
Validation Size: 0
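Because `get_y=lambda o: o`, each sample's target is the same image file as its input, and only the input (a `PILImageInput`) goes through `RandomCutout`. The pretext loss is then a plain mean squared error over every pixel, $\frac{1}{N}\sum_i (\hat{y}_i - y_i)^2$. The numpy sketch below (random arrays standing in for images, values in [0, 1], one hand-placed hole) shows one consequence: a model that simply copied its input would be penalised only on the cut-out pixels, so its MSE equals the masked fraction times the mean squared target value inside the hole.

```python
import numpy as np

rng = np.random.default_rng(42)
target = rng.random((160, 160, 3))      # stand-in for a clean image in [0, 1]
mask = np.zeros((160, 160), dtype=bool)
mask[20:60, 30:90] = True               # one rectangular hole
inp = target.copy()
inp[mask] = 0.0                         # cut-out input, like RandomCutout produces

def mse(pred, y):
    return np.mean((pred - y) ** 2)

# A "copy the input" model is only wrong inside the hole
copy_loss = mse(inp, target)
frac = mask.mean()                                # fraction of pixels cut out
closed_form = frac * np.mean(target[mask] ** 2)   # masked fraction x mean squared value in the hole
print(copy_loss, closed_form)
```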


In [11]:
learn = unet_learner(dbunch, m, opt_func=opt_func, metrics=[], loss_func=MSELoss())
if dump: print(learn.model); exit()
if fp16: learn = learn.to_fp16()
cbs = MixUp(mixup) if mixup else []
learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)

# I'm not using fastai2's .export() because I only want to save 
# the model's parameters. 
torch.save(learn.model[0].state_dict(), 'all_train_unsup_val_pretext.pth')

epoch,train_loss,valid_loss,time
0,0.00583,,02:45
1,0.005231,,02:41
2,0.004981,,02:39
3,0.005106,,02:39
4,0.005025,,02:39
5,0.004764,,02:39
6,0.004623,,02:39
7,0.004679,,02:39
8,0.004866,,02:39
9,0.0046,,02:39


  warn("Your generator is empty.")
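Every training run in this notebook calls `learn.fit_flat_cos(epochs, lr, ...)`. As we understand fastai2's defaults (`pct_start=0.75`, `div_final=1e5`), this holds the learning rate flat for the first 75% of training and then cosine-anneals it towards `lr / 1e5`. The sketch below re-implements that shape in plain Python to make the schedule concrete; `flat_cos_lr` is our own helper, so treat the exact defaults as assumptions and check them against your fastai version.

```python
import math

def flat_cos_lr(pos, lr, pct_start=0.75, div_final=1e5):
    """Learning rate at training progress pos in [0, 1] for a flat-then-cosine schedule."""
    if pos < pct_start:
        return lr                                  # flat phase
    p = (pos - pct_start) / (1 - pct_start)        # progress through the annealing phase
    end = lr / div_final
    # Cosine interpolation from lr (p=0) down to end (p=1)
    return lr + (1 + math.cos(math.pi * (1 - p))) * (end - lr) / 2

for pos in (0.0, 0.5, 0.75, 0.875, 1.0):
    print(f"{pos:>5}: {flat_cos_lr(pos, 8e-3):.2e}")
```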


## Train with all data in `/train` and `/unsup`

In [12]:
dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),
                   splitter=RandomSplitter(valid_pct=0),
                   get_items=get_train_and_unsup, 
                   get_y=lambda o: o)
dbunch = dblock.databunch(source, path=source, bs=bs, num_workers=workers,
                          item_tfms=item_tfms, batch_tfms=batch_tfms)

#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel
dbunch.vocab = ['R', 'G', 'B']

print("Training Size:", len(dbunch.train_ds))
print("Validation Size:", len(dbunch.valid_ds))

Training Size: 22419
Validation Size: 0


In [13]:
learn = unet_learner(dbunch, m, opt_func=opt_func, metrics=[], loss_func=MSELoss())
if dump: print(learn.model); exit()
if fp16: learn = learn.to_fp16()
cbs = MixUp(mixup) if mixup else []
learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)

# I'm not using fastai2's .export() because I only want to save 
# the model's parameters. 
torch.save(learn.model[0].state_dict(), 'all_train_unsup_pretext.pth')

epoch,train_loss,valid_loss,time
0,0.006544,,02:15
1,0.005819,,02:15
2,0.005358,,02:15
3,0.00497,,02:15
4,0.005081,,02:15
5,0.005071,,02:16
6,0.004868,,02:15
7,0.004608,,02:16
8,0.004858,,02:15
9,0.005015,,02:15


## Train with all data in `/train`

In [14]:
dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),
                   splitter=RandomSplitter(valid_pct=0),
                   get_items=get_train_items, 
                   get_y=lambda o: o)

dbunch = dblock.databunch(source, path=source, bs=bs, num_workers=workers,
                          item_tfms=item_tfms, batch_tfms=batch_tfms)

#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel
dbunch.vocab = ['R', 'G', 'B']

print("Training Size:", len(dbunch.train_ds))
print("Validation Size:", len(dbunch.valid_ds))

Training Size: 14669
Validation Size: 0


In [15]:
learn = unet_learner(dbunch, m, opt_func=opt_func, metrics=[], loss_func=MSELoss())
if dump: print(learn.model); exit()
if fp16: learn = learn.to_fp16()
cbs = MixUp(mixup) if mixup else []
learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)

# I'm not using fastai2's .export() because I only want to save 
# the model's parameters. 
torch.save(learn.model[0].state_dict(), 'all_train_pretext.pth')

epoch,train_loss,valid_loss,time
0,0.010095,,01:29
1,0.006838,,01:29
2,0.006041,,01:29
3,0.005875,,01:29
4,0.005637,,01:29
5,0.005434,,01:29
6,0.00535,,01:29
7,0.005186,,01:29
8,0.005243,,01:30
9,0.00521,,01:29


## Train with partial data from `/train`

In [16]:
dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),
                   splitter=RandomSplitter(valid_pct=0),
                   get_items=get_train_items_that_are_present_in_val, 
                   get_y=lambda o: o)

dbunch = dblock.databunch(source, path=source, bs=bs, num_workers=workers,
                          item_tfms=item_tfms, batch_tfms=batch_tfms)

#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel
dbunch.vocab = ['R', 'G', 'B']

print("Training Size:", len(dbunch.train_ds))
print("Validation Size:", len(dbunch.valid_ds))

Training Size: 1275
Validation Size: 0


In [17]:
learn = unet_learner(dbunch, m, opt_func=opt_func, metrics=[], loss_func=MSELoss())
if dump: print(learn.model); exit()
if fp16: learn = learn.to_fp16()
cbs = MixUp(mixup) if mixup else []
learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)

# I'm not using fastai2's .export() because I only want to save 
# the model's parameters. 
torch.save(learn.model[0].state_dict(), 'partial_train_pretext.pth')

epoch,train_loss,valid_loss,time
0,0.838443,,00:08
1,0.678933,,00:08
2,0.534371,,00:08
3,0.397004,,00:08
4,0.276421,,00:08
5,0.180806,,00:08
6,0.111064,,00:08
7,0.064325,,00:08
8,0.035989,,00:08
9,0.020368,,00:08


# Downstream Task: Image网

Now that we've trained an encoder on each pretext dataset, let's fine-tune each one on the Image网 classification task and compare them against one another and against a randomly initialised baseline.
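Each downstream run builds a fresh `xresnet34` classifier and copies the pretext encoder weights into it. The encoder checkpoint lacks the classifier head, which is why the loading code patches in the final Linear layer's weights before calling `load_state_dict`. The toy PyTorch sketch below (a made-up six-layer `nn.Sequential`, not xresnet34) reproduces that situation and shows an alternative to the patch: passing `strict=False` and inspecting the reported missing keys.

```python
import torch
from torch import nn

def make_classifier(n_classes: int = 20) -> nn.Sequential:
    """Toy stand-in for a classifier: conv 'body' followed by a pooling/linear 'head'."""
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),   # 0  body
        nn.ReLU(),                       # 1  body, no parameters
        nn.Conv2d(8, 16, 3, padding=1),  # 2  body
        nn.AdaptiveAvgPool2d(1),         # 3  head, no parameters
        nn.Flatten(),                    # 4  head, no parameters
        nn.Linear(16, n_classes),        # 5  head
    )

# "Pretext" encoder: the first three layers, so its state_dict keys keep indices 0 and 2
pretrained = make_classifier()
encoder = nn.Sequential(*list(pretrained.children())[:3])
encoder_state = encoder.state_dict()

# Downstream: a fresh classifier, loaded non-strictly from the encoder-only checkpoint
classifier = make_classifier()
result = classifier.load_state_dict(encoder_state, strict=False)
print("missing:", result.missing_keys)        # only the new head's parameters
print("unexpected:", result.unexpected_keys)
```

The trade-off is that `strict=False` also silently skips keys that fail to match for other reasons (a renamed prefix, say), so checking `missing_keys` as above is what keeps it safe.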

In [18]:
def get_dbunch(size, bs, sh=0., workers=None):
    if size<=224: 
        path = URLs.IMAGEWANG_160
    else: 
        path = URLs.IMAGEWANG
    source = untar_data(path)
    if workers is None: workers = min(8, num_cpus())
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label)
    item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    batch_tfms=RandomErasing(p=0.9, max_count=3, sh=sh) if sh else None
    return dblock.databunch(source, path=source, bs=bs, num_workers=workers,
                            item_tfms=item_tfms, batch_tfms=batch_tfms)
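`GrandparentSplitter(valid_name='val')` assigns each file to train or validation based on its grandparent folder (`train/<class>/x.JPEG` versus `val/<class>/x.JPEG`). `get_image_files` also returns the `/unsup` images, but their grandparent is neither `train` nor `val`, so they should land in neither split. The pure-`pathlib` sketch below mimics that rule; `grandparent_split` and the file paths are hypothetical.

```python
from pathlib import PurePosixPath

def grandparent_split(paths, train_name="train", valid_name="val"):
    """Return (train_idxs, valid_idxs) based on each path's grandparent folder name."""
    train = [i for i, p in enumerate(paths) if PurePosixPath(p).parent.parent.name == train_name]
    valid = [i for i, p in enumerate(paths) if PurePosixPath(p).parent.parent.name == valid_name]
    return train, valid

files = [
    "root/train/class_a/a.JPEG",
    "root/val/class_a/b.JPEG",
    "root/unsup/c.JPEG",        # grandparent is the dataset root
]
train_idxs, valid_idxs = grandparent_split(files)
print(train_idxs, valid_idxs)
```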

In [19]:
dbunch = get_dbunch(160, 32)

## Random Baseline

In [20]:
for run in range(runs):
    print(f'Run: {run}')
    #CHANGE: No self-attention
    sa = 0
    learn = Learner(dbunch, m(c_out=20, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func,
                    metrics=[accuracy, top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    if dump: print(learn.model); exit()
    if fp16: learn = learn.to_fp16()
    cbs = MixUp(mixup) if mixup else []
    learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)

Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.724964,3.348793,0.036905,0.394248,00:33
1,1.513149,3.536191,0.042504,0.442861,00:33
2,1.370328,3.05226,0.109443,0.53805,00:33
3,1.322803,2.70726,0.192924,0.668872,00:32
4,1.246846,3.402213,0.086791,0.607025,00:32
5,1.200642,2.542671,0.251463,0.755663,00:33
6,1.150564,2.537648,0.27946,0.718758,00:32
7,1.095762,2.193474,0.37465,0.827691,00:32
8,1.115557,2.220904,0.369814,0.800713,00:32
9,1.060507,2.137547,0.39043,0.845253,00:32


Run: 1


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.693955,3.488328,0.018834,0.359379,00:32
1,1.455517,3.557442,0.017562,0.494782,00:32
2,1.403235,2.993247,0.109188,0.575464,00:33
3,1.306308,2.801047,0.202087,0.620005,00:32
4,1.232584,2.753357,0.22143,0.656147,00:33
5,1.166696,2.436998,0.29575,0.774752,00:32
6,1.170094,2.429303,0.308984,0.781369,00:32
7,1.129331,2.747638,0.241537,0.669381,00:32
8,1.086292,2.10087,0.417918,0.847035,00:32
9,1.059242,2.251438,0.399847,0.811148,00:32


Run: 2


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.704429,3.499991,0.02876,0.360906,00:32
1,1.483799,3.526596,0.01069,0.359888,00:32
2,1.383331,3.063984,0.125732,0.558666,00:32
3,1.304391,2.90804,0.12955,0.618733,00:32
4,1.234187,2.628224,0.240774,0.701196,00:32
5,1.210015,2.450081,0.302876,0.739628,00:32
6,1.161548,2.539947,0.293968,0.714431,00:32
7,1.099507,2.363112,0.316365,0.786969,00:33
8,1.084282,2.270023,0.366251,0.8338,00:32
9,1.061747,2.26742,0.372614,0.79435,00:32


Results:
- Run 1: 0.547213
- Run 2: 0.544668
- Run 3: 0.557903

Average: **55.0%**


## All data in `/train`, `/unsup` and `/val`

In [21]:
for run in range(runs):
    print(f'Run: {run}')
    #CHANGE: No self-attention
    sa = 0
    learn = Learner(dbunch, m(c_out=20, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func,
                    metrics=[accuracy, top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    if dump: print(learn.model); exit()
    if fp16: learn = learn.to_fp16()
    cbs = MixUp(mixup) if mixup else []

    # Load the encoder weights generated by training on our pretext task
    state_dict = torch.load('all_train_unsup_val_pretext.pth')
    # HACK: the checkpoint only holds the encoder, so a strict load_state_dict()
    # would fail on the missing final Linear layer (index 11). Fill the gap with
    # this model's own randomly initialised Linear weights.
    linear_layer = learn.model[-1]
    state_dict['11.weight'] = linear_layer.weight
    state_dict['11.bias'] = linear_layer.bias

    learn.model.load_state_dict(state_dict)

    learn.freeze()
    learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)

Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.582127,3.750611,0.007381,0.334945,00:32
1,1.419643,2.922854,0.146857,0.661491,00:32
2,1.306587,2.740122,0.202596,0.677017,00:32
3,1.236147,2.322164,0.324256,0.823874,00:32
4,1.15281,2.667139,0.246373,0.689997,00:32
5,1.125192,2.126453,0.393484,0.871469,00:32
6,1.124609,2.090095,0.40112,0.846017,00:32
7,1.047484,2.320644,0.36167,0.805803,00:33
8,1.022953,2.271408,0.380249,0.814457,00:32
9,1.019367,2.093393,0.435989,0.855434,00:32


Run: 1


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.54791,3.181367,0.06872,0.435989,00:32
1,1.403507,3.091935,0.127768,0.565029,00:32
2,1.316903,2.781852,0.221176,0.630695,00:32
3,1.219741,2.856547,0.18198,0.655638,00:32
4,1.165347,2.693729,0.193433,0.713159,00:32
5,1.117393,2.510223,0.285569,0.745737,00:32
6,1.086503,2.292799,0.339527,0.799186,00:32
7,1.05075,2.299235,0.375414,0.781369,00:32
8,1.037902,2.139894,0.414864,0.84958,00:32
9,1.01504,2.004684,0.47213,0.862815,00:32


Run: 2


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.551739,3.794199,0.008145,0.209723,00:32
1,1.405913,3.885333,0.036142,0.302622,00:32
2,1.287837,2.71175,0.177399,0.676253,00:32
3,1.216502,2.651858,0.245355,0.682616,00:32
4,1.176918,2.853776,0.194961,0.640112,00:33
5,1.108003,2.454046,0.314075,0.76508,00:32
6,1.12013,2.347095,0.336727,0.772207,00:33
7,1.056684,3.060613,0.212268,0.780351,00:32
8,1.017083,2.01868,0.458386,0.869941,00:32
9,1.015788,1.988194,0.472639,0.861542,00:32


Results:
- Run 1: 0.608297
- Run 2: 0.595062
- Run 3: 0.607788

Average: **60.4%**


## All data in `/train` and `/unsup`

In [22]:
for run in range(runs):
    print(f'Run: {run}')
    #CHANGE: No self-attention
    sa = 0
    learn = Learner(dbunch, m(c_out=20, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func,
                    metrics=[accuracy, top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    if dump: print(learn.model); exit()
    if fp16: learn = learn.to_fp16()
    cbs = MixUp(mixup) if mixup else []

    # Load the encoder weights generated by training on our pretext task
    state_dict = torch.load('all_train_unsup_pretext.pth')
    # HACK: the checkpoint only holds the encoder, so a strict load_state_dict()
    # would fail on the missing final Linear layer (index 11). Fill the gap with
    # this model's own randomly initialised Linear weights.
    linear_layer = learn.model[-1]
    state_dict['11.weight'] = linear_layer.weight
    state_dict['11.bias'] = linear_layer.bias

    learn.model.load_state_dict(state_dict)

    learn.freeze()
    learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)

Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.585423,3.597252,0.024434,0.373886,00:32
1,1.41034,3.021539,0.099262,0.590736,00:32
2,1.312899,2.971085,0.148384,0.586154,00:32
3,1.259073,2.683551,0.215831,0.665309,00:32
4,1.184609,2.485646,0.290914,0.744464,00:32
5,1.131635,2.294315,0.347417,0.804276,00:32
6,1.08762,2.176512,0.37465,0.829982,00:32
7,1.053774,2.630029,0.302112,0.736574,00:32
8,1.010145,2.411612,0.346144,0.803258,00:32
9,1.002379,2.138042,0.441079,0.836854,00:33


Run: 1


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.544064,3.742224,0.001018,0.229066,00:35
1,1.453853,3.200466,0.090863,0.438025,00:35
2,1.295131,2.710228,0.187834,0.678035,00:35
3,1.234436,2.303281,0.306185,0.786969,00:35
4,1.169761,2.532847,0.286332,0.750064,00:34
5,1.130456,2.166461,0.378722,0.815729,00:34
6,1.087212,2.602861,0.253754,0.739119,00:34
7,1.077098,2.373262,0.347926,0.804785,00:34
8,1.030411,2.155382,0.411555,0.838127,00:34
9,0.994264,2.111377,0.43548,0.832527,00:34


Run: 2


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.601778,3.392368,0.045813,0.465513,00:35
1,1.41485,3.797504,0.016289,0.217104,00:34
2,1.31141,2.874049,0.15551,0.632731,00:34
3,1.250876,2.909004,0.146093,0.694833,00:34
4,1.180566,2.979107,0.204378,0.599644,00:34
5,1.121922,2.301704,0.335709,0.801731,00:34
6,1.101594,2.299592,0.344108,0.809366,00:34
7,1.067181,2.299928,0.360143,0.842963,00:34
8,1.049793,2.060683,0.426317,0.859761,00:34
9,1.013166,2.086261,0.419191,0.86536,00:34


Results:
- Run 1: 0.599644
- Run 2: 0.597098
- Run 3: 0.581573

Average: **59.3%**


## All data in `/train`

In [23]:
for run in range(runs):
    print(f'Run: {run}')
    #CHANGE: No self-attention
    sa = 0
    learn = Learner(dbunch, m(c_out=20, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func,
                    metrics=[accuracy, top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    if dump: print(learn.model); exit()
    if fp16: learn = learn.to_fp16()
    cbs = MixUp(mixup) if mixup else []

    # Load the encoder weights generated by training on our pretext task
    state_dict = torch.load('all_train_pretext.pth')
    # HACK: the checkpoint only holds the encoder, so a strict load_state_dict()
    # would fail on the missing final Linear layer (index 11). Fill the gap with
    # this model's own randomly initialised Linear weights.
    linear_layer = learn.model[-1]
    state_dict['11.weight'] = linear_layer.weight
    state_dict['11.bias'] = linear_layer.bias

    learn.model.load_state_dict(state_dict)

    learn.freeze()
    learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)


Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.569853,3.524181,0.041741,0.362942,00:32
1,1.421617,3.673763,0.016798,0.323237,00:33
2,1.324057,2.890579,0.165182,0.58819,00:33
3,1.259301,2.402827,0.285314,0.762026,00:32
4,1.157544,3.302516,0.109952,0.547468,00:32
5,1.114174,2.525295,0.27717,0.792568,00:32
6,1.101109,2.296636,0.355307,0.812166,00:32
7,1.056738,2.196374,0.38763,0.813439,00:32
8,0.997091,2.193357,0.396284,0.835327,00:32
9,1.01173,2.344701,0.363197,0.786714,00:32


Run: 1


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.573777,3.525678,0.066684,0.595317,00:32
1,1.407014,3.314837,0.05294,0.626623,00:32
2,1.300834,2.617093,0.224485,0.713413,00:32
3,1.23296,2.498491,0.23212,0.764317,00:32
4,1.13788,3.015863,0.175108,0.633495,00:32
5,1.160997,2.236256,0.356579,0.817256,00:32
6,1.08365,2.16752,0.395266,0.822347,00:32
7,1.038221,2.59683,0.31662,0.720285,00:32
8,1.055219,2.119983,0.406465,0.84169,00:32
9,1.010749,2.261474,0.395266,0.795113,00:32


Run: 2


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.571346,3.861682,0.006872,0.327055,00:32
1,1.406196,3.761446,0.033342,0.349198,00:32
2,1.290671,2.95942,0.15831,0.544413,00:33
3,1.222605,2.328174,0.301858,0.817765,00:32
4,1.174556,2.72966,0.224739,0.64622,00:33
5,1.129882,2.41244,0.313566,0.787987,00:32
6,1.088257,2.41469,0.302622,0.799186,00:32
7,1.083549,2.314783,0.328328,0.801222,00:33
8,0.997049,2.091954,0.434716,0.856452,00:32
9,1.027624,1.928081,0.486892,0.874523,00:32


Results:
- Run 1: 0.612115
- Run 2: 0.600407
- Run 3: 0.603970

Average: **60.5%**


## Partial data from `/train`

In [25]:
for run in range(runs):
    print(f'Run: {run}')
    #CHANGE: No self-attention
    sa = 0
    learn = Learner(dbunch, m(c_out=20, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func,
                    metrics=[accuracy, top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    if dump: print(learn.model); exit()
    if fp16: learn = learn.to_fp16()
    cbs = MixUp(mixup) if mixup else []

    # Load the encoder weights generated by training on our pretext task
    state_dict = torch.load('partial_train_pretext.pth')
    # HACK: the checkpoint only holds the encoder, so a strict load_state_dict()
    # would fail on the missing final Linear layer (index 11). Fill the gap with
    # this model's own randomly initialised Linear weights.
    linear_layer = learn.model[-1]
    state_dict['11.weight'] = linear_layer.weight
    state_dict['11.bias'] = linear_layer.bias

    learn.model.load_state_dict(state_dict)

    learn.freeze()
    learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)


Run: 0


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.579269,3.575554,0.038178,0.523543,00:32
1,1.42733,5.354814,0.0,0.107661,00:32
2,1.315489,2.618873,0.231102,0.710359,00:32
3,1.230068,2.538547,0.218885,0.748282,00:32
4,1.191853,2.902922,0.201323,0.643421,00:32
5,1.135853,2.259957,0.341054,0.824128,00:33
6,1.119195,2.286839,0.35887,0.808094,00:32
7,1.070681,2.728824,0.26139,0.703232,00:32
8,1.037029,2.173724,0.390685,0.855943,00:33
9,1.005619,2.023341,0.469585,0.861542,00:32


Run: 1


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.579783,4.154885,0.019343,0.222194,00:32
1,1.416207,2.860526,0.113515,0.668618,00:32
2,1.315893,2.941454,0.132349,0.621787,00:32
3,1.208766,3.048025,0.193688,0.560957,00:32
4,1.1603,2.585316,0.276661,0.721303,00:32
5,1.130774,2.221207,0.369051,0.829982,00:32
6,1.108004,2.198512,0.378213,0.828201,00:32
7,1.07841,2.4887,0.30084,0.829728,00:32
8,1.037706,2.100374,0.415882,0.839145,00:32
9,1.016165,2.109907,0.429371,0.837109,00:32


Run: 2


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.579392,3.727465,0.011199,0.345635,00:32
1,1.396257,3.523748,0.043013,0.285569,00:32
2,1.300451,3.058381,0.138712,0.55663,00:32
3,1.220316,2.313607,0.30084,0.807076,00:32
4,1.197507,3.465457,0.081955,0.490965,00:33
5,1.134546,2.392729,0.324001,0.798931,00:32
6,1.098854,2.353947,0.343853,0.771698,00:32
7,1.069024,2.247727,0.386612,0.810639,00:32
8,1.047567,2.031735,0.439807,0.862051,00:32
9,0.987922,2.192773,0.408246,0.842454,00:32


Results:
- Run 1: 0.590481
- Run 2: 0.589463
- Run 3: 0.594553

Average: **59.1%**
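To put the five configurations side by side, the sketch below recomputes each average from the per-run accuracies listed in the Results above and pairs it with the number of images used for pretext training (0 for the random baseline). It only re-tabulates numbers already in this notebook; the dictionary labels are our own.

```python
from statistics import mean, stdev

# (pretext images, per-run accuracies) copied from the Results lists above
results = {
    "random init (no pretext)": (0, [0.547213, 0.544668, 0.557903]),
    "train + unsup + val": (26348, [0.608297, 0.595062, 0.607788]),
    "train + unsup": (22419, [0.599644, 0.597098, 0.581573]),
    "train": (14669, [0.612115, 0.600407, 0.603970]),
    "train (val classes only)": (1275, [0.590481, 0.589463, 0.594553]),
}

summary = {name: (n, mean(accs), stdev(accs)) for name, (n, accs) in results.items()}
# Sort by pretext dataset size, largest first
for name, (n, avg, sd) in sorted(summary.items(), key=lambda kv: -kv[1][0]):
    print(f"{name:<26} {n:>6} imgs  {avg:.1%} ± {sd:.1%}")
```

Sorting by dataset size makes it easy to see whether downstream accuracy tracks the amount of pretext data.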
