<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10

# Random augmentations applied every time a sample is fetched from the dataset.
# NOTE(review): Normalize receives a single mean/std value (0.5) although
# CIFAR10 images have three channels; this presumably relies on the value
# being broadcast over the channel dimension -- confirm against torchvision.
_augmentation_steps = [
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(_augmentation_steps)

# CIFAR10 training split (downloaded into ./data when missing), served in
# shuffled mini-batches of 64 samples.
train_dataset = CIFAR10(
    transform=transform,
    download=True,
    train=True,
    root='./data',
)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)

# Define SimCLR model
class SimCLR(nn.Module):
    """Backbone encoder followed by a two-layer MLP projection head.

    Args:
        base_encoder: Backbone module. Its ``fc`` attribute is replaced in
            place with ``nn.Identity`` so the encoder outputs features
            instead of class logits. Note that this mutates the module that
            was passed in.
        projection_dim: Output width of the projection head.

    The projection head's input/hidden width is taken from
    ``base_encoder.fc.in_features`` when that attribute exists (so backbones
    wider than ResNet-18 work), and falls back to 512 otherwise.
    """

    def __init__(self, base_encoder, projection_dim):
        super().__init__()
        # Read the feature width *before* the classifier layer is discarded;
        # 512 (the previously hard-coded value) remains the fallback.
        classifier = getattr(base_encoder, "fc", None)
        feature_dim = getattr(classifier, "in_features", 512)

        self.encoder = base_encoder
        self.encoder.fc = nn.Identity()  # Remove the final FC layer
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, projection_dim),
        )

    def forward(self, x):
        """Return ``(h, z)``: encoder features and their projection.

        Args:
            x: Input batch accepted by the wrapped encoder.

        Returns:
            Tuple ``(h, z)`` where ``h`` is the encoder output and ``z`` is
            ``projection_head(h)``.
        """
        h = self.encoder(x)
        z = self.projection_head(h)
        return h, z

# Define the NT-Xent loss (Normalized Temperature-scaled Cross Entropy)
class NTXentLoss(nn.Module):
    """NT-Xent contrastive loss over a batch of paired projections.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are treated as a positive
    pair. For each of the ``2N`` projections, every other projection in the
    batch except itself is a candidate, and the loss is the cross entropy of
    picking its partner among those candidates using temperature-scaled
    cosine similarity.

    Args:
        batch_size: Kept for interface compatibility. ``forward`` reads the
            actual batch size from its inputs, so a smaller final batch is
            handled correctly.
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def forward(self, z_i, z_j):
        """Compute the mean NT-Xent loss for two batches of projections.

        Args:
            z_i: Tensor of shape ``(N, D)``, first view's projections.
            z_j: Tensor of shape ``(N, D)``, second view's projections.

        Returns:
            Scalar tensor: loss averaged over all ``2N`` projections.

        Raises:
            ValueError: If the inputs differ in shape or the batch is empty.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"z_i and z_j must have the same shape, got "
                f"{tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )
        batch_size = z_i.shape[0]
        if batch_size == 0:
            raise ValueError("NTXentLoss requires a non-empty batch")
        num_views = 2 * batch_size

        # Pairwise cosine similarity between all 2N projections -> (2N, 2N).
        representations = torch.cat([z_i, z_j], dim=0)
        similarity_matrix = self.similarity_f(
            representations.unsqueeze(1), representations.unsqueeze(0)
        )
        logits = similarity_matrix / self.temperature

        # A projection must never be its own candidate: setting the diagonal
        # to -inf gives it zero probability mass in the softmax.
        self_mask = torch.eye(num_views, dtype=torch.bool, device=z_i.device)
        logits = logits.masked_fill(self_mask, float("-inf"))

        # The positive for row k is row k + N, and vice versa.
        index = torch.arange(batch_size, device=z_i.device)
        targets = torch.cat([index + batch_size, index], dim=0)

        loss = self.criterion(logits, targets)
        return loss / num_views

# Initialize the SimCLR model, optimizer, and loss function
base_encoder = models.resnet18(weights=None)
projection_dim = 128
simclr_model = SimCLR(base_encoder, projection_dim)
optimizer = optim.Adam(simclr_model.parameters(), lr=0.001)
criterion = NTXentLoss(batch_size=64, temperature=0.5)


class TwoViewDataset(Dataset):
    """Yield two independently augmented views of the same underlying sample.

    The wrapped dataset is indexed twice per item; because its transform is
    random, each access produces a different augmentation of the same image.
    Labels are discarded since contrastive pre-training does not use them.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        view_a, _ = self.base_dataset[index]
        view_b, _ = self.base_dataset[index]
        return view_a, view_b


# Contrastive learning needs x_i and x_j to be views of the SAME image.
# Zipping two independently shuffled loaders pairs unrelated images, so a
# single loader over two-view samples is used instead.
pair_loader = DataLoader(
    TwoViewDataset(train_dataset),
    batch_size=train_loader.batch_size,
    shuffle=True,
)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0.0
    simclr_model.train()
    for x_i, x_j in pair_loader:
        # Removed .cuda() calls to use CPU
        _, z_i = simclr_model(x_i)
        _, z_j = simclr_model(x_j)

        loss = criterion(z_i, z_j)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    avg_loss = total_loss / len(pair_loader)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

print("Training complete.")