In [2]:
# Self-Supervised Learning with DINO on CIFAR-10 using ResNet50
# This notebook includes local and global crops, centering and sharpening,
# and evaluates the model under different training conditions.

# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.models import resnet50
from torchvision.models.resnet import ResNet
import copy
from tqdm import tqdm
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [32]:
# Data Augmentation Transformations for Local and Global Crops
class DINOTransform(object):
    def __init__(self, global_crop_scale=(0.4, 1.0), local_crop_scale=(0.05, 0.4), local_crops_number=4, image_size=32):
        self.global_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=global_crop_scale),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])
        
        self.local_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=local_crop_scale),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])
        self.local_crops_number = local_crops_number

    def __call__(self, x):
        # Two global crops for teacher and student, several local crops for student
        crops = [self.global_transform(x) for _ in range(2)]  # 2 global crops
        crops.extend([self.local_transform(x) for _ in range(self.local_crops_number)])  # 4 local crops is default

        return crops

sl_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.4, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262))
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262))
])

In [78]:
# Load and preprocess CIFAR-10 dataset
def load_data(batch_size):
    ssl_transform = DINOTransform()
    ssl_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=ssl_transform)
    sl_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=sl_transform)
    val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform)

    ssl_train_loader = DataLoader(ssl_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=8, drop_last=True)
    sl_train_loader = DataLoader(sl_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=8, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=8)
    return ssl_train_loader, sl_train_loader, val_loader

In [64]:
# Define DINO framework with ResNet50
class DINO(nn.Module):
    def __init__(self, student: nn.Module,
                    teacher: nn.Module,
                    num_classes: int,
                    device: torch.device,
                    temperature_student=0.07,
                    temperature_teacher=0.04,
                    learning_rate=0.001,
                    momentum=0.9,
                    center_momentum=0.9,
                    local_crops_number=4):
        super(DINO, self).__init__()
        self.device = device
        self.student = student.to(device)
        self.teacher = teacher.to(device)
        self.teacher.load_state_dict(self.student.state_dict())
        self.optimizer = optim.AdamW(self.student.parameters(), lr=learning_rate)
        self.temperature_student = temperature_student
        self.temperature_teacher = temperature_teacher
        self.center_momentum = center_momentum
        self.momentum = momentum
        self.register_buffer('center', torch.zeros(1, num_classes))
        self.local_crops_number=local_crops_number
        
        teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False
        
    def H(self, teacher_outputs, student_outputs):
        """
        Custom cross-entropy for soft labels, applied with centering and sharpening.
        """

        #print("Student_outputs shape:", student_outputs.shape)
        #print("Teacher outputs shape: ", teacher_outputs.shape)
        #print("student_outputs device: ", student_outputs.get_device())
        student_probs = nn.functional.log_softmax(student_outputs / self.temperature_student, dim=1)

        teacher_outputs = teacher_outputs.detach()
        #print("teacher_outputs device: ", teacher_outputs.get_device())
        centered_output = (teacher_outputs - self.center)
        teacher_probs = nn.functional.softmax(centered_output / self.temperature_teacher, dim=1)

        #print("Student_probs shape: ", student_probs.shape)
        #print("Teacher probs shape: ", teacher_probs.shape)

        
        loss = - (teacher_probs * student_probs).sum(dim=1).mean()

        #print(len(losses))
        #print("losses stack shape", torch.stack(losses).shape)

        return loss
    
    def train_step(self, crops: tuple[list, list]):
        """
        Perform a single training step with DINO.
        Args:
            crops (tuple of list of Tensor): List containing augmented two augmented views of the same image.
        """
        # Forward pass through the student model on both global and local crops
        student_outputs = torch.stack([self.student(crop.to(self.device)) for crop in crops])
        #student_outputs = self.student(crops)
        
        # Forward pass through the teacher model on global crops only
        with torch.no_grad():
            #print(crops[:2].shape)
            teacher_crops = crops[:2] # We use only the two global crops
            teacher_outputs = torch.stack([self.teacher(crop.to(self.device)) for crop in teacher_crops])
        
        # Compute DINO loss using soft cross-entropy with centering and sharpening for every pair t and s
        losses = []
        for t in teacher_outputs:
            for s in student_outputs:
                dino_loss = self.H(t, s)
                losses.append(dino_loss)
        
        #print("loss shape: ", torch.stack(losses).shape)

        loss = torch.stack(losses).mean()
        #print("Loss: ", loss.item())
        
        #loss /= len(student_outputs)  # Normalize by the number of student outputs
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update teacher model after each batch
        for param_student, param_teacher in zip(self.student.parameters(), self.teacher.parameters()):
            param_teacher.data = self.momentum * param_teacher.data + (1 - self.momentum) * param_student.data

        with torch.no_grad():
            self.center = self.center_momentum * self.center + (1 - self.center_momentum)*teacher_outputs.mean(dim=0)
        
        return loss.item()

In [79]:
# Training loop
def ssl_train(dino, train_loader, epochs):
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (images, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
            #print("data shape: ", data.shape)
            #print("images shape: ", np.shape(images))
            loss = dino.train_step(images)
            #print("batch loss: ", batch_loss)
            total_loss += loss
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

In [87]:
# Evaluation function
def evaluate(student, val_loader, device):
    """
    Evaluate the student model on the validation set.
    """
    student.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, labels in val_loader:
            data = data.to(device)
            labels = labels.to(device)
            outputs = student(data)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
        
    student.train()
    
    accuracy = 100 * correct / total
    print(f'Validation Accuracy: {accuracy:.2f}%')

In [91]:
# Hyperparameters
batch_size = 256
ssl_epochs = 5
sl_epochs = 20
learning_rate = 0.001
temperature_teacher = 0.04
temperature_student = 0.07
num_classes = 10
local_crops_number = 4

In [99]:
# Model Initialization with ResNet50
student_model = resnet50(num_classes=num_classes)
teacher_model = resnet50(num_classes=num_classes)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
supervised_only_model = copy.deepcopy(student_model).to(device)

# Instantiate DINO framework
dino = DINO(student=student_model, teacher=teacher_model, device=device, num_classes=num_classes,
            temperature_student=temperature_student, temperature_teacher=temperature_teacher, 
            learning_rate=learning_rate, local_crops_number=local_crops_number)

dino.to(device)

# Load data
ssl_train_loader, sl_train_loader, val_loader = load_data(batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [100]:
# 1. Evaluate ResNet50 without any training (random weights), should be around 10%
print("Evaluation with random weights (no training):")
evaluate(dino.student, val_loader, device)

Evaluation with random weights (no training):
Validation Accuracy: 11.03%


In [101]:
# 2. Train with only self-supervised learning, should be roughly the same as without any supervised training
print("\nTraining with self-supervised learning (DINO):")
ssl_train(dino, ssl_train_loader, ssl_epochs)
print("\nEvaluation after self-supervised learning:")
evaluate(dino.student, val_loader, device)


Training with self-supervised learning (DINO):


100%|██████████| 195/195 [01:14<00:00,  2.63it/s]

Epoch 1/5, Loss: 9.0973



100%|██████████| 195/195 [01:14<00:00,  2.62it/s]

Epoch 2/5, Loss: 3.9559



100%|██████████| 195/195 [01:14<00:00,  2.62it/s]

Epoch 3/5, Loss: 1.7042



100%|██████████| 195/195 [01:14<00:00,  2.61it/s]

Epoch 4/5, Loss: 1.2449



100%|██████████| 195/195 [01:11<00:00,  2.71it/s]

Epoch 5/5, Loss: 1.6333

Evaluation after self-supervised learning:





Validation Accuracy: 10.00%


In [102]:
# 3. Supervised training function (on top of self-supervised pre-trained student model)
def supervised_train(model, train_loader, epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        for batch_idx, (data, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
            data = data.to(device)
            labels = labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")
    print("Supervised training complete.")

In [103]:
# 4. Train supervised-only ResNet50 model for comparison
print("\nTraining ResNet50 model with supervised learning only:")
supervised_train(supervised_only_model, sl_train_loader, sl_epochs)
print("\nEvaluation after supervised training only:")
evaluate(supervised_only_model, val_loader, device)


Training ResNet50 model with supervised learning only:


100%|██████████| 195/195 [00:09<00:00, 21.25it/s]

Epoch [1/20], Loss: 1.9524



100%|██████████| 195/195 [00:09<00:00, 21.34it/s]

Epoch [2/20], Loss: 1.6387



100%|██████████| 195/195 [00:09<00:00, 21.41it/s]

Epoch [3/20], Loss: 1.6971



100%|██████████| 195/195 [00:09<00:00, 21.35it/s]

Epoch [4/20], Loss: 1.3400



100%|██████████| 195/195 [00:09<00:00, 21.29it/s]

Epoch [5/20], Loss: 1.4890



100%|██████████| 195/195 [00:09<00:00, 21.29it/s]

Epoch [6/20], Loss: 1.3899



100%|██████████| 195/195 [00:09<00:00, 21.37it/s]

Epoch [7/20], Loss: 1.4210



100%|██████████| 195/195 [00:09<00:00, 21.18it/s]


Epoch [8/20], Loss: 1.4299


100%|██████████| 195/195 [00:09<00:00, 20.67it/s]

Epoch [9/20], Loss: 1.2288



100%|██████████| 195/195 [00:09<00:00, 20.47it/s]

Epoch [10/20], Loss: 1.8314



100%|██████████| 195/195 [00:09<00:00, 21.24it/s]

Epoch [11/20], Loss: 1.6746



100%|██████████| 195/195 [00:09<00:00, 21.34it/s]

Epoch [12/20], Loss: 1.3585



100%|██████████| 195/195 [00:09<00:00, 21.43it/s]

Epoch [13/20], Loss: 1.3113



100%|██████████| 195/195 [00:09<00:00, 21.38it/s]

Epoch [14/20], Loss: 1.1822



100%|██████████| 195/195 [00:09<00:00, 21.45it/s]

Epoch [15/20], Loss: 1.1847



100%|██████████| 195/195 [00:09<00:00, 21.47it/s]

Epoch [16/20], Loss: 1.4402



100%|██████████| 195/195 [00:09<00:00, 21.27it/s]

Epoch [17/20], Loss: 1.2057



100%|██████████| 195/195 [00:09<00:00, 21.37it/s]

Epoch [18/20], Loss: 1.0685



100%|██████████| 195/195 [00:09<00:00, 21.42it/s]

Epoch [19/20], Loss: 1.0385



100%|██████████| 195/195 [00:09<00:00, 21.38it/s]

Epoch [20/20], Loss: 1.0537
Supervised training complete.

Evaluation after supervised training only:





Validation Accuracy: 64.41%


In [None]:
# 5. Fine-tune (supervised) after self-supervised pre-training, should be slighlty better than only supervised
print("\nFine-tuning (supervised) after self-supervised pre-training:")
supervised_train(dino.student, sl_train_loader, sl_epochs)
print("\nEvaluation after self-supervised pre-training + supervised fine-tuning:")
evaluate(dino.student, val_loader, device)


Fine-tuning (supervised) after self-supervised pre-training:


100%|██████████| 195/195 [00:09<00:00, 21.37it/s]

Epoch [1/20], Loss: 1.5518



100%|██████████| 195/195 [00:09<00:00, 21.39it/s]

Epoch [2/20], Loss: 1.5922



100%|██████████| 195/195 [00:09<00:00, 21.27it/s]

Epoch [3/20], Loss: 1.3611



100%|██████████| 195/195 [00:09<00:00, 21.56it/s]

Epoch [4/20], Loss: 1.8953



100%|██████████| 195/195 [00:09<00:00, 21.43it/s]

Epoch [5/20], Loss: 2.0084



100%|██████████| 195/195 [00:09<00:00, 21.32it/s]

Epoch [6/20], Loss: 1.7771



100%|██████████| 195/195 [00:09<00:00, 21.22it/s]


Epoch [7/20], Loss: 1.5285


100%|██████████| 195/195 [00:09<00:00, 21.47it/s]

Epoch [8/20], Loss: 1.6608



100%|██████████| 195/195 [00:09<00:00, 21.35it/s]

Epoch [9/20], Loss: 1.4558



100%|██████████| 195/195 [00:09<00:00, 21.42it/s]

Epoch [10/20], Loss: 1.3326



100%|██████████| 195/195 [00:09<00:00, 21.29it/s]

Epoch [11/20], Loss: 1.4449



100%|██████████| 195/195 [00:09<00:00, 21.44it/s]

Epoch [12/20], Loss: 1.4216



100%|██████████| 195/195 [00:09<00:00, 21.44it/s]

Epoch [13/20], Loss: 1.2233



100%|██████████| 195/195 [00:09<00:00, 21.41it/s]

Epoch [14/20], Loss: 1.0270



100%|██████████| 195/195 [00:09<00:00, 21.32it/s]

Epoch [15/20], Loss: 1.2347



100%|██████████| 195/195 [00:09<00:00, 21.27it/s]

Epoch [16/20], Loss: 0.9913



100%|██████████| 195/195 [00:09<00:00, 21.49it/s]

Epoch [17/20], Loss: 1.0343



100%|██████████| 195/195 [00:09<00:00, 21.41it/s]

Epoch [18/20], Loss: 0.9823



100%|██████████| 195/195 [00:09<00:00, 21.32it/s]

Epoch [19/20], Loss: 0.8926



100%|██████████| 195/195 [00:09<00:00, 21.25it/s]

Epoch [20/20], Loss: 0.9074
Supervised training complete.

Evaluation after self-supervised pre-training + supervised fine-tuning:





Validation Accuracy: 69.15%
