<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
import torch.optim as optim

class SimCLR(nn.Module):
    """Backbone encoder followed by a small MLP projection head.

    The backbone's ``fc`` layer is replaced with ``nn.Identity`` so that the
    backbone outputs the features that used to feed ``fc``. Note that this
    replacement is done in place on the ``base_model`` object passed in.

    Args:
        base_model: Module exposing an ``fc`` attribute that has an
            ``in_features`` field (e.g. the ResNet built further down).
        projection_dim: Output width of the projection head.
    """

    def __init__(self, base_model, projection_dim):
        super().__init__()
        # The feature width must be captured while the original fc still exists.
        feature_dim = base_model.fc.in_features
        base_model.fc = nn.Identity()
        self.base_model = base_model

        # Projection head: Linear -> ReLU -> Linear, with a fixed hidden width of 256.
        hidden_layer = nn.Linear(feature_dim, 256)
        output_layer = nn.Linear(256, projection_dim)
        self.projector = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Return the projection of ``x`` after encoding it with the backbone."""
        return self.projector(self.base_model(x))

# Augmentation pipeline used to produce views for SimCLR: a random 224-pixel
# resized crop and a random horizontal flip, followed by tensor conversion and
# normalization with mean 0.5 / std 0.5.
transform = T.Compose(
    [
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Backbone: a ResNet-18 created without pretrained weights (`weights=None`
# replaces the older `pretrained=False` argument).
base_model = models.resnet18(weights=None)

# Attach a 128-dimensional projection head and optimize every parameter with Adam.
model = SimCLR(base_model=base_model, projection_dim=128)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Function to compute contrastive loss
def contrastive_loss(z_i, z_j, temperature=0.5):
    """Compute a normalized, temperature-scaled cross-entropy contrastive loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair. For each
    of the ``2 * batch_size`` rows, every other row except itself competes with
    the row's positive partner in a softmax over cosine similarities.

    Args:
        z_i: Projections of the first view, shape ``(batch_size, dim)``.
        z_j: Projections of the second view, same shape as ``z_i``.
        temperature: Divisor applied to the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor holding the mean loss over all ``2 * batch_size`` rows.
        With a single pair there are no negatives, so the loss is zero.

    Raises:
        ValueError: If the inputs are not non-empty 2-D tensors of equal shape.
    """
    if z_i.dim() != 2 or z_i.shape != z_j.shape or z_i.shape[0] == 0:
        raise ValueError(
            "z_i and z_j must be non-empty 2-D tensors of identical shape, "
            f"got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
        )

    batch_size = z_i.shape[0]
    num_views = 2 * batch_size

    # After L2 normalization the dot product below is the cosine similarity.
    z_i = nn.functional.normalize(z_i, dim=1)
    z_j = nn.functional.normalize(z_j, dim=1)
    representations = torch.cat([z_i, z_j], dim=0)
    similarity_matrix = torch.matmul(representations, representations.T) / temperature

    # A row must never be contrasted with itself: -inf removes the diagonal
    # from the softmax. The mask is built on the same device as the inputs.
    self_mask = torch.eye(num_views, dtype=torch.bool, device=similarity_matrix.device)
    logits = similarity_matrix.masked_fill(self_mask, float("-inf"))

    # The positive for row k sits batch_size rows away (wrapping around),
    # because the two views were concatenated along dim 0.
    targets = (torch.arange(num_views, device=logits.device) + batch_size) % num_views

    return nn.functional.cross_entropy(logits, targets)

# Training loop
def train(model, dataloader, optimizer, epochs):
    """Train ``model`` with the contrastive loss and print the mean loss per epoch.

    Each batch is concatenated with a copy of itself flipped along dim 3, the
    combined batch is passed through ``model``, and the two halves of the
    output are used as the positive pairs for ``contrastive_loss``.

    Args:
        model: ``nn.Module`` mapping an image batch to projection vectors.
        dataloader: Iterable yielding ``(images, _)`` pairs; the second element
            is ignored. Assumes ``images`` is 4-D with width on dim 3
            (e.g. NCHW) -- TODO confirm against the dataset in use.
        optimizer: Optimizer that updates ``model``'s parameters.
        epochs: Number of passes over ``dataloader``.
    """
    model.train()

    # Feed batches on whatever device the model's parameters already live on.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for images, _ in dataloader:
            if device is not None:
                images = images.to(device)
            # The positive pair for every image is its flipped copy.
            images = torch.cat([images, images.flip(dims=[3])], dim=0)
            features = model(images)
            z_i, z_j = torch.chunk(features, 2, dim=0)
            loss = contrastive_loss(z_i, z_j)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Count batches instead of calling len(dataloader) so that empty or
        # length-less iterables do not raise; an epoch with no batches reports NaN.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

# Assuming you have a DataLoader `dataloader` for your dataset
# train(model, dataloader, optimizer, epochs=100)