<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the SimCLR model
class SimCLR(nn.Module):
    """SimCLR network: an encoder followed by a two-layer MLP projection head.

    The forward pass returns L2-normalized projections, so the dot product of
    two output rows is their cosine similarity.

    Args:
        base_encoder: Module whose output's last dimension is ``feature_dim``.
        projection_dim: Size of the output projection.
        feature_dim: Width of the encoder output, also used as the hidden
            width of the projection head. Defaults to 512 (the previously
            hard-coded value), so existing callers are unaffected.
    """

    def __init__(self, base_encoder, projection_dim, feature_dim=512):
        super().__init__()
        self.encoder = base_encoder
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, projection_dim),
        )

    def forward(self, x):
        """Encode ``x`` and return unit-length projections (last dim ``projection_dim``)."""
        features = self.encoder(x)
        # Encoder features are L2-normalized before the head (kept from the
        # original design; the SimCLR paper normalizes only the final output).
        features = nn.functional.normalize(features, dim=-1)
        projections = self.projection_head(features)
        # Unit-length rows make a plain matmul in the loss a cosine similarity.
        return nn.functional.normalize(projections, dim=-1)

def NTXentLoss(projections, temperature, projections_other=None):
    """
    Compute the SimCLR NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Two calling conventions are supported:

    * ``NTXentLoss(z, temperature)``: ``z`` stacks both augmented views with
      shape ``(2N, D)``. Row ``k`` and row ``k + N`` are the two views of the
      same sample and form the positive pair.
    * ``NTXentLoss(z_a, temperature, z_b)``: ``z_a`` and ``z_b`` are the two
      ``(N, D)`` views. Row ``k`` of each forms a positive pair.

    Rows are L2-normalized here, so the logits are cosine similarities divided
    by ``temperature``. For each row, its counterpart view is the positive and
    the other ``2N - 2`` rows are negatives. Self-similarity is excluded.

    Args:
        projections: Stacked ``(2N, D)`` projections, or the first view when
            ``projections_other`` is given.
        temperature: Positive softmax temperature.
        projections_other: Optional second view with the same number of rows
            as ``projections``.

    Returns:
        Scalar tensor: mean loss over all ``2N`` rows.

    Raises:
        ValueError: if the two views differ in row count, or if the stacked
            batch is empty or has an odd number of rows.
    """
    if projections_other is not None:
        if projections_other.size(0) != projections.size(0):
            raise ValueError("both views must contain the same number of rows")
        projections = torch.cat([projections, projections_other], dim=0)

    n_rows = projections.size(0)
    if n_rows == 0 or n_rows % 2 != 0:
        raise ValueError(
            f"expected a non-empty stacked batch with an even number of rows, got {n_rows}"
        )
    half = n_rows // 2

    # Cosine similarity between all pairs of rows, scaled by the temperature.
    projections = nn.functional.normalize(projections, dim=1)
    similarity_matrix = torch.matmul(projections, projections.T) / temperature

    # A row must never be its own candidate: push the diagonal to -inf so it
    # contributes exp(-inf) = 0 to the softmax denominator.
    self_mask = torch.eye(n_rows, dtype=torch.bool, device=projections.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float('-inf'))

    # The target for row k is its other view, row (k + N) mod 2N. It must NOT
    # be the masked diagonal, which would make the loss infinite.
    indices = torch.arange(n_rows, device=projections.device)
    labels = (indices + half) % n_rows
    return nn.functional.cross_entropy(similarity_matrix, labels)

# Data preparation with augmentations.
# `transform` produces ONE random augmented view. SimCLR needs two independent
# views of every image, so the dataset uses `two_view_transform` below.
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    # ColorJitter() with no arguments is a no-op. Use SimCLR's CIFAR-10
    # strength s = 0.5: brightness/contrast/saturation 0.8*s, hue 0.2*s.
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


class TwoViewTransform:
    """Apply ``base_transform`` twice to one image and return both views as a tuple."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, image):
        # Each call draws fresh random augmentation parameters.
        return self.base_transform(image), self.base_transform(image)


two_view_transform = TwoViewTransform(transform)

# CIFAR-10 dataset: each item is ((view_a, view_b), label).
dataset = datasets.CIFAR10(root='./data', train=True, transform=two_view_transform, download=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Resolve the device once, before the model is moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base encoder (adjusted architecture for CIFAR-10)
base_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),   # 32x32 -> 16x16
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 16x16 -> 8x8
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 8 * 8, 512)  # Output size after convolution and flattening
)
model = SimCLR(base_encoder, projection_dim=128).to(device)

# Optimizer and parameters
optimizer = optim.Adam(model.parameters(), lr=3e-4)
temperature = 0.5
num_epochs = 20

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for (view_a, view_b), _ in dataloader:
        # Stack both views into one 2N batch, so that row k and row k + N hold
        # the two augmentations of the same image.
        images = torch.cat([view_a, view_b], dim=0).to(device)

        # Forward pass
        projections = model(images)

        # Contrastive loss over the stacked views
        loss = NTXentLoss(projections, temperature)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than just the last batch.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")