<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F


# Custom Dataset for SimCLR
class SimCLRDataset(Dataset):
    """Wrap a labelled dataset so every item yields two transformed views.

    Args:
        dataset: Sized, indexable collection whose items unpack as
            ``(sample, label)``; the label is discarded.
        transform: Callable applied to the sample, once per returned view.
    """

    def __init__(self, dataset, transform):
        self.transform = transform
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return a 2-tuple holding two transformed views of item ``index``."""
        sample, _label = self.dataset[index]
        # The transform is invoked twice on the same sample, so a stochastic
        # transform produces two different views.
        return tuple(self.transform(sample) for _ in range(2))


# SimCLR Model
class SimCLR(nn.Module):
    """Encoder followed by a two-layer MLP projection head.

    Args:
        encoder: Module producing feature vectors; it must expose an
            ``output_dim`` attribute giving the width of those features.
        projection_dim: Width of the projected embedding (default 128).
    """

    def __init__(self, encoder, projection_dim=128):
        super().__init__()
        self.encoder = encoder

        # Projection head: features -> 512 hidden units -> projection_dim.
        hidden_layer = nn.Linear(encoder.output_dim, 512)
        activation = nn.ReLU()
        output_layer = nn.Linear(512, projection_dim)
        self.projector = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Encode ``x``, project it, and L2-normalise the result along dim 1."""
        projected = self.projector(self.encoder(x))
        return F.normalize(projected, dim=1)


# Contrastive Loss
class ContrastiveLoss(nn.Module):
    """NT-Xent (normalised temperature-scaled cross-entropy) contrastive loss.

    For each of the ``2 * B`` embeddings, the other view of the same sample is
    the positive and the remaining ``2 * B - 2`` embeddings are negatives.

    Args:
        temperature: Positive scalar dividing the similarity logits.
    """

    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Return the mean NT-Xent loss over both views.

        Args:
            z_i: ``(B, D)`` embeddings of the first view.
            z_j: ``(B, D)`` embeddings of the second view; row ``k`` must
                correspond to the same sample as row ``k`` of ``z_i``.

        Returns:
            Scalar tensor with the loss averaged over all ``2 * B`` rows.

        Note:
            Similarities are plain dot products, so they are cosine
            similarities only when the inputs are already L2-normalised.
        """
        batch_size = z_i.shape[0]
        total = 2 * batch_size

        z = torch.cat([z_i, z_j], dim=0)
        sim = torch.mm(z, z.t()) / self.temperature

        # A sample must not act as its own negative: drop the diagonal from
        # the softmax denominator by setting it to -inf.
        self_mask = torch.eye(total, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, float('-inf'))

        # Row k (first view) pairs with row k + B (second view) and vice versa.
        pos_index = torch.cat([
            torch.arange(batch_size, total, device=z.device),
            torch.arange(0, batch_size, device=z.device),
        ])
        pos_sim = sim.gather(1, pos_index.unsqueeze(1)).squeeze(1)

        # -log(exp(pos) / sum_{k != self} exp(sim_k)), computed stably.
        loss = torch.logsumexp(sim, dim=1) - pos_sim
        return loss.mean()


# Encoder (example: a minimal single-convolution CNN; a ResNet backbone could be substituted)
class Encoder(nn.Module):
    """Minimal convolutional encoder: one 3x3 conv, ReLU, global average pool.

    Args:
        output_dim: Number of convolution output channels; also stored as
            ``self.output_dim`` so downstream modules can size their inputs.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=output_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, x):
        """Return one feature row per leading-dimension entry of ``x``.

        For a 4-D ``(batch, 3, H, W)`` input the result is
        ``(batch, output_dim)``.
        """
        activated = F.relu(self.conv(x))
        pooled = self.pool(activated)
        # Collapse the 1x1 spatial dimensions left by the pooling layer.
        return pooled.view(pooled.size(0), -1)


# Example usage
# Model and objective: a 256-channel encoder, a 128-d projection head, and
# the contrastive criterion with its default temperature.
encoder = Encoder(output_dim=256)
model = SimCLR(encoder, projection_dim=128)
criterion = ContrastiveLoss()

# Augmentation pipeline applied twice per image by SimCLRDataset.
transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# CIFAR-10 training split (downloaded into ./data on first use); the base
# dataset gets no transform of its own, SimCLRDataset applies it instead.
train_data = datasets.CIFAR10(root='data', train=True, download=True)
simclr_dataset = SimCLRDataset(train_data, transform)
train_loader = DataLoader(simclr_dataset, batch_size=64, shuffle=True)

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train for 10 epochs, reporting the mean per-batch loss after each one.
for epoch in range(10):
    total_loss = 0
    for x1, x2 in train_loader:
        x1 = x1.to(device)
        x2 = x2.to(device)

        # Clear gradients from the previous step before the new forward pass.
        optimizer.zero_grad()

        z1 = model(x1)
        z2 = model(x2)
        loss = criterion(z1, z2)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")