<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim

class SimCLR(nn.Module):
    """Backbone encoder followed by a two-layer MLP projection head.

    The backbone must expose a linear ``fc`` attribute (as the torchvision
    ResNet used in this file does). That layer is replaced with
    ``nn.Identity`` in place, so the ``base_model`` object passed in is
    modified and afterwards emits features instead of class logits.

    Args:
        base_model: Backbone network whose ``fc`` layer is stripped.
        projection_dim: Width of both linear layers' outputs in the
            projection head, and therefore of the returned embedding.
    """

    def __init__(self, base_model, projection_dim):
        super().__init__()
        # Read the feature width before the head is discarded.
        feature_dim = base_model.fc.in_features
        base_model.fc = nn.Identity()
        self.encoder = base_model
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
        )

    def forward(self, x):
        """Encode ``x`` and return its projection."""
        # Collapse everything after the batch dimension into one axis
        # before handing the features to the linear projection head.
        features = self.encoder(x).flatten(start_dim=1)
        return self.projector(features)

# Random augmentation pipeline applied by the dataset to every sample it
# loads; each sample therefore arrives as ONE augmented 224x224 tensor.
transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.5,
        ),
        transforms.ToTensor(),
    ]
)

# CIFAR-10 is downloaded into ./data on first use.
dataset = datasets.CIFAR10(root='data', download=True, transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

# Randomly initialised ResNet-18 backbone wrapped with a 128-d projection head.
base_model = torchvision.models.resnet18(weights=None)
simclr_model = SimCLR(base_model, projection_dim=128)

# Adam over every parameter of the wrapped model (encoder and projector).
optimizer = optim.Adam(simclr_model.parameters(), lr=0.0003)

def contrastive_loss(out_1, out_2, temperature):
    """Return the NT-Xent (normalised temperature-scaled cross-entropy) loss.

    Row ``i`` of ``out_1`` and row ``i`` of ``out_2`` form a positive pair;
    every other row of the concatenated batch serves as a negative.

    The loss is evaluated as a cross-entropy over the similarity logits
    rather than by exponentiating them explicitly. This is mathematically the
    same quantity, but it does not overflow to ``inf``/``nan`` when
    ``temperature`` is small, and all helper tensors are created on the
    inputs' device so the function also works for GPU tensors.

    Args:
        out_1: Projections of the first view, shape ``(batch, dim)``.
        out_2: Projections of the second view, same shape as ``out_1``.
        temperature: Positive scalar dividing the cosine similarities.

    Returns:
        Scalar tensor: the loss averaged over all ``2 * batch`` anchors.
    """
    batch_size = out_1.shape[0]
    out_1 = nn.functional.normalize(out_1, dim=1)
    out_2 = nn.functional.normalize(out_2, dim=1)
    out = torch.cat([out_1, out_2], dim=0)

    # Pairwise cosine similarities, scaled by the temperature.
    logits = torch.mm(out, out.t()) / temperature

    # An anchor must not be compared with itself: a logit of -inf contributes
    # exp(-inf) = 0 to the softmax denominator.
    self_mask = torch.eye(
        2 * batch_size, dtype=torch.bool, device=logits.device
    )
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive for anchor i is i + batch_size (first half) or
    # i - batch_size (second half), i.e. the index range rolled by batch_size.
    targets = torch.arange(2 * batch_size, device=logits.device).roll(batch_size)

    return nn.functional.cross_entropy(logits, targets)

# Contrastive training needs two DIFFERENT views of every image. The batch
# coming out of the dataloader holds a single augmented view per image, so
# cloning it (as this loop previously did) produced identical "pairs" whose
# similarity is trivially 1. A second view is therefore derived here by
# re-augmenting each image tensor independently. The output size matches the
# 224x224 images produced by the dataset so both views share one shape.
# NOTE(review): the second view is a re-crop of the first rather than an
# independent crop of the raw image; having the dataset return two views per
# sample would be the cleaner long-term arrangement.
second_view_augment = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
])

simclr_model.train()
for epoch in range(100):
    running_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        images, _ = batch  # labels are not used for self-supervised training
        images_1 = images
        # Augment per image so every sample draws its own random parameters.
        images_2 = torch.stack([second_view_augment(img) for img in images])
        out_1 = simclr_model(images_1)
        out_2 = simclr_model(images_2)
        loss = contrastive_loss(out_1, out_2, temperature=0.5)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the epoch mean instead of whichever batch happened to come last;
    # max(..., 1) keeps an empty dataloader from raising.
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(num_batches, 1)}")