<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet18, ResNet18_Weights

# Define a model with a (ResNet-style) backbone as the base and a projection head
class SimCLR(nn.Module):
    """Backbone encoder followed by a SimCLR-style MLP projection head.

    Args:
        base_model: backbone exposing a final ``fc`` layer with an
            ``in_features`` attribute (e.g. a torchvision ResNet). It is
            modified in place: its ``fc`` is replaced with ``nn.Identity`` so
            the backbone returns the features that fed the old classifier.
        projection_dim: output width of the projection head.
        hidden_dim: width of the head's hidden layer. Defaults to
            ``projection_dim`` (the original behaviour). SimCLR reference
            setups typically keep it equal to the backbone feature width,
            i.e. ``base_model.fc.in_features``.
    """

    def __init__(self, base_model, projection_dim=128, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = projection_dim

        # Read the feature width before the classifier layer is discarded.
        in_features = base_model.fc.in_features
        base_model.fc = nn.Identity()
        self.base_model = base_model

        # Projection head: Linear -> ReLU -> Linear
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim),
        )

    def forward(self, x):
        """Return projection-head embeddings of width ``projection_dim``.

        Assumes the backbone (with ``fc`` removed) yields a
        (batch, in_features) tensor, as torchvision ResNets do.
        """
        features = self.base_model(x)
        return self.fc(features)

# Example training loop: SimCLR-style contrastive pre-training on dummy data.


def augment(images, noise_std=0.1):
    """Return a randomly perturbed "view" of a batch of images.

    Cheap stand-in for a real SimCLR augmentation pipeline: each sample is
    horizontally flipped (last axis) with probability 0.5, then Gaussian
    noise with standard deviation ``noise_std`` is added. Assumes an
    (N, C, H, W) tensor.
    """
    flip = torch.rand(images.size(0), 1, 1, 1, device=images.device) < 0.5
    images = torch.where(flip, images.flip(-1), images)
    return images + noise_std * torch.randn_like(images)


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    Args:
        z1, z2: (batch, dim) projections of two views of the same batch.
            Row i of ``z1`` and row i of ``z2`` form a positive pair; every
            other row of the combined 2*batch set is treated as a negative.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar loss averaged over all 2*batch anchors.
    """
    batch = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2B, D), unit-norm rows
    logits = z @ z.t() / temperature  # (2B, 2B) scaled cosine similarities
    # An anchor must not be compared with itself: drop the diagonal from the softmax.
    self_mask = torch.eye(2 * batch, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Anchor i of the first view pairs with i + B of the second view, and vice versa.
    idx = torch.arange(batch, device=z.device)
    targets = torch.cat([idx + batch, idx])
    return F.cross_entropy(logits, targets)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): starts from ImageNet-supervised weights; purely self-supervised
# pre-training would normally use weights=None instead.
base_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model = SimCLR(base_model=base_model).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0003)
# Contrastive objective over two augmented views. A loss such as
# MSE(embeddings, embeddings) is identically zero with zero gradients.
criterion = nt_xent_loss

# Dummy data loader with random tensors simulating image inputs
batch_size = 32
dummy_images = torch.randn(100, 3, 224, 224)  # Simulate 100 images of size 3x224x224
dummy_dataset = TensorDataset(dummy_images)
dataloader = DataLoader(dummy_dataset, batch_size=batch_size, shuffle=True)

# Training loop
epochs = 5
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for (images,) in dataloader:
        images = images.to(device)
        optimizer.zero_grad()

        # Two independently augmented views of the same batch are the positive pairs.
        z1 = model(augment(images))
        z2 = model(augment(images))

        loss = criterion(z1, z2)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(dataloader):.4f}")