In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR10
import torchvision.models as models
from tqdm import tqdm

In [None]:

# Randomised augmentation pipeline for SimCLR on 32x32 CIFAR-10 images:
# geometric crop/flip -> colour distortion -> blur -> tensor -> per-channel normalisation.
simclr_transform = T.Compose(
    [
        # geometric: random resized crop back to 32x32, then a random horizontal flip
        T.RandomResizedCrop(size=32),
        T.RandomHorizontalFlip(),
        # colour distortion: jitter 80% of the time, grayscale 20% of the time
        T.RandomApply(
            [T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        T.RandomGrayscale(p=0.2),
        # blur is not wrapped in RandomApply, so it runs on every image
        T.GaussianBlur(kernel_size=3),
        # convert to a float tensor, then normalise each channel
        # (these are the commonly used CIFAR-10 mean/std values)
        T.ToTensor(),
        T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ]
)

In [None]:

# CIFAR-10 training split, fetched into ./data when not already present. Every item
# access runs `simclr_transform`, so the image returned for an index is freshly augmented.
train_dataset = CIFAR10(
    root='./data',
    train=True,
    transform=simclr_transform,
    download=True,
)
# Shuffled batches of 256; drop_last=True discards the incomplete final batch so every
# batch has exactly 256 samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)


100%|██████████| 170M/170M [00:02<00:00, 68.0MB/s]


In [None]:

class Encoder(nn.Module):
    """Backbone of ``base_model`` followed by a linear map to a 128-d unit-norm embedding.

    The last child module of ``base_model`` (its classifier) is dropped; the remaining
    children run in order as ``self.backbone``. ``base_model.fc.in_features`` sets the
    input width of the new 128-output linear layer ``self.fc``.

    Args:
        base_model: module with an ``fc`` attribute that is registered as its last child
            (e.g. a torchvision ResNet).
    """

    def __init__(self, base_model):
        super().__init__()
        children = list(base_model.children())
        self.backbone = nn.Sequential(*children[:-1])
        self.fc = nn.Linear(base_model.fc.in_features, 128)

    def forward(self, x):
        """Return L2-normalised embeddings of shape ``(batch, 128)``."""
        features = self.backbone(x).flatten(1)
        embedding = self.fc(features)
        # Each row is scaled to unit L2 norm.
        return F.normalize(embedding, dim=1)


In [None]:

class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) used as the SimCLR projection head.

    Args:
        in_dim: width of the input features.
        hidden_dim: width of the hidden layer.
        out_dim: width of the output projection.
    """

    def __init__(self, in_dim=128, hidden_dim=512, out_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Project ``x`` of shape ``(..., in_dim)`` to ``(..., out_dim)``."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

In [None]:

class SimCLR(nn.Module):
    """SimCLR network: an ``Encoder`` followed by a ``ProjectionHead``.

    ``forward(x)`` returns ``(h, z)``: ``h`` is the encoder output and ``z`` is the
    projection head applied to ``h``.

    Args:
        base_model: backbone passed through to ``Encoder``.
    """

    def __init__(self, base_model):
        super().__init__()
        # Registration order (encoder, then head) is kept so parameter order is unchanged.
        self.encoder = Encoder(base_model)
        self.projection_head = ProjectionHead()

    def forward(self, x):
        h = self.encoder(x)
        return h, self.projection_head(h)

In [None]:

class NTXentLoss(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair. Each of the ``2N``
    rows is classified against the other ``2N - 1`` rows by cosine similarity, with its
    partner as the target class. The result is the mean cross-entropy over all ``2N`` rows.

    Args:
        temperature: divisor applied to cosine similarities before the softmax.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Return the scalar NT-Xent loss.

        Args:
            z_i: tensor of shape ``(N, D)``.
            z_j: tensor of shape ``(N, D)``; row ``k`` is the positive of ``z_i[k]``.

        Raises:
            ValueError: if ``z_i`` and ``z_j`` do not have the same shape. Mismatched
                batch sizes would otherwise silently misalign the positive labels.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} "
                f"and {tuple(z_j.shape)}"
            )
        batch_size = z_i.size(0)

        # Unit-normalise once, then a single matmul gives the (2N, 2N) cosine-similarity
        # matrix. Pairwise F.cosine_similarity on broadcast views would instead build a
        # (2N, 2N, D) intermediate.
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
        sim = (z @ z.T) / self.temperature

        # A row must never count as its own negative. Fill the diagonal with the dtype's
        # lowest finite value: a fixed -9e15 overflows float16 and makes masked_fill fail.
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)

        # Positive of row k is k + N in the first half and k - N in the second half.
        idx = torch.arange(batch_size, device=z.device)
        labels = torch.cat([idx + batch_size, idx])

        return F.cross_entropy(sim, labels)


In [None]:

class _TwoViewDataset(torch.utils.data.Dataset):
    """Wrap a map-style dataset so item ``i`` becomes ``((view_1, view_2), label)``.

    The two views come from two separate ``base[i]`` calls. Any random transform inside
    ``base`` is therefore drawn independently for two views of the *same* sample.
    """

    # NOTE(review): with num_workers > 0 this class is pickled into DataLoader workers.
    # Under the 'spawn' start method (macOS/Windows default), classes defined in a
    # notebook cannot be unpickled there; move it to a .py module on those platforms.

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        view_1, label = self.base[index]
        view_2, _ = self.base[index]
        return (view_1, view_2), label


def _two_view_loader(loader):
    """Return a loader whose batches are ``((x_i, x_j), labels)`` with row-aligned views.

    If ``loader.dataset`` already yields a pair of views per item, ``loader`` is returned
    unchanged. Otherwise its dataset is wrapped in ``_TwoViewDataset`` and re-batched with
    ``loader``'s own batch sampler, which keeps the batch size, shuffling and ``drop_last``.
    """
    if len(loader.dataset) == 0:
        return loader
    first_input, _ = loader.dataset[0]
    if isinstance(first_input, (list, tuple)):
        return loader
    if loader.batch_sampler is None:
        raise ValueError("train() needs a DataLoader with automatic batching (batch_size set)")
    return DataLoader(
        _TwoViewDataset(loader.dataset),
        batch_sampler=loader.batch_sampler,
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
    )


def train(model, loader, optimizer, criterion, device):
    """Run one epoch of SimCLR training and return the mean per-batch loss.

    A positive pair must be two augmentations of the *same* image. Zipping a shuffled
    loader with itself runs two independent permutations, so row ``k`` of the two batches
    would be unrelated images. NT-Xent then has no consistent signal, and the embedding
    collapses to a constant (loss stuck at ln(2N - 1), about 6.2364 for N = 256). Batches
    are therefore drawn from ``_two_view_loader(loader)``, which pairs views by index.

    Args:
        model: module whose forward returns ``(h, z)``; only ``z`` is used.
        loader: DataLoader yielding either ``(images, labels)`` from a randomly augmented
            dataset, or ``((view_1, view_2), labels)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``criterion(z_i, z_j)`` returning a scalar loss tensor.
        device: device each view batch is moved to.

    Returns:
        float: mean of ``loss.item()`` over processed batches (0.0 if there were none).
    """
    model.train()
    paired_loader = _two_view_loader(loader)
    total_loss = 0.0
    num_batches = 0

    for (x_i, x_j), _ in tqdm(paired_loader, total=len(paired_loader), desc="Training"):
        x_i, x_j = x_i.to(device), x_j.to(device)
        optimizer.zero_grad()

        _, z_i = model(x_i)
        _, z_j = model(x_j)

        loss = criterion(z_i, z_j)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(num_batches, 1)


In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are repeatable.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): torchvision's resnet18 keeps its ImageNet stem (strided conv + max-pool),
# which downsamples 32x32 CIFAR images aggressively; CIFAR SimCLR setups often replace it.
# Confirm this is intended.
base_model = models.resnet18(weights=None)
model = SimCLR(base_model).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = NTXentLoss(temperature=0.5).to(device)

EPOCHS = 10

for epoch in range(EPOCHS):
    avg_loss = train(model, train_loader, optimizer, criterion, device)
    print(f"Epoch [{epoch+1}/{EPOCHS}] - Loss: {avg_loss:.4f}")

print("Training Complete!")

Training: 100%|██████████| 195/195 [30:12<00:00,  9.30s/it]

Epoch [1/10] - Loss: 6.2365



Training: 100%|██████████| 195/195 [30:38<00:00,  9.43s/it]

Epoch [2/10] - Loss: 6.2364



Training: 100%|██████████| 195/195 [30:30<00:00,  9.39s/it]

Epoch [3/10] - Loss: 6.2364



Training: 100%|██████████| 195/195 [29:27<00:00,  9.06s/it]


Epoch [4/10] - Loss: 6.2364


Training: 100%|██████████| 195/195 [28:56<00:00,  8.91s/it]


Epoch [5/10] - Loss: 6.2364


Training: 100%|██████████| 195/195 [29:31<00:00,  9.09s/it]

Epoch [6/10] - Loss: 6.2364



Training: 100%|██████████| 195/195 [29:01<00:00,  8.93s/it]


Epoch [7/10] - Loss: 6.2364


Training:   5%|▍         | 9/195 [01:28<27:45,  8.95s/it]