<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL)_with_Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets

# Encoder model
class SimpleEncoder(nn.Module):
    """Small convolutional encoder mapping RGB image batches to embeddings.

    The network is three 3x3, stride-2 convolutions (3 -> 64 -> 128 -> 256
    channels), each followed by ReLU, then adaptive average pooling to a 4x4
    grid and a linear projection to ``feature_dim``.

    Args:
        feature_dim: Size of the embedding produced for each image.
    """

    def __init__(self, feature_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # `fc` expects a 256 x 4 x 4 feature map. For 32x32 inputs the conv
        # stack already yields 4x4, so this pooling is an identity there; for
        # other input sizes it prevents a shape mismatch in `fc`. The layer
        # has no parameters, so existing state dicts still load.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(256 * 4 * 4, feature_dim)

    def forward(self, x):
        """Encode a batch of 3-channel images.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)``.

        Returns:
            Tensor of shape ``(batch, feature_dim)``.
        """
        x = self.conv(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

# Contrastive loss function
def contrastive_loss(features, temperature=0.5):
    """Compute a cross-entropy contrastive loss over stacked view embeddings.

    ``features`` holds the embeddings of two views stacked along dim 0: with
    ``N = features.size(0) // 2``, rows ``i`` and ``i + N`` form a positive
    pair and every other row acts as a negative for row ``i``.

    Args:
        features: Tensor of shape ``(2 * N, dim)`` with ``N >= 1``.
        temperature: Divisor applied to the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor: the mean cross-entropy of picking each row's positive
        partner among all other rows.

    Raises:
        ValueError: If ``features`` is not 2-D or its first dimension is not
            a positive even number.
    """
    if features.dim() != 2:
        raise ValueError(
            f"features must be 2-D (2 * N, dim), got {features.dim()}-D input"
        )
    num_rows = features.size(0)
    if num_rows < 2 or num_rows % 2 != 0:
        raise ValueError(
            "features must stack two views along dim 0 (an even, non-zero "
            f"number of rows), got {num_rows}"
        )
    batch_size = num_rows // 2

    # Cosine similarity between every pair of rows, scaled by temperature.
    features = F.normalize(features, dim=1)
    logits = torch.matmul(features, features.T) / temperature

    # A row must never pick itself: -inf contributes exp(-inf) = 0 to the
    # softmax denominator, which is equivalent to dropping the diagonal.
    self_mask = torch.eye(num_rows, dtype=torch.bool, device=features.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive partner of row i is row (i + N) mod 2N.
    targets = (torch.arange(num_rows, device=features.device) + batch_size) % num_rows

    return F.cross_entropy(logits, targets)

# Dataset and transformations
class TwoViewTransform:
    """Apply one random augmentation pipeline twice to the same image.

    Calling the instance returns a ``(view_a, view_b)`` tuple of two
    independently augmented versions of the input, which serve as the
    positive pair for contrastive learning.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, image):
        return self.base_transform(image), self.base_transform(image)


# Random augmentation pipeline that produces a single view.
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # One mean/std entry per RGB channel.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Each dataset item is ((view_a, view_b), label); labels are not used below.
dataset = datasets.CIFAR10(
    root='/tmp/CIFAR10',
    train=True,
    transform=TwoViewTransform(transform),
    download=True,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

# Initialize encoder and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = SimpleEncoder().to(device)
optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)

# Training loop
for epoch in range(100):
    epoch_loss = 0.0
    for (view_a, view_b), _ in dataloader:
        # Stack the two views so rows i and i + N are the positive pair, the
        # layout contrastive_loss expects. Duplicating a single embedding
        # instead would make every positive similarity exactly 1, leaving the
        # positive term constant and teaching no augmentation invariance.
        views = torch.cat([view_a, view_b], dim=0).to(device)
        optimizer.zero_grad()

        # Forward pass over both views in one batch
        features = encoder(views)

        # Compute loss and update weights
        loss = contrastive_loss(features)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {epoch_loss / len(dataloader):.4f}')