<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/SimCLR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import numpy as np

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Stochastic augmentation pipeline. Every call draws fresh random parameters,
# so applying it twice to the same image produces two different views.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
_augmentation_steps = [
    # Geometric augmentations.
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    # Photometric augmentations.
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=3),
    # Conversion to a tensor followed by per-channel standardisation.
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform = transforms.Compose(_augmentation_steps)

# CIFAR-10 is downloaded into ./data when it is not already present; the
# pipeline above is applied to each image as it is fetched.
dataset = datasets.CIFAR10(root='./data', transform=transform, download=True)

# Shuffled mini-batches of 256 samples, loaded by 4 worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
)

# SimCLR network: an encoder backbone followed by an MLP projection head
class SimCLR(nn.Module):
    """Encoder backbone with a two-layer projection head.

    Args:
        base_model: Callable invoked as ``base_model(weights=None)``. The
            module it returns must expose an ``fc`` attribute that has an
            ``in_features`` field (as torchvision ResNets do).
        projection_dim: Width of the projection head's output.

    Attributes:
        backbone: The encoder with its ``fc`` layer replaced by ``nn.Identity``.
        backbone_fc_in_features: Input width of the removed ``fc`` layer.
        projection: ``Linear -> ReLU -> Linear`` head applied to the features.
    """

    def __init__(self, base_model, projection_dim=128):
        super().__init__()

        encoder = base_model(weights=None)
        feature_dim = encoder.fc.in_features
        # Drop the classifier so the backbone returns what used to feed `fc`.
        encoder.fc = nn.Identity()

        self.backbone = encoder
        self.backbone_fc_in_features = feature_dim

        hidden_layer = nn.Linear(feature_dim, 512)
        output_layer = nn.Linear(512, projection_dim)
        self.projection = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Return ``(h, z)``: the backbone features and their projection."""
        features = self.backbone(x)
        return features, self.projection(features)

# Build the SimCLR network on a randomly initialised ResNet-50 and move it to
# the selected device before handing its parameters to the optimiser.
model = SimCLR(base_model=resnet50)
model = model.to(device)

# Adam over every parameter (backbone and projection head), learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# NOTE(review): `criterion` is not referenced by the code visible here, which
# trains with nt_xent_loss instead; kept in case other cells rely on it.
criterion = nn.CrossEntropyLoss()

# Contrastive learning loss (NT-Xent)
def nt_xent_loss(out_1, out_2, temperature=0.5):
    """Compute the NT-Xent (normalised temperature-scaled cross-entropy) loss.

    Row ``i`` of ``out_1`` and row ``i`` of ``out_2`` form a positive pair;
    every other row in the concatenated batch acts as a negative for both.

    The embeddings are L2-normalised here so that similarities are cosine
    similarities bounded in ``[-1, 1]``. Without this step the raw dot
    products are unbounded and exponentiating them overflows to inf/NaN.
    Inputs that are already unit-norm are left unchanged by the
    normalisation, so the result for them is the same as before.

    Args:
        out_1: Tensor of shape ``(N, D)`` holding the first view's embeddings.
        out_2: Tensor of shape ``(N, D)`` holding the second view's embeddings.
        temperature: Positive scaling factor applied to the similarities.

    Returns:
        Scalar tensor: the loss averaged over all ``2N`` anchors.
    """
    batch_size = out_1.shape[0]

    # Cosine similarity requires unit-length embeddings.
    out_1 = nn.functional.normalize(out_1, dim=-1)
    out_2 = nn.functional.normalize(out_2, dim=-1)

    # (2N, 2N) matrix of temperature-scaled similarities between all views.
    out = torch.cat([out_1, out_2], dim=0)
    logits = torch.mm(out, out.t()) / temperature

    # An anchor must not be contrasted with itself: -inf contributes
    # exp(-inf) = 0 to the denominator.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # Similarity of each anchor with its positive partner, for both orderings.
    positives = torch.sum(out_1 * out_2, dim=-1) / temperature
    positives = torch.cat([positives, positives], dim=0)

    # -log(exp(pos) / sum_k exp(sim_k)) evaluated in log space for stability.
    loss = torch.logsumexp(logits, dim=-1) - positives
    return loss.mean()

# Training loop
class TwoViewDataset(Dataset):
    """Serve two independently augmented views of each sample of a dataset.

    The wrapped dataset must return ``(sample, label)`` pairs and apply a
    stochastic transform on every access. Indexing it twice therefore yields
    two different augmentations of the same underlying image. The label is
    discarded because contrastive pre-training does not use it.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        view_1, _ = self.base_dataset[index]
        view_2, _ = self.base_dataset[index]
        return view_1, view_2


# `dataloader` yields (image, label) batches, so unpacking one as
# `img1, img2` would feed class labels to the model as the second view.
# Build a loader of positive pairs instead, reusing the existing settings.
pair_loader = DataLoader(
    TwoViewDataset(dataset),
    batch_size=dataloader.batch_size,
    shuffle=True,
    num_workers=dataloader.num_workers,
)

model.train()
for epoch in range(100):
    loss_sum = 0.0
    batch_count = 0
    for img1, img2 in pair_loader:
        img1, img2 = img1.to(device), img2.to(device)

        # Only the projected embeddings (second output) enter the loss.
        _, out1 = model(img1)
        _, out2 = model(img2)
        loss = nt_xent_loss(out1, out2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        batch_count += 1

    # Report the mean over the epoch rather than the last batch only; this
    # also avoids touching an undefined `loss` when the loader is empty.
    epoch_loss = loss_sum / batch_count if batch_count else float("nan")
    print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

# Save the model checkpoint
torch.save(model.state_dict(), 'simclr_model.pth')

# Load the model checkpoint; map_location keeps the reload working when the
# checkpoint was written on a different device from the current one.
model.load_state_dict(torch.load('simclr_model.pth', map_location=device))