<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class SimCLRModel(nn.Module):
    """Backbone encoder followed by an MLP projection head.

    The backbone's ``fc`` layer is replaced with ``nn.Identity`` so the
    encoder emits the features that previously fed ``fc``; a two-layer MLP
    then projects them, and the result is L2-normalised along the last
    dimension.

    Note:
        ``base_model`` is stored by reference and modified in place: after
        construction its ``fc`` attribute is an ``nn.Identity``.

    Args:
        base_model: Module exposing an ``fc`` attribute that has an
            ``in_features`` attribute (e.g. an ``nn.Linear``).
        hidden_dim: Width of the projection head's hidden layer.
        out_dim: Size of the projected embedding.

    Raises:
        AttributeError: If ``base_model`` has no ``fc.in_features``.
    """

    def __init__(self, base_model, hidden_dim=256, out_dim=128):
        super().__init__()
        # Read the feature width before ``fc`` is swapped out below.
        in_features = base_model.fc.in_features
        self.encoder = base_model
        self.encoder.fc = nn.Identity()  # Remove the classifier layer
        self.projection_head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Encode ``x``, project it, and L2-normalise along the last dim."""
        x = self.encoder(x)
        x = self.projection_head(x)
        return F.normalize(x, dim=-1)

def simclr_loss(features, temperature=0.5):
    """Cross-entropy over temperature-scaled pairwise cosine similarities.

    Builds the ``(N, N)`` cosine-similarity matrix between the rows of
    ``features``, divides it by ``temperature`` and applies cross-entropy
    with target ``i`` for row ``i``.

    NOTE(review): each row's target is its own diagonal entry
    (self-similarity), so the loss only pushes different samples apart.
    The usual SimCLR / NT-Xent objective takes a second augmented view as
    the positive and masks the diagonal -- confirm this simplification is
    intended before relying on it for real training.

    Args:
        features: 2-D tensor of shape ``(N, D)``, one embedding per row.
            Rows do not need to be normalised beforehand.
        temperature: Divisor applied to the similarities before softmax.

    Returns:
        Scalar tensor with the mean cross-entropy loss.

    Raises:
        ValueError: If ``features`` is not 2-D.
    """
    if features.dim() != 2:
        raise ValueError(
            f"features must be 2-D (batch, dim), got {features.dim()}-D"
        )

    batch_size = features.shape[0]
    # Create the targets on the same device as the logits; a CPU-only
    # ``arange`` makes cross_entropy fail when features live on an accelerator.
    labels = torch.arange(batch_size, device=features.device)

    # Normalise once and take a matrix product instead of broadcasting to an
    # (N, N, D) intermediate. eps matches F.cosine_similarity's default.
    unit = F.normalize(features, dim=-1, eps=1e-8)
    similarity_matrix = unit @ unit.T

    logits = similarity_matrix / temperature
    return F.cross_entropy(logits, labels)

# Demo: run a few optimisation steps of a ResNet-18 SimCLR model on noise.
base_model = models.resnet18(weights=None)  # no pretrained weights requested
model = SimCLRModel(base_model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    # Clear gradients from the previous step before the next backward pass.
    optimizer.zero_grad()

    # A fresh batch of Gaussian noise stands in for real images.
    img = torch.randn(64, 3, 224, 224)
    features = model(img)
    loss = simclr_loss(features)

    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch + 1}, Loss: {loss.item():.4f}')