<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.models import resnet50, ResNet50_Weights

# Define the SimCLR model
class SimCLR(nn.Module):
    """Backbone encoder followed by a two-layer MLP projection head.

    Args:
        base_model: Network exposing an ``fc`` layer with an ``in_features``
            attribute (e.g. a torchvision ResNet). Its ``fc`` attribute is
            replaced **in place** by ``nn.Identity`` so the backbone returns
            the features that used to feed ``fc``.
        projection_dim: Width of both the hidden and the output layer of the
            projection head.
    """

    def __init__(self, base_model, projection_dim=128):
        super().__init__()

        # Read the feature width before the classifier is dropped.
        feature_dim = base_model.fc.in_features
        base_model.fc = nn.Identity()
        self.base_model = base_model

        # Linear -> ReLU -> Linear head applied on top of the backbone features.
        hidden_layer = nn.Linear(feature_dim, projection_dim)
        output_layer = nn.Linear(projection_dim, projection_dim)
        self.proj_head = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Return the projection of ``x``.

        The backbone output is L2-normalised along ``dim=1`` before it is
        passed through the projection head.
        """
        unit_features = F.normalize(self.base_model(x), dim=1)
        return self.proj_head(unit_features)

# Define the contrastive loss function
def contrastive_loss(out1, out2, temperature=0.07):
    """Compute the NT-Xent contrastive loss between two batches of embeddings.

    Row ``i`` of ``out1`` and row ``i`` of ``out2`` are treated as a positive
    pair; every other row in the combined batch of ``2N`` embeddings acts as a
    negative.

    Args:
        out1: Tensor of shape ``(N, D)`` with the embeddings of the first view.
        out2: Tensor of shape ``(N, D)`` with the embeddings of the second view.
        temperature: Positive scaling factor dividing the cosine similarities.

    Returns:
        Scalar tensor: cross-entropy averaged over all ``2N`` anchors.
    """
    batch_size = out1.size(0)

    # After L2 normalisation the dot product below is the cosine similarity.
    out = torch.cat([F.normalize(out1, dim=1), F.normalize(out2, dim=1)], dim=0)
    sim_matrix = torch.matmul(out, out.T) / temperature

    # An embedding must never be its own candidate: remove the diagonal from
    # the softmax entirely instead of subtracting a large constant.
    self_mask = torch.eye(out.size(0), dtype=torch.bool, device=out.device)
    sim_matrix = sim_matrix.masked_fill(self_mask, float("-inf"))

    # The positive for row i (first view) is row i + N (second view) and vice
    # versa. Pointing row i at column i would target the masked diagonal.
    indices = torch.arange(batch_size, device=out.device)
    labels = torch.cat([indices + batch_size, indices], dim=0)

    return F.cross_entropy(sim_matrix, labels)

# Random augmentation pipeline; applying it twice to one image gives two views
augmentation = transforms.Compose(
    [
        # Take a random crop and rescale it to 224x224
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        # Randomly perturb brightness, contrast, saturation and hue
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # Channel statistics expected by the ImageNet-trained ResNet
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset wrapper that returns two augmented views of each image
class AugmentedDataset(torch.utils.data.Dataset):
    """Wrap a labelled dataset so each item becomes two transformed views.

    Args:
        dataset: Indexable collection whose items unpack as ``(image, label)``.
        transform: Callable applied twice to every image.
    """

    def __init__(self, dataset, transform):
        self.transform = transform
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return a tuple with two transformed versions of item ``index``."""
        # The label is discarded; only the image is used.
        image, _label = self.dataset[index]
        return tuple(self.transform(image) for _ in range(2))

# CIFAR-10 training split (stored under ./data), wrapped so every sample
# becomes a pair of augmented views, then batched for training
cifar10_dataset = datasets.CIFAR10(root='data', train=True, download=True)
augmented_dataset = AugmentedDataset(dataset=cifar10_dataset, transform=augmentation)
dataloader = torch.utils.data.DataLoader(
    augmented_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,  # keep every batch at exactly 128 pairs
)

# Initialize the device, model, and optimizer
# Pick the device once so the model and the batches use the same one
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 backbone with pretrained ImageNet weights, wrapped by SimCLR
base_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model = SimCLR(base_model=base_model)
model = model.to(device)

# The optimizer is created after the model has been moved to the device
optimizer = optim.Adam(model.parameters(), lr=0.0003)

# Training loop: one optimizer step per batch of augmented pairs
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for view_a, view_b in dataloader:
        # Both views must live on the same device as the model
        view_a = view_a.to(device)
        view_b = view_b.to(device)

        # Clear the gradients left over from the previous step
        optimizer.zero_grad()

        # Encode each view in its own forward pass, then score their agreement
        loss = contrastive_loss(model(view_a), model(view_b))

        loss.backward()
        optimizer.step()

        # .item() extracts a Python float from the loss tensor
        total_loss += loss.item()

    # Report the mean batch loss for this epoch
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(dataloader):.4f}")