<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_SimCLR_for_Vision_SSL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# Example encoder (ResNet)
class ResNetEncoder(nn.Module):
    """Image backbone built from torchvision's ResNet-18.

    The network is created with ``models.ResNet18_Weights.DEFAULT`` weights
    (torchvision may download them on first use). Its final fully connected
    classifier is replaced by ``nn.Identity()``, so ``forward`` returns the
    features that would have fed the classifier instead of class logits.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Swap out the classification head; we only want the features.
        backbone.fc = nn.Identity()
        self.resnet = backbone

    def forward(self, x):
        """Run ``x`` through the headless backbone and return its features."""
        features = self.resnet(x)
        return features

# Define encoder output dimension
encoder_output_dim = 512  # For ResNet18

# SimCLR Model
class SimCLRModel(nn.Module):
    """SimCLR network: an encoder followed by a two-layer MLP projection head.

    Args:
        encoder: module mapping a batch of inputs to 2-D ``(batch, in_dim)``
            features (assumed; ResNetEncoder-style backbones satisfy this).
        in_dim: width of the encoder's features. ``None`` (the default) uses
            the module-level ``encoder_output_dim``, looked up when the model
            is constructed.
        hidden_dim: width of the projection head's hidden layer.
        out_dim: width of the returned embeddings.
    """

    def __init__(self, encoder, in_dim=None, hidden_dim=128, out_dim=128):
        super().__init__()
        if in_dim is None:
            # Resolved at call time (not def time) so rebinding the global
            # before construction still takes effect, as in the original.
            in_dim = encoder_output_dim
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Encode ``x``, project it, and L2-normalise each row (dim=1)."""
        h = self.encoder(x)
        return F.normalize(self.projector(h), dim=1)

# Contrastive loss calculation
def contrastive_loss(out1, out2, temperature=0.5):
    """Compute SimCLR's NT-Xent (normalised temperature-scaled cross-entropy) loss.

    The two views are stacked into 2N embeddings. Each embedding's positive is
    the other view of the same sample, and its negatives are the remaining
    2N - 2 embeddings from *both* views. The result is symmetric in
    ``out1`` / ``out2``.

    Args:
        out1: ``(N, D)`` embeddings of the first view.
        out2: ``(N, D)`` embeddings of the second view; row ``i`` pairs with
            row ``i`` of ``out1``.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean loss over all 2N anchors.

    Raises:
        ValueError: if the inputs are not equally-shaped 2-D tensors, or the
            batch is empty.
    """
    if out1.dim() != 2 or out1.shape != out2.shape:
        raise ValueError(
            f"out1 and out2 must be 2-D tensors of the same shape, "
            f"got {tuple(out1.shape)} and {tuple(out2.shape)}"
        )
    n = out1.shape[0]
    if n == 0:
        raise ValueError("contrastive_loss needs a non-empty batch")

    # Cosine similarity: normalise (a no-op for already unit-norm inputs).
    z = F.normalize(torch.cat([out1, out2], dim=0), dim=1)
    logits = z @ z.T / temperature
    # An embedding must not count itself as a candidate.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # Anchor i (first view) is positive with i + n, and vice versa.
    idx = torch.arange(n, device=z.device)
    targets = torch.cat([idx + n, idx])
    return F.cross_entropy(logits, targets)

# Example usage
encoder = ResNetEncoder()
model = SimCLRModel(encoder)

# Example data: a random batch standing in for real images (N, C, H, W)
inputs = torch.randn(8, 3, 224, 224)

# Simulate two augmented views. Feeding the identical tensor twice would make
# out1 == out2 (no augmentation at all), so build two *different* views with
# cheap tensor-level stand-ins (horizontal flip + noise) for a real SimCLR
# augmentation pipeline (random crop, colour jitter, blur, ...).
view1 = inputs + 0.1 * torch.randn_like(inputs)
view2 = torch.flip(inputs, dims=[3]) + 0.1 * torch.randn_like(inputs)

out1 = model(view1)
out2 = model(view2)

# Calculate contrastive loss
loss = contrastive_loss(out1, out2)
print(loss.item())