<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Complete_Contrastive_Learning_Training_Loop_(Synthetic_Example).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# --- Contrastive Model ---
class ContrastiveModel(nn.Module):
    """Two-layer MLP that maps inputs to an embedding used by the contrastive loss.

    Architecture: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, projection_dim)``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        projection_dim: Size of the output embedding.
        hidden_dim: Width of the hidden layer. Defaults to 256, the value the
            original implementation hard-coded.
    """

    def __init__(self, input_dim, projection_dim, hidden_dim=256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            # In-place is safe here: the pre-activation is not reused elsewhere.
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim),
        )

    def forward(self, x):
        """Return the embedding of ``x``; output shape is ``(..., projection_dim)``."""
        return self.encoder(x)

# --- Contrastive Loss (NT-Xent-style) ---
def contrastive_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair. Every
    other row of either tensor acts as a negative for them. Each of the ``2N``
    embeddings is used as an anchor. Its softmax runs over the ``2N - 1`` other
    embeddings, so the positive appears exactly once in the denominator.

    Args:
        z_i: Embeddings of the first view, shape ``(N, D)``.
        z_j: Embeddings of the second view, same shape as ``z_i``.
        temperature: Positive scale that divides the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all ``2N`` anchors.

    Raises:
        ValueError: if ``z_i`` and ``z_j`` differ in shape, or the batch is empty.
    """
    if z_i.shape != z_j.shape:
        raise ValueError(
            f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
        )
    batch_size = z_i.size(0)
    if batch_size == 0:
        raise ValueError("contrastive_loss needs at least one positive pair")
    n_total = 2 * batch_size

    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # [2N, D], unit-norm rows
    logits = torch.matmul(z, z.T) / temperature           # cosine similarity / T

    # An anchor must never be scored against itself. -inf contributes exp(-inf) = 0
    # to the softmax, and masked_fill's backward zeroes the gradient there.
    self_mask = torch.eye(n_total, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive for anchor k is k + N (first half) or k - N (second half).
    # Using the full row as logits keeps the positive exactly once in the
    # denominator. The previous version also left it among the negatives,
    # which counted it twice.
    targets = (torch.arange(n_total, device=z.device) + batch_size) % n_total
    return F.cross_entropy(logits, targets)

# --- Synthetic Dataset ---
class ContrastiveDataset(Dataset):
    """Synthetic dataset of positive pairs for contrastive learning.

    Stores ``n_samples`` base vectors drawn from a standard normal. Each
    ``__getitem__`` call returns two independently perturbed copies of one
    base vector. Fresh Gaussian noise is drawn on every access, and the two
    copies serve as the views of a positive pair.

    Args:
        n_samples: Number of base vectors.
        dim: Length of each vector.
        noise_std: Standard deviation of the additive Gaussian noise applied to
            each view. Defaults to 0.05, the value the original implementation
            hard-coded.
    """

    def __init__(self, n_samples=1000, dim=128, noise_std=0.05):
        self.samples = torch.randn(n_samples, dim)
        self.noise_std = noise_std

    def __getitem__(self, idx):
        """Return ``(x_i, x_j)``: two noisy views of ``self.samples[idx]``."""
        x = self.samples[idx]
        x_i = x + self.noise_std * torch.randn_like(x)
        x_j = x + self.noise_std * torch.randn_like(x)
        return x_i, x_j

    def __len__(self):
        """Number of base vectors."""
        return len(self.samples)

# --- Training ---
def train(model, dataloader, optimizer, device, epochs=10, loss_fn=None):
    """Train ``model`` with a two-view contrastive objective.

    Each batch from ``dataloader`` must be a pair ``(x_i, x_j)`` of tensors.
    Both views pass through the same ``model``, and ``loss_fn`` scores the
    resulting embeddings. The mean batch loss of each epoch is printed.

    Args:
        model: Module mapping a batch of inputs to embeddings.
        dataloader: Iterable of ``(x_i, x_j)`` batches, e.g. a ``DataLoader``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of full passes over ``dataloader``.
        loss_fn: Callable ``(z_i, z_j) -> scalar tensor``. Defaults to
            ``contrastive_loss``. Use ``functools.partial`` to change its
            temperature.

    Returns:
        List containing the mean batch loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    if loss_fn is None:
        loss_fn = contrastive_loss
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0  # counted while iterating; also works for loaders without __len__
        for x_i, x_j in dataloader:
            x_i, x_j = x_i.to(device), x_j.to(device)
            z_i = model(x_i)
            z_j = model(x_j)
            loss = loss_fn(z_i, z_j)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            # Averaging would divide by zero; fail with an actionable message instead.
            raise ValueError("dataloader yielded no batches; nothing to train on")
        avg_loss = total_loss / n_batches
        history.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
    return history

# --- Main ---
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyperparameters for the synthetic run.
input_dim, projection_dim, batch_size = 128, 64, 64

# 2048 synthetic base vectors, served as noisy positive pairs in shuffled batches.
dataset = ContrastiveDataset(n_samples=2048, dim=input_dim)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Encoder on the chosen device, optimized with Adam at learning rate 1e-3.
model = ContrastiveModel(input_dim, projection_dim)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

train(model, dataloader, optimizer, device, epochs=10)