<a href="https://colab.research.google.com/github/abibalimi/self-supervised/blob/main/SimCLR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Reproducing SimCLR

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path
import matplotlib.pyplot as plt
import time
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
from tqdm.auto import tqdm

#### Steps:


1.   Data Augmentation
2.   Encoder Network
3.   Projection Head
4.   Contrastive Loss

#### Hyperparameters

In [1]:
BATCH_SIZE = 1024
BASE_LR = 1e-3 #3 * BATCH_SIZE / 256  # Learning rate = 1.2
WEIGHT_DECAY = 1e-6
WARM_UP_RATE = 0.4
TEMPERATURE = 0.5
EPOCHS = 10

#          ***         Data Augmentation         ***         #

In [3]:
# Define Albumentations augmentations
augmentation = A.Compose([
    # Inception-style cropping: random crop, flip, and resize to 32x32
    A.RandomResizedCrop((32, 32), scale=(0.08, 1.0)),
    A.HorizontalFlip(),
    A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2()
])

In [4]:
# Custom dataset to apply Albumentations
class AugmentedDataset(Dataset):
    def __init__(self, dataset, augment):
        self.dataset = dataset
        self.augmentation = augment

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        image = np.array(image)  # Convert PIL Image to numpy array
        x1 = self.augmentation(image=image)['image']
        x2 = self.augmentation(image=image)['image']
        return x1, x2, label


#          ***         Encoder Network (ResNet-18)         ***         #

In [6]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.backbone = resnet18()   # Random initialization
        self.backbone.fc = nn.Identity()  # Remove the final classification layer

    def forward(self, x):
        return self.backbone(x)

#          ***         Projection Head         ***         #

In [7]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=256, output_dim=128):
        super(ProjectionHead, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

#          ***         SimCLR Model         ***         #

In [8]:
class SimCLR(nn.Module):
    def __init__(self, encoder, projection_head):
        super(SimCLR, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def forward(self, x1, x2):
        # Encode the two augmented views
        h1 = self.encoder(x1)
        h2 = self.encoder(x2)

        # Project to the lower-dimensional space
        z1 = self.projection_head(h1)
        z2 = self.projection_head(h2)

        return z1, z2

#          ***         Contrastive Loss (NT-Xent)         ***         #

In [9]:
def contrastive_loss(z1, z2, temperature=0.1):
    BATCH_SIZE = z1.size(0)
    z = torch.cat([z1, z2], dim=0)  # Concatenate both views
    z = nn.functional.normalize(z, dim=1)  # Normalize feature vectors

    # Compute similarity matrix
    sim_matrix = torch.matmul(z, z.T) / temperature

    # Create labels for positive pairs
    labels = torch.arange(BATCH_SIZE, device=z.device)
    labels = torch.cat([labels + BATCH_SIZE, labels])  # Positive pairs are diagonal elements

    # Compute cross-entropy loss
    loss = nn.functional.cross_entropy(sim_matrix, labels)
    return loss

# Utils

#### Load CIFAR-10 dataset

#### Initialize the model

#### Learning rate scheduler

In [None]:
def lr_scheduler(optimizer, epochs=EPOCHS, warm_up_rate=WARM_UP_RATE):
    """Schedules the learning rate"""
    warmup_epochs = epochs * warm_up_rate  # 10%
    total_epochs = epochs # 100
    warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
    return scheduler, warmup_scheduler, warmup_epochs

#### Save checkpoints

In [None]:
def save_checkpoint(epoch, model, optimizer, scheduler, loss, checkpoint_dir):
    """ Function to save checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, checkpoint_dir / f"simclr_checkpoint_epoch_{epoch+1}.pth")
    print(f"✅ Checkpoint saved at epoch {epoch}")

In [13]:
!pip install wandb -Uq

In [14]:
import wandb

In [15]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33molush[0m ([33molush-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Define a sweep

In [16]:
sweep_config = {
    'method': 'random'
    }

In [17]:
metric = {
    'name': 'loss',
    'goal': 'minimize'
    }

sweep_config['metric'] = metric

In [18]:
parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
        }
    }

sweep_config['parameters'] = parameters_dict

In [20]:
parameters_dict.update({
    'learning_rate': {
        # a flat distribution between 1e-3 and 0.1
        'distribution': 'uniform',
        'min': 1e-3,
        'max': 1.
      },
    'batch_size': {
        # integers between 128 and 1024
        # with evenly-distributed logarithms
        'distribution': 'q_log_uniform_values',
        'q': 8,
        'min': 128,
        'max': 1024,
      }
    })

In [21]:
parameters_dict.update({
    'epochs': {
        'value': 10}
    })

In [24]:
sweep_config['backbone'] = 'ResNet18'

In [25]:
import pprint
pprint.pprint(sweep_config)

{'backbone': 'ResNet18',
 'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'loss'},
 'parameters': {'batch_size': {'distribution': 'q_log_uniform_values',
                               'max': 1024,
                               'min': 128,
                               'q': 8},
                'epochs': {'value': 10},
                'learning_rate': {'distribution': 'uniform',
                                  'max': 1.0,
                                  'min': 0.001},
                'optimizer': {'values': ['adam', 'sgd']}}}


## Initialize the Sweep

In [23]:
sweep_id = wandb.sweep(sweep_config, project="simCLR-sweeps")

Create sweep with ID: u16zfjto
Sweep URL: https://wandb.ai/olush-ai/simCLR-sweeps/sweeps/u16zfjto


## Define SimCLR code


In [None]:
def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config

        loader = build_dataset(config.batch_size)
        network = build_network()
        optimizer = build_optimizer(network, config.optimizer, config.learning_rate)

        # Tell wandb to watch what the model gets up to: gradients, weights, and more!
        wandb.watch(network, log="all", log_freq=10)


        #for epoch in range(config.epochs):
        for epoch in tqdm(range(config.epochs)):
            avg_loss = train_epoch(network, loader, optimizer)
            wandb.log({"loss": avg_loss, "epoch": epoch})
            print(f"Epoch [{epoch+1}/{EPOCHS}], Losses :: Train = {avg_loss:.4f}")

In [27]:
def build_dataset(batch_size):
    """Loads split datasets"""
    # download CIFAR10 training dataset
    dataset = CIFAR10(root='./data', train=True, download=True)
    augmented_dataset = AugmentedDataset(dataset, augmentation)
    loader = DataLoader(augmented_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    return loader


def build_network():
    """Initializes the model/network"""
    encoder = Encoder().to(device)
    projection_head = ProjectionHead().to(device)
    model = SimCLR(encoder, projection_head).to(device)
    return model.to(device)


def build_optimizer(network, optimizer, learning_rate):
    """Initializes the optimizer"""
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9,
                              weight_decay=WEIGHT_DECAY)
    elif optimizer == "adam":
        optimizer = optim.Adam(network.parameters(),
                               lr=learning_rate,
                               weight_decay=WEIGHT_DECAY)
    return optimizer



def train_epoch(model, loader, optimizer):
    cumu_loss = 0
    #for _, (data, target) in enumerate(loader):
    for _, (x1, x2, _) in enumerate(loader):
        # Zero the gradients for every batch!
        optimizer.zero_grad()

        # Move data to device and make predictions for this batch (Forward pass)
        x1, x2 = x1.to(device), x2.to(device)

        # ➡ Forward pass : Compute contrastive loss
        z1, z2 = model(x1, x2)
        loss = contrastive_loss(z1, z2, TEMPERATURE)

        # ⬅ Backward pass + weight update
        loss.backward()
        optimizer.step()

        wandb.log({"batch loss": loss.item()})

    return cumu_loss / len(loader)


def val_epoch(model, loader, optimizer):
    cumu_loss = 0
    # Set the model to evaluation/validation mode
    model.eval()
    with torch.no_grad():
        for batch_idx, (x1, x2, _) in enumerate(loader):
            # Move data to device and predict
            x1, x2 = x1.to(device), x2.to(device)
            z1, z2 = model(x1, x2)

            # Compute contrastive loss
            val_loss = contrastive_loss(z1, z2, TEMPERATURE)
            per_epoch_val_loss += val_loss.item()

    wandb.log({"val batch loss": val_loss.item()})

    return cumu_loss / len(loader)

## Activate sweep agents

In [None]:
wandb.agent(sweep_id, train, count=10)