<a href="https://colab.research.google.com/github/KAVYANSHTYAGI/Self-Supervised/blob/main/SimCLR_cifar.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# SimCLR-style augmentation pipeline: crop, flip, colour jitter, grayscale, blur
def get_transforms():
    """Build the random augmentation applied to each CIFAR-10 PIL image.

    Every call of the returned transform draws fresh random parameters, so
    applying it twice to the same image yields two different views.

    Returns:
        A ``transforms.Compose`` that produces a normalised 3x32x32 float tensor.
    """
    steps = [
        transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=5),
        transforms.ToTensor(),
        # A one-element mean/std is broadcast across all three channels.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
    return transforms.Compose(steps)

# Load Dataset (CIFAR-10 training split for now)
def get_dataset(batch_size=256, num_workers=4, root='./data'):
    """Build a shuffled DataLoader over the CIFAR-10 training split.

    Each sample passes through ``get_transforms()`` exactly once, so a batch
    holds ONE augmented view per image (plus its class label). SimCLR needs
    two independent views per image; callers must draw the second view
    themselves rather than reusing the same tensor.

    Args:
        batch_size: images per batch.
        num_workers: number of DataLoader worker processes.
        root: directory where CIFAR-10 is stored / downloaded to.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    dataset = torchvision.datasets.CIFAR10(
        root=root, train=True, transform=get_transforms(), download=True
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Projection head g(.): Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm, then L2-normalise
class ProjectionHead(nn.Module):
    """Two-layer MLP mapping encoder features to unit-norm embeddings.

    Args:
        input_dim: width of the incoming feature vectors.
        proj_dim: width of the output embedding.
    """

    def __init__(self, input_dim=512, proj_dim=128):
        super().__init__()
        hidden = 1024
        # Submodule names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, proj_dim)
        self.bn2 = nn.BatchNorm1d(proj_dim)

    def forward(self, x):
        """Project a batch (assumed 2-D: N x input_dim) to L2-normalised rows of width proj_dim."""
        out = x
        for layer in (self.fc1, self.bn1, self.relu, self.fc2, self.bn2):
            out = layer(out)
        return F.normalize(out, dim=1)

# SimCLR Model: encoder f(.) followed by projection head g(.)
class SimCLR(nn.Module):
    """Backbone encoder plus projection head producing unit-norm embeddings.

    Args:
        backbone: module mapping a batch of images to flat feature vectors.
        projection_dim: width of the output embedding.
        feature_dim: width of the backbone's output features. Defaults to 512,
            which matches the ResNet-18 (fc replaced by ``nn.Identity()``)
            used in this notebook; set it when using a different backbone,
            otherwise the projection head's first Linear layer will not fit.
    """

    def __init__(self, backbone, projection_dim=128, feature_dim=512):
        super().__init__()
        self.encoder = backbone
        self.projection_head = ProjectionHead(input_dim=feature_dim, proj_dim=projection_dim)

    def forward(self, x):
        """Return projected embeddings z = g(f(x))."""
        h = self.encoder(x)
        return self.projection_head(h)

# NT-Xent (normalised temperature-scaled cross-entropy) contrastive loss
def contrastive_loss(z_i, z_j, temperature=0.07):
    """SimCLR NT-Xent loss for two batches of paired embeddings.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are the positive pair; every
    other embedding in the combined 2N batch acts as a negative. The loss is
    computed as a cross-entropy over masked similarity logits, which uses a
    log-sum-exp internally. That avoids the ``exp()`` overflow, and hence the
    inf/nan loss, that a direct ``exp``/``sum``/``log`` formulation hits at
    small temperatures or in half precision.

    Args:
        z_i, z_j: embeddings of the two views, shape (N, D); assumed
            L2-normalised so that dot products are cosine similarities.
        temperature: softmax temperature dividing the similarities.

    Returns:
        Scalar tensor: mean loss over all 2N anchors.
    """
    batch_size = z_i.shape[0]
    z = torch.cat([z_i, z_j], dim=0)
    logits = z @ z.T / temperature

    # An embedding must be neither its own positive nor its own negative.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Anchor k in the first half pairs with k + N; anchor k + N pairs with k.
    idx = torch.arange(batch_size, device=z.device)
    targets = torch.cat([idx + batch_size, idx], dim=0)
    return F.cross_entropy(logits, targets)

# Training
if __name__ == "__main__":
    import multiprocessing

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load Data.
    # SimCLR contrasts two *independently* augmented views of every image.
    # Feeding the same augmented batch twice (x1 == x2) makes every positive
    # pair identical, so the loss can be driven to ~0 without learning
    # useful features. Instead, the dataset transform runs the random
    # augmentation pipeline twice per image.
    augment = get_transforms()

    def two_views(img):
        """Return two independent random augmentations of one image."""
        return augment(img), augment(img)

    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=two_views, download=True)
    # `two_views` is a local closure and cannot be pickled, so DataLoader
    # worker processes only work with the "fork" start method; elsewhere
    # load in-process instead of crashing.
    num_workers = 4 if multiprocessing.get_start_method() == "fork" else 0
    train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=num_workers)

    # Define Model
    resnet = torchvision.models.resnet18(pretrained=False)
    resnet.fc = nn.Identity()  # Remove classification head; the encoder now emits pooled features
    model = SimCLR(resnet).to(device)

    optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-6)

    # Training Loop
    for epoch in range(10):  # Train for 10 epochs first
        model.train()
        total_loss = 0.0
        for (x1, x2), _ in tqdm(train_loader):
            x1, x2 = x1.to(device), x2.to(device)

            z1, z2 = model(x1), model(x2)
            loss = contrastive_loss(z1, z2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")


Files already downloaded and verified


100%|██████████| 196/196 [04:27<00:00,  1.37s/it]


Epoch 1, Loss: 0.0120


100%|██████████| 196/196 [04:28<00:00,  1.37s/it]


Epoch 2, Loss: 0.0088


100%|██████████| 196/196 [04:30<00:00,  1.38s/it]


Epoch 3, Loss: 0.0049


100%|██████████| 196/196 [04:28<00:00,  1.37s/it]


Epoch 4, Loss: 0.0032


100%|██████████| 196/196 [04:28<00:00,  1.37s/it]


Epoch 5, Loss: 0.0027


100%|██████████| 196/196 [04:28<00:00,  1.37s/it]


Epoch 6, Loss: 0.0020


100%|██████████| 196/196 [04:28<00:00,  1.37s/it]


Epoch 7, Loss: 0.0019


100%|██████████| 196/196 [04:28<00:00,  1.37s/it]


Epoch 8, Loss: 0.0014


100%|██████████| 196/196 [04:29<00:00,  1.37s/it]


Epoch 9, Loss: 0.0019


100%|██████████| 196/196 [04:29<00:00,  1.37s/it]

Epoch 10, Loss: 0.0022





TensorFlow / Keras version of the same SimCLR pipeline

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np
import tensorflow_addons as tfa  # For potential Gaussian blur (if needed)
import math

# 1. Data Augmentations (SimCLR style)
def random_resized_crop(image, crop_size=32, scale=(0.2, 1.0), ratio=(3 / 4, 4 / 3)):
    """Crop a random region of ``image`` and resize it to ``crop_size`` x ``crop_size``.

    Uses torchvision ``RandomResizedCrop`` semantics, matching the PyTorch
    pipeline:
      * ``scale`` bounds the crop's *area* as a fraction of the image area;
      * ``ratio`` bounds its width/height aspect ratio, sampled log-uniformly.

    Scaling the side length instead would turn scale=(0.2, 1.0) into crops
    covering only 4%-100% of the image.

    Args:
        image: HxWxC image tensor.
        crop_size: output height and width.
        scale: (min, max) fraction of the image area kept by the crop.
        ratio: (min, max) width/height aspect ratio of the crop.

    Returns:
        Tensor of shape [crop_size, crop_size, C] (as produced by ``tf.image.resize``).
    """
    shape = tf.shape(image)
    height = tf.cast(shape[0], tf.float32)
    width = tf.cast(shape[1], tf.float32)

    target_area = height * width * tf.random.uniform([], scale[0], scale[1])
    aspect = tf.exp(tf.random.uniform([], math.log(ratio[0]), math.log(ratio[1])))

    # Clip so the crop always fits inside the image and is at least 1 pixel.
    crop_w = tf.clip_by_value(tf.round(tf.sqrt(target_area * aspect)), 1.0, width)
    crop_h = tf.clip_by_value(tf.round(tf.sqrt(target_area / aspect)), 1.0, height)

    image = tf.image.random_crop(
        image, size=[tf.cast(crop_h, tf.int32), tf.cast(crop_w, tf.int32), shape[2]]
    )
    return tf.image.resize(image, (crop_size, crop_size))

def color_jitter(image, brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2):
    """Randomly jitter brightness, contrast, saturation and hue of an RGB image.

    Args:
        image: float RGB image expected to hold values in [0, 1].
        brightness: brightness factor drawn from [max(0, 1 - brightness), 1 + brightness].
        contrast: contrast factor drawn from [1 - contrast, 1 + contrast].
        saturation: saturation factor drawn from [1 - saturation, 1 + saturation].
        hue: maximum hue shift passed to ``tf.image.random_hue``.

    Returns:
        The jittered image, clipped to [0, 1].
    """
    def _clip(x):
        return tf.clip_by_value(x, 0.0, 1.0)

    # Multiplicative brightness, as in torchvision's ColorJitter. The additive
    # tf.image.random_brightness with max_delta=0.8 could push a [0, 1] image
    # almost entirely out of range.
    factor = tf.random.uniform([], max(1.0 - brightness, 0.0), 1.0 + brightness)
    image = _clip(image * factor)
    # Clip after every op. Not all tf.image colour ops clamp their output, and
    # the saturation/hue ops convert through HSV, which is only well defined
    # for inputs in [0, 1].
    image = _clip(tf.image.random_contrast(image, lower=1 - contrast, upper=1 + contrast))
    image = _clip(tf.image.random_saturation(image, lower=1 - saturation, upper=1 + saturation))
    image = _clip(tf.image.random_hue(image, max_delta=hue))
    return image

def random_grayscale(image, p=0.2):
    """With probability ``p``, replace an RGB image by its grayscale version.

    The grayscale result is expanded back to three channels so the output
    shape matches the input.
    """
    draw = tf.random.uniform([], 0, 1)

    def _as_gray():
        return tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))

    def _unchanged():
        return image

    return tf.cond(draw < p, _as_gray, _unchanged)

def gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Blur an HxWxC image with a separable Gaussian kernel.

    A normalised 1-D Gaussian of length ``kernel_size`` is applied per channel
    as two depthwise convolutions (vertical, then horizontal) with 'SAME'
    padding. tensorflow-addons also offers image filtering if a full 2-D
    filter is preferred.
    """
    half = kernel_size // 2
    taps = tf.range(-half, half + 1, dtype=tf.float32)
    weights = tf.exp(-0.5 * (taps / sigma) ** 2)
    weights = weights / tf.reduce_sum(weights)

    # depthwise_conv2d works on 4-D batches, so add a leading batch axis.
    batch = tf.expand_dims(image, axis=0)
    channels = tf.shape(batch)[-1]
    # Filter layout is [height, width, in_channels, multiplier]: one copy per channel.
    for kernel_shape in ([kernel_size, 1, 1, 1], [1, kernel_size, 1, 1]):
        kernel = tf.repeat(tf.reshape(weights, kernel_shape), channels, axis=2)
        batch = tf.nn.depthwise_conv2d(batch, kernel, strides=[1, 1, 1, 1], padding='SAME')
    return tf.squeeze(batch, axis=0)

def normalize(image):
    """Standardise with mean 0.5 and std 0.5 (same for every channel).

    An input in [0, 1] maps to [-1, 1].
    """
    mean = 0.5
    std = 0.5
    centered = image - mean
    return centered / std

def simclr_augment(image):
    """Produce one random SimCLR view of an image with pixel values in [0, 255].

    The pixels are scaled to [0, 1], then passed through crop, flip, colour
    jitter, random grayscale, Gaussian blur and mean/std normalisation, in
    that order.
    """
    view = tf.cast(image, tf.float32) / 255.0
    pipeline = (
        lambda t: random_resized_crop(t, crop_size=32),
        tf.image.random_flip_left_right,
        color_jitter,
        lambda t: random_grayscale(t, p=0.2),
        lambda t: gaussian_blur(t, kernel_size=5, sigma=1.0),
        normalize,
    )
    for step in pipeline:
        view = step(view)
    return view

def two_augmentations(image):
    """Return a pair of independently augmented views of the same image."""
    first_view = simclr_augment(image)
    second_view = simclr_augment(image)
    return first_view, second_view

# 2. Dataset Preparation (CIFAR-10)
def prepare_dataset(batch_size=256):
    """Build a tf.data pipeline of (view1, view2) batches from CIFAR-10 training images.

    Labels are discarded. Images are shuffled with a 10k buffer, and each is
    augmented twice in parallel. The result is batched and prefetched.
    """
    (train_images, _), _ = tf.keras.datasets.cifar10.load_data()
    return (
        tf.data.Dataset.from_tensor_slices(train_images)
        .shuffle(buffer_size=10000)
        .map(two_augmentations, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

# 3. Projection Head with BatchNorm
class ProjectionHead(tf.keras.Model):
    """MLP head: Dense(1024) -> BN -> ReLU -> Dense(proj_dim) -> BN -> L2-normalise.

    ``input_dim`` is accepted for parity with the PyTorch head but is not
    used: Keras ``Dense`` layers infer their input width on the first call.
    """

    def __init__(self, input_dim=512, proj_dim=128):
        super().__init__()
        # Attribute names determine checkpoint variable paths; keep them stable.
        self.dense1 = layers.Dense(1024)
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.ReLU()
        self.dense2 = layers.Dense(proj_dim)
        self.bn2 = layers.BatchNormalization()

    def call(self, x, training=False):
        hidden = self.relu(self.bn1(self.dense1(x), training=training))
        projected = self.bn2(self.dense2(hidden), training=training)
        # Unit-length rows, so dot products in the loss are cosine similarities.
        return tf.math.l2_normalize(projected, axis=1)

# 4. SimCLR Model
def get_backbone():
    """Return a randomly initialised ResNet50 feature extractor for 32x32x3 inputs.

    The classification top is dropped. Global average pooling turns the last
    feature map into a single 2048-dimensional vector per image.
    """
    return tf.keras.applications.ResNet50(
        include_top=False,
        weights=None,
        pooling='avg',
        input_shape=(32, 32, 3),
    )

class SimCLR(tf.keras.Model):
    """ResNet50 encoder f(.) followed by the projection head g(.)."""

    def __init__(self, projection_dim=128):
        super().__init__()
        # ResNet50 with average pooling emits 2048-dimensional features.
        self.encoder = get_backbone()
        self.projection_head = ProjectionHead(input_dim=2048, proj_dim=projection_dim)

    def call(self, x, training=False):
        features = self.encoder(x, training=training)
        return self.projection_head(features, training=training)

# 5. Contrastive Loss (NT-Xent Loss)
def contrastive_loss(z1, z2, temperature=0.07):
    """NT-Xent loss where row k of ``z1`` and row k of ``z2`` form the positive pair.

    Returns the mean sparse softmax cross-entropy over all 2N anchors.
    """
    n = tf.shape(z1)[0]
    embeddings = tf.concat([z1, z2], axis=0)  # [2N, D]
    logits = tf.matmul(embeddings, embeddings, transpose_b=True) / temperature  # [2N, 2N]

    # Push every self-similarity to a huge negative value so it drops out of the softmax.
    logits += tf.eye(2 * n) * -1e9

    # Rolling [0 .. 2N-1] by N gives each anchor the index of its other view.
    targets = tf.roll(tf.range(2 * n), shift=n, axis=0)

    per_anchor = tf.keras.losses.sparse_categorical_crossentropy(targets, logits, from_logits=True)
    return tf.reduce_mean(per_anchor)

# 6. Training Loop
if __name__ == "__main__":
    # Place work on the first visible GPU, otherwise on the CPU.
    device = "/GPU:0" if tf.config.list_physical_devices('GPU') else "/CPU:0"
    with tf.device(device):
        # Prepare dataset
        train_dataset = prepare_dataset(batch_size=256)

        # Define model and optimizer
        model = SimCLR(projection_dim=128)
        optimizer = optimizers.Adam(learning_rate=5e-4)

        # Training parameters
        epochs = 10

        # Compile the step into a graph. Otherwise every ResNet50
        # forward/backward pass runs op-by-op in eager mode. A different final
        # batch size causes at most one extra trace.
        @tf.function
        def train_step(x1, x2):
            with tf.GradientTape() as tape:
                # Forward pass for both augmentations
                z1 = model(x1, training=True)
                z2 = model(x2, training=True)
                loss = contrastive_loss(z1, z2, temperature=0.07)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            return loss

        # Training loop
        for epoch in range(epochs):
            total_loss = 0.0
            steps = 0
            for x1, x2 in train_dataset:
                total_loss += float(train_step(x1, x2))
                steps += 1
            # max(...) guards against an empty dataset.
            print(f"Epoch {epoch+1}, Loss: {total_loss / max(steps, 1):.4f}")
