In [1]:
# Import the necessary libraries
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Import torch vision
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# Import resnet50 model from torchvision
from torchvision.models import resnet50
from torch.optim.optimizer import Optimizer, required
import re



In [2]:
# import a simple dataset for testing model architecture
# import the cifar10 dataset



# Data augmentation

In [3]:
# configure the image height and width for resizing the images to input the model
# As the pytorch resnet model requires the input image to be 224x224 ,even with pre-trained weights equal False
# we will resize the images to 224x224, as size bigger than 224x224 will be cropped to 224x224
image_height = 224
image_width = 224

def get_color_distortion(s=1.0):
    # s is the strength of color distortion.
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([
    rnd_color_jitter,
    rnd_gray])
    return color_distort
# This is the best combination of data augmentation techniques for the SimCLR model shown in the paper
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop((image_height, image_width)), # This follow the random cropping and resizing in the paper
    get_color_distortion(s=1),
    transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)), # In the paper, the kernel size is 10% of the image height and width and sigma is between 0.1 and 2.0. As the kernel size must be odd, we choose 23 as the kernel size.
    transforms.ToTensor()
])




# Define the loss function for SimCLR

In [4]:
def nt_xent_loss(queries, keys, temperature = 0.1):
    b, device = queries.shape[0], queries.device

    n = b * 2  # 同一图片内部不同patch也是负样本
    projs = torch.cat((queries, keys))
    logits = projs @ projs.t()

    mask = torch.eye(n, device=device).bool()
    logits = logits[~mask].reshape(n, n - 1)  # 同一图片内部不同patch也是负样本，除了自己和自己
    logits /= temperature

    labels = torch.cat(((torch.arange(b, device = device) + b - 1), torch.arange(b, device=device)), dim=0)
    loss = F.cross_entropy(logits, labels, reduction = 'sum')
    loss /= n
    return loss



# Get the backbone model f(.) to train on the data augmentation dataset.

In [5]:

# Load the resnet50 model which returns the features before the classification layer
model = resnet50(pretrained=False) # Optionally, you can set pretrained=True to use the pre-trained weights 
# return the features before the classification layer
model.fc = nn.Identity() # Remove the classification layer
# Print the model architecture
# print(model)
# get the output shape of the model by passing a random tensor of the image size
# print(model(torch.randn(1, 3, image_height, image_width)).shape)




# Define the model
class SimCLR(nn.Module):
    def __init__(self, model, temperature=0.1):
        super(SimCLR, self).__init__()
        # get the device of the model
        self.model = model
        # This is the two-layer MLP projection head as described in the paper whcih represents the g(.) function
        self.projection_head = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, 128)
        )
        self.temperature = temperature
        # Define the cosine similarity function
        # cosine_similarity = lambda z_i, z_j: torch.dot(z_i, z_j) / (torch.norm(z_i) * torch.norm(z_j))
    def forward(self, x_i, x_j):
        h_i = self.model(x_i)
        h_j = self.model(x_j)
        # print(h.shape)
        z_i = self.projection_head(h_i)
        # print(z_i.shape)
        z_j = self.projection_head(h_j)
        # get the normalized projection head output
        # z_i = nn.functional.normalize(z_i, dim=1)
        # z_j = nn.functional.normalize(z_j, dim=1)

        # Loss calculation by nt_xent_loss function
        loss = nt_xent_loss(z_i, z_j, self.temperature)
        return loss
# Create the SimCLR model




In [6]:
# define the Lars optimizer from scratch 
# As the the coursework limit to use of 4 external libraries, we will implement the LARS optimizer from scratch

from torch.optim.optimizer import Optimizer, required
import re

EETA_DEFAULT = 0.001


class LARS(Optimizer):
    """
    Layer-wise Adaptive Rate Scaling for large batch training.
    Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum=0.9,
        use_nesterov=False,
        weight_decay=0.0,
        exclude_from_weight_decay=None,
        exclude_from_layer_adaptation=None,
        classic_momentum=True,
        eeta=EETA_DEFAULT,
    ):
        """Constructs a LARSOptimizer.
        Args:
        lr: A `float` for learning rate.
        momentum: A `float` for momentum.
        use_nesterov: A 'Boolean' for whether to use nesterov momentum.
        weight_decay: A `float` for weight decay.
        exclude_from_weight_decay: A list of `string` for variable screening, if
            any of the string appears in a variable's name, the variable will be
            excluded for computing weight decay. For example, one could specify
            the list like ['batch_normalization', 'bias'] to exclude BN and bias
            from weight decay.
        exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
            for layer adaptation. If it is None, it will be defaulted the same as
            exclude_from_weight_decay.
        classic_momentum: A `boolean` for whether to use classic (or popular)
            momentum. The learning rate is applied during momeuntum update in
            classic momentum, but after momentum for popular momentum.
        eeta: A `float` for scaling of learning rate when computing trust ratio.
        name: The name for the scope.
        """

        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            exclude_from_weight_decay=exclude_from_weight_decay,
            exclude_from_layer_adaptation=exclude_from_layer_adaptation,
            classic_momentum=classic_momentum,
            eeta=eeta,
        )

        super(LARS, self).__init__(params, defaults)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
        # arg is None.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def step(self, epoch=None, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            eeta = group["eeta"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                param = p.data
                grad = p.grad.data

                param_state = self.state[p]

                # TODO: get param names
                # if self._use_weight_decay(param_name):
                grad += self.weight_decay * param

                if self.classic_momentum:
                    trust_ratio = 1.0

                    # TODO: get param names
                    # if self._do_layer_adaptation(param_name):
                    w_norm = torch.norm(param)
                    g_norm = torch.norm(grad)

                    device = g_norm.get_device()
                    trust_ratio = torch.where(
                        w_norm.gt(0),
                        torch.where(
                            g_norm.gt(0),
                            (self.eeta * w_norm / g_norm),
                            torch.Tensor([1.0]).to(device),
                        ),
                        torch.Tensor([1.0]).to(device),
                    ).item()

                    scaled_lr = lr * trust_ratio
                    if "momentum_buffer" not in param_state:
                        next_v = param_state["momentum_buffer"] = torch.zeros_like(
                            p.data
                        )
                    else:
                        next_v = param_state["momentum_buffer"]

                    next_v.mul_(momentum).add_(scaled_lr, grad)
                    if self.use_nesterov:
                        update = (self.momentum * next_v) + (scaled_lr * grad)
                    else:
                        update = next_v

                    p.data.add_(-update)
                else:
                    raise NotImplementedError

        return loss

    def _use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True



# Start training the model

In [14]:
# Start the training loop for SimCLR
model = resnet50(pretrained=False)
model.fc = nn.Identity()
simclr_model = SimCLR(model)
# Hyperparameters
batch_size = 64
learning_rate = 0.0001
num_epochs = 30

# Simple transform for the original CIFAR-10 dataset and resize the images to 224x224
transform = transforms.Compose([
    transforms.Resize((image_height, image_width)),
    transforms.ToTensor()
])
# define the dataset with and without data augmentation and with 
train_dataset = datasets.CIFAR10(root='dataset/', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# move the model to the device
model = model.to(device)  # this is the model you want to save for pre-training, where f(.) is the ResNet-50
simclr_model = simclr_model.to(device)

# Define the optimizer
optimizer = LARS(
    [params for params in model.parameters() if params.requires_grad],
    lr=0.2,
    weight_decay=1e-6,
    exclude_from_weight_decay=["batch_normalization", "bias"],
)

# Start the training loop for SimCLR
for epoch in range(num_epochs):
    for data in train_loader:
        images_i, _ = data
        images_i = images_i.to(device)
        # apply the data augmentation techniques to the images
        images_j = data_transforms(data[0])
        images_j = images_j.to(device)
        # Perform the forward pass
        loss = simclr_model(images_i, images_j)
        # Perform the backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# Save the model
torch.save(model.state_dict(), 'simclr_backbone.ckpt')
torch.save(simclr_model.state_dict(), 'simclr_model.ckpt')


Files already downloaded and verified




TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>