In [12]:
import torch
import random
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset

In [25]:
import logging
from datetime import datetime

# Logging config
def setup_logging(log_file='simclr_training.log'):
    """Route INFO-and-above records from the root logger to a file and the console.

    The function is safe to call more than once (for example when the notebook
    cell is re-run): handlers installed by an earlier call are removed and
    closed first, so each record is emitted once per destination instead of
    once per call.

    Args:
        log_file: Path of the log file. Records are appended to it.
    """
    root_logger = logging.getLogger('')

    # Remove only the handlers this function installed previously; handlers
    # attached by other code are left untouched.
    for handler in list(root_logger.handlers):
        if getattr(handler, '_simclr_managed', False):
            root_logger.removeHandler(handler)
            handler.close()

    # Configure the root logger explicitly: logging.basicConfig() silently does
    # nothing when the root logger already has a handler, which would skip
    # both the file handler and the level.
    root_logger.setLevel(logging.INFO)

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))

    for handler in (file_handler, console):
        # Tag the handler so a later call can recognise and replace it.
        handler._simclr_managed = True
        root_logger.addHandler(handler)

setup_logging()

In [13]:
# NT-Xent (normalized temperature-scaled cross-entropy) loss, used below to train SimCLR on CIFAR-10.
def nt_xent_loss(z1, z2, temperature=0.5):
    """Compute the NT-Xent contrastive loss for two batches of paired embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair; every
    other row in the concatenated batch of ``2N`` embeddings acts as a
    negative. Embeddings are L2-normalized here, so similarities are cosine
    similarities divided by ``temperature``.

    Args:
        z1: Tensor of shape (N, D), embeddings of the first view.
        z2: Tensor of shape (N, D), embeddings of the second view.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar tensor: the loss averaged over all ``2N`` anchors.

    Raises:
        ValueError: If ``z1`` and ``z2`` do not have the same shape.
    """
    if z1.shape != z2.shape:
        raise ValueError(
            f"z1 and z2 must have the same shape, got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )

    N = z1.size(0)
    z = torch.cat([z1, z2], dim=0)  # (2N, D)
    z = F.normalize(z, dim=1)

    sim_matrix = torch.matmul(z, z.T) / temperature

    # Mask self-similarity. -inf is representable in every floating dtype
    # (a large finite constant such as -9e15 is not in float16) and
    # contributes exactly zero to the log-sum-exp below.
    self_mask = torch.eye(2 * N, dtype=torch.bool, device=z.device)
    sim_matrix = sim_matrix.masked_fill(self_mask, float('-inf'))

    # Positive indices: (i, i+N) and (i+N, i)
    anchors = torch.arange(2 * N, device=z.device)
    positives = torch.cat([anchors[N:], anchors[:N]])

    pos_sim = sim_matrix[anchors, positives]  # (2N,)

    # -log(exp(pos) / sum(exp(sim))) == logsumexp(sim) - pos. Computing it in
    # log space avoids overflowing exp() when the temperature is small.
    loss = torch.logsumexp(sim_matrix, dim=1) - pos_sim
    return loss.mean()


In [14]:
def test_nt_xent_loss():
    """Sanity-check nt_xent_loss on a few hand-picked inputs.

    Seeds the global torch RNG with 42 so the random inputs are reproducible.
    Raises AssertionError on the first failing check; prints a confirmation
    line when every check passes.
    """
    torch.manual_seed(42)
    # Test 1: Identical views → low-ish loss
    z = F.normalize(torch.randn(8, 128), dim=1)
    loss = nt_xent_loss(z, z)
    assert loss.item() < 1.2, f"Expected reasonably low loss for identical views, got {loss.item()}"

    # Test 2: Symmetric property
    z1 = F.normalize(torch.randn(16, 128), dim=1)
    z2 = F.normalize(torch.randn(16, 128), dim=1)
    loss1 = nt_xent_loss(z1, z2)
    loss2 = nt_xent_loss(z2, z1)
    assert torch.allclose(loss1, loss2, atol=1e-5), f"Loss not symmetric: {loss1.item()} vs {loss2.item()}"

    # Test 3: Random projections → moderate/high loss
    z1 = F.normalize(torch.randn(64, 128), dim=1)
    z2 = F.normalize(torch.randn(64, 128), dim=1)
    loss = nt_xent_loss(z1, z2)
    assert loss.item() > 2.0, f"Expected high loss for random vectors, got {loss.item()}"

    # Test 4: Batch size 1 → once self-similarity is masked, the positive is
    # the only remaining candidate for each anchor, so the loss must be a
    # valid number and essentially zero. The check is deliberately not wrapped
    # in try/except: catching Exception would also catch AssertionError and
    # make this test impossible to fail.
    z1 = F.normalize(torch.randn(1, 128), dim=1)
    z2 = F.normalize(torch.randn(1, 128), dim=1)
    loss = nt_xent_loss(z1, z2)
    assert not torch.isnan(loss), "Loss is NaN for batch size 1"
    assert abs(loss.item()) < 1e-6, f"Expected ~0 loss for batch size 1, got {loss.item()}"

    print("All tests passed successfully.")

# Run tests
test_nt_xent_loss()

All tests passed successfully.


In [18]:
# SimCLR training requires two augmented views of the same sample.
class SimCLRDataset(torch.utils.data.Dataset):
    """Dataset wrapper that yields two transformed views of each base sample.

    Each item of ``base_dataset`` is expected to unpack into exactly two
    values; the first is transformed twice and the second is discarded.
    """

    def __init__(self, base_dataset, transform):
        # transform is called twice per item, once for each returned view.
        self.transform = transform
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        sample, _unused = self.base_dataset[index]
        first_view = self.transform(sample)
        second_view = self.transform(sample)
        return first_view, second_view

In [19]:
# Augmenting the images
def get_simclr_augmentations():
    """Build the augmentation pipeline used to create each SimCLR view.

    The steps run in the listed order: random resized crop to 32 pixels,
    random horizontal flip, colour jitter, random grayscale (p=0.2),
    Gaussian blur with a 3-pixel kernel, and conversion to a tensor.
    ColorJitter and GaussianBlur are added directly to the pipeline, with no
    probability wrapper around them.

    Returns:
        A ``transforms.Compose`` object wrapping the steps above.
    """
    pipeline = [
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=3),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)

def get_cifar10_dataloader(batch_size=256, root='./data', num_workers=4):
    """Create a shuffled DataLoader of SimCLR view pairs over the CIFAR-10 training split.

    The CIFAR-10 dataset is created without a transform of its own; the SimCLR
    augmentations are applied by ``SimCLRDataset``, twice per sample. The data
    is downloaded into ``root`` if it is not already there.

    Args:
        batch_size: Number of view pairs per batch.
        root: Directory where CIFAR-10 is stored or downloaded to.
        num_workers: Number of DataLoader worker processes. Pass 0 to load in
            the main process.

    Returns:
        A DataLoader whose batches are ``(view_1, view_2)`` pairs. The last
        incomplete batch is dropped (``drop_last=True``), so every batch has
        exactly ``batch_size`` pairs.
    """
    transform = get_simclr_augmentations()
    base_dataset = datasets.CIFAR10(root=root, train=True, download=True)
    dataset = SimCLRDataset(base_dataset, transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

In [24]:
from tqdm import tqdm

def train_simclr(model, data_loader, optimizer, device, temperature=0.5, epoch=1):
    """Train ``model`` for one epoch with the NT-Xent loss and return the mean batch loss.

    Args:
        model: Module mapping a batch of inputs to embeddings; it is called
            once per view.
        data_loader: Sized iterable yielding ``(x1, x2)`` batches of paired views.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the input batches are moved to before the forward pass.
        temperature: Temperature forwarded to ``nt_xent_loss``.
        epoch: Epoch number, used only in the progress bar and log messages.

    Returns:
        The average of the per-batch loss values over the epoch, as a float.

    Raises:
        ValueError: If ``data_loader`` yields no batches, in which case an
            average loss is undefined.
    """
    num_batches = len(data_loader)
    if num_batches == 0:
        # Fail early with an actionable message instead of a ZeroDivisionError
        # when the epoch average is computed.
        raise ValueError(
            "data_loader yields no batches; check the dataset size against "
            "batch_size (and drop_last) before training"
        )

    model.train()
    total_loss = 0.0

    loop = tqdm(enumerate(data_loader), total=num_batches, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, (x1, x2) in loop:
        x1, x2 = x1.to(device), x2.to(device)

        z1 = model(x1)
        z2 = model(x2)

        loss = nt_xent_loss(z1, z2, temperature)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call copies the value to the
        # host, which forces a synchronisation when the loss lives on a GPU.
        batch_loss = loss.item()
        total_loss += batch_loss
        avg_loss = total_loss / (batch_idx + 1)

        loop.set_postfix(loss=batch_loss, avg_loss=avg_loss)
        if batch_idx % 10 == 0:
            logging.info(f"Epoch [{epoch}] Batch [{batch_idx}/{num_batches}] Loss: {batch_loss:.4f}")

    avg_loss = total_loss / num_batches
    logging.info(f"Epoch [{epoch}] Completed. Avg Loss: {avg_loss:.4f}")
    return avg_loss

In [26]:
if __name__ == "__main__":
    # Train on the GPU when PyTorch can see one, otherwise on the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # NOTE(review): SimCLR is not defined in the cells visible here; it must
    # already be defined in the kernel before this cell runs — confirm.
    model = SimCLR().to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    dataloader = get_cifar10_dataloader(batch_size=256)

    # Ten epochs, numbered 1 through 10.
    for epoch in range(1, 10 + 1):
        train_simclr(
            model=model,
            data_loader=dataloader,
            optimizer=optimizer,
            device=device,
            temperature=0.5,
            epoch=epoch,
        )




KeyboardInterrupt: 