<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/_SimCLR_for_Image_Representation_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from PIL import Image

# Define SimCLR model
class SimCLR(nn.Module):
    """Encoder plus MLP projection head producing unit-length embeddings.

    Args:
        base_model: Backbone module exposing an ``fc`` attribute that has
            ``in_features`` (e.g. a torchvision ResNet). It is modified in
            place: its ``fc`` layer is replaced by ``nn.Identity`` so the
            backbone returns the features that used to feed ``fc``.
        projection_dim: Size of the embedding returned by ``forward``.
    """

    def __init__(self, base_model, projection_dim):
        super().__init__()

        # Record the feature width before the classifier is swapped out.
        feature_dim = base_model.fc.in_features
        base_model.fc = nn.Identity()

        self.encoder = base_model
        self.in_features = feature_dim

        # Two-layer MLP: features -> 256 -> projection_dim.
        hidden_layer = nn.Linear(feature_dim, 256)
        output_layer = nn.Linear(256, projection_dim)
        self.projection_head = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Return the projected embeddings of ``x``, L2-normalized along dim 1."""
        features = self.encoder(x)
        projected = self.projection_head(features)
        return F.normalize(projected, dim=1)

# Data augmentation for self-supervised learning
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

# Load dataset and create DataLoader
# The dataset deliberately yields plain [0, 1] tensors (ToTensor only).
# get_augmented_images() turns each tensor back into a PIL image by scaling
# with 255 and then applies the random augmentations itself, once per view.
# Feeding it the output of `transform` (which ends in Normalize) would hand it
# values outside [0, 1] that wrap around when cast to uint8, and would also
# augment every image twice.
dataset = datasets.CIFAR10(root='./data', train=True, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Initialize model, optimizer, and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
projection_dim = 128
model = SimCLR(base_model, projection_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# SimCLR loss function
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Compute the NT-Xent (normalized temperature-scaled cross entropy) loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are treated as a positive
    pair; every other embedding in the combined batch acts as a negative.

    Args:
        z_i: Embeddings of the first view, shape ``(B, D)``.
        z_j: Embeddings of the second view, same shape as ``z_i``.
        temperature: Divisor applied to the similarity matrix.

    Returns:
        Scalar tensor: mean cross entropy over the ``2B`` anchors.

    Raises:
        ValueError: If ``z_i`` and ``z_j`` do not have the same shape.

    Note:
        Similarities are plain dot products; the function does not normalize
        its inputs, so pass unit-length embeddings for cosine similarity.
    """
    if z_i.shape != z_j.shape:
        raise ValueError(
            f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
        )

    batch_size = z_i.size(0)
    n = 2 * batch_size

    z = torch.cat((z_i, z_j), dim=0)
    sim = torch.matmul(z, z.T) / temperature

    # An embedding must never be selected as its own match.
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))

    # The positive for anchor k is its other view: k + B in the first half,
    # k - B in the second half, i.e. (k + B) mod 2B.
    anchors = torch.arange(n, device=z.device)
    labels = (anchors + batch_size) % n

    return F.cross_entropy(sim, labels)

# Function to apply transformations twice and get two versions of the same image
def get_augmented_images(images):
    """Create two independently augmented views of every image in a batch.

    Args:
        images: Iterable of float image tensors in channel-first ``(C, H, W)``
            layout with values in ``[0, 1]`` (for example a batch produced by
            ``ToTensor``). Values outside that range are clipped before the
            conversion to 8-bit pixels, so tensors that were already
            mean/std-normalized should not be passed here.

    Returns:
        Tuple ``(images_i, images_j)`` of stacked tensors. Each holds one
        randomly augmented, normalized 224x224 view per input image, in the
        same order as ``images``.
    """
    # Both views come from the same random pipeline; each call draws fresh
    # random parameters, so a single Compose is enough.
    augment = transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])

    views_i = []
    views_j = []
    for img in images:
        # Clamp and round before the uint8 cast so out-of-range or inexact
        # floats saturate instead of wrapping around; detach()/cpu() make the
        # conversion safe for CUDA or grad-tracking tensors.
        pixels = img.detach().cpu().clamp(0, 1).mul(255).round().to(torch.uint8)
        pil_image = Image.fromarray(pixels.permute(1, 2, 0).contiguous().numpy())

        # Convert once, augment twice.
        views_i.append(augment(pil_image))
        views_j.append(augment(pil_image))

    images_i = torch.stack(views_i)
    images_j = torch.stack(views_j)
    return images_i, images_j

# Training loop with added loss tracking
num_epochs = 10    # used for both the loop bound and the progress message
temperature = 0.5  # fixed temperature passed to nt_xent_loss

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for images, _ in dataloader:
        # Labels are ignored: the two augmented views of each image form the
        # positive pair for the contrastive loss.
        images_i, images_j = get_augmented_images(images)
        images_i, images_j = images_i.to(device), images_j.to(device)

        optimizer.zero_grad()
        z_i = model(images_i)
        z_j = model(images_j)

        # Contrastive loss between the two views
        loss = nt_xent_loss(z_i, z_j, temperature=temperature)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

print("Enhanced training completed!")