<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset


# Define SimCLR model
class SimCLR(nn.Module):
    """SimCLR network: a backbone encoder followed by an MLP projection head.

    The encoder is ``base_model`` with its last child module removed (the
    classification head, e.g. ``fc`` on a torchvision ResNet). Its output is
    flattened and mapped by a two-layer MLP to ``projection_dim`` features,
    which are L2-normalized per sample.

    Args:
        base_model: Module whose last child is the classifier and which
            exposes ``fc.in_features`` (the width of the features fed to it).
        projection_dim: Output dimensionality of the projection head.
        hidden_dim: Width of the projection head's hidden layer
            (defaults to 256, the previously hard-coded value).
    """

    def __init__(self, base_model, projection_dim, hidden_dim=256):
        super().__init__()
        # Keep every child except the last one (the classification head).
        self.encoder = nn.Sequential(*list(base_model.children())[:-1])
        # Projection head used only for the contrastive objective.
        self.projector = nn.Sequential(
            nn.Linear(base_model.fc.in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim),
        )

    def forward(self, x):
        """Return L2-normalized projections of shape ``(N, projection_dim)``."""
        # flatten (unlike view) also works if the encoder output is non-contiguous.
        h = torch.flatten(self.encoder(x), start_dim=1)
        z = self.projector(h)
        # Unit-length outputs so dot products equal cosine similarities.
        return F.normalize(z, dim=1)


# Define random dataset with two augmented views
class RandomDataset(Dataset):
    """Synthetic dataset yielding two views of a freshly drawn random image.

    Every access draws a new image with values uniform in ``[0, 1)``, which is
    the value range torchvision assumes for float tensor images. Each call
    therefore returns different pixels, even for the same index.

    Args:
        size: Number of items reported by ``len()``.
        transform: Optional callable applied separately to produce each
            view. When ``None``, both views are the same raw image tensor.
        image_shape: ``(C, H, W)`` shape of the generated images.
    """

    def __init__(self, size, transform=None, image_shape=(3, 224, 224)):
        self.size = size
        self.transform = transform
        self.image_shape = tuple(image_shape)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Without this check, plain iteration over the dataset (sequence
        # protocol) would never stop, because it relies on IndexError.
        if not -self.size <= idx < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")
        # torch.rand rather than torch.randn: colour transforms such as
        # ColorJitter clamp float images to [0, 1], so unbounded normal
        # noise would be clipped and distorted.
        image = torch.rand(self.image_shape)
        if self.transform is None:
            return image, image
        # Apply the transform twice so each view gets independent randomness.
        return self.transform(image), self.transform(image)


# Define contrastive loss function
def contrastive_loss(out_1, out_2, temperature=0.5):
    """Compute the NT-Xent / InfoNCE contrastive loss for two batches of views.

    The batches are concatenated into ``2N`` samples. Row ``i`` of ``out_1``
    and row ``i`` of ``out_2`` form a positive pair. Every other sample in the
    combined batch is a negative. The logits are cosine similarities divided
    by ``temperature``, and each sample's similarity with itself is excluded.

    Args:
        out_1: First augmented view, shape ``[batch_size, projection_dim]``.
        out_2: Second augmented view, same shape as ``out_1``.
        temperature: Positive temperature used to scale the similarities.

    Returns:
        Scalar tensor: the loss averaged over all ``2N`` anchors.

    Raises:
        ValueError: If the two views differ in shape or ``temperature`` is
            not positive.
    """
    if out_1.shape != out_2.shape:
        raise ValueError(
            f"out_1 and out_2 must have the same shape, got "
            f"{tuple(out_1.shape)} and {tuple(out_2.shape)}"
        )
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    batch_size = out_1.size(0)
    # Cosine similarity as a matmul of unit vectors. This needs O(N^2) memory,
    # instead of the O(N^2 * D) tensor that broadcasting cosine_similarity
    # over unsqueezed inputs materializes. eps matches cosine_similarity.
    z = F.normalize(torch.cat([out_1, out_2], dim=0), dim=1, eps=1e-8)
    logits = z @ z.T / temperature

    # Drop self-similarities from every denominator (exp(-inf) == 0).
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive for anchor i is i + N (first view) or i - N (second view).
    idx = torch.arange(batch_size, device=logits.device)
    targets = torch.cat([idx + batch_size, idx])

    # cross_entropy is -log softmax at the target, computed via log-sum-exp.
    # Unlike exp()/sum(), it does not overflow to NaN for small temperatures.
    return F.cross_entropy(logits, targets)


# Augmentation pipeline; each view is an independent draw through it.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),  # Brightness, Contrast, Saturation, Hue
    transforms.GaussianBlur(kernel_size=3),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Dataset and DataLoader
dataset = RandomDataset(1000, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the SimCLR model (moved to the device *before* building the
# optimizer, so the optimizer holds the on-device parameters) and optimizer.
model = SimCLR(models.resnet18(weights=None), projection_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    model.train()
    epoch_loss = 0.0
    for view_1, view_2 in dataloader:
        # Batches come off the DataLoader on the CPU; move them to the model.
        view_1, view_2 = view_1.to(device), view_2.to(device)
        optimizer.zero_grad()
        out_1 = model(view_1)
        out_2 = model(view_2)
        loss = contrastive_loss(out_1, out_2, temperature=0.5)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f'Epoch {epoch + 1}, Loss: {epoch_loss / len(dataloader):.4f}')