<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Define the contrastive model
class ContrastiveModel(nn.Module):
    """Small convolutional encoder that maps images to ``feature_dim`` embeddings.

    Architecture: one 3x3 convolution (stride 1, padding 1, 64 output
    channels) followed by ReLU, then flatten and a linear projection. The
    convolution preserves height and width, so the linear layer's input width
    depends on the input resolution; that resolution is therefore a
    constructor argument instead of a hard-coded 32x32.

    Args:
        feature_dim: size of the output embedding.
        in_channels: number of channels in the input images (default 3).
        image_size: height and width of the square input images (default 32,
            matching the CIFAR-10 images used for training in this file).
    """

    def __init__(self, feature_dim, in_channels=3, image_size=32):
        super().__init__()
        conv_channels = 64
        # kernel_size=3 / stride=1 / padding=1 keeps H and W unchanged, so the
        # flattened size is conv_channels * image_size * image_size.
        # Layer order is kept identical so state_dict keys stay the same.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, conv_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(conv_channels * image_size * image_size, feature_dim),
        )

    def forward(self, x):
        """Encode a (batch, in_channels, image_size, image_size) tensor to (batch, feature_dim)."""
        return self.encoder(x)

# Define contrastive loss
def contrastive_loss(features, temperature=0.5, features_b=None):
    """
    Contrastive (InfoNCE-style) loss on L2-normalised feature similarities.

    When ``features_b`` is given, this computes the two-view NT-Xent loss
    used by SimCLR. ``features`` and ``features_b`` hold embeddings of two
    augmented views of the same N samples. Row i of one view is the positive
    for row i of the other, and the other 2N - 2 rows are negatives. Each
    row's similarity with itself is excluded from the softmax.

    When ``features_b`` is None, the original single-view behaviour is kept
    for backward compatibility: every row's target is itself. A normalised
    vector's similarity with itself is always the largest possible value, so
    this mode only pushes different samples apart. It cannot teach
    invariance to augmentations.

    Args:
        features: 2-D tensor of shape (N, D).
        temperature: positive scalar dividing the similarities.
        features_b: optional tensor with the same shape as ``features``,
            holding the second view's embeddings.

    Returns:
        Scalar tensor: mean cross-entropy over all rows.

    Raises:
        ValueError: if ``temperature`` is not positive, or if the two views
            have different shapes.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    if features_b is None:
        # Legacy single-view path (numerically identical to the original).
        z = F.normalize(features, dim=1)
        logits = torch.matmul(z, z.T) / temperature
        targets = torch.arange(z.size(0), device=z.device)
        return F.cross_entropy(logits, targets)

    if features_b.shape != features.shape:
        raise ValueError(
            f"views must have the same shape, got {tuple(features.shape)} "
            f"and {tuple(features_b.shape)}"
        )

    n = features.size(0)
    z = F.normalize(torch.cat([features, features_b], dim=0), dim=1)
    logits = torch.matmul(z, z.T) / temperature
    # A sample must be neither its own positive nor its own negative.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i (first view) pairs with row i + n (second view), and vice versa.
    targets = torch.cat([
        torch.arange(n, 2 * n, device=z.device),
        torch.arange(0, n, device=z.device),
    ])
    return F.cross_entropy(logits, targets)


# Augmentation pipeline
transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.ToTensor(),
    # Per-channel mean/std for the three RGB channels.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


class TwoViews:
    """Dataset transform that returns two independently augmented views of one image."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, img):
        # Each call re-samples the random crop/flip/jitter, so the two views
        # differ. Their agreement is the contrastive training signal.
        return self.base_transform(img), self.base_transform(img)


# Dataset and DataLoader: each sample is ((view_a, view_b), label).
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=TwoViews(transform))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize model, optimizer, and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ContrastiveModel(feature_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0.0
    model.train()
    for (view_a, view_b), _ in train_loader:  # labels unused: self-supervised
        view_a = view_a.to(device)
        view_b = view_b.to(device)
        # Positives are the two augmented views of the same image.
        loss = contrastive_loss(model(view_a), features_b=model(view_b))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-batch loss; a raw sum scales with the number of batches.
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")