In [None]:
#default_exp utils

# Utility Functions

> Utility functions to help with downstream tasks

In [None]:
#hide
from nbdev.showdoc import *
from self_supervised.byol import *
from self_supervised.simclr import *
from self_supervised.swav import *

In [None]:
#export
from fastai.vision.all import *

## Loading Weights for Downstream Tasks

In [None]:
#export
def transfer_weights(learn:Learner, weights_path:Path, device:torch.device=None,
                     src_prefix:str='encoder.', dst_prefix:str='0.'):
    "Load and freeze pretrained weights inplace from `weights_path` using `device`"
    # `weights_path` may hold a `Learner.save` checkpoint (a dict with a 'model' entry)
    # or a raw PyTorch `state_dict`. Every key of `learn.model` starting with `dst_prefix`
    # (by default the body at index 0 of `nn.Sequential(body, head)`) is looked up in that
    # file as `src_prefix + <rest of key>` (by default the pretrained `encoder`).
    # Keys outside `dst_prefix` (e.g. a freshly created head) are left untouched.
    # Raises `ValueError` when a matched tensor has a different shape and `Exception`
    # when no key matched at all; calls `learn.freeze()` on success.
    if device is None: device = learn.dls.device
    # NOTE: `torch.load` unpickles the file -- only load checkpoints from trusted sources.
    new_state_dict = torch.load(weights_path, map_location=device)
    # allow for simply exporting the raw PyTorch model instead of a `Learner.save` checkpoint
    if 'model' in new_state_dict: new_state_dict = new_state_dict['model']
    # `state_dict()` returns references to the module's tensors, so `copy_` below
    # writes the pretrained values straight into the model.
    learn_state_dict = learn.model.state_dict()
    matched_layers = 0
    for name, param in learn_state_dict.items():
        # Only remap body keys: head keys ('1.…') stripped the same way could collide with
        # unrelated encoder keys and get overwritten or spuriously rejected.
        if not name.startswith(dst_prefix): continue
        src_name = src_prefix + name[len(dst_prefix):]
        if src_name not in new_state_dict: continue
        input_param = new_state_dict[src_name]
        if input_param.shape != param.shape:
            raise ValueError(f'Shape mismatch at {src_name}, please ensure you have the same backbone')
        param.copy_(input_param)
        matched_layers += 1
    if matched_layers == 0: raise Exception("No shared weight names were found between the models")
    learn.model.load_state_dict(learn_state_dict)
    learn.freeze()
    print("Weights successfully transferred!")

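Models trained with this library store their backbone under an `encoder.` prefix, so their `state_dict` keys don't line up with a plain `fastai` body and can't be loaded as-is. `transfer_weights` remaps those keys onto the downstream model's body, copies every matching weight, and then freezes the model with `learn.freeze()`.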

Example usage:

First prepare the downstream-task dataset (`ImageWoof` is shown here):

In [None]:
path = untar_data(URLs.IMAGEWOOF)
items = get_image_files(path)
# Items whose grandparent folder is named 'val' form the validation set
splits = GrandparentSplitter(valid_name='val')(items)

# x: open the image file; y: class label taken from the parent folder name
tfms = [[PILImage.create], [parent_label, Categorize()]]
dsets = Datasets(items, tfms, splits=splits)

item_tfms = [ToTensor(), Resize(224)]
# NOTE(review): `FlipItem` and `RandomResizedCrop` are item-level transforms; used in
# `after_batch` they may flip a whole batch at once or skip batched tensors entirely.
# Consider `Flip()` / `RandomResizedCropGPU` -- verify against the installed fastai version.
batch_tfms = [
    FlipItem(),
    RandomResizedCrop(224, min_scale=0.35),
    IntToFloatTensor(),
    Normalize.from_stats(*imagenet_stats),
]
dls = dsets.dataloaders(bs=32, after_item=item_tfms, after_batch=batch_tfms)

For the sake of example, we create and save an untrained SwAV model (in practice you would pretrain it for many epochs first):

In [None]:
# Untrained SwAV model, saved only to demonstrate the save -> transfer round trip.
# Distinct names keep it from being confused with the downstream `learn` built below.
swav_net = create_swav_model(arch=xresnet50, pretrained=False)
swav_learn = Learner(dls, swav_net, SWAVLoss(), cbs=[SWAV()])
# `Learner.save` resolves this path relative to the learner's `path/model_dir`
swav_learn.save('../../../swav_test');

Next, build a classification `Learner` from an `xresnet` body and a simple custom head:

In [None]:
# The body must use the same architecture as the pretrained SwAV encoder (`xresnet50`):
# a different architecture generally yields mismatched layer shapes, which
# `transfer_weights` rejects with a `ValueError`.
body = create_body(xresnet50, pretrained=False)
# `*2` because fastai's `create_head` concat-pools (average + max) by default
nf = num_features_model(body)*2
head = create_head(nf, dls.c)
arch = nn.Sequential(body, head)

learn = Learner(dls, arch, splitter=default_split)

Finally, load the pretrained weights into it:

In [None]:
# The checkpoint was written by `Learner.save` relative to that learner's `path/model_dir`;
# presumably it resolves to this file -- verify if either path is changed.
transfer_weights(learn, weights_path='../../swav_test.pth')

Weights successfully transferred!


Now we can do downstream tasks with our pretrained models!