# BYOL Linear Evaluation on Imagenet
This notebook implements the linear evaluation protocol stated in Section C.1 of the Appendix of "Bootstrap Your Own Latent - A New Approach to Self-Supervised Learning" paper. 

The authors state the following: 

"At training time, we apply spatial augmentations, i.e., random crops with resize to 224 × 224 pixels, and random flips. At test time, images are resized to 256 pixels along the shorter side using bicubic resampling, after which a 224 × 224 center crop is applied. In both cases, we normalize the color channels by subtracting the average color and dividing by the standard deviatio (computed on ImageNet), after applying the augmentations. We optimize the cross-entropy loss using SGD with Nesterov momentum over
80 epochs, using a batch size of 1024 and a momentum of 0.9. [...] We finally sweep over 5 learning rates {0.4, 0.3, 0.2, 0.1, 0.05} on a local validation set (10009 images from ImageNet train set), and report the accuracy of the best validation hyperparameter on the test set.

This is the basic setup, which should result in 74.3% accuracy when using ResNet50 (authors use a modified protocol for slightly better results)

In [10]:
import numpy as np
import random
import torch
import os
from torchvision import transforms, datasets, models
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import tqdm.notebook

In [11]:
# Get device (Can use multiple GPU, DataParallel added below!)
device = torch.device(TODO)

# Set model path
CHECKPOINT_PATH = TODO
IMAGENET_PATH = TODO

# Set seed
SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [12]:
# Define ResNet50 model (https://github.com/yaox12/BYOL-PyTorch/blob/master/model/basic_modules.py)
class ResNet(torch.nn.Module):
    def __init__(self, net_name, pretrained=False, use_fc=False):
        super().__init__()
        base_model = models.__dict__[net_name](pretrained=pretrained)
        self.encoder = torch.nn.Sequential(*list(base_model.children())[:-1])

        self.use_fc = use_fc
        if self.use_fc:
            self.fc = torch.nn.Linear(2048, 1000)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        if self.use_fc:
            x = self.fc(x)
        return x


In [13]:
# Create model
model = ResNet("resnet50", pretrained=False, use_fc=True)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)
print("Created ResNet model!")

Created ResNet model!


In [14]:
# Load Weights
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)["model"]
state_dict = {}
length = len(model.encoder.state_dict())
for name, param in zip(model.encoder.state_dict(), list(checkpoint.values())[:length]):
    state_dict[name] = param
print("Loaded Weights")

Loaded Weights


## Beginning LR Sweep Phase

Sweeping over learning rates to find best lr. Then, we will perform training and testing with the best lr.

In [34]:
# Training HP (from BYOL paper)
BATCH_SIZE = 1024 #Paper used 1024, I can't
NUM_EPOCHS = 80
MOMENTUM = 0.9
lr_list = [0.4, 0.3, 0.2, 0.1, 0.5] #n. We finally sweep over 5 learning rates {0.4, 0.3, 0.2, 0.1, 0.05} 

In [35]:
# Data Augmentations (from BYOL paper)
#At training time, 
#   we apply spatial augmentations, i.e., random crops with resize to 224 × 224 pixels, and random flips. 

# At test time, images are resized to 256 pixels along the shorter side using bicubic resampling, 
#   after which a 224 × 224 center crop is applied. 
# 
# In both cases, we normalize the color channels by subtracting the average color and dividing by the standard deviation (computed on ImageNet), after applying the augmentations
train_transforms = transforms.Compose([
                                        transforms.RandomResizedCrop((224, 224)),
                                        transforms.RandomHorizontalFlip(p=0.5),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                        ])
test_transforms = transforms.Compose([
                                        transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
                                        transforms.RandomCrop((224, 224)),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                        ])
                                        
# Get Dataloaders
train_dataset = datasets.ImageFolder(os.path.join(IMAGENET_PATH, "train"), transform=train_transforms)
val_dataset = datasets.ImageFolder(os.path.join(IMAGENET_PATH, "train"), transform=test_transforms)
num_train_samples = len(train_dataset) 

 # Val is 10,009 of train samples for lr sweep
indices = list(range(num_train_samples))
random.shuffle(indices)
train_dataset = torch.utils.data.dataset.Subset(train_dataset, indices[10009:])
val_dataset = torch.utils.data.dataset.Subset(val_dataset, indices[:10009])

assert len(train_dataset) == num_train_samples - 10009
assert len(val_dataset) == 10009

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          num_workers=8, drop_last=False, shuffle=True)

val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                          num_workers=8, drop_last=False, shuffle=True)

In [36]:
# Utils 
# Code taken from https://github.com/facebookresearch/moco/blob/main/main_lincls.py
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [37]:
# Train and Eval functions
# Code taken from https://github.com/facebookresearch/moco/blob/main/main_lincls.py
def evaluate(model, val_dataloder):
    top1 = AverageMeter('Acc@1', ':6.2f')
    model.eval()
    with torch.no_grad():
        for images, target in val_dataloder:
            images, target = images.to(device), target.to(device)

            output = model(images)
            acc1 = accuracy(output, target)
            top1.update(acc1[0], images.size(0))
    
    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return top1.avg


def train_linear_eval(model, train_dataloader, val_dataloader, learning_rate):
    # Freeze backbone
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False

    # Init FC layer (per source above)
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    # optimize only the linear classifier
    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    assert len(parameters) == 2  # fc.weight, fc.bias
    optimizer = torch.optim.SGD(parameters, learning_rate,
                                nesterov=True,
                                momentum=MOMENTUM)
    
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    best_val_acc = 0.
    for epoch in tqdm.notebook.trange(NUM_EPOCHS, desc="training", unit="epoch"):
        with tqdm.notebook.tqdm(train_dataloader, desc="epoch {}".format(epoch + 1), unit="batch", total=len(train_dataloader)) as batch_iterator:
            for images, target in batch_iterator:
                images, target = images.to(device), target.to(device)

                output = model(images)
                loss = criterion(output, target)

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
            val_acc = evaluate(model, val_dataloader)
            best_val_acc = max(val_acc, best_val_acc)
    print(f"Training Complete. Achieved best validation accuracy of {best_val_acc:.3f}")
    return best_val_acc

In [38]:
# Sweeps over LR and returns best one and its corresp. val accuracy
def sweep_lr(lr_list, model, train_dataloader, val_dataloader):
    best_lr = None
    best_lr_acc = 0.
    for lr in tqdm.notebook.tqdm(lr_list):
        print(f"Training with LR = {lr}")
        best_acc = train_linear_eval(model, train_dataloader, val_dataloader, lr)
        if best_acc > best_lr_acc:
            best_lr_acc = best_acc
            best_lr = lr
    return best_lr, best_lr_acc
             

In [39]:
LR, acc = sweep_lr(lr_list, model, train_dataloader, val_dataloader)
print(f"Found Best LR = {LR}, which results in validation accuracy of {acc:3.f}")

  0%|          | 0/5 [00:00<?, ?it/s]

Training with LR = 0.4


training:   0%|          | 0/8 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/1271158 [00:00<?, ?batch/s]

KeyboardInterrupt: 

## Training and Testing Linear Evaluation using Best LR

Paper reports 74.3%

In [None]:
del val_loader
del val_dataset

train_dataset = datasets.ImageFolder(os.path.join(IMAGENET_PATH, "train"), transform=train_transforms)

test_dataset = datasets.ImageFolder(os.path.join(IMAGENET_PATH, "val"), transform=test_transforms)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          num_workers=8, drop_last=False, shuffle=True)

test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                          num_workers=8, drop_last=False, shuffle=True)

train_linear_eval(model, train_dataloader, test_dataloader, LR)

NameError: name 'train_loader' is not defined