<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.transforms as T
import torchvision.models as models
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
from torch import nn

# SimCLR-style augmentation pipeline. Every call draws fresh random
# parameters, so applying it twice to one image gives two different views.
transform = T.Compose(
    [
        # Random crop of the image, resized to 224x224.
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(p=0.5),
        # Colour distortion is applied to 80% of the samples.
        T.RandomApply(
            [T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        T.RandomGrayscale(p=0.2),
        # Blur is not wrapped in RandomApply, so it runs on every sample.
        T.GaussianBlur((5, 5)),
        T.ToTensor(),
    ]
)

# Dataset wrapper that yields two augmented views of each underlying sample
class SimCLRDataset(Dataset):
    """Wrap a labelled dataset and return a pair of augmented views per item.

    ``dataset`` must be indexable and yield ``(image, label)`` pairs; the
    label is discarded. ``transform`` is called twice on the same image, so a
    stochastic transform produces two different views.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(view_1, view_2)`` for the sample at ``index``."""
        # The label is not needed for self-supervised pretraining.
        image, _label = self.dataset[index]
        # Two separate transform calls, evaluated in order.
        return tuple(self.transform(image) for _ in range(2))

# Encoder: ResNet-50 backbone followed by an MLP projection head
class Encoder(nn.Module):
    """ResNet-50 backbone with a two-layer projection head.

    The backbone's final classification layer is replaced by an identity so
    it returns its pre-classifier features, which the projection head maps to
    a 128-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone. No weights argument is passed because
        # random initialisation is the default, and the ``pretrained`` keyword
        # is deprecated in recent torchvision releases (it emits a warning).
        self.base_model = models.resnet50()
        # Read the feature width from the layer being removed rather than
        # hard-coding it, so the head always matches the backbone.
        feature_dim = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()  # Remove the final classification layer
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )

    def forward(self, x):
        """Return the projected embeddings for a batch of images.

        Args:
            x: Image batch accepted by the ResNet-50 backbone
                (presumably ``(batch, 3, H, W)`` floats -- confirm with caller).

        Returns:
            Tensor whose last dimension is 128 (the projection head output).
        """
        features = self.base_model(x)
        return self.projection_head(features)

# Contrastive loss function
def contrastive_loss(embeddings, temperature=0.5):
    """Compute a SimCLR-style contrastive (InfoNCE) loss.

    ``embeddings`` holds two views per sample stacked along dim 0: rows
    ``[0, N)`` are the first views and rows ``[N, 2N)`` the matching second
    views. For every row, the other view of the same sample is the positive
    and the remaining ``2N - 2`` rows are negatives.

    Args:
        embeddings: 2-D tensor with ``2N`` rows (``N >= 1``).
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all ``2N`` rows.

    Raises:
        ValueError: If the number of rows is zero or odd.
    """
    total = embeddings.size(0)
    if total == 0 or total % 2 != 0:
        raise ValueError(
            "embeddings must contain a non-zero, even number of rows "
            f"(two views per sample); got {total}"
        )
    batch_size = total // 2  # Two views per sample

    # Create every helper tensor on the embeddings' device; otherwise the
    # loss fails with a device mismatch when the model runs on an accelerator.
    device = embeddings.device
    sample_ids = torch.arange(batch_size, device=device).repeat(2)
    same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)
    off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)

    # Cosine similarity between every pair of rows.
    embeddings = F.normalize(embeddings, dim=1)
    similarity_matrix = torch.matmul(embeddings, embeddings.T)

    # Remove self-comparisons. Boolean indexing flattens row-major and each
    # row keeps exactly ``total - 1`` entries, so the reshape is exact.
    same_sample = same_sample[off_diagonal].view(total, total - 1)
    similarity_matrix = similarity_matrix[off_diagonal].view(total, total - 1)

    # Exactly one positive per row; explicit shapes (instead of -1) keep the
    # reshape well-defined when there are no negatives (a single pair).
    positives = similarity_matrix[same_sample].view(total, 1)
    negatives = similarity_matrix[~same_sample].view(total, total - 2)

    # The positive is placed in column 0, so the target class is always 0.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(total, dtype=torch.long, device=device)

    return F.cross_entropy(logits, targets)

# Training loop
def train_simclr(encoder, dataloader, optimizer, temperature=0.5, epochs=10, device=None):
    """Train ``encoder`` with the contrastive loss and report per-epoch loss.

    Args:
        encoder: Module mapping an image batch to embeddings.
        dataloader: Iterable yielding ``(img1, img2)`` batches of paired views.
        optimizer: Optimizer already bound to ``encoder``'s parameters.
        temperature: Temperature forwarded to ``contrastive_loss``.
        epochs: Number of passes over ``dataloader``.
        device: Device to train on (``torch.device`` or string). Defaults to
            CPU. The loss computation must support the chosen device.

    Returns:
        List with the mean batch loss of each epoch, in order.

    Raises:
        ValueError: If ``dataloader`` yields no batches during an epoch.
    """
    device = torch.device("cpu") if device is None else torch.device(device)
    encoder.train()
    encoder.to(device)

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        # Count batches while iterating instead of calling len(dataloader),
        # so length-less iterables work and an empty loader is detected.
        num_batches = 0
        for img1, img2 in dataloader:
            img1, img2 = img1.to(device), img2.to(device)
            optimizer.zero_grad()

            # Each view is encoded in its own forward pass, then stacked so
            # rows [0, N) and [N, 2N) hold matching views.
            embeddings1 = encoder(img1)
            embeddings2 = encoder(img2)
            embeddings = torch.cat([embeddings1, embeddings2], dim=0)

            loss = contrastive_loss(embeddings, temperature)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError(
                "dataloader yielded no batches; cannot compute an average loss"
            )

        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1}, Loss: {mean_loss:.4f}")

    return epoch_losses

# Data pipeline: CIFAR-10 training images are loaded without a transform,
# because SimCLRDataset applies `transform` itself (twice per image).
cifar10_dataset = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=None,
)
simclr_dataset = SimCLRDataset(cifar10_dataset, transform=transform)
dataloader = DataLoader(
    simclr_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

# Model and optimizer; train_simclr runs the training on CPU.
encoder = Encoder()
optimizer = torch.optim.Adam(encoder.parameters(), lr=3e-4)

train_simclr(
    encoder,
    dataloader,
    optimizer,
    temperature=0.5,
    epochs=10,
)