# Tutorial: Train MoCo on CIFAR-10

Main Reference:

[Tutorial 2: Train MoCo on CIFAR-10](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_moco_memory_bank.html)

When training self-supervised models with a contrastive loss, we usually face one big problem: the loss needs many negative examples to work well, and therefore a large batch size. However, not everyone has access to a cluster full of GPUs or TPUs. To solve this problem, alternative approaches have been developed. Some of them use a memory bank that stores old negative examples, which we can query to compensate for the smaller batch size. MoCo takes this approach one step further by including a momentum encoder.
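
To make the memory bank idea concrete, here is a minimal, framework-free sketch. It assumes the bank behaves like a fixed-size first-in-first-out queue of past key embeddings; the `MemoryBank` class and its method names are hypothetical and only illustrate the idea, they are not lightly's implementation.

```python
from collections import deque


class MemoryBank:
    """Fixed-size FIFO store of past key embeddings (hypothetical helper)."""

    def __init__(self, size):
        # A deque with maxlen drops its oldest entries once it is full.
        self._keys = deque(maxlen=size)

    def negatives(self):
        # Snapshot of the stored keys, oldest first.
        return list(self._keys)

    def enqueue(self, batch_keys):
        # Append the newest keys; the oldest ones are evicted when full.
        self._keys.extend(batch_keys)


bank = MemoryBank(size=4)
bank.enqueue(["k0", "k1", "k2"])  # keys from the first mini-batch
bank.enqueue(["k3", "k4", "k5"])  # second mini-batch evicts k0 and k1
print(bank.negatives())  # ['k2', 'k3', 'k4', 'k5']
```

With a bank like this, the number of negatives per step is set by the bank size rather than by the batch size.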

We use the CIFAR-10 dataset for this tutorial.

# Imports

In [1]:
import copy

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision

from lightly.data import LightlyDataset
from lightly.loss import NTXentLoss
from lightly.models.modules.heads import MoCoProjectionHead
from lightly.models.utils import (
    batch_shuffle,
    batch_unshuffle,
    deactivate_requires_grad,
    update_momentum,
)
from lightly.transforms import MoCoV2Transform, utils

# Configuration

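The default configuration uses a batch size of 512, which requires around 6.4GB of GPU memory. With it, training for 100 epochs should reach around 73% test set accuracy, and training for 200 epochs about 80%. Note that the cell below deviates from this default: it uses a batch size of 256 and only 10 epochs for a quick run, so the accuracy reported at the end of this notebook is far below these numbers.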

In [2]:
num_workers = 8
batch_size = 256
memory_bank_size = 1024
seed = 1
max_epochs = 10

In [3]:
# The dataset structure should be like this:
# cifar10/train/
#  L airplane/
#    L 10008_airplane.png
#    L ...
#  L automobile/
#  L bird/
#  L cat/
#  L deer/
#  L dog/
#  L frog/
#  L horse/
#  L ship/
#  L truck/
path_to_train = "/data/cifar-10-kaggle/cifar10/cifar10/train"
path_to_test = "/data/cifar-10-kaggle/cifar10/cifar10/test"

In [4]:
pl.seed_everything(seed)

Global seed set to 1


1

# Setup data augmentations and loaders

We start with our data preprocessing pipeline. We can implement augmentations from the MoCo paper using the transforms provided by lightly. Images from the CIFAR-10 dataset have a resolution of 32x32 pixels. Let’s use this resolution to train our model.

In [5]:
# disable blur because we're working with tiny images
moco_transform = MoCoV2Transform(
    input_size=32,
    gaussian_blur=0.0,
)

We don’t want any augmentation for our test data. Therefore, we create custom, torchvision-based data transformations that only resize the images and normalize them in the same way as the training data. For training the classifier we add light augmentations (random crop and horizontal flip).

In [6]:
# Augmentations typically used to train on cifar-10
train_classifier_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=utils.IMAGENET_NORMALIZE["mean"],
            std=utils.IMAGENET_NORMALIZE["std"],
        ),
    ]
)

# No additional augmentations for the test set
test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=utils.IMAGENET_NORMALIZE["mean"],
            std=utils.IMAGENET_NORMALIZE["std"],
        ),
    ]
)

# We use the moco augmentations for training moco
dataset_train_moco = LightlyDataset(input_dir=path_to_train, transform=moco_transform)

# Since we also train a linear classifier on the pre-trained moco model we
# use light augmentations (random crop and horizontal flip) here. MoCo
# augmentations are very strong and usually reduce accuracy of models which
# are not used for contrastive learning. Our linear layer will be trained
# using cross entropy loss and labels provided by the dataset.
dataset_train_classifier = LightlyDataset(
    input_dir=path_to_train, transform=train_classifier_transforms
)

dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms)

Create the dataloaders to load and preprocess the data in the background.

In [7]:
dataloader_train_moco = torch.utils.data.DataLoader(
    dataset_train_moco,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

dataloader_train_classifier = torch.utils.data.DataLoader(
    dataset_train_classifier,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)

dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)
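
The effect of `drop_last` on the number of steps per epoch can be checked with a little arithmetic. The sketch below assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images; `steps_per_epoch` is a hypothetical helper, not part of PyTorch.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last):
    """Number of batches a DataLoader yields per epoch (hypothetical helper)."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept.
    return math.ceil(num_samples / batch_size)


num_train_images = 50_000  # assumption: standard CIFAR-10 training split
num_test_images = 10_000   # assumption: standard CIFAR-10 test split

print(steps_per_epoch(num_train_images, 256, drop_last=True))   # 195
print(steps_per_epoch(num_test_images, 256, drop_last=False))   # 40
```

The 195 training steps agree with the `195/195` progress bars in the training logs further down.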

# Create the MoCo Lightning Module

Now we create our MoCo model. We use PyTorch Lightning to train it and follow the specification of the lightning module. In this example we set the number of features for the hidden dimension to 512. The momentum for the momentum encoder is set to 0.99 (default is 0.999) since other reports show that this works better for CIFAR-10.
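
To see what the momentum value controls, here is a minimal sketch of an exponential moving average update on plain floats. It assumes the momentum parameters are updated as `m * momentum + (1 - m) * online`; both function names are hypothetical and stand in for what `update_momentum` does on real model weights.

```python
def update_momentum_sketch(online_params, momentum_params, m):
    """Return EMA-updated momentum parameters (plain floats for illustration)."""
    return [m * k + (1.0 - m) * q for q, k in zip(online_params, momentum_params)]


def fraction_absorbed(m, steps):
    """How much of a constant online weight of 1.0 the momentum copy has absorbed."""
    momentum_params = [0.0]
    for _ in range(steps):
        momentum_params = update_momentum_sketch([1.0], momentum_params, m)
    return momentum_params[0]


# Compare the two momentum values over one epoch of 195 steps.
for m in (0.99, 0.999):
    print(f"m={m}: {fraction_absorbed(m, steps=195):.3f} absorbed after 195 steps")
```

A smaller momentum lets the momentum encoder follow the online encoder more quickly, which matters when the total number of training steps is small.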

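For the backbone we use the torchvision MobileNetV2 with its classification head removed, followed by global average pooling and a linear layer that maps the 1280 output channels to 512 features. You can use another model following our playground to use custom backbones.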

In [8]:
class MocoModel(pl.LightningModule):
    def __init__(self, num_ftrs=512):
        super().__init__()

        self.num_ftrs = num_ftrs

        # create a Mobilenet_v2 backbone and remove the classification head
        mobilenet_v2 = torchvision.models.mobilenet_v2()
        self.backbone = nn.Sequential(
            *list(mobilenet_v2.children())[:-1],
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(1280, num_ftrs, bias=False),
        )

        # resnet = ResNetGenerator("resnet-18", 1, num_splits=8)
        # self.backbone = nn.Sequential(
        #     *list(resnet.children())[:-1],
        #     nn.AdaptiveAvgPool2d(1),
        # )

        # create the projection head and the momentum copies of backbone and head
        self.projection_head = MoCoProjectionHead(num_ftrs, 512, 128)
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

        # create our loss with the optional memory bank
        self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size)

    def training_step(self, batch, batch_idx):
        (x_q, x_k), _, _ = batch

        # update momentum
        update_momentum(self.backbone, self.backbone_momentum, 0.99)
        update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

        # get queries
        q = self.backbone(x_q).flatten(start_dim=1)
        q = self.projection_head(q)

        # get keys
        k, shuffle = batch_shuffle(x_k)
        k = self.backbone_momentum(k).flatten(start_dim=1)
        k = self.projection_head_momentum(k)
        k = batch_unshuffle(k, shuffle)

        loss = self.criterion(q, k)
        self.log("train_loss_ssl", loss)
        return loss

    def on_train_epoch_end(self):
        self.custom_histogram_weights()

    # We provide a helper method to log weights in tensorboard
    # which is useful for debugging.
    def custom_histogram_weights(self):
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def configure_optimizers(self):
        optim = torch.optim.SGD(
            self.parameters(),
            lr=6e-2,
            momentum=0.9,
            weight_decay=5e-4,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
        return [optim], [scheduler]
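
The loss computed in `training_step` can be illustrated without any framework. The following is a minimal sketch of an InfoNCE-style objective, assuming cosine similarity, a single positive key per query, and negatives taken from a memory bank. `info_nce` and its helpers are hypothetical names; this is not lightly's `NTXentLoss` implementation.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def info_nce(query, positive_key, negative_keys, temperature=0.1):
    """Cross entropy over similarities, with the positive key as the target."""
    q = normalize(query)
    # Index 0 holds the positive logit, the rest are negatives.
    logits = [dot(q, normalize(positive_key)) / temperature]
    logits += [dot(q, normalize(k)) / temperature for k in negative_keys]
    # Numerically stable log-sum-exp.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum_exp - logits[0]


negatives = [[0.0, 1.0], [-1.0, 0.0]]
aligned = info_nce([1.0, 0.0], [1.0, 0.0], negatives)     # query matches its key
misaligned = info_nce([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0]])  # query matches a negative
print(aligned < misaligned)  # True
```

The loss is small when the query is closer to its positive key than to every negative, and large when a negative is the closest match.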

# Create the Classifier Lightning Module

We create a linear classifier on top of the frozen features extracted by the MoCo backbone and train it on the labeled dataset.

In [9]:
class Classifier(pl.LightningModule):
    def __init__(self, backbone):
        super().__init__()
        # use the pretrained MoCo backbone (MobileNetV2 in this notebook)
        self.backbone = backbone

        # freeze the backbone
        deactivate_requires_grad(backbone)

        # create a linear layer for our downstream classification model
        self.fc = nn.Linear(512, 10)

        self.criterion = nn.CrossEntropyLoss()
        self.validation_step_outputs = []

    def forward(self, x):
        y_hat = self.backbone(x).flatten(start_dim=1)
        y_hat = self.fc(y_hat)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss_fc", loss)
        return loss

    def on_train_epoch_end(self):
        self.custom_histogram_weights()

    # We provide a helper method to log weights in tensorboard
    # which is useful for debugging.
    def custom_histogram_weights(self):
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.forward(x)
        y_hat = torch.nn.functional.softmax(y_hat, dim=1)

        # calculate number of correct predictions
        _, predicted = torch.max(y_hat, 1)
        num = predicted.shape[0]
        correct = (predicted == y).float().sum()
        self.validation_step_outputs.append((num, correct))
        return num, correct

    def on_validation_epoch_end(self):
        # calculate and log top1 accuracy
        if self.validation_step_outputs:
            total_num = 0
            total_correct = 0
            for num, correct in self.validation_step_outputs:
                total_num += num
                total_correct += correct
            acc = total_correct / total_num
            self.log("val_acc", acc, on_epoch=True, prog_bar=True)
            self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optim = torch.optim.SGD(self.fc.parameters(), lr=30.0)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
        return [optim], [scheduler]
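
To get a feeling for the learning rate schedule, the sketch below evaluates the closed-form cosine annealing curve for the classifier's settings (`lr=30.0`, `T_max=10`). It assumes a minimum learning rate of 0 and one scheduler step per epoch; `cosine_annealing_lr` is a hypothetical helper, not the PyTorch scheduler itself.

```python
import math


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max epochs."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)


# Learning rate at the start of each epoch for a 10-epoch run.
schedule = [cosine_annealing_lr(30.0, epoch, t_max=10) for epoch in range(11)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.3f}")
```

The learning rate starts at its full value, passes half of it at the midpoint, and approaches zero at the last epoch.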

# Train the MoCo model

`TIP`: Check out the TensorBoard logs while the model is training.

Run `tensorboard --logdir lightning_logs/` to start TensorBoard.

We can instantiate the model and train it using the lightning trainer.

In [10]:
model = MocoModel()
trainer = pl.Trainer(max_epochs=max_epochs,
                     devices=1,
                     accelerator="gpu",
                     precision=16,
                     )
trainer.fit(model, dataloader_train_moco)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type               | Params
----------------------------------------------------------------
0 | backbone                 | Sequential         | 2.9 M 
1 | projection_head          | MoCoProjectionHead | 328 K 
2 | backbone_momentum        | Sequential         | 2.9 M 
3 | projection_head_momentum | MoCoProjectionHead | 328 K 
4 | criterion                | NTXentLoss         | 0     
----------------------------------------------------------------
3.2 M     Trainable params
3.2 M     Non-trainable params
6.4 M     Total params
25.660    Total estimated model params size (MB)


Epoch 9: 100%|██████████| 195/195 [00:28<00:00,  6.76it/s, v_num=1]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 195/195 [00:31<00:00,  6.21it/s, v_num=1]


# Train the Classifier

In [11]:
model.eval()
classifier = Classifier(model.backbone)
trainer = pl.Trainer(max_epochs=max_epochs, devices=1, accelerator="gpu")
trainer.fit(classifier, dataloader_train_classifier, dataloader_test)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | backbone  | Sequential       | 2.9 M 
1 | fc        | Linear           | 5.1 K 
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
5.1 K     Trainable params
2.9 M     Non-trainable params
2.9 M     Total params
11.537    Total estimated model params size (MB)


Epoch 9: 100%|██████████| 195/195 [00:09<00:00, 20.21it/s, v_num=2, val_acc=0.149]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 195/195 [00:10<00:00, 17.93it/s, v_num=2, val_acc=0.149]
