<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

class SimCLR(nn.Module):
    """Backbone encoder followed by a two-layer MLP projection head.

    The backbone's last child module is dropped (the classifier ``fc`` in
    torchvision ResNets) and the remaining layers are used as the encoder.
    The encoder output is flattened per sample and passed through the
    projection head.

    Args:
        base_model: Callable that builds the backbone ``nn.Module``. It is
            invoked as ``base_model(pretrained=False)``, e.g.
            ``torchvision.models.resnet18``.
        projection_dim: Width of the hidden and output layers of the
            projection head.
    """

    # Feature width used when the backbone exposes no ``fc`` Linear layer to
    # inspect (512 matches ResNet-18).
    _DEFAULT_FEATURE_DIM = 512

    def __init__(self, base_model, projection_dim):
        super().__init__()
        backbone = base_model(pretrained=False)

        # Read the encoder output width from the classifier being removed
        # instead of hard-coding it, so wider backbones work as well.
        classifier = getattr(backbone, "fc", None)
        if isinstance(classifier, nn.Linear):
            feature_dim = classifier.in_features
        else:
            feature_dim = self._DEFAULT_FEATURE_DIM

        # Remove the final child module (the fully connected classifier)
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
        )

    def forward(self, x):
        """Return the projected embedding for a batch of inputs ``x``."""
        features = self.encoder(x)
        # Collapse everything after the batch dimension into one feature axis
        features = torch.flatten(features, start_dim=1)
        return self.projector(features)

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Compute the NT-Xent (normalized temperature-scaled cross entropy) loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row of the concatenated batch serves as a negative for it.

    Args:
        z_i: Tensor of shape ``(batch, dim)`` holding the first view's embeddings.
        z_j: Tensor of the same shape holding the second view's embeddings.
        temperature: Positive scaling factor applied to the cosine similarities.

    Returns:
        Scalar tensor: the cross entropy averaged over all ``2 * batch`` anchors.

    Raises:
        ValueError: If the inputs differ in shape, the batch is empty, or
            ``temperature`` is not positive.
    """
    if z_i.shape != z_j.shape:
        raise ValueError(
            f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
        )
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    batch_size = z_i.shape[0]
    if batch_size == 0:
        raise ValueError("nt_xent_loss requires a non-empty batch")

    device = z_i.device
    num_anchors = 2 * batch_size

    # Unit-normalize so that the dot products below are cosine similarities
    representations = torch.cat(
        [F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0
    )
    similarity_matrix = torch.matmul(representations, representations.T) / temperature

    # An anchor must never be compared with itself: setting the diagonal to
    # -inf removes it from the softmax denominator without reshaping anything.
    self_mask = torch.eye(num_anchors, dtype=torch.bool, device=device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float("-inf"))

    # The positive for anchor k sits batch_size rows away (wrapping around)
    targets = (torch.arange(num_anchors, device=device) + batch_size) % num_anchors

    return F.cross_entropy(similarity_matrix, targets)

# Example usage: contrastive pre-training of a ResNet-18 encoder on CIFAR-10

# Use a GPU when one is available; otherwise everything stays on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

base_model = models.resnet18
projection_dim = 128
# Move the model before building the optimizer so it tracks the final parameters
model = SimCLR(base_model, projection_dim).to(device)

# Random crop (resized to 224x224) and random horizontal flip, then tensor conversion
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    # Class labels are ignored: training is self-supervised
    for images, _ in data_loader:
        images = images.to(device)
        optimizer.zero_grad()
        # Create positive pairs: each image is paired with its horizontal mirror.
        # NOTE(review): the second view is a deterministic flip of the first,
        # not an independently augmented sample — confirm this is intended.
        images = torch.cat([images, images.flip(dims=[3])], dim=0)
        features = model(images)
        # First half of the batch holds the originals, second half the mirrors
        z_i, z_j = torch.chunk(features, 2, dim=0)
        loss = nt_xent_loss(z_i, z_j)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch of the epoch, not an epoch average
    print(f'Epoch {epoch + 1}, Loss: {loss.item():.4f}')