<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Define the SimCLR model
class SimCLR(nn.Module):
    """SimCLR-style network: a base encoder followed by a projection head.

    Args:
        base_encoder: Module mapping inputs to feature vectors of width
            ``encoder_output_dim`` (e.g. a ResNet trunk or a simple MLP).
        encoder_output_dim: Width of the features produced by ``base_encoder``;
            it is the input width of the projection head.
        projection_dim: Width of both the hidden layer and the output of the
            two-layer projection head.
    """

    def __init__(self, base_encoder, encoder_output_dim, projection_dim):
        super().__init__()
        # Registration order (encoder first, then projector) and the Sequential
        # indices 0/1/2 determine the state_dict keys, so keep them stable.
        self.encoder = base_encoder
        head_layers = [
            nn.Linear(encoder_output_dim, projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
        ]
        self.projector = nn.Sequential(*head_layers)

    def forward(self, x):
        """Encode ``x`` and map the features into the projection space."""
        return self.projector(self.encoder(x))

# Base encoder for the demo: a small MLP that flattens each 28x28 single-channel
# image and maps it to a feature vector. Any module whose output width equals
# the SimCLR `encoder_output_dim` could be used instead.
IMAGE_SIDE = 28
HIDDEN_DIM = 256
ENCODER_OUTPUT_DIM = 128  # single source of truth for the encoder/projector interface
PROJECTION_DIM = 64

base_encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(IMAGE_SIDE * IMAGE_SIDE, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, ENCODER_OUTPUT_DIM),
)

# Attach the projection head (128 -> 64 -> 64) on top of the encoder.
model = SimCLR(
    base_encoder,
    encoder_output_dim=ENCODER_OUTPUT_DIM,
    projection_dim=PROJECTION_DIM,
)

# Adam over every trainable parameter of the model (encoder + projection head).
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Cross-entropy loss over class-index targets.
criterion = nn.CrossEntropyLoss()

# Synthetic stand-in data: NUM_SAMPLES single-channel 28x28 "images" drawn from
# a standard normal, each paired with a random integer label in [0, 10).
NUM_SAMPLES = 1000
BATCH_SIZE = 32
dummy_images = torch.randn(NUM_SAMPLES, 1, 28, 28)
dummy_labels = torch.randint(0, 10, (NUM_SAMPLES,))
dataset = TensorDataset(dummy_images, dummy_labels)
# Reshuffled every epoch; the last batch holds the remainder (1000 % 32 = 8).
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop: SimCLR-style self-supervised pre-training. Each batch is
# augmented twice, and the two projected views are pulled together with the
# NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.
# The dataset's labels are not used by this objective.


def augment(images, noise_std=0.1):
    """Return a randomly perturbed copy of ``images``.

    Each sample is flipped along its last dimension with probability 0.5, then
    Gaussian noise with standard deviation ``noise_std`` is added. These are
    lightweight stand-ins for the crop/colour augmentations used on real images.
    """
    flip_shape = (images.size(0),) + (1,) * (images.dim() - 1)
    flip = torch.rand(flip_shape, device=images.device) < 0.5
    images = torch.where(flip, torch.flip(images, dims=[-1]), images)
    return images + noise_std * torch.randn_like(images)


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent contrastive loss between two batches of paired projections.

    Row i of ``z1`` and row i of ``z2`` are the two views of the same sample
    (the positive pair). Every other row of the combined 2N batch is a negative.

    Args:
        z1, z2: Projections of shape (N, D) for the two augmented views.
        temperature: Divisor applied to the cosine similarities before softmax.

    Returns:
        Scalar loss averaged over all 2N anchors.
    """
    n = z1.size(0)
    # Unit-normalise so dot products are cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    logits = z @ z.t() / temperature
    # An embedding must never be scored against itself.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Anchor i (view 1) is positive with i + n (view 2), and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)]).to(z.device)
    return F.cross_entropy(logits, targets)


epochs = 10
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for images, _ in dataloader:  # labels are ignored: the objective is self-supervised
        optimizer.zero_grad()
        # Two independent random views of the same batch form the positive pairs.
        z1 = model(augment(images))
        z2 = model(augment(images))
        loss = nt_xent_loss(z1, z2)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(dataloader):.4f}")

# Forward pass to verify output shape.
# Switch to eval mode first. Otherwise the check runs in train mode: dropout
# (if present) makes the output stochastic, and BatchNorm layers (if present)
# update their running statistics even under torch.no_grad().
model.eval()
with torch.no_grad():
    sample_input = torch.randn(32, 1, 28, 28)  # Batch of 32, 28x28 grayscale images
    sample_output = model(sample_input)
    print(f"Sample output shape: {sample_output.shape}")  # Expected: (32, 64)