<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Contrastive_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class ContrastiveModel(nn.Module):
    """Two-layer MLP encoder that maps flat input vectors to embeddings.

    The network is ``Linear -> ReLU -> Linear`` and is stored in
    ``self.encoder``.

    Args:
        feature_dim: Size of the embedding produced for each input.
        input_dim: Size of each flat input vector. Defaults to ``28 * 28``.
        hidden_dim: Width of the hidden layer. Defaults to ``128``.
    """

    def __init__(self, feature_dim: int, input_dim: int = 28 * 28,
                 hidden_dim: int = 128) -> None:
        super().__init__()
        # Layer order is kept as-is so the state-dict keys (encoder.0.*,
        # encoder.2.*) stay compatible with previously saved weights.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``feature_dim``.
        """
        return self.encoder(x)

def contrastive_loss(features, labels, margin=1.0):
    """Compute a margin-based contrastive loss over all pairs in a batch.

    Same-label pairs are pulled together by penalising their Euclidean
    distance. Different-label pairs are pushed apart with a per-pair hinge
    ``max(0, margin - distance)``. The result is the mean positive-pair
    distance plus the mean negative-pair hinge.

    Args:
        features: 2-D tensor of embeddings, one row per sample.
        labels: 1-D tensor of class labels, one entry per row of
            ``features``.
        margin: Distance beyond which a negative pair stops contributing.

    Returns:
        Scalar tensor holding the loss. If there are no positive pairs, or
        no negative pairs, the corresponding term is 0 rather than NaN.
    """
    pairwise_distances = torch.cdist(features, features, p=2)
    same_label = labels.unsqueeze(1) == labels.unsqueeze(0)

    # A sample paired with itself always has distance 0; counting those
    # diagonal entries as positives would only dilute the positive term.
    not_self = ~torch.eye(
        same_label.shape[0], dtype=torch.bool, device=same_label.device
    )

    # Float masks (matching the distances' dtype/device) let both terms be
    # computed as masked sums, which stay differentiable even when a pair
    # set is empty.
    positive_mask = (same_label & not_self).to(pairwise_distances)
    negative_mask = (~same_label).to(pairwise_distances)

    positive_term = (
        (pairwise_distances * positive_mask).sum()
        / positive_mask.sum().clamp(min=1)
    )

    # The hinge is applied to each pair BEFORE averaging. Clamping the sum
    # instead would let one far-apart negative pair cancel the penalty for
    # every pair that violates the margin.
    hinge = (margin - pairwise_distances).clamp(min=0)
    negative_term = (hinge * negative_mask).sum() / negative_mask.sum().clamp(min=1)

    return positive_term + negative_term

# Settings for the demonstration run below.
NUM_EPOCHS = 5  # kept small for simplicity
BATCH_SIZE = 32
INPUT_DIM = 28 * 28
NUM_CLASSES = 10

model = ContrastiveModel(feature_dim=64)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = contrastive_loss

# Example training loop: each iteration performs one optimisation step on a
# freshly drawn random batch.
for epoch in range(NUM_EPOCHS):
    # Simulated batch of inputs and integer labels.
    batch_features = torch.randn(BATCH_SIZE, INPUT_DIM)
    batch_labels = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,))

    # Clear gradients left over from the previous step, then run the
    # forward pass, backward pass and parameter update.
    optimizer.zero_grad()
    features = model(batch_features)
    loss = criterion(features, batch_labels)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")