In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SupCon(nn.Module):
    """Supervised contrastive (SupCon) loss over one embedding per sample.

    For every anchor i the loss is the negative mean, over its positives p
    (other samples in the batch sharing i's label), of
        log( exp(z_i . z_p / tau) / sum_{k != i} exp(z_i . z_k / tau) )
    where z are the L2-normalised embeddings and tau is ``temperature``.
    Anchors without any positive in the batch contribute 0 but are still
    counted in the final mean over the batch.
    """

    def __init__(self, temperature=0.07):
        super(SupCon, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """
        features: Tensor of shape (B, D), where B is the batch size and D is the embedding size.
        labels:   Tensor of shape (B,) holding the class id of each sample.

        Returns a scalar loss tensor. A batch with fewer than two samples has no
        pairs to contrast; it yields a zero loss (still attached to the autograd
        graph) instead of NaN.
        """
        if features.shape[0] < 2:
            # With B == 1 the denominator over k != i is empty: log(0) = -inf and
            # 0 * inf = NaN would poison both the loss and the gradients.
            return features.sum() * 0.0

        device = features.device
        # normalize embeddings, so sim(z_i, z_j) = z_i . z_j (cosine similarity)
        features = F.normalize(features, dim=1)
        # sim_ij = z_i . z_j / tau
        sim_matrix = torch.matmul(features, features.T) / self.temperature  # shape: (B, B)

        # (B,) -> (B, 1) column vector; comparing against its transpose broadcasts to (B, B)
        labels = labels.contiguous().view(-1, 1)
        # mask[i, j] = 1 if samples i and j share a label, else 0
        mask = torch.eq(labels, labels.T).float().to(device)
        # drop self-comparisons from both the positives and the denominator
        logits_mask = torch.ones_like(mask)
        logits_mask.fill_diagonal_(0)
        mask = mask * logits_mask

        # row-wise log-softmax restricted to k != i; the row max is subtracted only
        # for numerical stability (it cancels out of log_prob), so it is detached
        logits_max = torch.max(sim_matrix, dim=1, keepdim=True)[0].detach()
        logits = sim_matrix - logits_max
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True))

        # mean log-likelihood over positives; the clamp avoids 0/0 for anchors
        # without positives (their numerator is 0, so they contribute 0)
        mask_sum = torch.clamp(mask.sum(dim=1), min=1.0)
        mean_log_prob_pos = (mask * log_prob).sum(dim=1) / mask_sum

        return -mean_log_prob_pos.mean()


class CombinedLoss(nn.Module):
    """Convex mix of cross-entropy and SupCon: (1 - alpha) * CE + alpha * SupCon.

    alpha:       weight of the SupCon term (0 -> pure cross-entropy, 1 -> pure SupCon).
    temperature: temperature forwarded to the SupCon term (0.07, SupCon's own default).
    """

    def __init__(self, alpha = 0.2, temperature = 0.07):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.entropy_loss = nn.CrossEntropyLoss()
        self.supcon_loss = SupCon(temperature=temperature)

    def forward(self, embeddings, outputs, labels):
        """
        embeddings: embeddings fed to the SupCon term.
        outputs:    class logits fed to nn.CrossEntropyLoss.
        labels:     class ids shared by both terms.
        """
        ce_term = self.entropy_loss(outputs, labels)
        supcon_term = self.supcon_loss(embeddings, labels)
        return ce_term * (1 - self.alpha) + supcon_term * self.alpha

In [None]:
import torch
import torch.nn as nn

class CNN(nn.Module):
    """AlexNet-style CNN with a projection head (for SupCon) and a classifier head (for CE).

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of outputs of the classifier head.
        proj_dim:    size of the projection/embedding returned for SupCon.

    forward(x) returns ``(proj, logits)`` with shapes (B, proj_dim) and (B, num_classes);
    the classifier is applied on top of ``proj``.

    The conv stack (three 2x max-pools) is followed by an adaptive average pool to
    4x4, so the projector always receives 256*4*4 features. For 32x32 inputs
    (CIFAR-10) the conv output is already 4x4 and the pool is an exact identity.
    Other resolutions of at least 8x8 also work.
    """

    def __init__(self, in_channels = 3, num_classes = 10, proj_dim = 128):
        super(CNN, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # Fixes the spatial size fed to the projector regardless of input resolution.
        # It has no parameters, so existing state_dicts still load unchanged.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))

        self.projector = nn.Sequential(
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, proj_dim),
            nn.BatchNorm1d(proj_dim)
        )

        self.classifier = nn.Sequential(
            nn.Linear(proj_dim, proj_dim*4),
            nn.ReLU(),
            nn.Linear(proj_dim*4, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)

        proj = self.projector(x)         # for SupCon
        logits = self.classifier(proj)   # for CE

        return proj, logits

------

In [None]:
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

In [None]:
def get_device():
    """Return the best available torch device and print which one was picked.

    Preference order: CUDA (NVIDIA GPU), then MPS (Apple Silicon GPU), then CPU.
    """
    if torch.cuda.is_available():
        print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        return torch.device("cuda")

    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        print("Using MPS (Apple Silicon GPU)")
        return torch.device("mps")

    # nothing accelerated available
    print("Using CPU")
    return torch.device("cpu")

In [None]:
# Run configuration.
DEVICE = get_device()
BATCH_SIZE = 64
EPOCHS = 10
# Seed torch's global RNG so weight init, DataLoader shuffling and the default
# random_split generator are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
"""
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # TODO: add transformations/augmentations?
"""
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),        # randomly flip the image horizontally
    transforms.RandomRotation(15),                 # randomly rotate the image by ±15 degrees
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),  # random crop and resize to 32x32
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # adjust colors
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)

train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Build the network with its default (CIFAR-10) configuration and move it to the selected device.
model = CNN()
model = model.to(DEVICE)

In [None]:
# Criteria used by the two training stages below.
sup_con_loss = SupCon()                      # stage 1: encoder trained on embeddings
cross_entropy_loss = nn.CrossEntropyLoss()   # stage 2: classifier trained on logits
total_step = len(train_loader)               # number of batches per training epoch

In [None]:
# Stage 1 setup: train the encoder (features + projector) only.
# Every parameter is made trainable, then the classifier head is frozen.
model.requires_grad_(True)
model.classifier.requires_grad_(False)

# NOTE(review): lr=0.01 is 10x Adam's default of 1e-3; confirm this is intended.
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [None]:
# Stage 1 loop: optimise the encoder with SupCon, then report the mean per-batch
# SupCon loss on the training and validation splits for each epoch.
for epoch in range(EPOCHS):
    # training pass
    model.train()
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        embeddings, _ = model(images)  # logits are ignored in this stage
        loss = sup_con_loss(embeddings, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader)

    # validation pass: eval mode, no autograd graph
    model.eval()
    with torch.no_grad():
        validation_loss = 0
        for images, labels in validation_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            embeddings, _ = model(images)
            loss = sup_con_loss(embeddings, labels)
            validation_loss += loss.item()
    validation_loss /= len(validation_loader)

    print("\n".join([
        f"> Epoch {epoch+1}/{EPOCHS}",
        f"  Train Loss: {train_loss:.4f}",
        f"  Validation Loss: {validation_loss:.4f}",
    ]))

In [None]:
# Stage 2 setup: train the classifier head only.
# Every parameter is frozen, then only the classifier is made trainable again.
model.requires_grad_(False)
model.classifier.requires_grad_(True)

# A new optimizer, so no Adam state carries over from stage 1.
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [None]:

# Stage 2 loop: fit the classifier head with cross-entropy on top of the frozen
# encoder; report sample-weighted loss and accuracy for each epoch.
for epoch in range(EPOCHS):
    # requires_grad=False freezes the encoder's weights but NOT the running
    # statistics of the projector's BatchNorm1d: in train mode they would keep
    # drifting every batch. So the whole model goes to eval mode and only the
    # trainable classifier is switched back to train mode.
    model.eval()
    model.classifier.train()

    train_loss_sum, train_correct, train_seen = 0.0, 0, 0
    for i, (images, labels) in enumerate(train_loader): # iterating over all the batches
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        _, logits = model(images)
        loss = cross_entropy_loss(logits, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns a per-batch mean, so weight it by the batch size;
        # a short final batch then counts only for the samples it contains.
        n_batch = labels.size(0)
        train_loss_sum += loss.item() * n_batch
        train_correct += (torch.argmax(logits, dim=1) == labels).sum().item()
        train_seen += n_batch

    train_loss = train_loss_sum / train_seen
    train_accuracy = train_correct / train_seen * 100

    # validation
    model.eval()
    with torch.no_grad():
        val_loss_sum, val_correct, val_seen = 0.0, 0, 0
        for images, labels in validation_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            _, logits = model(images)
            loss = cross_entropy_loss(logits, labels)
            n_batch = labels.size(0)
            val_loss_sum += loss.item() * n_batch
            val_correct += (torch.argmax(logits, dim=1) == labels).sum().item()
            val_seen += n_batch

    validation_loss = val_loss_sum / val_seen
    validation_accuracy = val_correct / val_seen * 100

    print(f"> Epoch {epoch+1}/{EPOCHS}")
    print(f"  Training loss      : {train_loss:.4f}, Training accuracy  : {train_accuracy:.2f}%")
    print(f"  Validation loss    : {validation_loss:.4f}, Validation accuracy: {validation_accuracy:.2f}%")