In [7]:
import os
import time
import copy
import math
import random
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image

import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive

# --- Configuration ---
# Optimisation and schedule lengths
BATCH_SIZE = 64
EPOCHS_PRETRAIN = 10
EPOCHS_LINEAR_EVAL = 10
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-6

# BYOL head widths and the base EMA momentum of the target network
PROJECTION_DIM = 128
HIDDEN_DIM = 512
MOMENTUM_BASE = 0.99

# DataLoader worker processes
NUM_WORKERS = 2

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pinned host memory only benefits host -> GPU copies, so enable it on CUDA only.
PIN_MEMORY = DEVICE.type == 'cuda'

print(f"Using device: {DEVICE}")

# --- Reproducibility: seed every RNG the notebook draws from ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force deterministic cuDNN kernels and disable autotuning (repeatability over speed).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


Using device: cuda


In [8]:
# Root of the melanoma image dataset (one sub-directory per class label).
# NOTE: this name shadows `torch.utils.data.Dataset` imported in the first cell;
# it is kept because later cells build file paths from it.
Dataset = "/kaggle/input/melanoma-cancer-dataset"

file_paths = []
labels = []

# Directory listings are sorted so the index (and therefore the seeded shuffle
# below) does not depend on filesystem enumeration order. Non-directory entries
# at the top level are skipped instead of crashing os.listdir(class_dir).
for class_name in sorted(os.listdir(Dataset)):
    class_dir = os.path.join(Dataset, class_name)
    if not os.path.isdir(class_dir):
        continue
    for image_name in sorted(os.listdir(class_dir)):
        file_paths.append(os.path.join(class_dir, image_name))
        labels.append(class_name)

df = pd.DataFrame({"file_path": file_paths, "label": labels})
# An explicit random_state makes the shuffle reproducible independently of how
# much of NumPy's global RNG earlier cells have consumed.
# NOTE(review): `df` is not referenced by the CIFAR-10 cells below — confirm it is still needed.
df = df.sample(frac=1, random_state=SEED).reset_index(drop=True)

In [9]:
# @title 2. BYOL Data Augmentations
# This cell defines the custom data augmentation pipeline for BYOL.

class CustomBYOLTransform:
    """Return two independently augmented views of one PIL image for BYOL.

    Both views use the same recipe: random resized crop (scale 0.2-1.0),
    horizontal flip, colour jitter (p=0.8), random grayscale (p=0.2), Gaussian
    blur (p=0.5), tensor conversion and ImageNet normalisation. The recipe is
    built as two separate pipelines so each view gets its own random draw.

    Args:
        size: side length in pixels of each square output view.
    """

    def __init__(self, size=32):
        self.transform1 = self._build_view_transform(size)
        # A different recipe is often used for the second view; here it is a
        # second, independent random application of the same recipe.
        self.transform2 = self._build_view_transform(size)

    @staticmethod
    def _blur_kernel_size(size):
        """Return an odd Gaussian-blur kernel size of roughly 10% of ``size``.

        torchvision's GaussianBlur rejects even or non-positive kernel sizes, so
        int(0.1 * size) is bumped to the next odd integer (32 -> 3, 224 -> 23).
        """
        kernel_size = int(0.1 * size)
        if kernel_size % 2 == 0:
            kernel_size += 1
        return kernel_size

    @classmethod
    def _build_view_transform(cls, size):
        """Build one view's augmentation pipeline producing ``size x size`` tensors."""
        # ImageNet stats for normalization
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        return transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # Gaussian blur is a key BYOL augmentation; the kernel must be odd.
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=cls._blur_kernel_size(size))],
                p=0.5,
            ),
            transforms.ToTensor(),
            normalize,
        ])

    def __call__(self, x):
        """Return ``(view1, view2)``: two augmentations of the PIL image ``x``."""
        return self.transform1(x), self.transform2(x)

# CIFAR-10 dataset with custom BYOL transformations
class CIFAR10BYOL(datasets.CIFAR10):
    """CIFAR-10 whose items are the raw output of ``transform`` (label dropped).

    With a transform set (e.g. CustomBYOLTransform), an item is exactly what the
    transform returns, because BYOL pre-training does not use labels. Without a
    transform, an item is ``(PIL image, target)``. ``target_transform`` is
    never applied.
    """

    def __init__(self, root, train=True, transform=None, download=False):
        super().__init__(root, train=train, transform=transform, download=download)

    def __getitem__(self, index):
        pil_image = Image.fromarray(self.data[index])

        if self.transform is None:
            return pil_image, self.targets[index]

        # e.g. CustomBYOLTransform -> (view1, view2)
        return self.transform(pil_image)

# Load CIFAR-10 for BYOL pre-training: with a transform set, CIFAR10BYOL yields
# only the transform output (two augmented views) and drops the label.
train_dataset_byol = CIFAR10BYOL(root='./data', train=True, download=True, transform=CustomBYOLTransform())
# drop_last=True discards the incomplete final batch so every pre-training step
# sees exactly BATCH_SIZE images. Pinned memory follows the config (CUDA only)
# instead of being hard-coded, which warns on CPU-only runtimes.
train_loader_byol = DataLoader(train_dataset_byol, batch_size=BATCH_SIZE, shuffle=True,
                               num_workers=NUM_WORKERS, drop_last=True, pin_memory=PIN_MEMORY)

# Datasets for linear evaluation (standard, label-preserving transforms).
# The same ImageNet normalisation as the BYOL views is used; Normalize is stateless.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform_eval = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_normalize,
])
test_transform_eval = transforms.Compose([
    transforms.ToTensor(),
    imagenet_normalize,
])

train_dataset_eval = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform_eval)
test_dataset_eval = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform_eval)

train_loader_eval = DataLoader(train_dataset_eval, batch_size=BATCH_SIZE, shuffle=True,
                               num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader_eval = DataLoader(test_dataset_eval, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"BYOL training dataset size: {len(train_dataset_byol)}")
print(f"Linear evaluation training dataset size: {len(train_dataset_eval)}")
print(f"Linear evaluation test dataset size: {len(test_dataset_eval)}")

100%|██████████| 170M/170M [00:02<00:00, 60.0MB/s] 


BYOL training dataset size: 50000
Linear evaluation training dataset size: 50000
Linear evaluation test dataset size: 10000


In [10]:
# @title 3. BYOL Model Components
# This cell defines the encoder, projector, and predictor networks.

class BYOLResNet(nn.Module):
    """Randomly initialised ResNet-50 backbone returning pooled, flattened features.

    The torchvision classifier (``fc``) is removed; ``forward`` returns the
    global-average-pooled features flattened to ``(batch, feature_dim)``.

    Args:
        cifar_stem: if True, replace the ImageNet stem (7x7 stride-2 conv plus
            max-pool), which shrinks 32x32 inputs very aggressively, with a 3x3
            stride-1 conv and no max-pool. This is a common CIFAR adaptation.
            Defaults to False, which keeps the stock torchvision architecture
            and therefore the existing checkpoint layout.
    """

    def __init__(self, cifar_stem=False):
        super().__init__()
        resnet = models.resnet50(weights=None)  # no pre-trained ImageNet weights

        if cifar_stem:
            resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            resnet.maxpool = nn.Identity()  # keeps child indices (and state_dict keys) stable

        # Width of the pooled features, read from the classifier being removed
        # rather than hard-coded (2048 for ResNet-50).
        self.feature_dim = resnet.fc.in_features

        # Every child except the final fully connected classifier.
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        # (B, C, 1, 1) after adaptive average pooling -> (B, C)
        return torch.flatten(self.encoder(x), start_dim=1)

class ProjectionHead(nn.Module):
    """Two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    BYOL uses it as the projector that maps encoder features into the
    ``output_dim``-dimensional projection space.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

class PredictionHead(nn.Module):
    """Predictor MLP (Linear -> BatchNorm1d -> ReLU -> Linear) of BYOL's online branch.

    It maps an online projection to a prediction of the target projection; the
    architecture matches the projector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        expand = nn.Linear(input_dim, hidden_dim)
        norm = nn.BatchNorm1d(hidden_dim)
        activation = nn.ReLU(inplace=True)
        contract = nn.Linear(hidden_dim, output_dim)
        self.net = nn.Sequential(expand, norm, activation, contract)

    def forward(self, x):
        prediction = self.net(x)
        return prediction

# Main BYOL Model
class BYOL(nn.Module):
    """BYOL model with an online and a target branch.

    Online branch: encoder -> projector -> predictor.
    Target branch: encoder -> projector. It starts as a deep copy of the online
    encoder/projector, is excluded from gradient updates, and is pulled towards
    the online weights by ``update_target_network`` (exponential moving average).

    Args:
        encoder: backbone exposing ``feature_dim``; used directly as the online encoder.
        projection_dim: output width of the projector and predictor.
        hidden_dim: hidden width of the projector and predictor MLPs.
    """

    def __init__(self, encoder: "BYOLResNet", projection_dim, hidden_dim):
        super().__init__()
        self.online_encoder = encoder
        self.online_projector = ProjectionHead(encoder.feature_dim, hidden_dim, projection_dim)
        self.online_predictor = PredictionHead(projection_dim, hidden_dim, projection_dim)

        self.target_encoder = copy.deepcopy(encoder)
        self.target_projector = copy.deepcopy(self.online_projector)

        # Freeze target network parameters: only the EMA update may change them.
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    def forward(self, x1, x2):
        """Run both views through both branches.

        Returns:
            ``(online_pred_out1, online_pred_out2, target_proj_out1, target_proj_out2)``;
            the two target outputs are computed under ``torch.no_grad()``
            (stop-gradient on the target branch).
        """
        # Online network forward pass
        online_proj_out1 = self.online_projector(self.online_encoder(x1))
        online_pred_out1 = self.online_predictor(online_proj_out1)

        online_proj_out2 = self.online_projector(self.online_encoder(x2))
        online_pred_out2 = self.online_predictor(online_proj_out2)

        # Target network forward pass (no gradients needed for target)
        with torch.no_grad():
            target_proj_out1 = self.target_projector(self.target_encoder(x1))
            target_proj_out2 = self.target_projector(self.target_encoder(x2))

        return online_pred_out1, online_pred_out2, target_proj_out1, target_proj_out2

    @torch.no_grad()
    def update_target_network(self, momentum):
        """EMA update: ``target = momentum * target + (1 - momentum) * online``.

        Applied to encoder and projector parameters in place, so the target
        tensors keep their identity and no new storage is allocated per step.

        Raises:
            ValueError: if ``momentum`` is outside ``[0, 1]``.
        """
        if not 0.0 <= momentum <= 1.0:
            raise ValueError(f"momentum must be in [0, 1], got {momentum}")

        branch_pairs = (
            (self.online_encoder, self.target_encoder),
            (self.online_projector, self.target_projector),
        )
        for online_net, target_net in branch_pairs:
            for online_param, target_param in zip(online_net.parameters(), target_net.parameters()):
                target_param.mul_(momentum).add_(online_param, alpha=1.0 - momentum)

# Instantiate the BYOL model. BYOL uses `byol_encoder` as its online encoder and
# deep-copies it (and the online projector) to build the frozen target branch.
byol_encoder = BYOLResNet().to(DEVICE)
byol_model = BYOL(byol_encoder, PROJECTION_DIM, HIDDEN_DIM)
byol_model = byol_model.to(DEVICE)  # also moves the heads created inside BYOL

# Check model architecture
print(byol_model)

BYOL(
  (online_encoder): BYOLResNet(
    (encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
   

In [11]:
# @title 4. BYOL Loss Function
# This cell defines the custom BYOL loss.

def byol_loss(online_pred, target_proj):
    """BYOL regression loss between online predictions and target projections.

    Both inputs are L2-normalised along the last dimension. For unit vectors
    ``||p - z||^2 = 2 - 2 * <p, z>``, so each row contributes two minus twice
    the cosine similarity (range ``[0, 4]``). This is the normalised MSE of the
    BYOL paper, not a "negative" MSE. The result is averaged over all leading
    dimensions and returned as a scalar tensor.
    """
    p = F.normalize(online_pred, dim=-1)
    z = F.normalize(target_proj, dim=-1)
    cosine = (p * z).sum(dim=-1)
    per_sample = 2 - 2 * cosine
    return per_sample.mean()

In [None]:
# @title 5. Self-Supervised Pre-training Loop
# This cell contains the main training logic for BYOL.

def adjust_learning_rate(optimizer, init_lr, epoch, total_epochs):
    """Set every param group's LR from a half-cosine schedule (cosine decay).

    ``lr = init_lr * (1 + cos(pi * epoch / total_epochs)) / 2``, so
    ``epoch == 0`` yields ``init_lr`` and ``epoch == total_epochs`` yields 0.
    Mutates ``optimizer.param_groups`` in place; returns None.
    """
    cosine_scale = 0.5 * (1. + math.cos(math.pi * epoch / total_epochs))
    new_lr = init_lr * cosine_scale
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def adjust_momentum(epoch, total_epochs, momentum_base):
    """Cosine-anneal the EMA momentum from ``momentum_base`` towards 1.

    ``m = 1 - (1 - momentum_base) * (cos(pi * epoch / total_epochs) + 1) / 2``,
    giving ``momentum_base`` at ``epoch == 0`` and exactly 1 at
    ``epoch == total_epochs``.
    """
    cosine_weight = (math.cos(math.pi * epoch / total_epochs) + 1) / 2
    return 1. - (1. - momentum_base) * cosine_weight

print("--- Starting BYOL Self-Supervised Pre-training ---")
# Optimise only the online branch (encoder, projector, predictor). The target
# branch has requires_grad=False and is changed solely by the EMA update below.
trainable_params = [p for p in byol_model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# The learning rate is set explicitly each epoch by adjust_learning_rate(); no
# torch LR scheduler is created, because one that is never stepped is dead code.

pretrain_losses = []

for epoch in range(1, EPOCHS_PRETRAIN + 1):
    byol_model.train()
    total_loss = 0.0
    start_time = time.time()

    # Evaluate both cosine schedules at the number of *completed* epochs. If the
    # 1-based `epoch` were used, the first epoch would already be decayed and
    # the last epoch would run with lr == 0 and momentum == 1 (frozen target),
    # wasting a whole epoch.
    completed_epochs = epoch - 1
    adjust_learning_rate(optimizer, LEARNING_RATE, completed_epochs, EPOCHS_PRETRAIN)
    current_momentum = adjust_momentum(completed_epochs, EPOCHS_PRETRAIN, MOMENTUM_BASE)

    for x1, x2 in train_loader_byol:
        # non_blocking only overlaps the copy when the batch is in pinned memory
        x1 = x1.to(DEVICE, non_blocking=True)
        x2 = x2.to(DEVICE, non_blocking=True)

        optimizer.zero_grad()

        # Forward pass through BYOL model
        online_pred_out1, online_pred_out2, target_proj_out1, target_proj_out2 = byol_model(x1, x2)

        # Symmetric loss: each view's online prediction regresses the other view's target projection.
        loss1 = byol_loss(online_pred_out1, target_proj_out2)
        loss2 = byol_loss(online_pred_out2, target_proj_out1)
        loss = loss1 + loss2

        loss.backward()
        optimizer.step()

        # EMA update of the target network after every optimiser step
        byol_model.update_target_network(current_momentum)

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader_byol)
    pretrain_losses.append(avg_loss)
    elapsed_time = time.time() - start_time

    print(f"Pretrain Epoch {epoch}/{EPOCHS_PRETRAIN} | Loss: {avg_loss:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f} | Momentum: {current_momentum:.4f} | Time: {elapsed_time:.1f}s")

print("--- BYOL Self-Supervised Pre-training Finished ---")

# Save the pre-trained (online) encoder weights.
# NOTE(review): `Dataset` is the input dataset directory. If that directory is a
# read-only mount (as Kaggle input datasets usually are), this save fails.
# Consider a writable output directory, and change the load in the
# linear-evaluation cell to match.
encoder_ckpt_path = os.path.join(Dataset, "byol_encoder_cifar10.pth")
torch.save(byol_model.online_encoder.state_dict(), encoder_ckpt_path)
print(f"Saved pre-trained encoder to {encoder_ckpt_path}")

--- Starting BYOL Self-Supervised Pre-training ---


In [None]:
# @title 6. Linear Evaluation (Downstream Task)
# This cell defines a linear classifier and evaluates the pre-trained encoder.

class LinearClassifier(nn.Module):
    """Linear probe: one affine layer mapping frozen features to class logits.

    Args:
        feature_dim: width of the input feature vectors.
        num_classes: number of output logits (10 for CIFAR-10).
    """

    def __init__(self, feature_dim, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_features=feature_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits

print("\n--- Starting Linear Evaluation ---")


def evaluate_accuracy(model_fn, loader):
    """Return top-1 accuracy (percent) of ``model_fn`` over every batch of ``loader``.

    ``model_fn`` maps a batch of images already on DEVICE to class logits and is
    called under ``torch.no_grad()``. Callers must put any modules it uses into
    eval mode first.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            predicted = model_fn(images).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total


# Load the pre-trained encoder (the BYOL 'online_encoder' saved after pre-training).
pretrained_encoder = BYOLResNet().to(DEVICE)
checkpoint_path = os.path.join(Dataset, "byol_encoder_cifar10.pth")
# map_location lets a checkpoint written by a GPU run load on a CPU-only runtime.
pretrained_encoder.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

# Freeze the encoder completely: no gradients, AND eval mode. In train mode its
# BatchNorm layers would normalise with per-batch statistics and keep updating
# running_mean/var while the probe is trained and tested, so the "frozen"
# features would drift.
for param in pretrained_encoder.parameters():
    param.requires_grad = False
pretrained_encoder.eval()

# Attach a new linear classifier head; only its parameters are optimised.
classifier = LinearClassifier(pretrained_encoder.feature_dim, num_classes=10).to(DEVICE)
optimizer_eval = optim.Adam(classifier.parameters(), lr=LEARNING_RATE)
criterion_eval = nn.CrossEntropyLoss()

linear_eval_train_losses = []
linear_eval_test_accuracies = []

for epoch in range(1, EPOCHS_LINEAR_EVAL + 1):
    classifier.train()
    total_loss = 0.0
    start_time = time.time()

    for images, labels in train_loader_eval:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer_eval.zero_grad()

        with torch.no_grad():  # frozen encoder: features need no autograd graph
            features = pretrained_encoder(images)

        outputs = classifier(features)
        loss = criterion_eval(outputs, labels)

        loss.backward()
        optimizer_eval.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader_eval)
    linear_eval_train_losses.append(avg_train_loss)

    # Evaluation on test set (encoder already in eval mode)
    classifier.eval()
    accuracy = evaluate_accuracy(lambda batch: classifier(pretrained_encoder(batch)), test_loader_eval)
    linear_eval_test_accuracies.append(accuracy)

    elapsed_time = time.time() - start_time
    print(f"Linear Eval Epoch {epoch}/{EPOCHS_LINEAR_EVAL} | Train Loss: {avg_train_loss:.4f} | Test Acc: {accuracy:.2f}% | Time: {elapsed_time:.1f}s")

print("--- Linear Evaluation Finished ---")

# Baseline: a randomly initialised encoder + linear head trained END-TO-END
# (the encoder is not frozen) on the same labelled loader, i.e. ordinary
# supervised training for the same number of epochs, not a frozen linear probe.
print("\n--- Training Baseline (Randomly Initialized Encoder) ---")
random_encoder = BYOLResNet().to(DEVICE)  # New, random encoder
random_classifier = LinearClassifier(random_encoder.feature_dim, num_classes=10).to(DEVICE)
random_model = nn.Sequential(random_encoder, random_classifier).to(DEVICE)  # Treat as one model for simplicity

optimizer_random = optim.Adam(random_model.parameters(), lr=LEARNING_RATE)
criterion_random = nn.CrossEntropyLoss()

random_eval_test_accuracies = []

for epoch in range(1, EPOCHS_LINEAR_EVAL + 1):
    random_model.train()
    for images, labels in train_loader_eval:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer_random.zero_grad()
        outputs = random_model(images)
        loss = criterion_random(outputs, labels)
        loss.backward()
        optimizer_random.step()

    random_model.eval()
    accuracy = evaluate_accuracy(random_model, test_loader_eval)
    random_eval_test_accuracies.append(accuracy)

    if epoch % 10 == 0 or epoch == EPOCHS_LINEAR_EVAL:
        print(f"Random Baseline Eval Epoch {epoch}/{EPOCHS_LINEAR_EVAL} | Test Acc: {accuracy:.2f}%")

print("--- Random Baseline Training Finished ---")

In [None]:
# @title 7. Visualization of Impact
# This cell provides code to visualize the training loss and embedding space.

# --- Loss Curve ---
# NOTE(review): figures are saved under `Dataset` (the input dataset directory).
# If that directory is a read-only mount (as Kaggle inputs usually are),
# savefig will fail; consider a writable output directory.
fig, ax = plt.subplots(figsize=(10, 5))
# Epochs are 1-based in the training loops, so label the x-axis the same way.
ax.plot(range(1, len(pretrain_losses) + 1), pretrain_losses, label='BYOL Pre-training Loss')
ax.set_title('BYOL Self-Supervised Pre-training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid(True)
ax.legend()
fig.savefig(os.path.join(Dataset, "byol_pretrain_loss.png"))
plt.show()

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(linear_eval_test_accuracies) + 1), linear_eval_test_accuracies,
        label='BYOL Pre-trained + Linear Eval Acc')
ax.plot(range(1, len(random_eval_test_accuracies) + 1), random_eval_test_accuracies,
        label='Random Init + Full Train Acc')
ax.set_title('Linear Evaluation Test Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.grid(True)
ax.legend()
fig.savefig(os.path.join(Dataset, "linear_eval_accuracy.png"))
plt.show()

# --- Embedding Space Visualization (t-SNE) ---
print("\n--- Visualizing Embedding Space with t-SNE ---")

# t-SNE on all 10,000 test images is expensive; the first 1000 are enough to visualise.
num_samples_tsne = 1000
subset_dataset_eval = torch.utils.data.Subset(test_dataset_eval, range(num_samples_tsne))
subset_loader_eval = DataLoader(subset_dataset_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

pretrained_encoder.eval()  # Set encoder to evaluation mode
embeddings = []
labels = []

with torch.no_grad():
    for images, targets in subset_loader_eval:
        features = pretrained_encoder(images.to(DEVICE))
        embeddings.append(features.cpu().numpy())
        labels.append(targets.cpu().numpy())

embeddings = np.vstack(embeddings)
labels = np.concatenate(labels)

print(f"Running t-SNE on {embeddings.shape[0]} embeddings...")
# The iteration count is left at scikit-learn's default (1000). `n_iter` was
# deprecated in scikit-learn 1.5 in favour of `max_iter` and later removed, so
# passing it breaks on recent versions.
tsne = TSNE(n_components=2, random_state=SEED, perplexity=30)
embeddings_2d = tsne.fit_transform(embeddings)

# Plotting the t-SNE results, with the legend showing CIFAR-10 class names instead of indices.
class_names = test_dataset_eval.classes
fig, ax = plt.subplots(figsize=(12, 10))
sns.scatterplot(
    x=embeddings_2d[:, 0], y=embeddings_2d[:, 1],
    hue=[class_names[i] for i in labels], hue_order=class_names,
    palette=sns.color_palette("tab10", 10),
    legend='full', alpha=0.7, ax=ax,
)
ax.set_title('BYOL Pre-trained CIFAR-10 Embeddings (t-SNE)')
ax.set_xlabel('t-SNE Component 1')
ax.set_ylabel('t-SNE Component 2')
ax.grid(True)
fig.savefig(os.path.join(Dataset, "byol_embeddings_tsne.png"))
plt.show()

# --- Discussing the Impact with Students ---
# NOTE(review): the "random" baseline is trained end-to-end on the full labelled
# training split (50,000 images), so "limited labeled data" below overstates its handicap.
print("\n--- Discussing the Impact of BYOL ---")
print("1. Self-Supervised Pre-training Loss Curve:")
print("   - Show how the BYOL loss (a combination of MSE on normalized predictions) smoothly decreases, indicating the model is learning to predict the target network's output.")
print("   - Emphasize that this happens *without any human labels*.")
print("   - Contrast with supervised learning loss curves that rely on explicit labels.")

print("\n2. Linear Evaluation Accuracy:")
print(f"   - **BYOL Pre-trained Encoder Accuracy: {linear_eval_test_accuracies[-1]:.2f}%**")
print(f"   - **Randomly Initialized Encoder Accuracy: {random_eval_test_accuracies[-1]:.2f}%**")
print("   - Highlight the significant performance gap.")
print("   - Explain that BYOL's pre-training forces the encoder to learn *general-purpose, semantically rich features* (like edges, textures, object parts) that are useful for downstream tasks (like classification) without ever seeing a single label.")
print("   - Contrast with the 'random' baseline, which learns everything from scratch with limited labeled data, often performing much worse.")

print("\n3. Embedding Space Visualization (t-SNE):")
print("   - Show the t-SNE plot where points are colored by their true CIFAR-10 classes.")
print("   - Point out that even though BYOL never saw these labels, images from the same class tend to *cluster together* in the embedding space.")
print("   - This visually demonstrates that the encoder has learned to group similar concepts, which is why it performs well in linear evaluation.")
print("   - Explain that this 'clustering' means the representations are 'disentangled' or 'meaningful' for classification.")

print("\n4. Key Concepts of BYOL:")
print("   - **No Negative Pairs:** Explain how BYOL elegantly avoids the need for computationally expensive negative samples, unlike SimCLR or MoCo.")
print("   - **Target Network & EMA:** Describe the 'target network' as a slowly updated copy of the online network (Exponential Moving Average - EMA). This provides a stable target to predict, preventing trivial solutions (collapse).")
print("   - **Predictor:** Explain the role of the predictor head in predicting the target's output, further preventing collapse by ensuring the online network's representation doesn't become too close to the target's immediately.")
print("   - **Asymmetric Design:** Highlight the asymmetry between the online and target networks (online has a predictor, target uses EMA, stop-gradient on target).")

print("\nBy demonstrating these aspects, students can grasp the power of self-supervised learning, specifically BYOL, in learning effective representations from vast amounts of unlabeled data, a critical skill in modern AI.")