<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch torchvision

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18

# Device setup: run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the contrastive loss function (SimCLR's NT-Xent)
class ContrastiveLoss(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are treated as two views of the
    same sample (a positive pair). Each of the ``2N`` embeddings is an anchor
    that must identify its partner among the other ``2N - 1`` embeddings (its
    own entry is excluded). All remaining embeddings, from either view, act as
    negatives. The result is the mean cross-entropy over all ``2N`` anchors, so
    it is symmetric in ``z_i`` and ``z_j``.

    Args:
        temperature: Divisor applied to the cosine similarities before the softmax.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Return the scalar NT-Xent loss for a batch of positive pairs.

        Args:
            z_i: Embeddings of the first views, shape ``(N, D)``.
            z_j: Embeddings of the second views, shape ``(N, D)``.

        Returns:
            A 0-dim tensor holding the mean loss over all ``2N`` anchors.

        Raises:
            ValueError: If the inputs are not 2-D, differ in shape, or are empty.
        """
        if z_i.dim() != 2 or z_i.shape != z_j.shape:
            raise ValueError(
                "z_i and z_j must be 2-D tensors of equal shape, got "
                f"{tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )
        n = z_i.size(0)
        if n == 0:
            raise ValueError("ContrastiveLoss needs at least one positive pair")

        # Stack both views and L2-normalise so dot products are cosine similarities.
        z = nn.functional.normalize(torch.cat([z_i, z_j], dim=0), dim=-1)
        similarity_matrix = torch.matmul(z, z.T) / self.temperature

        # An embedding must never be scored against itself. Masking the diagonal
        # with -inf makes it contribute exp(-inf) = 0 to the softmax denominator.
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        similarity_matrix = similarity_matrix.masked_fill(self_mask, float("-inf"))

        # Row k < n (a z_i view) has its partner at row k + n, and row k + n at row k.
        idx = torch.arange(n, device=z.device)
        labels = torch.cat([idx + n, idx])
        return nn.functional.cross_entropy(similarity_matrix, labels)

# Data transformations and dataloaders
class TwoViewTransform:
    """Apply the same random augmentation pipeline twice to one image.

    SimCLR's positive pairs are two *independent* augmentations of the same
    image. Every call to the wrapped pipeline re-samples its random parameters,
    so the two returned views generally differ.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, image):
        return self.base_transform(image), self.base_transform(image)


# Per-view augmentation pipeline.
# NOTE(review): CIFAR-10 images are 32x32, so cropping to 224 upsamples ~7x and
# makes each forward pass much more expensive; consider 32 for CIFAR.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # one entry per RGB channel
])
dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=TwoViewTransform(transform)
)
# drop_last keeps every batch the same size, so each anchor always faces the
# same number of negatives in the contrastive loss.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

# Model: ResNet18 backbone trained from scratch, with its final fc layer
# replaced by a linear projection to 128-d embeddings.
model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 128)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
contrastive_loss = ContrastiveLoss()

# Training loop
num_epochs = 10
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    for (view_1, view_2), _ in dataloader:
        # Row k of view_1 and row k of view_2 are independent augmentations of
        # the same image, i.e. a positive pair. Labels are unused (self-supervised).
        view_1 = view_1.to(device)
        view_2 = view_2.to(device)

        # One joint forward pass over both views, then split back into the halves.
        features = model(torch.cat([view_1, view_2], dim=0))
        z_i, z_j = features.chunk(2, dim=0)

        # Compute the contrastive loss
        loss = contrastive_loss(z_i, z_j)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print epoch progress
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")