<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision.models import resnet50

# Define the SimCLR model
class SimCLR(nn.Module):
    """Backbone feature extractor followed by a two-layer MLP projection head.

    The backbone consists of every child module of ``base_model`` except the
    last one. The projection head is ``Linear -> ReLU -> Linear`` and maps
    ``base_model.fc.in_features`` inputs to ``projection_dim`` outputs.

    Args:
        base_model: Module that exposes ``children()`` and an ``fc``
            attribute with ``in_features`` (e.g. a torchvision ResNet).
        projection_dim: Width of both the hidden and the output layer of the
            projection head.
    """

    def __init__(self, base_model: nn.Module, projection_dim: int) -> None:
        super().__init__()

        # Keep all child modules but the final one as the feature extractor.
        child_modules = list(base_model.children())
        self.backbone = nn.Sequential(*child_modules[:-1])

        # The head consumes as many features as the dropped `fc` layer did.
        feature_dim = base_model.fc.in_features
        hidden_layer = nn.Linear(feature_dim, projection_dim)
        output_layer = nn.Linear(projection_dim, projection_dim)
        self.projection = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projection of ``x``.

        The backbone output is flattened from dimension 1 onwards before it
        is passed through the projection head.
        """
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection(features)

# CIFAR-10 training split, preprocessed as: resize -> centre crop -> tensor
# -> per-channel normalisation (the mean/std values below match the commonly
# used ImageNet statistics).
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
]
transform = transforms.Compose(preprocessing_steps)
dataset = CIFAR10(root="data", train=True, download=True, transform=transform)
# Shuffled mini-batches of 32 samples.
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# Define the SimCLR model with a randomly initialised ResNet50 backbone.
# resnet50() with no arguments loads no pretrained weights, and it avoids the
# `pretrained=` keyword that newer torchvision releases flag as deprecated.
base_model = resnet50()
simclr_model = SimCLR(base_model, projection_dim=128)

# Example of forward pass on a single batch. This is only a shape check, so
# autograd tracking is switched off to avoid keeping the backbone's
# intermediate activations alive for a backward pass that never happens.
with torch.no_grad():
    images, _ = next(iter(dataloader))
    projections = simclr_model(images)
print(projections.shape)  # Expected shape: [batch_size, projection_dim]