In [None]:
# First, restart your Colab runtime completely:
# Runtime → Restart runtime

# Then add this at the very beginning of your notebook:

# Clear leftover user variables from earlier runs. This runs BEFORE the imports
# below so the modules used afterwards are not deleted again. Underscore-prefixed
# names (dunders, IPython's _ih/_oh history) and IPython's own helpers are kept so
# the kernel's history, Out[] cache and magics keep working.
_KEEP_GLOBALS = {'In', 'Out', 'get_ipython', 'exit', 'quit'}
for _name in list(globals()):
    if not _name.startswith('_') and _name not in _KEEP_GLOBALS:
        del globals()[_name]

import gc
import os

import torch

# Set memory optimization BEFORE any model creation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # reset_peak_memory_stats raises without a CUDA device, so only call it on GPU
    torch.cuda.reset_peak_memory_stats()
    print(f"GPU Memory after cleanup: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
else:
    print("No CUDA device available; skipped GPU memory cleanup.")

In [None]:
!pip install timm tqdm matplotlib

# Cell 1: Mount Google Drive for saving models and checkpoints
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ----------------------------- Run bookkeeping -----------------------------
EXPERIMENT_RUN = 5          # bump for every new configuration; names the checkpoint folder
RANDOM_SEED = 42

# ------------------------------ Dataset split ------------------------------
CIFAR100_NUM_CLASSES = 100
IND_CLASS_RATIO = 0.8             # share of classes used for pretraining (rest held out)
PRETRAIN_EXAMPLE_RATIO = 0.75     # share of each pretraining class's images used for pretraining
VAL_RATIO = 0.2                   # share of the pretraining images held out for validation
BATCH_SIZE = 16
NUM_OF_WORKERS = 1

# --------------------------------- Model -----------------------------------
MODEL = 'vit_small_patch16_224'
DROPOUT_PATH_RATE = 0.1     # passed to timm as drop_path_rate
DROPOUT_RATE = 0.1          # passed to timm as drop_rate
# NOTE(review): this overrides the 16px patch implied by MODEL's name. With the
# 224x224 inputs produced by the data transforms, a 2px patch yields
# (224 / 2) ** 2 = 12544 tokens per image, which is extremely memory/compute heavy.
# Confirm this is intended.
PATCH_SIZE = 2

# ------------------------ Learning-rate schedule ---------------------------
LR_WARMUP_EPOCHS = 20
WARMUP_LR_START = 1e-06     # LR at epoch 0 of the warmup
ETA_MIN = 1e-06             # LR floor at the end of cosine annealing

# ------------------------- Optimisation / training -------------------------
TRAINING_EPOCHS = 200
OPTIMIZER_LR = 0.0005
WEIGHT_DECAY = 0.05
BETAS = (0.9, 0.999)
ACCUMULATION_STEPS = 16     # effective batch size = BATCH_SIZE * ACCUMULATION_STEPS
PATIENCE_EPOCHS = 10        # early-stopping patience on validation accuracy
LABEL_SMOOTHING = 0.2

In [None]:
# Create the checkpoint directory for this experiment run on Google Drive.
# (A shell `mkdir` here must interpolate the Python variable; os.makedirs avoids that pitfall.)
import os

CHECKPOINT_DIR_PATH = f"/content/drive/MyDrive/vit_pretrained_cifar100/checkpoints/cifar100_{EXPERIMENT_RUN}"
os.makedirs(CHECKPOINT_DIR_PATH, exist_ok=True)

In [None]:
# Cell 2: Import libraries and set up device
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
import numpy as np
import os
import random
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # Use notebook version for better Colab progress bars
import math
import time
import requests
import tarfile
import shutil
from IPython.display import display, clear_output

# For ViT model
import timm

def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch, then print the library versions.

    When a GPU is present this also seeds every CUDA device and puts cuDNN into
    deterministic, non-benchmarking mode (slower, but repeatable).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # every visible GPU
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Python reads PYTHONHASHSEED at interpreter start-up, so this only affects
    # Python processes launched afterwards, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Record the environment for future reproducibility
    cuda_version = torch.version.cuda if has_cuda else 'N/A'
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA version: {cuda_version}")

set_seed()

# DataLoader worker hook: reseed NumPy and `random` inside each worker process
def worker_init_fn(worker_id):
    """Seed NumPy and `random` from this process's torch initial seed.

    `worker_id` is accepted because DataLoader passes it, but it is not used here.
    The torch seed is reduced to 32 bits because NumPy seeds must fit in uint32.
    """
    derived_seed = torch.initial_seed() & 0xFFFF_FFFF
    np.random.seed(derived_seed)
    random.seed(derived_seed)

# Pick the compute device (Colab usually provides a GPU) and report what was found
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU Model: {gpu_name}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Available memory: {total_gb:.2f} GB")

In [None]:
# Cell 4: Dataset wrapper that remaps class labels
class ClassRemappingDataset(torch.utils.data.Dataset):
    """Wrap an (image, label) dataset and translate every label via `class_mapping`.

    In this notebook the mapping sends the randomly chosen pretraining class ids
    to consecutive indices (0, 1, 2, ...) for the classifier head.
    """

    def __init__(self, dataset, class_mapping):
        self.dataset = dataset              # underlying dataset yielding (img, label)
        self.class_mapping = class_mapping  # original label -> new label

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, original_label = self.dataset[index]
        return image, self.class_mapping[original_label]

In [None]:
# Cell 6: Dataset preparation function - tracks left-out examples
def _cifar100_transforms():
    """Return (train_transform, eval_transform); both produce normalised 224x224 tensors."""
    normalize = transforms.Normalize(
        mean=(0.5071, 0.4867, 0.4408),
        std=(0.2675, 0.2565, 0.2761)
    )

    # Stronger augmentation for training from scratch
    train_transform = transforms.Compose([
        transforms.Resize(224),  # Upsample 32x32 CIFAR images to the ViT input size
        transforms.TrivialAugmentWide(),
        # Crop back to 224 (padding scaled from the usual 4px at 32px) so training images
        # have the same size as the eval images and the 224px model input
        transforms.RandomCrop(224, padding=28),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.15)
    ])

    eval_transform = transforms.Compose([
        transforms.Resize(224),  # Resize to ViT input size of 224x224
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, eval_transform


def _make_loader(dataset, shuffle):
    """Build a DataLoader using the notebook-wide batch size and worker settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_OF_WORKERS,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )


def prepare_dataset(data_dir='/content/drive/MyDrive/data'):
    """
    Prepare CIFAR-100 for pretraining:
    - IND_CLASS_RATIO (80%) of the classes, chosen at random, are used for pretraining;
      the remaining classes are reserved (OOD / continual learning)
    - PRETRAIN_EXAMPLE_RATIO (75%) of each pretraining class's training images are used;
      the other 25% are recorded in class_info['left_out_ind_indices']
    - VAL_RATIO (20%) of the pretraining images are held out for validation and are
      served with the non-augmenting eval transform

    Returns:
        (pretrain_loader, val_loader, full_test_loader, class_info)
    """
    train_transform, eval_transform = _cifar100_transforms()

    # Load CIFAR-100. A second view of the training split uses the eval transform so
    # validation images are not randomly augmented.
    train_dataset = CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
    train_eval_dataset = CIFAR100(root=data_dir, train=True, download=True, transform=eval_transform)
    test_dataset = CIFAR100(root=data_dir, train=False, download=True, transform=eval_transform)

    # Number of classes
    num_of_classes = CIFAR100_NUM_CLASSES
    num_of_pretrain_classes = int(IND_CLASS_RATIO * num_of_classes)

    all_class_indicies = list(range(num_of_classes))
    random.shuffle(all_class_indicies)  # Shuffle before splitting into pretraining and OOD classes

    pretrain_classes = all_class_indicies[:num_of_pretrain_classes]
    ood_classes = all_class_indicies[num_of_pretrain_classes:]

    ind_class_mapping = {class_idx: i for i, class_idx in enumerate(pretrain_classes)}
    # ex. pretrain_class_mapping = {23: 0, 11: 1, 93: 2, ...}

    print(f"Selected {len(pretrain_classes)} classes for pretraining")
    print(f"Reserved {len(ood_classes)} classes for OOD detection")

    # Group example indices by label using the stored targets, rather than decoding and
    # augmenting every image just to read its label.
    pretrain_class_set = set(pretrain_classes)
    class_indices = {}  # label -> indices of its training examples (dataset order)
    for idx, label in enumerate(train_dataset.targets):
        if label in pretrain_class_set:
            class_indices.setdefault(label, []).append(idx)

    pretrained_ind_indices, pretrained_left_out_indices = [], []
    for label, example_indices in class_indices.items():
        num_of_pretrain_examples = int(PRETRAIN_EXAMPLE_RATIO * len(example_indices))
        pretrained_ind_indices.extend(example_indices[:num_of_pretrain_examples])
        pretrained_left_out_indices.extend(example_indices[num_of_pretrain_examples:])  # 25% left out

    # Create subset datasets (augmented view for training, eval view for validation)
    print("Creating subset datasets...")
    pretrain_dataset = ClassRemappingDataset(
        Subset(train_dataset, pretrained_ind_indices), ind_class_mapping
    )
    pretrain_eval_dataset = ClassRemappingDataset(
        Subset(train_eval_dataset, pretrained_ind_indices), ind_class_mapping
    )

    # Create train-val split
    n_pretrain = len(pretrain_dataset)
    n_val = int(VAL_RATIO * n_pretrain)
    n_train = n_pretrain - n_val

    # Use a seeded generator for a reproducible split
    generator = torch.Generator().manual_seed(RANDOM_SEED)
    pretrain_train_dataset, val_split = random_split(
        pretrain_dataset, [n_train, n_val], generator=generator
    )
    # Same held-out examples, served through the non-augmenting eval transform
    pretrain_val_dataset = Subset(pretrain_eval_dataset, val_split.indices)

    # Create dataloaders
    print("Creating data loaders...")
    pretrain_loader = _make_loader(pretrain_train_dataset, shuffle=True)
    val_loader = _make_loader(pretrain_val_dataset, shuffle=False)
    full_test_loader = _make_loader(test_dataset, shuffle=False)

    print(f"Pretraining on {len(pretrain_train_dataset)} samples")
    print(f"Validation on {len(pretrain_val_dataset)} samples")
    print(f"Full test set has {len(test_dataset)} samples")
    print(f"Left out {len(pretrained_left_out_indices)} examples from pretraining classes for continual learning")

    # Store class information
    class_info = {
        'num_of_classes': num_of_classes, # Old name for this property "n_classes"
        'pretrain_classes': pretrain_classes,
        'left_out_classes': ood_classes, # Old name for this property "continual_classes"
        'left_out_ind_indices': pretrained_left_out_indices, # Old name for this property "left_out_indices"
        'pretrained_ind_indices': pretrained_ind_indices,
        'pretrain_class_mapping': ind_class_mapping # Old name for this property "class_mapping"
    }

    return pretrain_loader, val_loader, full_test_loader, class_info

In [None]:
# Cell 7: Model creation and learning rate scheduler
def create_vit_model_from_scratch(num_classes, **model_kwargs):
    """
    Create a randomly initialised timm ViT (no pre-trained weights).

    Args:
        num_classes: number of outputs of the classification head.
        **model_kwargs: optional extra keyword arguments forwarded to
            ``timm.create_model`` (e.g. ``img_size``). They override the
            notebook-wide defaults set below.

    Returns:
        The newly created model with Linear/LayerNorm weights re-initialised.
    """
    create_kwargs = dict(
        pretrained=False,
        num_classes=num_classes,
        drop_path_rate=DROPOUT_PATH_RATE,
        drop_rate=DROPOUT_RATE,
        patch_size=PATCH_SIZE,
    )
    create_kwargs.update(model_kwargs)
    model = timm.create_model(MODEL, **create_kwargs)

    # Transformer-style initialisation: truncated normal for Linear weights,
    # identity affine for LayerNorm
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False has no weight/bias
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    model.apply(_init_weights)
    return model

# Learning rate scheduler with warmup
class WarmupCosineScheduler():
    """Per-epoch LR schedule: linear warmup followed by cosine annealing.

    For ``epoch < warmup_epochs`` the LR rises linearly from ``warmup_start_lr``
    towards each param group's initial LR. After that it follows a half cosine from
    that LR down to ``eta_min`` at ``max_epochs``. Epochs at or beyond ``max_epochs``
    stay at ``eta_min`` instead of climbing back up the cosine.
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, warmup_start_lr=None, eta_min=None):
        """
        Args:
            optimizer: torch optimizer; each param group's current ``lr`` is taken
                as its peak (base) learning rate.
            warmup_epochs: number of linear-warmup epochs (0 disables warmup).
            max_epochs: total schedule length in epochs.
            warmup_start_lr: LR at epoch 0; defaults to WARMUP_LR_START.
            eta_min: final LR after annealing; defaults to ETA_MIN.
        """
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = WARMUP_LR_START if warmup_start_lr is None else warmup_start_lr
        self.eta_min = ETA_MIN if eta_min is None else eta_min

        # Peak LR of every param group, captured once at construction
        self.base_lr = [group['lr'] for group in optimizer.param_groups]

    def step(self, epoch):
        """Set the LR for (0-based) ``epoch`` on every param group and return the new LRs."""
        if epoch < self.warmup_epochs:
            # Linear warmup
            lr_mult = epoch / self.warmup_epochs
            new_lrs = [self.warmup_start_lr + lr_mult * (base - self.warmup_start_lr)
                       for base in self.base_lr]
        else:
            # Cosine annealing; max(1, ...) avoids division by zero when
            # max_epochs == warmup_epochs, and min(1.0, ...) holds eta_min past the end
            cosine_span = max(1, self.max_epochs - self.warmup_epochs)
            progress = min(1.0, (epoch - self.warmup_epochs) / cosine_span)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            new_lrs = [self.eta_min + cosine_decay * (base - self.eta_min)
                       for base in self.base_lr]

        for group, lr in zip(self.optimizer.param_groups, new_lrs):
            group['lr'] = lr

        return [group['lr'] for group in self.optimizer.param_groups]

In [None]:
# MixUp data augmentation (CutMix is not implemented here)
def mixup_data(x, y, alpha=1.0):
    '''Mix a batch with a shuffled copy of itself (MixUp).

    Args:
        x: input batch tensor; dimension 0 is the batch dimension.
        y: labels for ``x`` (one per batch element).
        alpha: Beta(alpha, alpha) concentration for the mixing weight;
            ``alpha <= 0`` disables mixing (lam = 1).

    Returns:
        (mixed_x, y_a, y_b, lam) where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y`` and ``y_b = y[perm]``.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    # Draw the permutation on the CPU generator, then move it next to x so the
    # function works on whatever device the batch lives on.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """MixUp loss: the lam-weighted blend of the loss against both label sets."""
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second

In [None]:
def _train_one_epoch(model, train_loader, criterion, optimizer, scaler,
                     accumulation_steps, amp_device_type, desc, current_lr):
    """Run one training epoch with MixUp, mixed precision and gradient accumulation.

    Returns:
        (average per-batch loss, accuracy in %). With MixUp the accuracy is an
        approximation: predictions are scored against the label with the larger
        mixing weight.
    """
    from torch.amp.autocast_mode import autocast

    model.train()
    num_batches = len(train_loader)
    running_loss, correct, seen = 0.0, 0, 0
    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(train_loader, desc=desc)
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Number of batches in the accumulation window this batch belongs to. The last
        # window is shorter when num_batches is not a multiple of accumulation_steps;
        # dividing by the true window size keeps its update on the same scale.
        window_start = (batch_idx // accumulation_steps) * accumulation_steps
        window_size = min(accumulation_steps, num_batches - window_start)

        # Apply MixUp augmentation with 90% probability
        use_mixup = np.random.random() < 0.9
        if use_mixup:
            inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=1.0)

        # Forward pass with mixed precision
        with autocast(amp_device_type):
            outputs = model(inputs)
            if use_mixup:
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            else:
                loss = criterion(outputs, targets)
            loss = loss / window_size

        # Backward pass (with gradient scaling)
        scaler.scale(loss).backward()

        # Step the optimizer once per accumulation window
        if (batch_idx + 1) % accumulation_steps == 0 or batch_idx + 1 == num_batches:
            scaler.unscale_(optimizer)  # unscale first so clipping sees true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # Track statistics
        with torch.no_grad():
            running_loss += loss.item() * window_size  # undo the window scaling for logging
            _, predicted = outputs.max(1)
            seen += targets.size(0)
            if use_mixup:
                reference = targets_a if lam > 0.5 else targets_b
            else:
                reference = targets
            correct += predicted.eq(reference).sum().item()

        pbar.set_postfix({
            'loss': running_loss / (batch_idx + 1),
            'acc': 100. * correct / seen,
            'lr': current_lr
        })

        # Release cached GPU memory only occasionally (every 100 batches)
        if batch_idx % 100 == 0 and batch_idx > 0:
            torch.cuda.empty_cache()

    return running_loss / max(num_batches, 1), 100. * correct / max(seen, 1)


def _evaluate(model, loader, criterion, amp_device_type, desc):
    """Evaluate ``model`` on ``loader``; return (average per-batch loss, accuracy in %)."""
    from torch.amp.autocast_mode import autocast

    model.eval()
    total_loss, correct, seen = 0.0, 0, 0

    with torch.no_grad(), autocast(amp_device_type):
        pbar = tqdm(loader, desc=desc)
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            seen += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_postfix({
                'loss': total_loss / (batch_idx + 1),
                'acc': 100. * correct / seen
            })

            # Release cached GPU memory only occasionally
            if batch_idx % 100 == 0 and batch_idx > 0:
                torch.cuda.empty_cache()

    return total_loss / max(len(loader), 1), 100. * correct / max(seen, 1)


def train_model_from_scratch(model, train_loader, val_loader, class_info, dataset_name,
                           num_epochs=1):
    """
    Train the ViT from scratch.

    Uses AdamW, warmup + cosine LR, MixUp, mixed precision, gradient accumulation
    (ACCUMULATION_STEPS) and early stopping on validation accuracy (PATIENCE_EPOCHS).
    Checkpoints go to CHECKPOINT_DIR_PATH via save_model:
    - the best model whenever validation accuracy improves,
    - a checkpoint every 5 epochs,
    - the final model when training ends.

    Args:
        model: model to train (moved to ``device``).
        train_loader: DataLoader yielding (inputs, targets) for training.
        val_loader: DataLoader yielding (inputs, targets) for validation.
        class_info: class metadata stored with every checkpoint.
        dataset_name: label used in plots and checkpoint log messages.
        num_epochs: maximum number of epochs (early stopping may end sooner).

    Returns:
        (model, history). history holds per-epoch lists
        'train_loss', 'train_acc', 'val_loss', 'val_acc' and 'lr'.
    """
    from torch.amp import GradScaler

    # Move model to device before building the optimizer
    model = model.to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=OPTIMIZER_LR,
        weight_decay=WEIGHT_DECAY,
        betas=BETAS
    )
    scheduler = WarmupCosineScheduler(
        optimizer,
        warmup_epochs=LR_WARMUP_EPOCHS,
        max_epochs=num_epochs
    )

    # Gradient scaler for mixed precision (a no-op when not on CUDA)
    scaler = GradScaler('cuda', enabled=(device.type == 'cuda'))
    amp_device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'lr': []
    }

    # Best model tracking / early stopping
    best_val_acc = 0.0
    patience = PATIENCE_EPOCHS
    patience_counter = 0

    # Checkpoint directory on Google Drive
    checkpoint_dir = CHECKPOINT_DIR_PATH
    os.makedirs(checkpoint_dir, exist_ok=True)

    last_epoch = -1  # stays -1 (and is saved as such) if no epoch runs
    print(f"<======= Starting training for {num_epochs} epochs... =======>")
    for epoch in range(num_epochs):
        last_epoch = epoch
        epoch_start_time = time.time()

        # Update learning rate
        current_lr = scheduler.step(epoch)
        history['lr'].append(current_lr[0])

        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, ACCUMULATION_STEPS,
            amp_device_type, desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
            current_lr=current_lr[0]
        )
        val_loss, val_acc = _evaluate(
            model, val_loader, criterion, amp_device_type,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]"
        )

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% - "
              f"LR: {current_lr[0]:.6f} - "
              f"Time: {epoch_time:.1f}s")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Plot training progress after the first epoch and every 5 epochs
        if (epoch + 1) % 5 == 0 or epoch == 0:
            plot_training_progress(history, dataset_name)

        # Save best model and check early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            print(f"New best validation accuracy: {best_val_acc:.2f}%")
            save_model(model, optimizer, epoch, history, class_info, dataset_name,
                       checkpoint_dir=checkpoint_dir, is_best=True)
        else:
            patience_counter += 1
            print(f"Validation accuracy did not improve. Patience: {patience_counter}/{patience}")

            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            save_model(model, optimizer, epoch, history, class_info, dataset_name,
                       checkpoint_dir=checkpoint_dir, is_best=False,
                       checkpoint_name=f"checkpoint_epoch_{epoch+1}")

    # Save final model
    save_model(model, optimizer, last_epoch, history, class_info, dataset_name,
               checkpoint_dir=checkpoint_dir, is_best=False)

    return model, history

In [None]:
# Cell 9: Functions for saving, evaluating, and loading models
def save_model(model, optimizer, epoch, history, class_info, dataset_name,
              checkpoint_dir=None, is_best=False, checkpoint_name=None):
    """
    Write a training checkpoint to ``<checkpoint_dir>/vit_<tag>_checkpoint.pth``.

    ``tag`` is ``checkpoint_name`` when given (non-empty); otherwise it is ``'best'``
    or ``'final'`` depending on ``is_best``. ``checkpoint_dir`` defaults to
    CHECKPOINT_DIR_PATH and is created if missing. ``dataset_name`` is currently unused.

    The saved dict holds: epoch, model_state_dict, optimizer_state_dict, history, class_info.
    """
    target_dir = CHECKPOINT_DIR_PATH if checkpoint_dir is None else checkpoint_dir
    tag = checkpoint_name or ('best' if is_best else 'final')

    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, f"vit_{tag}_checkpoint.pth")

    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        history=history,
        class_info=class_info,
    )

    torch.save(payload, target_path)
    print(f"Saved {tag} model checkpoint to {target_path}")

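# Currently unused draft for evaluating the model on the full test set.
# NOTE: it still reads the old class_info keys ('n_classes', 'class_mapping', 'continual_classes');
# prepare_dataset now stores them as 'num_of_classes', 'pretrain_class_mapping' and 'left_out_classes'.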
# def evaluate_model(model, test_loader, class_info):
#     """
#     Evaluate model on the full test set with consistent class mapping
#     """
#     model.eval()
#     test_loss = 0.0
#     test_correct = 0
#     test_total = 0

#     # Class-wise accuracy
#     class_correct = {}
#     class_total = {}

#     # Initialize counters for each class
#     for cls in range(class_info['n_classes']):
#         class_correct[cls] = 0
#         class_total[cls] = 0

#     criterion = nn.CrossEntropyLoss()

#     # Create proper reverse mapping using class_mapping from class_info
#     if 'class_mapping' in class_info:
#         # Use the stored mapping if available
#         inverse_mapping = {i: cls for cls, i in class_info['class_mapping'].items()}
#     else:
#         # Fallback to create mapping from pretrain_classes
#         inverse_mapping = {i: cls for i, cls in enumerate(class_info['pretrain_classes'])}

#     with torch.no_grad():
#         # Use mixed precision for evaluation - fixed deprecated API
#         from torch.amp.autocast_mode import autocast
#         with autocast('cuda' if torch.cuda.is_available() else 'cpu'):
#             pbar = tqdm(test_loader, desc="Evaluating")
#             for batch_idx, (inputs, targets) in enumerate(pbar):
#                 inputs, targets = inputs.to(device), targets.to(device)

#                 # Forward pass
#                 outputs = model(inputs)
#                 loss = criterion(outputs, targets)

#                 # Track overall statistics - fix the class mapping issue
#                 _, predicted_indices = outputs.max(1)

#                 # SAFER APPROACH: Only count samples where the original class is in pretrain_classes
#                 mask = torch.tensor([t.item() in class_info['pretrain_classes'] for t in targets.cpu()],
#                                     device=device)

#                 # Filtered statistics
#                 test_total += mask.sum().item()

#                 correct_predictions = torch.zeros_like(targets, dtype=torch.bool)
#                 for i, (pred_idx, target) in enumerate(zip(predicted_indices, targets)):
#                     if pred_idx.item() in inverse_mapping and target.item() in class_info['pretrain_classes']:
#                         orig_class = inverse_mapping[pred_idx.item()]
#                         if orig_class == target.item():
#                             correct_predictions[i] = True

#                 test_correct += (correct_predictions & mask).sum().item()

#                 # Track class-wise statistics - only for classes in pretrain_classes
#                 for cls in class_info['pretrain_classes']:
#                     cls_idx = (targets == cls)
#                     class_total[cls] += cls_idx.sum().item()
#                     class_correct[cls] += (correct_predictions & cls_idx).sum().item()

#                 # Update progress bar
#                 pbar.set_postfix({
#                     'acc': 100. * test_correct / test_total if test_total > 0 else 0
#                 })

#                 # Free up GPU memory
#                 del inputs, targets, outputs
#                 torch.cuda.empty_cache()

#     # Calculate average test metrics
#     test_acc = 100. * test_correct / test_total if test_total > 0 else 0

#     print(f"Test Acc: {test_acc:.2f}%")

#     # Calculate accuracy for pretrain and continual classes
#     pretrain_correct = sum(class_correct[cls] for cls in class_info['pretrain_classes'])
#     pretrain_total = sum(class_total[cls] for cls in class_info['pretrain_classes'])
#     pretrain_acc = 100. * pretrain_correct / pretrain_total if pretrain_total > 0 else 0

#     continual_correct = sum(class_correct.get(cls, 0) for cls in class_info['continual_classes'])
#     continual_total = sum(class_total.get(cls, 0) for cls in class_info['continual_classes'])
#     continual_acc = 100. * continual_correct / continual_total if continual_total > 0 else 0

#     print(f"Pretrain Classes Acc: {pretrain_acc:.2f}%")
#     print(f"Continual Classes Acc: {continual_acc:.2f}%")

#     return test_acc, {
#         'pretrain_acc': pretrain_acc,
#         'continual_acc': continual_acc,
#         'class_acc': {cls: 100. * class_correct.get(cls, 0) / class_total.get(cls, 1)
#                      for cls in range(class_info['n_classes'])}
#     }

def load_pretrained_model(dataset_name, model_type='best', checkpoint_dir=None):
    """
    Rebuild a ViT from a checkpoint written by save_model and move it to ``device``.

    Args:
        dataset_name: only used in the log message.
        model_type: checkpoint tag, e.g. 'best', 'final' or 'checkpoint_epoch_5'.
        checkpoint_dir: directory to read from; defaults to CHECKPOINT_DIR_PATH.

    Returns:
        (model, class_info)

    Raises:
        FileNotFoundError: if ``vit_<model_type>_checkpoint.pth`` does not exist.
    """
    source_dir = checkpoint_dir if checkpoint_dir is not None else CHECKPOINT_DIR_PATH
    checkpoint_path = os.path.join(source_dir, f"vit_{model_type}_checkpoint.pth")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # NOTE: torch.load deserialises the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    class_info = checkpoint['class_info']
    head_size = len(class_info['pretrain_classes'])
    model = create_vit_model_from_scratch(num_classes=head_size)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    print(f"Loaded {model_type} {dataset_name} ViT model from {checkpoint_path}")
    print(f"Model was trained for {checkpoint['epoch'] + 1} epochs")
    return model, class_info

In [None]:

# Cell 10: Visualization functions
def plot_training_progress(history, dataset_name):
    """
    Show a 1x3 snapshot of training so far: loss curves, accuracy curves and the
    learning-rate schedule (log scale).
    """
    fig, (loss_ax, acc_ax, lr_ax) = plt.subplots(1, 3, figsize=(15, 5))

    # Loss curves
    for key, label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
        loss_ax.plot(history[key], label=label)
    loss_ax.set_title(f'{dataset_name} Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Accuracy curves
    for key, label in (('train_acc', 'Train Acc'), ('val_acc', 'Val Acc')):
        acc_ax.plot(history[key], label=label)
    acc_ax.set_title(f'{dataset_name} Training Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.legend()

    # Learning-rate schedule
    lr_ax.plot(history['lr'])
    lr_ax.set_title(f'{dataset_name} Learning Rate')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_yscale('log')

    fig.tight_layout()
    plt.show()

def plot_training_history(history, dataset_name):
    """
    Plot the complete training history after training (loss, accuracy, learning rate).

    The figure is saved as ``<CHECKPOINT_DIR_PATH>_training_history.png``
    (next to the checkpoint folder) and then displayed.
    """
    plt.figure(figsize=(15, 10))

    # Plot loss
    plt.subplot(2, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title(f'{dataset_name} Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Plot accuracy
    plt.subplot(2, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title(f'{dataset_name} Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    # Plot learning rate
    plt.subplot(2, 2, 3)
    plt.plot(history['lr'])
    plt.title(f'{dataset_name} Learning Rate')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')

    # Lay out the panels BEFORE saving so the PNG gets the same non-overlapping
    # layout as the displayed figure
    plt.tight_layout()

    # Save the figure to Google Drive
    save_path = f"{CHECKPOINT_DIR_PATH}_training_history.png"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path)
    print(f"Saved training history plot to {save_path}")

    plt.show()

In [None]:
# Cell 11: Main training function
def train_cifar100():
    """
    Pretrain a ViT on the CIFAR-100 pretraining split from scratch.

    Seeds all RNGs with RANDOM_SEED, builds the data loaders, trains with early
    stopping (checkpoints go to Google Drive), then plots and saves the history.

    Returns:
        (model, history, class_info)
    """
    # Seed from the config cell so changing RANDOM_SEED actually changes the run
    set_seed(RANDOM_SEED)

    print("=== Pretraining ViT on CIFAR-100 from scratch ===")
    cifar_train_loader, cifar_val_loader, cifar_test_loader, cifar_class_info = prepare_dataset()

    cifar_model = create_vit_model_from_scratch(num_classes=len(cifar_class_info['pretrain_classes']))
    cifar_model, cifar_history = train_model_from_scratch(
        cifar_model, cifar_train_loader, cifar_val_loader, cifar_class_info, 'cifar100', TRAINING_EPOCHS
    )

    # Full-test-set evaluation is disabled while evaluate_model is commented out;
    # cifar_test_loader is kept so it can be re-enabled.
    # print("\n=== Evaluating CIFAR-100 ViT on Full Test Set ===")
    # evaluate_model(cifar_model, cifar_test_loader, cifar_class_info)

    # Plot final training history
    plot_training_history(cifar_history, 'cifar100')

    print("\nCIFAR-100 pretraining complete! Model saved to Google Drive.")
    return cifar_model, cifar_history, cifar_class_info

In [None]:
def print_memory_usage():
    """Print allocated and reserved CUDA memory in GB; prints nothing without a GPU."""
    if not torch.cuda.is_available():
        return
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"GPU Memory: {allocated_gb:.2f} GB allocated, "
          f"{reserved_gb:.2f} GB reserved")

In [None]:
# Run the full CIFAR-100 pretraining (long-running: up to TRAINING_EPOCHS epochs).
# Comment out the line below to skip training when re-running the notebook.
cifar_model, cifar_history, cifar_class_info = train_cifar100()