Name : Saarthak Khamkar   
Roll No. : D088   
SAP ID : 60009230057

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
import torchvision.models as models
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler  # mixed precision

# Data Augmentation

In [None]:
# Per-channel mean / std handed to T.Normalize after conversion to a tensor.
NORM_MEAN = [0.4914, 0.4822, 0.4465]
NORM_STD = [0.247, 0.243, 0.261]

# Colour distortion; it is only applied to 80% of the samples (see RandomApply below).
color_jitter = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)

# Augmentation pipeline: every dataset item is ONE randomly augmented view of an image.
augmentation_steps = [
    T.RandomResizedCrop(32),
    T.RandomHorizontalFlip(),
    T.RandomApply([color_jitter], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(kernel_size=3),
    T.ToTensor(),
    T.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
simclr_transform = T.Compose(augmentation_steps)

# CIFAR-10 training split, downloaded into ./data when it is not present yet.
train_dataset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=simclr_transform,
)

# Shuffled mini-batches of 128 images, loaded by 4 worker processes into pinned memory.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

100%|██████████| 170M/170M [00:10<00:00, 16.1MB/s]


# Encoder Network

In [None]:
class Encoder(nn.Module):
    """Backbone network followed by a linear layer that yields unit-norm embeddings.

    Args:
        base_model: A module whose last child is its classification layer and
            which exposes that layer as ``base_model.fc`` (e.g. a torchvision
            ResNet). Every child except the last one is reused as the backbone.
        out_dim: Size of the embedding produced by the new linear layer.
            Defaults to 128, the value that was previously hard-coded.
    """

    def __init__(self, base_model, out_dim=128):
        super().__init__()
        # Keep everything but the final (classification) child of the base model.
        layers = list(base_model.children())
        self.backbone = nn.Sequential(*layers[:-1])
        # The replaced classifier tells us how wide the backbone features are.
        self.fc = nn.Linear(base_model.fc.in_features, out_dim)

    def forward(self, x):
        """Return L2-normalised embeddings of shape ``(batch, out_dim)``."""
        features = self.backbone(x)
        # flatten (unlike view) also works when the backbone output is not contiguous.
        features = torch.flatten(features, start_dim=1)
        embedding = self.fc(features)
        return F.normalize(embedding, dim=1)


# Projection Head

In [None]:
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied on top of the encoder output.

    Args:
        in_dim: Size of the incoming feature vectors.
        out_dim: Size of the projected vectors.
        hidden_dim: Width of the hidden layer. Defaults to 512, the value that
            was previously hard-coded.
    """

    def __init__(self, in_dim=128, out_dim=128, hidden_dim=512):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Project ``x`` of shape ``(batch, in_dim)`` to ``(batch, out_dim)``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# SimCLR Model

In [None]:
class SimCLR(nn.Module):
    """Full contrastive model: an ``Encoder`` whose output feeds a ``ProjectionHead``.

    Args:
        base_model: Network handed to ``Encoder`` to build the backbone.
    """

    def __init__(self, base_model):
        super().__init__()
        # Both sub-modules are created with their default dimensions.
        self.encoder = Encoder(base_model)
        self.projection_head = ProjectionHead()

    def forward(self, x):
        """Encode ``x`` and return the projection of the resulting embedding."""
        embedding = self.encoder(x)
        projection = self.projection_head(embedding)
        return projection

# NT-Xent Loss

In [None]:
class NTXentLoss(nn.Module):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            softmax.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Compute the loss for two batches of paired embeddings.

        Args:
            z_i: Tensor of shape ``(N, D)``, embeddings of the first views.
            z_j: Tensor of shape ``(N, D)``, embeddings of the second views;
                row ``k`` must belong to the same image as row ``k`` of ``z_i``.

        Returns:
            Scalar tensor: mean cross-entropy over all ``2N`` rows.

        Raises:
            ValueError: If ``z_i`` and ``z_j`` do not have the same shape.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )

        batch_size = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)  # 2N x D
        # Pairwise cosine similarity between all 2N embeddings -> 2N x 2N logits.
        sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) / self.temperature

        # A sample must never be its own candidate. -inf removes it from the
        # softmax exactly and, unlike a large finite constant, is valid in any dtype.
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, float("-inf"))

        # The positive for row k (first view) is row k + N (second view) and
        # vice versa. Pointing rows 0..N-1 at index k would select the masked
        # diagonal and blow the loss up to the magnitude of the mask value.
        index = torch.arange(batch_size, device=z.device)
        targets = torch.cat([index + batch_size, index], dim=0)
        return F.cross_entropy(sim, targets)

# Training Loop

In [None]:
def _random_hflip(batch):
    """Return ``batch`` with each sample mirrored along its last axis with probability 0.5."""
    flip = torch.rand(batch.size(0), device=batch.device) < 0.5
    # Reshape to (B, 1, 1, ...) so the per-sample decision broadcasts over the sample.
    flip = flip.view(-1, *([1] * (batch.dim() - 1)))
    return torch.where(flip, batch.flip(-1), batch)


def train(model, loader, optimizer, criterion, device, view_fn=None):
    """Run one contrastive training epoch and return the mean batch loss.

    Args:
        model: Network mapping a batch of images to embeddings.
        loader: Iterable of ``(x, label)`` batches; labels are ignored. ``x`` is
            either a list/tuple of two tensors (two independently augmented
            views of the same images, the preferred form) or a single tensor.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Callable ``criterion(z_i, z_j)`` returning a scalar loss.
        device: Device (``torch.device`` or string) the batches are moved to.
        view_fn: Optional callable that maps a batch tensor to an augmented
            batch. It is only used when the loader yields a single tensor, to
            build the second view. Defaults to a per-sample random horizontal
            flip, which is a minimal fallback; stronger augmentation should
            come from the loader or from ``view_fn``.

    Returns:
        Mean loss over the processed batches, or ``0.0`` if the loader is empty.

    Raises:
        ValueError: If a list/tuple batch does not contain exactly two views.
    """
    model.train()
    # Mixed precision only makes sense on CUDA; on CPU both helpers become no-ops.
    use_amp = torch.device(device).type == "cuda"
    scaler = GradScaler(enabled=use_amp)
    if view_fn is None:
        view_fn = _random_hflip

    total_loss = 0.0
    num_batches = 0

    for x, _ in tqdm(loader):
        if isinstance(x, (list, tuple)):
            if len(x) != 2:
                raise ValueError(f"expected exactly two views per batch, got {len(x)}")
            x_i, x_j = x[0].to(device), x[1].to(device)
        else:
            x_i = x.to(device)
            # Feeding the very same tensor twice makes every positive pair
            # identical, so the second view must differ from the first.
            x_j = view_fn(x_i)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):  # mixed precision
            z_i, z_j = model(x_i), model(x_j)
            loss = criterion(z_i, z_j)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        num_batches += 1

    # Average over the batches actually seen; avoids dividing by zero on an empty loader.
    return total_loss / num_batches if num_batches else 0.0

# Main

In [None]:
# Run configuration.
SEED = 42
NUM_EPOCHS = 1
LEARNING_RATE = 1e-3

# Seed PyTorch's RNG so weight initialisation and shuffling are repeatable.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet-pretrained ResNet-18 for faster convergence. The `weights=` enum is
# the supported replacement for the deprecated `pretrained=True` flag.
base_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model = SimCLR(base_model).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = NTXentLoss().to(device)

for epoch in range(NUM_EPOCHS):
    loss = train(model, train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 167MB/s]
  scaler = GradScaler()  # mixed precision
  with autocast():  # mixed precision
100%|██████████| 391/391 [26:36<00:00,  4.08s/it]

Epoch 1, Loss: 4500000175654456.5000





Colab Link : https://colab.research.google.com/drive/13bivmbUU0z0gGcpg-vnuLDBTfEdAEtkj?usp=sharing