<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # Add this import
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

# Define a simple encoder network
class Encoder(nn.Module):
    """Small convolutional encoder that maps images to unit-length embeddings.

    Two stride-2 3x3 convolutions (each followed by ReLU) downsample the
    input, the feature map is flattened, and a linear layer projects it to
    the embedding, which is L2-normalised along the feature dimension.

    Args:
        in_channels: Number of channels in the input images.
        embedding_dim: Size of the output embedding.
        image_size: Height/width of the (square) input images; it fixes the
            input width of the linear layer.

    The defaults reproduce the original fixed configuration (3-channel
    32x32 inputs, 128-dimensional embeddings).
    """

    def __init__(self, in_channels=3, embedding_dim=128, image_size=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # Each conv (kernel 3, stride 2, padding 1) maps a side of n pixels
        # to (n - 1) // 2 + 1; apply that once per conv layer to get the
        # spatial size seen by the linear layer (32 -> 16 -> 8 by default).
        side = image_size
        for _ in range(2):
            side = (side - 1) // 2 + 1
        self.fc = nn.Linear(128 * side * side, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return L2-normalised embeddings, one row per input image."""
        x = self.conv(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Unit-length rows make dot products between embeddings cosine
        # similarities.
        return F.normalize(x, dim=1)

# SimCLR model: a thin wrapper around an encoder network
class SimCLR(nn.Module):
    """Module that stores an encoder and delegates the forward pass to it."""

    def __init__(self, encoder):
        super().__init__()
        # Registered as a submodule, so its parameters are exposed through
        # this module's ``parameters()``.
        self.encoder = encoder

    def forward(self, x):
        """Return whatever the wrapped encoder produces for ``x``."""
        embeddings = self.encoder(x)
        return embeddings

# Load CIFAR-10. The dataset pipeline only converts images to tensors: the
# random crop/flip are applied per view through `augmentations` (defined
# below). Augmenting here as well would make both views crops of one shared
# random crop instead of independent crops of the original image.
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize the encoder and SimCLR model
encoder = Encoder()
simclr = SimCLR(encoder)
# Cross-entropy over a similarity matrix: row i is treated as a
# classification problem whose correct class is column i.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(simclr.parameters(), lr=0.001)

# Stochastic augmentations used to create the views of each image
augmentations = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
])

# Training loop
# Temperature for the softmax over similarities. The embeddings are
# unit-length, so raw similarities lie in [-1, 1]; without scaling the
# softmax stays almost flat and the loss can barely decrease. 0.5 is a
# commonly used value for CIFAR-10 -- tune as needed.
TEMPERATURE = 0.5

for epoch in range(10):
    running_loss = 0.0
    num_batches = 0
    for (images, _) in train_loader:
        # Apply two different augmentations to the same images
        aug1 = torch.stack([augmentations(img) for img in images])
        aug2 = torch.stack([augmentations(img) for img in images])

        # Forward pass
        z1 = simclr(aug1)
        z2 = simclr(aug2)

        # Compute similarity and loss: entry (i, j) compares view 1 of image
        # i with view 2 of image j, so the matching pair sits on the diagonal.
        sim_matrix = torch.matmul(z1, z2.T) / TEMPERATURE
        # Build targets on the same device as the logits.
        targets = torch.arange(sim_matrix.size(0), device=sim_matrix.device)
        loss = criterion(sim_matrix, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the epoch mean rather than the last batch's loss: the final
    # batch can be smaller than the rest, which makes its loss
    # incomparable. max(..., 1) avoids a failure on an empty loader.
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(num_batches, 1)}")