In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq self-supervised

# Barlow Twins ImageWang Tutorial

https://arxiv.org/pdf/2103.03230.pdf

**Note:** This notebook demonstrates how to use `SimCLR` callback with a single GPU. For distributed version, `DistributedSimCLR` checkout documentation.

First import **fastai** for training and other helpers, you can choose not to use **wandb** by setting `WANDB=False`.

In [1]:
from fastai.vision.all import *
torch.backends.cudnn.benchmark = True
WANDB = False
if WANDB:
    try:
        from fastai.callback.wandb import WandbCallback
        import wandb
    except:
        raise ImportError("Please run '!pip install wandb' on another cell to install wandb")

Then import **self_supervised** `augmentations` module for creating augmentations pipeline, `layers` module for creating encoder and model, and finally `simclr` for self-supervised training.

In [2]:
from self_supervised.augmentations import *
from self_supervised.layers import *
# from self_supervised.vision.simclr import *

In this notebook we will take a look at [ImageWang](https://github.com/fastai/imagenette#image%E7%BD%91) benchmark, how to train a self-supervised model using MoCo algorithm and then how to use this pretrained model for finetuning on the given downstream task. 

## Pretraining

In [3]:
def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    
    files = get_image_files(source)
    tfms = [[PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)], 
            [parent_label, Categorize()]]
    
    dsets = Datasets(files, tfms=tfms, splits=RandomSplitter(valid_pct=0.1)(files))
    
    batch_tfms = [IntToFloatTensor]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls

ImageWang has several benchmarks for different image sizes, in this tutorial we will go for `size=224` and also demonstrate how effectively you can utilize GPU memory.

Define batch size, resize resolution before batching and size for random cropping during self-supervised training. It's always good to use a batch size as high as it can fit the GPU memory.

In [4]:
# bs, resize, size = 256, 256, 224

Select architecture to train on, remember all **timm** and **fastai** models are available! We need to set `pretrained=False` here because using imagenet weights for ImageWang data would be cheating.

In [5]:
# arch = "xresnet34"
# encoder = create_encoder(arch, pretrained=False, n_in=3)

In [6]:
# if WANDB:
#     xtra_config = {"Arch":arch, "Resize":resize, "Size":size, "Algorithm":"Barlow-Twins"}
#     wandb.init(project="self-supervised-imagewang", config=xtra_config);

Initialize the Dataloaders using the function above.

In [7]:
# dls = get_dls(resize, bs)

Create SimCLR model. You can change values of `hidden_size`, `projection_size`, and `n_layers`. For this problem, defaults work just fine so we don't do any changes.

In [8]:
#export
class BarlowTwinsModel(Module):
    "An encoder followed by a projector" 
    def __init__(self,encoder,projector): self.encoder,self.projector = encoder,projector
    def forward(self,x): return self.projector(self.encoder(x))

In [9]:
#export
def create_barlow_twins_model(encoder, hidden_size=256, projection_size=128):
    "Create SimCLR model"
    n_in  = in_channels(encoder)
    with torch.no_grad(): representation = encoder(torch.randn((2,n_in,128,128)))
    projector = create_mlp_module(representation.size(1), hidden_size, projection_size, bn=True, nlayers=3) 
    apply_init(projector)
    return BarlowTwinsModel(encoder, projector)

In [10]:
# model = create_barlow_twins_model(encoder, hidden_size=768, projection_size=768)

In [11]:
# model.projector

Next step is perhaps the most critical step for achieving good results on a custom problem - data augmentation. For this, we will use utility function from `self_supervised.vision.simclr.get_simclr_aug_pipelines` but you can also use your own list of Pipeline augmentations. `self_supervised.vision.simclr.get_moco_aug_pipelines`should be enough for most of the cases since under the hood it uses `self_supervised.augmentations.get_multi_aug_pipelines` and `self_supervised.augmentations.get_batch_augs`. You can do shift+tab and see all the arguments that can be passed to `get_simclr_aug_pipelines`. You can simply pass anything that you could pass to `get_batch_augs` including custom `xtra_tfms`.

`get_simclr_aug_pipelines` excepts size for random resized cropping of the 2 views of a given image and the rest of the arguments are coming from `get_batch_augs()`

In [12]:
# aug_pipelines = get_multi_aug_pipelines(n=2, size=size, rotate=True, 
#                                         rotate_deg=10, jitter=True, bw=True, blur=False)

In [13]:
# aug_pipelines

Here, we will feed the augmentation pipelines and leave temperature parameter as default.

In [14]:
# pred = torch.randn(32,16)
# bs,nf = pred.size(0)//2,pred.size(1)
# I = torch.eye(nf)
# z1, z2 = pred[:bs],pred[bs:]

# z1norm = (z1 - z1.mean(0)) / z1.std(0, unbiased=False)
# z2norm = (z2 - z2.mean(0)) / z2.std(0, unbiased=False)

# C = (z1norm.T @ z2norm) / bs 
# # cdiff = (C - I)**2
# # loss = (cdiff*I + cdiff*(1-I)*lmb).sum() 

In [15]:
# C.max(), C.min()

In [16]:
#export
class BarlowTwins(Callback):
    order,run_valid = 9,True
    def __init__(self, aug_pipelines, lmb=5e-3, print_augs=False):
        assert_aug_pipelines(aug_pipelines)
        self.aug1, self.aug2 = aug_pipelines
        if print_augs: print(self.aug1), print(self.aug2)
        store_attr('lmb')
        
        
    def before_fit(self): 
        self.learn.loss_func = self.lf
        nf = self.learn.model.projector[-1].out_features
        self.I = torch.eye(nf).to(self.dls.device)
                    
            
    def before_batch(self):
        xi,xj = self.aug1(self.x), self.aug2(self.x)
        self.learn.xb = (torch.cat([xi, xj]),)
        
    
    def lf(self, pred, *yb):
        bs,nf = pred.size(0)//2,pred.size(1)
        z1, z2 = pred[:bs],pred[bs:]
        
        z1norm = (z1 - z1.mean(0)) / z1.std(0, unbiased=False)
        z2norm = (z2 - z2.mean(0)) / z2.std(0, unbiased=False)
        
        C = (z1norm.T @ z2norm) / bs 
        cdiff = (C - self.I)**2
        loss = (cdiff*self.I + cdiff*(1-self.I)*self.lmb).sum() 
        return loss

    
    @torch.no_grad()
    def show(self, n=1):
        bs = self.learn.x.size(0)//2
        x1,x2  = self.learn.x[:bs], self.learn.x[bs:] 
        idxs = np.random.choice(range(bs),n,False)
        x1 = self.aug1.decode(x1[idxs].to('cpu').clone()).clamp(0,1)
        x2 = self.aug2.decode(x2[idxs].to('cpu').clone()).clamp(0,1)
        images = []
        for i in range(n): images += [x1[i],x2[i]] 
        return show_batch(x1[0], None, images, max_n=len(images), nrows=n)

In [17]:
# cbs=[BarlowTwins(aug_pipelines, lmb=0.1)]
# if WANDB: cbs += [WandbCallback(log_preds=False,log_model=False)]

In [18]:
# learn = Learner(dls, model, cbs=cbs)

Before starting training let's check whether our augmentations makes sense or not. Since this step consumes GPU memory, once you are done with inspection, restart the notebook and skip this step. We can see that 2 views of the same image side by side and indeed augmentations look pretty good. Now, it's time restart the notebook and skip this step.

In [19]:
# b = dls.one_batch()
# learn._split(b)
# learn('before_batch')
# learn.sim_clr.show(n=5);

Use mixed precision with `to_fp16()` for more GPU memory, larger batch size and faster training . We could also use gradient checkpointing wrapper models from `self_supervised.layers` to save even more memory, e.g. `CheckpointSequential()`.

In [20]:
# learn.to_fp16();

In [21]:
# learn.lr_find()

Learning good representations via contrastive learning usually takes a lot of epochs. So here number epochs are set to 100. This might change depending on your data distribution and dataset size.

In [22]:
# lr,wd,epochs=1e-2,1e-2,100

In [23]:
# learn.unfreeze()
# learn.fit_flat_cos(epochs, lr, wd=wd, pct_start=0.5)

In [24]:
# if WANDB: wandb.finish()

### Search Best Lambda

In [25]:
import gc

In [26]:
bs, resize, size = 128, 256, 224
lr,wd,epochs=1e-2, 1e-2, 100
WANDB = True
arch = "xresnet34"

lmb = 5e-3
nhidden = 1024

for bs in (64,128):
    if WANDB:
        xtra_config = {"Arch":arch, "Resize":resize, "Size":size, "Algorithm":"Barlow-Twins"}
        wandb.init(project="self-supervised-imagewang", config=xtra_config)
    
    dls = get_dls(resize, bs)
    encoder = create_encoder(arch, pretrained=False, n_in=3)
    model = create_barlow_twins_model(encoder, hidden_size=nhidden, projection_size=nhidden)
    aug_pipelines = get_multi_aug_pipelines(n=2, size=size, rotate=True, 
                                            rotate_deg=10, jitter=True, bw=True, 
                                            blur=True, blur_s=(4,8), blur_p=0.2)
    cbs=[BarlowTwins(aug_pipelines, lmb=lmb)]
    if WANDB: cbs += [WandbCallback(log_preds=False,log_model=False)]
    
    learn = Learner(dls, model, cbs=cbs)
    learn.to_fp16()
    learn.fit_flat_cos(epochs, lr, wd=wd, pct_start=0.25)

    save_name = f'btwins_iwang_sz{size}_epc{epochs}_lmb{lmb}_bs{bs}'
    learn.save(save_name)
    torch.save(learn.model.encoder.state_dict(), learn.path/learn.model_dir/f'{save_name}_encoder.pth')
    
    if WANDB: wandb.finish()
    
    
    del dls, learn
    gc.collect()

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33mkeremturgutlu[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.22 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


epoch,train_loss,valid_loss,time
0,509.020874,571.987305,01:30
1,424.807922,483.454285,01:30
2,387.678741,470.431671,01:31
3,362.851898,416.969574,01:33
4,327.759827,385.156342,01:33
5,308.808716,337.863037,01:31
6,287.980652,324.501007,01:31
7,282.777222,317.835175,01:32
8,264.439758,301.704346,01:33
9,248.605438,308.721252,01:32


VBox(children=(Label(value=' 0.02MB of 0.02MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,100.0
train_loss,124.97174
raw_loss,136.00615
wd_0,0.01
sqr_mom_0,0.99
lr_0,0.0
mom_0,0.9
eps_0,1e-05
_runtime,9851.0
_timestamp,1616038430.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss,█▆▄▃▃▃▃▅▂▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,███████████████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.22 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


epoch,train_loss,valid_loss,time
0,460.065491,479.487885,01:28
1,390.107361,430.421143,01:26
2,328.203613,394.97113,01:27
3,304.454712,385.682831,01:28
4,277.867676,318.509552,01:27
5,255.575256,316.390228,01:25
6,234.127808,299.507111,01:27
7,236.583771,280.800354,01:26
8,214.562408,252.471283,01:25
9,197.241241,240.573242,01:24


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,100.0
train_loss,81.6367
raw_loss,78.64686
wd_0,0.01
sqr_mom_0,0.99
lr_0,0.0
mom_0,0.9
eps_0,1e-05
_runtime,8631.0
_timestamp,1616047066.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss,█▅▃▅▂▂▂▃▂▂▂▃▂▂▂▁▁▁▁▁▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,███████████████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


## Downstream Task

In [27]:
optdict = dict(sqr_mom=0.99,mom=0.95,beta=0.,eps=1e-4)
opt_func = partial(ranger, **optdict)

In [28]:
bs, resize, size = 128, 256, 224

In [29]:
bs, size

(128, 224)

In [30]:
def get_dls(size, bs, workers=None):
    path = URLs.IMAGEWANG_160 if size <= 160 else URLs.IMAGEWANG
    source = untar_data(path)
    files = get_image_files(source, folders=['train', 'val'])
    splits = GrandparentSplitter(valid_name='val')(files)
    
    item_aug = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    tfms = [[PILImage.create, ToTensor, *item_aug], 
            [parent_label, Categorize()]]
    
    dsets = Datasets(files, tfms=tfms, splits=splits)
    
    batch_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
    dls = dsets.dataloaders(bs=bs, num_workers=workers, after_batch=batch_tfms)
    return dls

In [31]:
def split_func(m): return L(m[0], m[1]).map(params)

def create_learner(size=size, arch='xresnet34', encoder_path=f'models/btwins_iwang_sz224_epc100_lmb0.005_encoder'):
    
    dls = get_dls(size, bs=bs//2)
    pretrained_encoder = torch.load(encoder_path)
    encoder = create_encoder(arch, pretrained=False, n_in=3)
    encoder.load_state_dict(pretrained_encoder)
    nf = encoder(torch.randn(2,3,224,224)).size(-1)
    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier)
    learn = Learner(dls, model, opt_func=opt_func, splitter=split_func,
                metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
    return learn

In [32]:
def finetune(size, epochs, arch, encoder_path, lr=1e-2, wd=1e-2):
    learn = create_learner(size, arch, encoder_path)
    learn.unfreeze()
    learn.fit_flat_cos(epochs, lr, wd=wd)
    final_acc = learn.recorder.values[-1][-2]
    return final_acc

### 5 epochs

In [33]:
runs = 5
lmb = 0.005
for bs in (64,128):
    print(lmb)
    finetune(size, epochs=5, arch='xresnet34', encoder_path=f'models/btwins_iwang_sz224_epc100_lmb{lmb}_bs{bs}_encoder.pth')

0.005


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.196903,2.078518,0.48791,0.84449,01:00
1,1.135539,1.906038,0.511326,0.892084,01:00
2,1.004799,1.69984,0.583609,0.914991,01:01
3,0.979391,1.63855,0.62255,0.912191,01:01
4,0.895962,1.476393,0.668109,0.935862,01:01


0.005


epoch,train_loss,valid_loss,accuracy,top_k_accuracy,time
0,1.154825,1.900912,0.484602,0.88903,00:57
1,1.07924,1.72367,0.589463,0.89972,00:58
2,1.012616,1.765324,0.549758,0.900993,00:59
3,1.003595,1.799174,0.54594,0.900484,00:59
4,0.917192,1.488612,0.653092,0.934589,00:59


In [None]:
acc = []
runs = 5
for i in range(runs): acc += [finetune(size, epochs=5, arch='xresnet34', encoder_path=f'models/btwins_iwang_sz224_epc100_lmb0.005_encoder.pth')]

In [None]:
np.mean(acc)

### 20 epochs

In [None]:
acc = []
runs = 3
for i in range(runs): acc += [finetune(size, epochs=20, arch='xresnet34', encoder_path=f'models/simclr_iwang_sz{size}_epc100_encoder.pth')]

In [None]:
np.mean(acc)

### 80 epochs

In [None]:
acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=80, arch='xresnet34',encoder_path=f'models/simclr_iwang_sz{size}_epc100_encoder.pth')]

In [None]:
np.mean(acc)

### 200 epochs

In [None]:
acc = []
runs = 1
for i in range(runs): acc += [finetune(size, epochs=200, arch='xresnet34', encoder_path=f'models/simclr_iwang_sz{size}_epc100_encoder.pth')]

In [None]:
np.mean(acc)