In [9]:
# Self-Supervised Learning with DINO on CIFAR-10 using ResNet50
# This notebook includes local and global crops, centering and sharpening,
# and evaluates the model under different training conditions.

# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.models import resnet50
from torchvision.models.resnet import ResNet
import copy
from tqdm import tqdm
from PIL import Image

In [83]:
# Data Augmentation Transformations for Local and Global Crops
class DINOTransform(object):
    """Multi-crop augmentation used for DINO self-supervised training.

    Calling an instance on one input image returns a single tensor that stacks
    two "global" crops followed by ``local_crops_number`` "local" crops. Every
    crop is resized to ``image_size`` x ``image_size``, color-jittered, randomly
    grayscaled and blurred, then normalized with CIFAR-10 channel statistics.
    Global and local crops differ only in the area range sampled by
    ``RandomResizedCrop`` (``global_crop_scale`` vs ``local_crop_scale``).
    """

    def __init__(self, global_crop_scale=(0.4, 1.0), local_crop_scale=(0.05, 0.4), local_crops_number=4, image_size=32):
        self.global_transform = self._build_pipeline(image_size, global_crop_scale)
        self.local_transform = self._build_pipeline(image_size, local_crop_scale)
        self.local_crops_number = local_crops_number

    @staticmethod
    def _build_pipeline(image_size, scale):
        """Return the shared augmentation chain, cropping with the given area ``scale``."""
        return transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=scale),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])

    def __call__(self, x):
        # Global crops come first: the teacher consumes only the first two views.
        global_views = [self.global_transform(x), self.global_transform(x)]
        local_views = [self.local_transform(x) for _ in range(self.local_crops_number)]
        return torch.stack(global_views + local_views)

# CIFAR-10 per-channel mean / std shared by the supervised and validation pipelines.
_CIFAR10_MEAN = (0.491, 0.482, 0.447)
_CIFAR10_STD = (0.247, 0.243, 0.262)

# Supervised-training augmentation: random resized crop + horizontal flip, then normalize.
sl_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.4, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Validation: no augmentation, only tensor conversion and normalization.
val_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)])

In [42]:
# Load and preprocess CIFAR-10 dataset
def load_data(batch_size, root='./data', num_workers=0, val_batch_size=1, local_crops_number=4):
    """Build the three CIFAR-10 DataLoaders used in this notebook.

    Args:
        batch_size: Batch size of both training loaders.
        root: Directory where CIFAR-10 is stored (downloaded if missing).
        num_workers: DataLoader worker processes for every loader. The DINO
            multi-crop transform is CPU-heavy, so values > 0 may speed up
            self-supervised training noticeably. 0 loads in the main process.
        val_batch_size: Batch size of the validation loader. The default of 1
            keeps the original behaviour; larger values make evaluation faster.
        local_crops_number: Number of local crops produced per image by
            ``DINOTransform`` (pass the notebook's hyperparameter to keep them in sync).

    Returns:
        ``(ssl_train_loader, sl_train_loader, val_loader)``:
            - ``ssl_train_loader``: shuffled training set whose samples are the
              stacked multi-crop views from ``DINOTransform``.
            - ``sl_train_loader``: shuffled training set with ``sl_transform``.
            - ``val_loader``: unshuffled test set with ``val_transform``.
    """
    ssl_transform = DINOTransform(local_crops_number=local_crops_number)
    ssl_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=ssl_transform)
    sl_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=sl_transform)
    val_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=val_transform)

    ssl_train_loader = DataLoader(ssl_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    sl_train_loader = DataLoader(sl_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)
    return ssl_train_loader, sl_train_loader, val_loader

In [None]:
# Define DINO framework with ResNet50
class DINO(nn.Module):
    """DINO self-distillation wrapper around a student / teacher network pair.

    The student is optimised with AdamW so that its temperature-scaled output
    distribution matches the teacher's centered and sharpened distribution.
    The teacher never receives gradients. After every step it tracks the
    student through an exponential moving average (EMA) of the parameters.

    Args:
        student: Network trained by gradient descent.
        teacher: Network with the same architecture as ``student``; its weights
            are overwritten with the student's at construction.
        num_classes: Output dimension of both networks (size of the center).
        device: Device the networks and the center buffer are placed on.
        temperature_student: Softmax temperature applied to student logits.
        temperature_teacher: Softmax temperature applied to teacher logits
            (smaller than the student's by default, i.e. sharpening).
        learning_rate: AdamW learning rate for the student.
        momentum: EMA coefficient of the teacher update.
        center_momentum: EMA coefficient of the center update.
        local_crops_number: Number of local crops per image (stored for reference).
    """

    def __init__(self, student: nn.Module,
                 teacher: nn.Module,
                 num_classes: int,
                 device: torch.device,
                 temperature_student=0.07,
                 temperature_teacher=0.04,
                 learning_rate=0.001,
                 momentum=0.9,
                 center_momentum=0.9,
                 local_crops_number=4):
        super().__init__()
        self.device = device
        self.student = student.to(device)
        self.teacher = teacher.to(device)
        # Start the teacher from exactly the student's weights and buffers.
        self.teacher.load_state_dict(self.student.state_dict())
        self.optimizer = optim.AdamW(self.student.parameters(), lr=learning_rate)
        self.temperature_student = temperature_student
        self.temperature_teacher = temperature_teacher
        self.center_momentum = center_momentum
        self.momentum = momentum
        # Create the center directly on `device` so it matches the teacher outputs
        # even if `.to(device)` is never called on this wrapper.
        self.register_buffer('center', torch.zeros(1, num_classes, device=device))
        self.local_crops_number = local_crops_number

        # The teacher is only ever updated by EMA, never by gradients.
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False

    def H(self, student_outputs, teacher_outputs, exclude_same_view=True):
        """Soft cross-entropy between teacher views and student views.

        Teacher logits are centered (``self.center`` is subtracted) and sharpened
        with ``temperature_teacher``. Student logits are scaled by
        ``temperature_student``.

        Args:
            student_outputs: Logits of shape (S, K), one row per student view.
            teacher_outputs: Logits of shape (T, K), one row per teacher view.
                Teacher row ``i`` is expected to come from the same crop as
                student row ``i``. This holds in ``train_step``, which feeds
                ``crops[:2]`` to the teacher and all crops to the student.
            exclude_same_view: If True (default), drop the pairs (i, i) where
                teacher and student saw the identical crop, as in the DINO
                multi-crop loss. Pass False to keep all T * S pairs.

        Returns:
            1-D tensor of per-pair losses in teacher-major order. Its length is
            T * S, or T * S - min(T, S) when same-view pairs are excluded.
        """
        student_log_probs = nn.functional.log_softmax(student_outputs / self.temperature_student, dim=1)
        centered_teacher = teacher_outputs.detach() - self.center
        teacher_probs = nn.functional.softmax(centered_teacher / self.temperature_teacher, dim=1)

        # pair_losses[i, j] = -sum_k teacher_probs[i, k] * student_log_probs[j, k]
        pair_losses = -(teacher_probs.unsqueeze(1) * student_log_probs.unsqueeze(0)).sum(dim=-1)
        if not exclude_same_view:
            return pair_losses.reshape(-1)

        num_teacher, num_student = pair_losses.shape
        same_view = torch.eye(num_teacher, num_student, device=pair_losses.device).bool()
        # Boolean-mask indexing flattens in row-major (teacher-major) order.
        return pair_losses[~same_view]

    def train_step(self, crops: torch.Tensor):
        """Run one DINO optimisation step on the stacked crops of a single image.

        Args:
            crops: Tensor whose first dimension indexes views, with the two
                global crops first. ``DINOTransform`` produces this layout, i.e.
                (2 + local_crops_number, C, H, W). The teacher sees
                ``crops[:2]``; the student sees every view.

        Returns:
            The scalar training loss as a Python float.
        """
        crops = crops.to(self.device)

        student_outputs = self.student(crops)
        with torch.no_grad():
            teacher_outputs = self.teacher(crops[:2])  # global crops only

        loss = self.H(student_outputs, teacher_outputs).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        with torch.no_grad():
            self._update_teacher()
            self._update_center(teacher_outputs)

        return loss.item()

    def _update_teacher(self):
        """EMA-update the teacher's parameters and copy over the student's buffers.

        Each parameter is set to momentum * teacher + (1 - momentum) * student.
        Buffers (e.g. BatchNorm running statistics) are copied from the student.
        The teacher runs in eval mode, and its buffers are otherwise never
        updated, so it would keep its initial running statistics forever.
        """
        m = self.momentum
        for p_student, p_teacher in zip(self.student.parameters(), self.teacher.parameters()):
            p_teacher.mul_(m).add_(p_student.detach(), alpha=1 - m)
        for b_student, b_teacher in zip(self.student.buffers(), self.teacher.buffers()):
            b_teacher.copy_(b_student)

    def _update_center(self, teacher_outputs):
        """EMA-update ``self.center`` in place with the batch mean of teacher logits."""
        batch_center = teacher_outputs.mean(dim=0, keepdim=True)
        self.center.mul_(self.center_momentum).add_(batch_center, alpha=1 - self.center_momentum)

In [None]:
# Training loop
def ssl_train(dino, train_loader, epochs):
    """Self-supervised DINO pre-training loop.

    Each batch from ``train_loader`` is split into individual images, and
    ``dino.train_step`` is called once per image (one optimizer step each).

    Args:
        dino: Object exposing ``train_step(crops) -> float``.
        train_loader: Sized iterable of ``(crops_batch, labels)``; labels are ignored.
        epochs: Number of passes over ``train_loader``.

    Returns:
        List with the mean per-image loss of every epoch (also printed).
    """
    history = []
    for epoch in range(epochs):
        total_loss, num_steps = 0.0, 0
        for data, _ in tqdm(train_loader, total=len(train_loader)):
            # One sample at a time for simplicity: each image's stacked crops
            # form the batch seen by train_step.
            for image_crops in data:
                total_loss += dino.train_step(image_crops)
                num_steps += 1

        # Average over train_step calls (images), not batches. Dividing by
        # len(train_loader) would inflate the reported loss by ~batch_size.
        avg_loss = total_loss / num_steps if num_steps else float('nan')
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
        history.append(avg_loss)
    return history

In [160]:
# Evaluation function
def evaluate(dino, val_loader, device=None):
    """Measure top-1 accuracy of ``dino.student`` on ``val_loader``.

    The index of the largest student output is taken as the predicted class.

    Args:
        dino: Object with a ``student`` nn.Module.
        val_loader: Iterable of ``(data, labels)`` batches.
        device: Device to run on. If omitted, the device of the student's
            first parameter is used.

    Returns:
        Accuracy in percent (float). The value is also printed.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model = dino.student
    if device is None:
        device = next(model.parameters()).device

    # Restore whatever mode the model was in, rather than always forcing train().
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for data, labels in val_loader:
                data = data.to(device)
                labels = labels.to(device)
                predicted = model(data).argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("val_loader yielded no samples")
    accuracy = 100 * correct / total
    print(f'Validation Accuracy: {accuracy:.2f}%')
    return accuracy

In [161]:
# Hyperparameters
batch_size = 128
epochs = 10
learning_rate = 0.001
temperature_teacher = 0.04   # teacher softmax temperature (sharper than the student's)
temperature_student = 0.07
num_classes = 10             # CIFAR-10 classes; also the output dimension of the DINO networks
local_crops_number = 4

# Seed the RNGs so model initialisation, augmentation and shuffling are reproducible.
# GPU kernels may still introduce some non-determinism.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [162]:
# Model Initialization with ResNet50
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Student and teacher share the architecture; DINO copies the student's weights into the teacher.
student_model = resnet50(num_classes=num_classes)
teacher_model = resnet50(num_classes=num_classes)

# Instantiate DINO framework
dino_kwargs = dict(
    num_classes=num_classes,
    temperature_student=temperature_student,
    temperature_teacher=temperature_teacher,
    learning_rate=learning_rate,
    local_crops_number=local_crops_number,
)
dino = DINO(student=student_model, teacher=teacher_model, device=device, **dino_kwargs)
# Also moves registered buffers (such as the DINO center) onto `device`.
dino.to(device)

# Load data: DINO multi-crop loader, supervised loader, validation loader
ssl_train_loader, sl_train_loader, val_loader = load_data(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [70]:
# 1. Baseline: evaluate the randomly initialised ResNet50 before any training
#    (chance level is 1 / num_classes, i.e. ~10% on CIFAR-10).
print("Evaluation with random weights (no training):")
evaluate(dino=dino, val_loader=val_loader, device=device);

Evaluation with random weights (no training):
Validation Accuracy: 10.15%


In [163]:
# 2. Train with only self-supervised learning
# NOTE(review): the DINO head is trained without labels, so its argmax index has no
# fixed correspondence to CIFAR-10 class ids. Accuracy here is not a meaningful probe;
# a linear probe or k-NN on features would be.
print("\nTraining with self-supervised learning (DINO):")
ssl_train(dino, ssl_train_loader, epochs)
print("\nEvaluation after self-supervised learning:")
evaluate(dino, val_loader, device);


Training with self-supervised learning (DINO):


 32%|███▏      | 127/391 [15:31<32:16,  7.34s/it]


KeyboardInterrupt: 

In [None]:
# 3. Supervised training function (on top of self-supervised pre-trained student model)
def supervised_train(dino, train_loader, val_loader, epochs, lr=None):
    """Train ``dino.student`` in place with cross-entropy on labelled batches.

    Args:
        dino: Object whose ``student`` nn.Module is trained.
        train_loader: Iterable of ``(data, labels)`` batches.
        val_loader: Currently unused; kept for signature compatibility.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate. Defaults to the notebook-level ``learning_rate``.
    """
    model = dino.student
    # Batches arrive on CPU from the DataLoader; move them to wherever the model lives.
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate if lr is None else lr)
    model.train()

    for epoch in range(epochs):
        running_loss, num_batches = 0.0, 0
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            loss = criterion(model(data), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
        # Report the epoch mean instead of only the last batch's loss.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss:.4f}")
    print("Supervised training complete.")

In [None]:
# 4. Train supervised-only ResNet50 model for comparison
# Fresh models are built here on purpose:
#  - `student_model` is the very module DINO trained above (DINO keeps the object it is given),
#    so deep-copying it would start from SSL-pretrained weights;
#  - reusing `teacher_model` would make this new DINO overwrite the first DINO's teacher
#    through `load_state_dict`.
print("\nTraining ResNet50 model with supervised learning only:")
supervised_only_model = resnet50(num_classes=num_classes)
supervised_only_dino = DINO(student=supervised_only_model, teacher=copy.deepcopy(supervised_only_model),
                            num_classes=num_classes, device=device)
supervised_train(supervised_only_dino, sl_train_loader, val_loader, epochs)
print("\nEvaluation after supervised training only:")
evaluate(supervised_only_dino, val_loader, device);

In [None]:
# 5. Fine-tune (supervised) after self-supervised pre-training
print("\nFine-tuning (supervised) after self-supervised pre-training:")
supervised_train(dino, sl_train_loader, val_loader, epochs)
print("\nEvaluation after self-supervised pre-training + supervised fine-tuning:")
evaluate(dino, val_loader, device);