<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# Data augmentation for contrastive learning.
# Random crop + flip alone are a weak augmentation for contrastive learning:
# two views of an image can still be matched by their colour statistics.
# Colour jitter and random grayscale are therefore added, with strengths that
# follow the SimCLR CIFAR-10 setup (colour-distortion strength s = 0.5, i.e.
# ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) applied with p=0.8).
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])

# Load the CIFAR-10 training split (downloaded into ./data if not present).
# Labels are returned by the dataset but not used by the self-supervised objective.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Define a simple encoder
class Encoder(nn.Module):
    """ResNet-18 backbone (randomly initialised) used as a feature extractor.

    The final classification layer is replaced by an identity, so ``forward``
    returns the backbone's pooled feature vector instead of class logits.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(weights=None)
        # Strip the classifier: keep the pooled features as the representation.
        backbone.fc = nn.Identity()
        self.model = backbone

    def forward(self, x):
        features = self.model(x)
        return features

# Define the projection head
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied on top of encoder features.

    Args:
        input_dim: size of each incoming feature vector.
        output_dim: size of each projected embedding.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.model(x)
        return projected

# Contrastive loss (NT-Xent, as used by SimCLR)
def contrastive_loss(out_1, out_2, temperature):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Row ``k`` of ``out_1`` and row ``k`` of ``out_2`` form a positive pair (two
    views of the same sample). Every other row of the combined ``2 * batch``
    set acts as a negative for that anchor. The negatives are essential: a loss
    that only rewards positive-pair similarity is minimised by collapsing every
    embedding onto one direction.

    Args:
        out_1: (batch, dim) embeddings of the first view.
        out_2: (batch, dim) embeddings of the second view, row-aligned with ``out_1``.
        temperature: positive scalar that divides the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all ``2 * batch`` anchors.

    Raises:
        ValueError: if ``out_1`` and ``out_2`` are not 2-D tensors of the same shape.
    """
    if out_1.dim() != 2 or out_1.shape != out_2.shape:
        raise ValueError(
            "out_1 and out_2 must be 2-D with the same shape, "
            f"got {tuple(out_1.shape)} and {tuple(out_2.shape)}"
        )
    batch_size = out_1.shape[0]

    # Cosine similarities between all 2B embeddings, scaled by temperature.
    z = nn.functional.normalize(torch.cat([out_1, out_2], dim=0), dim=1)
    logits = z @ z.t() / temperature

    # A sample must be neither its own positive nor its own negative.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # Anchor i in the first half is paired with i + B in the second half, and vice versa.
    idx = torch.arange(batch_size, device=z.device)
    targets = torch.cat([idx + batch_size, idx])
    return nn.functional.cross_entropy(logits, targets)

# Training loop: contrastive pre-training of the encoder + projection head.
device = torch.device("cpu")  # Use CPU
encoder = Encoder().to(device)
# 512 matches the pooled feature size of the ResNet-18 encoder.
projector = ProjectionHead(input_dim=512, output_dim=128).to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(projector.parameters()), lr=0.001)


def _two_views(img, _base_transform=transform):
    """Return two independent random augmentations of the same image."""
    return _base_transform(img), _base_transform(img)


# A positive pair must be two augmentations of the SAME image.
# Zipping two passes over train_loader draws two independently shuffled
# orders and so pairs unrelated images. Instead, the dataset emits both views
# for each sample. train_loader wraps this same dataset object, so each batch
# becomes ((view_1, view_2), labels).
train_dataset.transform = _two_views

for epoch in range(10):
    epoch_loss, num_batches = 0.0, 0
    for (x_i, x_j), _ in train_loader:
        x_i, x_j = x_i.to(device), x_j.to(device)
        out_1 = projector(encoder(x_i))
        out_2 = projector(encoder(x_j))
        loss = contrastive_loss(out_1, out_2, temperature=0.5)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Report the epoch mean rather than only the last (possibly smaller) batch;
    # max(..., 1) guards against an empty loader.
    print(f"Epoch {epoch}, Loss: {epoch_loss / max(num_batches, 1)}")