<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets, models

class SimCLR(nn.Module):
    """Encoder plus two-layer projection head built from an existing backbone.

    The encoder is made of every child module of ``base_model`` except the
    last one. The projection head is ``Linear -> ReLU -> Linear`` and maps
    the flattened encoder output to ``projection_dim`` values per sample.

    Args:
        base_model: Backbone module. It must expose an ``fc`` attribute with
            an ``in_features`` field; that value is used as the input width
            of the projection head (for torchvision ResNets ``fc`` is also
            the last child, i.e. the one that gets dropped).
        projection_dim: Width of both the hidden and the output layer of the
            projection head.
    """

    def __init__(self, base_model, projection_dim=128):
        super().__init__()

        # Keep all child modules but the final one as the feature encoder.
        backbone_layers = list(base_model.children())
        self.encoder = nn.Sequential(*backbone_layers[:-1])

        # Layers are created in the same order they appear in the head so
        # that parameter initialisation order stays the same.
        feature_dim = base_model.fc.in_features
        hidden_layer = nn.Linear(feature_dim, projection_dim)
        output_layer = nn.Linear(projection_dim, projection_dim)
        self.projection_head = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Return the projected embedding for a batch ``x``."""
        features = self.encoder(x)
        # Collapse everything after the batch dimension into one axis.
        batch_size = features.size(0)
        flat_features = features.view(batch_size, -1)
        return self.projection_head(flat_features)

# Stochastic augmentation pipeline applied to every view of an image.
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])


class TwoViewTransform:
    """Apply one stochastic transform twice to get a positive pair.

    Each call runs ``base_transform`` two separate times on the same input,
    so the two returned views differ whenever the transform is random.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, image):
        return self.base_transform(image), self.base_transform(image)


def nt_xent_loss(z1, z2, temperature=0.5):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    Args:
        z1: Embeddings of the first views, shape ``(N, D)``.
        z2: Embeddings of the second views, shape ``(N, D)``; row ``i`` is
            the positive partner of row ``i`` in ``z1``.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: cross-entropy over the ``2N`` embeddings where, for
        each row, the target is its partner view and all other rows act as
        negatives.
    """
    batch_size = z1.size(0)
    # L2-normalise so the dot product below is a cosine similarity.
    embeddings = nn.functional.normalize(torch.cat([z1, z2], dim=0), dim=1)
    similarity = embeddings @ embeddings.t() / temperature

    # A sample must not be scored against itself: mask the diagonal.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=similarity.device)
    similarity = similarity.masked_fill(self_mask, float('-inf'))

    # Row i (first view) pairs with row i + N (second view) and vice versa.
    indices = torch.arange(batch_size, device=similarity.device)
    targets = torch.cat([indices + batch_size, indices])
    return nn.functional.cross_entropy(similarity, targets)


# Each dataset item is now ((view1, view2), label); the default collate
# function turns that into a pair of batched tensors plus the labels.
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=TwoViewTransform(transform),
    download=True,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Randomly initialised backbone. The bare call avoids the `pretrained=`
# keyword, which is deprecated in newer torchvision releases.
base_model = models.resnet18()
simclr_model = SimCLR(base_model).to(device)
optimizer = optim.Adam(simclr_model.parameters(), lr=0.001)

simclr_model.train()
for epoch in range(10):
    running_loss = 0.0
    for (images1, images2), _ in dataloader:
        # Two independently augmented views of the same images. Feeding the
        # identical tensor twice would make the embeddings equal and leave
        # nothing to learn from.
        images1 = images1.to(device)
        images2 = images2.to(device)
        z1 = simclr_model(images1)
        z2 = simclr_model(images2)

        loss = nt_xent_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean over the epoch rather than only the last batch.
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}')