In [None]:
# Cell 1: Install third-party dependencies used below
# (timm for the ViT model, tqdm for progress bars, matplotlib for plots)
!pip install timm tqdm matplotlib

# Cell 2: Mount Google Drive for saving models and checkpoints
# NOTE: `google.colab` is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Create directories for checkpoints (one folder per dataset trained in this notebook)
!mkdir -p /content/drive/MyDrive/vit_pretraining/checkpoints/cifar100
!mkdir -p /content/drive/MyDrive/vit_pretraining/checkpoints/tiny-imagenet

In [None]:
# Cell 3: Import libraries and set up device
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
import numpy as np
import os
import random
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # Use notebook version for better Colab progress bars
import math
import time
import requests
import tarfile
import shutil
from IPython.display import display, clear_output

# For ViT model
import timm

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators.

    Args:
        seed: Integer seed handed to every generator (default 42).

    When CUDA is available the current CUDA device is seeded as well and
    cuDNN is switched to deterministic mode.
    """
    # CPU-side generators, in the order: stdlib, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-related to configure on a CPU-only runtime.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed every RNG once with the default seed (42) as soon as this cell runs
set_seed()

# Set device - Colab typically provides a GPU; falls back to CPU when CUDA is unavailable.
# `device` is a module-level global that the training / evaluation functions below read.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name and total memory of CUDA device 0 (bytes divided by 1e9)
    print(f"GPU Model: {torch.cuda.get_device_name(0)}")
    print(f"Available memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Cell 4: Download and prepare Tiny-ImageNet dataset
def download_and_extract_tiny_imagenet(data_dir="/content/drive/MyDrive/data"):
    """
    Downloads and extracts the Tiny-ImageNet dataset.

    The archive is fetched from the CS231n server, unpacked under ``data_dir``
    and then deleted. If ``<data_dir>/tiny-imagenet-200`` already exists the
    function returns immediately without touching the network.

    Args:
        data_dir: Directory that will contain the ``tiny-imagenet-200`` folder.
            It is created when missing.

    Returns:
        Path of the ``tiny-imagenet-200`` directory inside ``data_dir``.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Connection problem or timeout.
        zipfile.BadZipFile: The downloaded archive is not a valid zip file.
    """
    import zipfile

    os.makedirs(data_dir, exist_ok=True)

    # URL for Tiny ImageNet
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"

    # Where the archive is stored temporarily and where the dataset ends up
    zip_path = os.path.join(data_dir, "tiny-imagenet-200.zip")
    dataset_path = os.path.join(data_dir, "tiny-imagenet-200")

    # Check if the dataset already exists
    if os.path.exists(dataset_path):
        print(f"Dataset already exists at {dataset_path}")
        return dataset_path

    # Download the dataset
    print(f"Downloading Tiny-ImageNet from {url}...")
    start_time = time.time()

    # 1 MiB chunks: far fewer write calls than 1 KiB chunks for the same archive.
    block_size = 1024 * 1024

    try:
        # The context manager releases the connection even if the download fails.
        with requests.get(url, stream=True, timeout=60) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as the archive.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(zip_path, 'wb') as f, tqdm(total=total_size or None,
                                                 unit='B', unit_scale=True,
                                                 unit_divisor=1024,
                                                 desc='Downloading') as pbar:
                for data in response.iter_content(block_size):
                    f.write(data)
                    pbar.update(len(data))
    except BaseException:
        # Do not leave a truncated archive behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)
        raise

    download_time = time.time() - start_time
    print(f"Download completed in {download_time:.2f} seconds")

    # Extract the dataset
    print(f"Extracting dataset to {data_dir}...")
    extract_start_time = time.time()

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            for member in tqdm(zip_ref.infolist(), desc='Extracting '):
                zip_ref.extract(member, data_dir)
    except BaseException:
        # A half-extracted folder would satisfy the "already exists" check above
        # on the next run, so remove it before propagating the error.
        shutil.rmtree(dataset_path, ignore_errors=True)
        raise

    extract_time = time.time() - extract_start_time
    print(f"Extraction completed in {extract_time:.2f} seconds")

    # Clean up the zip file
    os.remove(zip_path)
    print(f"Removed zip file {zip_path}")

    return dataset_path

# Call the function to download and extract Tiny-ImageNet.
# Runs when this cell executes; it returns right away if the dataset folder already exists.
tiny_imagenet_path = download_and_extract_tiny_imagenet()

In [None]:
# Cell 5: Define TinyImageNet dataset class
class TinyImageNet(torch.utils.data.Dataset):
    """Tiny-ImageNet images and integer labels read from a directory tree.

    Supported layouts:
      * train: ``root/train/<class>/images/*.JPEG`` or ``root/train/<class>/*.JPEG``
      * val:   ``root/val/val_annotations.txt`` (tab separated ``file<TAB>class...``)
               with images in ``root/val/images`` or directly in ``root/val``;
               otherwise one sub-directory per class under ``root/val``.

    Directory and file listings are sorted, so ``images`` / ``targets`` have the
    same order on every filesystem and index-based splits are reproducible.

    Attributes:
        classes: Sorted list of class directory names.
        class_to_idx: Mapping from class name to integer label.
        images: List of image file paths.
        targets: List of integer labels, aligned with ``images``.
    """

    # File extensions accepted per layout (kept as in the original loaders).
    _TRAIN_EXTENSIONS = ('.JPEG',)
    _VAL_DIR_EXTENSIONS = ('.JPEG', '.jpeg', '.jpg', '.png')

    def __init__(self, root, train=True, transform=None):
        """
        Args:
            root: Path of the extracted dataset (contains ``train`` and/or ``val``).
            train: Load the training split when True, the validation split otherwise.
            transform: Optional callable applied to each PIL image in ``__getitem__``.

        Raises:
            RuntimeError: The directory of the requested split does not exist.
        """
        self.root = root
        self.transform = transform
        self.train = train
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.targets = []

        # Determine dataset directories
        if self.train:
            self.train_dir = os.path.join(root, 'train')
            if not os.path.isdir(self.train_dir):
                raise RuntimeError(f'Train directory not found at {self.train_dir}')
            self._load_train_data()
        else:
            self.val_dir = os.path.join(root, 'val')
            if not os.path.isdir(self.val_dir):
                raise RuntimeError(f'Val directory not found at {self.val_dir}')
            self._load_val_data()

    @staticmethod
    def _list_class_dirs(directory):
        """Return the sorted names of the sub-directories of ``directory``."""
        return [name for name in sorted(os.listdir(directory))
                if os.path.isdir(os.path.join(directory, name))]

    def _add_images_from(self, directory, class_idx, extensions):
        """Append every file in ``directory`` ending in ``extensions`` with label ``class_idx``."""
        # Sorted: os.listdir order is arbitrary and would make sample indices unstable.
        for img_file in sorted(os.listdir(directory)):
            if img_file.endswith(extensions):
                self.images.append(os.path.join(directory, img_file))
                self.targets.append(class_idx)

    def _load_train_data(self):
        """Populate classes, mapping, images and targets from the train directory."""
        self.classes = self._list_class_dirs(self.train_dir)
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        for class_dir in self.classes:
            class_path = os.path.join(self.train_dir, class_dir)

            # Prefer root/train/class/images/*.JPEG, fall back to root/train/class/*.JPEG
            images_dir = os.path.join(class_path, 'images')
            source_dir = images_dir if os.path.isdir(images_dir) else class_path
            self._add_images_from(source_dir, self.class_to_idx[class_dir],
                                  self._TRAIN_EXTENSIONS)

    def _load_val_data(self):
        """Populate classes, mapping, images and targets from the val directory."""
        val_annotations_path = os.path.join(self.val_dir, 'val_annotations.txt')

        if not os.path.isfile(val_annotations_path):
            # Alternative structure: val dir contains subdirectories for classes
            self.classes = self._list_class_dirs(self.val_dir)
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
            for class_dir in self.classes:
                self._add_images_from(os.path.join(self.val_dir, class_dir),
                                      self.class_to_idx[class_dir],
                                      self._VAL_DIR_EXTENSIONS)
            return

        # Standard structure with val_annotations.txt: read (file, class) pairs once.
        annotations = []
        with open(val_annotations_path, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) >= 2:
                    annotations.append((parts[0], parts[1]))

        # Take the class list from the train split when present, so both splits
        # share the same label indices.
        train_dir = os.path.join(self.root, 'train')
        if os.path.isdir(train_dir):
            self.classes = self._list_class_dirs(train_dir)

        # If train set was not available, extract classes from val_annotations
        if not self.classes:
            self.classes = sorted({class_id for _, class_id in annotations})

        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        images_dir = os.path.join(self.val_dir, 'images')
        if not os.path.isdir(images_dir):
            images_dir = self.val_dir

        # Keep annotation-file order; skip unknown classes and missing files.
        for img_file, class_id in annotations:
            if class_id not in self.class_to_idx:
                continue
            img_path = os.path.join(images_dir, img_file)
            if os.path.isfile(img_path):
                self.images.append(img_path)
                self.targets.append(self.class_to_idx[class_id])

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``; the transform is applied when set.

        An unreadable image is reported on stdout and replaced by a 64x64 gray
        placeholder that keeps the sample's label (deliberate best-effort).
        """
        img_path = self.images[index]
        target = self.targets[index]

        # Load image
        try:
            # `with` closes the underlying file as soon as the pixels are converted.
            with Image.open(img_path) as handle:
                img = handle.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy image as fallback
            img = Image.new('RGB', (64, 64), color='gray')

        if self.transform:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.images)

In [None]:
# Add this class after your TinyImageNet class definition
class ClassRemappingDataset(torch.utils.data.Dataset):
    """View over another dataset that translates each label through a lookup table.

    Args:
        dataset: Indexable object yielding ``(image, label)`` pairs.
        class_mapping: Mapping from the wrapped dataset's label to the new label;
            indexing it with an unknown label raises whatever the mapping raises.
    """

    def __init__(self, dataset, class_mapping):
        self.class_mapping = class_mapping
        self.dataset = dataset

    def __len__(self):
        # Same length as the wrapped dataset; no sample is filtered here.
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        image, original_label = sample
        # Swap the original class index for its new consecutive index.
        return image, self.class_mapping[original_label]

In [None]:
# Cell 6: Dataset preparation function
def prepare_dataset(dataset_name, data_dir='/content/drive/MyDrive/data',
                    batch_size=32, num_workers=2):
    """
    Prepare dataset for pretraining according to requirements:
    - 80% of classes for pretraining
    - 75% of each pretraining class examples
    - 20% of classes reserved for continual learning

    Args:
        dataset_name: ``'cifar100'`` or ``'tiny-imagenet'``.
        data_dir: Directory holding (or receiving) the dataset files.
        batch_size: Batch size of all three loaders (default 32, chosen for
            Colab's GPU memory constraints).
        num_workers: Worker processes per loader (default 2, reduced for Colab).

    Returns:
        ``(pretrain_loader, val_loader, full_test_loader, class_info)`` where
        ``class_info`` has the keys ``'n_classes'``, ``'pretrain_classes'`` and
        ``'continual_classes'`` (lists of original class indices).
        Labels in the pretrain/val loaders are remapped to ``0..len(pretrain_classes)-1``;
        labels in the test loader are the dataset's original class indices.

    Raises:
        ValueError: ``dataset_name`` is not supported.
    """
    if dataset_name == 'cifar100':
        # Define transforms with stronger augmentation for training from scratch
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandAugment(num_ops=2, magnitude=9),  # More aggressive augmentation
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
            transforms.Resize((224, 224))  # Resize to ViT input size
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
            transforms.Resize((224, 224))  # Resize to ViT input size
        ])

        # Load CIFAR-100 dataset
        train_dataset = CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
        test_dataset = CIFAR100(root=data_dir, train=False, download=True, transform=test_transform)

        # Number of classes
        n_classes = 100

    elif dataset_name == 'tiny-imagenet':
        # Define transforms for tiny-imagenet with stronger augmentation
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandAugment(num_ops=2, magnitude=9),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            transforms.Resize((224, 224))  # Resize to ViT input size
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            transforms.Resize((224, 224))  # Resize to ViT input size
        ])

        # Path to Tiny ImageNet
        tiny_imagenet_root = os.path.join(data_dir, "tiny-imagenet-200")

        # Load tiny-imagenet dataset
        train_dataset = TinyImageNet(root=tiny_imagenet_root, train=True, transform=train_transform)
        test_dataset = TinyImageNet(root=tiny_imagenet_root, train=False, transform=test_transform)

        # Number of classes
        n_classes = 200  # tiny-imagenet has 200 classes

    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    # Select 80% of classes for pretraining (uses the global `random` state)
    n_pretrain_classes = int(0.8 * n_classes)
    all_classes = list(range(n_classes))
    random.shuffle(all_classes)
    pretrain_classes = all_classes[:n_pretrain_classes]
    continual_classes = all_classes[n_pretrain_classes:]

    # Original class index -> consecutive index used as the model's output index
    class_mapping = {cls: i for i, cls in enumerate(pretrain_classes)}

    print(f"Selected {len(pretrain_classes)} classes for pretraining")
    print(f"Reserved {len(continual_classes)} classes for continual learning")

    # Group sample indices by class using the datasets' `targets` list.
    # Indexing the dataset itself would decode, augment and resize every image
    # just to read its label, which is extremely slow for data stored on Drive.
    pretrain_class_set = set(pretrain_classes)
    class_indices = {}
    for idx, label in enumerate(train_dataset.targets):
        if label in pretrain_class_set:
            class_indices.setdefault(label, []).append(idx)

    # Select 75% of samples for each pretraining class (the first ones by index)
    pretrain_indices = []
    for label, indices in class_indices.items():
        n_pretrain_samples = int(0.75 * len(indices))
        pretrain_indices.extend(indices[:n_pretrain_samples])

    # Create a subset dataset for pretraining
    pretrain_subset = Subset(train_dataset, pretrain_indices)

    # Wrap the subset with the class remapping dataset
    pretrain_dataset = ClassRemappingDataset(pretrain_subset, class_mapping)

    # Create train-val split (80-20) from the pretrain dataset.
    # NOTE(review): both halves share `train_dataset`, so validation samples go
    # through the training augmentations as well - confirm this is intended.
    n_pretrain = len(pretrain_dataset)
    n_val = int(0.2 * n_pretrain)
    n_train = n_pretrain - n_val

    pretrain_train_dataset, pretrain_val_dataset = random_split(
        pretrain_dataset, [n_train, n_val]
    )

    # Create dataloaders
    pretrain_loader = DataLoader(
        pretrain_train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True  # Faster data transfer to GPU
    )

    val_loader = DataLoader(
        pretrain_val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    full_test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    print(f"Pretraining on {len(pretrain_train_dataset)} samples")
    print(f"Validation on {len(pretrain_val_dataset)} samples")
    print(f"Full test set has {len(test_dataset)} samples")

    # Store the class information for later use
    class_info = {
        'n_classes': n_classes,
        'pretrain_classes': pretrain_classes,
        'continual_classes': continual_classes
    }

    return pretrain_loader, val_loader, full_test_loader, class_info

In [None]:
# Cell 7: Model creation and learning rate scheduler
def create_vit_model_from_scratch(num_classes):
    """
    Build the timm ``vit_small_patch16_224`` model without pre-trained weights.

    Args:
        num_classes: Size of the classification head.

    Returns:
        The model after every ``nn.Linear`` and ``nn.LayerNorm`` sub-module has
        been re-initialised by the rule below.
    """
    def _init_weights(module):
        # LayerNorm: unit scale, zero shift.
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return

        # Everything that is not a Linear layer keeps timm's own initialisation.
        if not isinstance(module, nn.Linear):
            return

        # Linear: truncated normal weights (std 0.02), zero bias when one exists.
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    # Random initialization (pretrained=False)
    vit = timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=num_classes)
    vit.apply(_init_weights)
    return vit

# Learning rate scheduler with warmup
class WarmupCosineScheduler:
    """Epoch-indexed learning-rate schedule: linear warmup, then cosine decay.

    For ``epoch < warmup_epochs`` the rate rises linearly from ``warmup_start_lr``
    towards each group's base rate; afterwards it follows a half cosine from the
    base rate down to ``eta_min`` at ``max_epochs``. Past ``max_epochs`` the rate
    stays at ``eta_min``.

    The schedule is a pure function of ``epoch``: ``step`` keeps no counter, so
    it can be called with any epoch in any order.

    Args:
        optimizer: Object exposing ``param_groups`` (list of dicts with an ``'lr'`` key).
            The rates present at construction time are taken as the base rates.
        warmup_epochs: Number of warmup epochs (0 disables warmup).
        max_epochs: Epoch at which the cosine decay reaches ``eta_min``.
        warmup_start_lr: Rate used at epoch 0 of the warmup.
        eta_min: Final rate of the cosine decay.
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, warmup_start_lr=1e-6, eta_min=1e-6):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        # Get base lr
        self.base_lr = [group['lr'] for group in optimizer.param_groups]

    def step(self, epoch):
        """Set every param group's rate for ``epoch`` and return the list of rates."""
        if self.warmup_epochs > 0 and epoch < self.warmup_epochs:
            # Linear warmup
            lr_mult = epoch / self.warmup_epochs
            for i, group in enumerate(self.optimizer.param_groups):
                group['lr'] = self.warmup_start_lr + lr_mult * (self.base_lr[i] - self.warmup_start_lr)
        else:
            # Cosine annealing
            decay_span = self.max_epochs - self.warmup_epochs
            if decay_span > 0:
                # Clamp so the cosine does not climb back up after max_epochs.
                progress = min(1.0, (epoch - self.warmup_epochs) / decay_span)
            else:
                # No room for a decay phase: the schedule is already finished.
                progress = 1.0
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            for i, group in enumerate(self.optimizer.param_groups):
                group['lr'] = self.eta_min + cosine_decay * (self.base_lr[i] - self.eta_min)

        return [group['lr'] for group in self.optimizer.param_groups]

In [None]:
# Cell 8: Training function
def _run_epoch(model, loader, criterion, desc, optimizer=None, lr=None):
    """Run one pass over ``loader``: a training pass when ``optimizer`` is given, else evaluation.

    Args:
        model: Network already placed on the global ``device``.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        desc: Progress-bar label.
        optimizer: Optimizer to step after each batch; ``None`` means no gradients
            are computed and the model runs in eval mode.
        lr: Learning rate shown in the progress bar (omitted when ``None``).

    Returns:
        ``(mean loss per batch, accuracy in percent)``; both 0.0 for an empty loader.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        pbar = tqdm(loader, desc=desc)
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)

            if training:
                # Zero the parameter gradients
                optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if training:
                # Backward pass and optimize
                loss.backward()
                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Track statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Update progress bar
            postfix = {
                'loss': running_loss / (batch_idx + 1),
                'acc': 100. * correct / total
            }
            if lr is not None:
                postfix['lr'] = lr
            pbar.set_postfix(postfix)

            # No per-batch torch.cuda.empty_cache(): it gives PyTorch no extra
            # memory and makes every step pay for releasing and re-acquiring
            # the cached blocks.

    n_batches = len(loader)
    mean_loss = running_loss / n_batches if n_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy


def train_model_from_scratch(model, train_loader, val_loader, class_info, dataset_name,
                             num_epochs=50):  # Reduced epochs for Colab
    """
    Train the ViT model from scratch with proper hyperparameters

    Uses AdamW (lr 5e-4, weight decay 0.05), label-smoothed cross entropy,
    a 10-epoch linear warmup followed by cosine decay, gradient clipping at
    norm 1.0 and early stopping after 10 epochs without a validation-accuracy
    improvement.

    Args:
        model: Network to train; moved to the global ``device``.
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        class_info: Class split description, stored inside every checkpoint.
        dataset_name: Used for the checkpoint folder and plot titles.
        num_epochs: Maximum number of epochs.

    Returns:
        ``(model, history)``; ``model`` holds the weights of the last epoch run
        (the best weights are in the ``best`` checkpoint) and ``history`` has
        per-epoch lists under ``'train_loss'``, ``'train_acc'``, ``'val_loss'``,
        ``'val_acc'`` and ``'lr'``.
    """
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Higher learning rate for training from scratch
    # Weight decay is important for regularization when training from scratch
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05, betas=(0.9, 0.999))

    # Learning rate scheduler with warmup
    warmup_epochs = 10
    scheduler = WarmupCosineScheduler(
        optimizer,
        warmup_epochs=warmup_epochs,
        max_epochs=num_epochs,
        warmup_start_lr=1e-6,
        eta_min=1e-6
    )
    # Move model to device
    model = model.to(device)

    # Training and validation history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'lr': []
    }

    # Best model tracking
    best_val_acc = 0.0
    patience = 10  # Early stopping patience
    patience_counter = 0

    # Checkpoint directory on Google Drive
    checkpoint_dir = f"/content/drive/MyDrive/vit_pretraining/checkpoints/{dataset_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Defined up front so the final save below also works when num_epochs is 0.
    epoch = -1

    print(f"Starting training for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        epoch_start_time = time.time()

        # Update learning rate
        current_lr = scheduler.step(epoch)
        history['lr'].append(current_lr[0])  # Log learning rate

        # Training phase
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
            optimizer=optimizer, lr=current_lr[0]
        )

        # Validation phase
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]"
        )

        # Calculate epoch time
        epoch_time = time.time() - epoch_start_time

        # Print epoch summary
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% - "
              f"LR: {current_lr[0]:.6f} - "
              f"Time: {epoch_time:.1f}s")

        # Save to history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Plot training progress
        if (epoch + 1) % 5 == 0 or epoch == 0:  # Every 5 epochs
            plot_training_progress(history, dataset_name)

        # Save best model and check early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0  # Reset patience counter
            print(f"New best validation accuracy: {best_val_acc:.2f}%")

            # Save the best model
            save_model(model, optimizer, epoch, history, class_info, dataset_name,
                      checkpoint_dir=checkpoint_dir, is_best=True)
        else:
            patience_counter += 1
            print(f"Validation accuracy did not improve. Patience: {patience_counter}/{patience}")

            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            save_model(model, optimizer, epoch, history, class_info, dataset_name,
                      checkpoint_dir=checkpoint_dir, is_best=False,
                      checkpoint_name=f"checkpoint_epoch_{epoch+1}")

    # Save final model
    save_model(model, optimizer, epoch, history, class_info, dataset_name,
              checkpoint_dir=checkpoint_dir, is_best=False)

    return model, history

In [None]:
# Cell 9: Functions for saving, evaluating, and loading models
def save_model(model, optimizer, epoch, history, class_info, dataset_name,
              checkpoint_dir=None, is_best=False, checkpoint_name=None):
    """
    Save model checkpoint

    The file is named ``vit_<type>_checkpoint.pth`` where ``<type>`` is
    ``checkpoint_name`` when given, otherwise ``'best'`` or ``'final'``
    depending on ``is_best``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Zero-based index of the last finished epoch.
        history: Training history dictionary, stored as is.
        class_info: Class split description, stored as is.
        dataset_name: Selects the default checkpoint folder.
        checkpoint_dir: Target folder; defaults to the Drive folder for ``dataset_name``.
        is_best: Chooses between the ``best`` and ``final`` file names.
        checkpoint_name: Explicit ``<type>`` part of the file name.
    """
    if checkpoint_dir is None:
        checkpoint_dir = f"/content/drive/MyDrive/vit_pretraining/checkpoints/{dataset_name}"

    if checkpoint_name:
        model_type = checkpoint_name
    else:
        model_type = 'best' if is_best else 'final'

    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, f"vit_{model_type}_checkpoint.pth")

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'history': history,
        'class_info': class_info
    }

    # Write to a temporary file and rename it into place, so an interrupted
    # save cannot leave a truncated file where a good checkpoint used to be.
    tmp_path = checkpoint_path + ".tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print(f"Saved {model_type} model checkpoint to {checkpoint_path}")

def evaluate_model(model, test_loader, class_info):
    """
    Evaluate model on the full test set

    The model predicts remapped indices ``0..len(pretrain_classes)-1`` while the
    test loader yields the dataset's original class indices, so predictions are
    translated back to original class indices before they are compared.
    Samples of continual classes can therefore never be counted as correct;
    they still count towards the totals.

    No loss is computed: the targets are in the original index space and are
    not valid indices into the model's smaller output layer.

    Args:
        model: Trained network, already on the global ``device``.
        test_loader: Loader yielding ``(inputs, targets)`` with original class indices.
        class_info: Dict with ``'n_classes'``, ``'pretrain_classes'`` and
            ``'continual_classes'``.

    Returns:
        ``(test_acc, details)`` where ``details`` has ``'pretrain_acc'``,
        ``'continual_acc'`` and ``'class_acc'`` (original class index -> accuracy).
        All accuracies are percentages; groups without samples report 0.
    """
    model.eval()

    n_classes = class_info['n_classes']
    pretrain_classes = class_info['pretrain_classes']
    continual_classes = class_info['continual_classes']

    # Lookup table: model output index -> original class index
    index_to_class = torch.tensor(pretrain_classes, dtype=torch.long, device=device)

    # Per-class counters indexed by ORIGINAL class index
    class_correct_counts = torch.zeros(n_classes, dtype=torch.long)
    class_total_counts = torch.zeros(n_classes, dtype=torch.long)

    test_correct = 0
    test_total = 0

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Evaluating")
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)

            # Translate predicted output indices to original class indices
            _, predicted_idx = outputs.max(1)
            predicted = index_to_class[predicted_idx]
            hits = predicted.eq(targets)

            # Track overall statistics
            test_total += targets.size(0)
            test_correct += hits.sum().item()

            # Track class-wise statistics for every class present in the batch
            class_total_counts += torch.bincount(targets, minlength=n_classes).cpu()
            class_correct_counts += torch.bincount(targets[hits], minlength=n_classes).cpu()

            # Update progress bar
            pbar.set_postfix({
                'acc': 100. * test_correct / test_total
            })

    class_correct = dict(enumerate(class_correct_counts.tolist()))
    class_total = dict(enumerate(class_total_counts.tolist()))

    # Calculate average test metrics
    test_acc = 100. * test_correct / test_total if test_total > 0 else 0

    print(f"Test Acc: {test_acc:.2f}%")

    # Calculate accuracy for pretrain and continual classes
    pretrain_correct = sum(class_correct[cls] for cls in pretrain_classes)
    pretrain_total = sum(class_total[cls] for cls in pretrain_classes)
    pretrain_acc = 100. * pretrain_correct / pretrain_total if pretrain_total > 0 else 0

    continual_correct = sum(class_correct[cls] for cls in continual_classes)
    continual_total = sum(class_total[cls] for cls in continual_classes)
    continual_acc = 100. * continual_correct / continual_total if continual_total > 0 else 0

    print(f"Pretrain Classes Acc: {pretrain_acc:.2f}%")
    print(f"Continual Classes Acc: {continual_acc:.2f}%")

    return test_acc, {
        'pretrain_acc': pretrain_acc,
        'continual_acc': continual_acc,
        'class_acc': {cls: 100. * class_correct[cls] / class_total[cls] if class_total[cls] > 0 else 0
                     for cls in range(n_classes)}
    }

def load_pretrained_model(dataset_name, model_type='best', checkpoint_dir=None):
    """
    Restore a ViT model and its class split from a saved checkpoint.

    Args:
        dataset_name: Selects the default checkpoint folder and appears in the log line.
        model_type: The ``<type>`` part of ``vit_<type>_checkpoint.pth`` (default ``'best'``).
        checkpoint_dir: Folder to read from; defaults to the Drive folder for ``dataset_name``.

    Returns:
        ``(model, class_info)`` with the model placed on the global ``device``.

    Raises:
        FileNotFoundError: The checkpoint file does not exist.
    """
    base_dir = (
        checkpoint_dir
        if checkpoint_dir is not None
        else f"/content/drive/MyDrive/vit_pretraining/checkpoints/{dataset_name}"
    )
    checkpoint_path = os.path.join(base_dir, f"vit_{model_type}_checkpoint.pth")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    saved = torch.load(checkpoint_path, map_location=device)
    class_info = saved['class_info']

    # The head size is the number of pretraining classes recorded in the checkpoint.
    head_size = len(class_info['pretrain_classes'])
    model = create_vit_model_from_scratch(num_classes=head_size)
    model.load_state_dict(saved['model_state_dict'])
    model = model.to(device)

    # The stored epoch is zero-based, hence the +1 in the message.
    trained_epochs = saved['epoch'] + 1
    print(f"Loaded {model_type} {dataset_name} ViT model from {checkpoint_path}")
    print(f"Model was trained for {trained_epochs} epochs")

    return model, class_info

In [None]:
# Cell 10: Visualization functions
def plot_training_progress(history, dataset_name):
    """
    Show loss, accuracy and learning-rate curves for the epochs recorded so far.

    Args:
        history: Dict with the lists ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'``, ``'val_acc'`` and ``'lr'``.
        dataset_name: Prefix of every panel title.
    """
    fig, (loss_ax, acc_ax, lr_ax) = plt.subplots(1, 3, figsize=(15, 5))

    # Left panel: loss
    loss_ax.plot(history['train_loss'], label='Train Loss')
    loss_ax.plot(history['val_loss'], label='Val Loss')
    loss_ax.set_title(f'{dataset_name} Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Middle panel: accuracy
    acc_ax.plot(history['train_acc'], label='Train Acc')
    acc_ax.plot(history['val_acc'], label='Val Acc')
    acc_ax.set_title(f'{dataset_name} Training Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.legend()

    # Right panel: learning rate on a logarithmic axis
    lr_ax.plot(history['lr'])
    lr_ax.set_title(f'{dataset_name} Learning Rate')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_yscale('log')

    fig.tight_layout()
    plt.show()

def plot_training_history(history, dataset_name):
    """
    Plot complete training history after training

    Draws loss, accuracy and learning-rate panels on a 2x2 grid, saves the
    figure as ``<dataset_name>_training_history.png`` in the Drive
    ``vit_pretraining`` folder and then shows it.

    Args:
        history: Dict with the lists ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'``, ``'val_acc'`` and ``'lr'``.
        dataset_name: Prefix of the panel titles and of the saved file name.
    """
    plt.figure(figsize=(15, 10))

    # Plot loss
    plt.subplot(2, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title(f'{dataset_name} Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Plot accuracy
    plt.subplot(2, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title(f'{dataset_name} Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    # Plot learning rate
    plt.subplot(2, 2, 3)
    plt.plot(history['lr'])
    plt.title(f'{dataset_name} Learning Rate')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')

    # Apply the layout BEFORE saving so the file matches what is displayed.
    plt.tight_layout()

    # Save the figure to Google Drive
    save_path = f"/content/drive/MyDrive/vit_pretraining/{dataset_name}_training_history.png"
    # Make sure the target folder exists even if the setup cell was skipped.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path)
    print(f"Saved training history plot to {save_path}")

    plt.show()

In [None]:
# Cell 11: Main training function
def train_cifar100():
    """
    Run the whole CIFAR-100 pipeline: data preparation, training from scratch,
    evaluation on the full test set and the final history plot.

    Returns:
        ``(model, history, class_info)`` of the finished run.
    """
    dataset_name = 'cifar100'

    # Fixed seed so the class split and initial weights repeat across runs
    set_seed(42)

    print("=== Pretraining ViT on CIFAR-100 from scratch ===")
    loaders_and_info = prepare_dataset(dataset_name)
    train_loader, val_loader, test_loader, class_info = loaders_and_info

    # One output per pretraining class
    head_size = len(class_info['pretrain_classes'])
    model = create_vit_model_from_scratch(num_classes=head_size)
    model, history = train_model_from_scratch(
        model, train_loader, val_loader, class_info, dataset_name
    )

    print("\n=== Evaluating CIFAR-100 ViT on Full Test Set ===")
    evaluate_model(model, test_loader, class_info)

    # Final plot of the complete history
    plot_training_history(history, dataset_name)

    print("\nCIFAR-100 pretraining complete! Model saved to Google Drive.")
    return model, history, class_info

def train_tiny_imagenet():
    """
    Run the whole Tiny-ImageNet pipeline: data preparation, training from scratch,
    evaluation on the full test set and the final history plot.

    Returns:
        ``(model, history, class_info)`` of the finished run.
    """
    dataset_name = 'tiny-imagenet'

    # Fixed seed so the class split and initial weights repeat across runs
    set_seed(42)

    print("\n=== Pretraining ViT on Tiny-ImageNet from scratch ===")
    loaders_and_info = prepare_dataset(dataset_name)
    train_loader, val_loader, test_loader, class_info = loaders_and_info

    # One output per pretraining class
    head_size = len(class_info['pretrain_classes'])
    model = create_vit_model_from_scratch(num_classes=head_size)
    model, history = train_model_from_scratch(
        model, train_loader, val_loader, class_info, dataset_name
    )

    print("\n=== Evaluating Tiny-ImageNet ViT on Full Test Set ===")
    evaluate_model(model, test_loader, class_info)

    # Final plot of the complete history
    plot_training_history(history, dataset_name)

    print("\nTiny-ImageNet pretraining complete! Model saved to Google Drive.")
    return model, history, class_info

In [None]:
# Cell 12: Evaluation function
def check_models_on_full_dataset():
    """
    Check the pre-trained models on the full dataset

    Loads the ``best`` checkpoints of both datasets, evaluates each model on
    its full test set and prints the accuracy on pretraining vs. continual
    classes. Returns ``None`` early (after printing which file is missing)
    when either checkpoint is absent.
    """
    print("=== Checking Pre-trained Models on Full Datasets ===")

    # Check if pretrained models exist
    cifar_checkpoint_path = "/content/drive/MyDrive/vit_pretraining/checkpoints/cifar100/vit_best_checkpoint.pth"
    tiny_checkpoint_path = "/content/drive/MyDrive/vit_pretraining/checkpoints/tiny-imagenet/vit_best_checkpoint.pth"

    models_found = True

    if not os.path.exists(cifar_checkpoint_path):
        print(f"CIFAR-100 model not found at {cifar_checkpoint_path}")
        models_found = False

    if not os.path.exists(tiny_checkpoint_path):
        print(f"Tiny-ImageNet model not found at {tiny_checkpoint_path}")
        models_found = False

    if not models_found:
        print("Some pretrained models not found. Please run the pretraining first.")
        return

    # Load the models; the class split comes from the checkpoint, not from a new shuffle
    cifar_model, cifar_class_info = load_pretrained_model('cifar100')
    tiny_model, tiny_class_info = load_pretrained_model('tiny-imagenet')

    # Load the datasets (only the test loaders are needed here)
    _, _, cifar_test_loader, _ = prepare_dataset('cifar100')
    _, _, tiny_test_loader, _ = prepare_dataset('tiny-imagenet')

    # Evaluate on full test set.
    # evaluate_model returns a 2-tuple (test_acc, details); the per-group
    # accuracies live in the second element.
    print("\n=== Evaluating CIFAR-100 ViT on Full Test Set ===")
    _, cifar_details = evaluate_model(cifar_model, cifar_test_loader, cifar_class_info)

    print("\n=== Evaluating Tiny-ImageNet ViT on Full Test Set ===")
    _, tiny_details = evaluate_model(tiny_model, tiny_test_loader, tiny_class_info)

    # Analyze performance on pretrain vs continual classes
    print("\n=== Performance Analysis ===")
    print("CIFAR-100:")
    print(f"  Pretrain Classes Accuracy: {cifar_details['pretrain_acc']:.2f}%")
    print(f"  Continual Classes Accuracy: {cifar_details['continual_acc']:.2f}%")

    print("\nTiny-ImageNet:")
    print(f"  Pretrain Classes Accuracy: {tiny_details['pretrain_acc']:.2f}%")
    print(f"  Continual Classes Accuracy: {tiny_details['continual_acc']:.2f}%")

In [None]:
# Cell 13: Run the training for CIFAR-100
# Uncomment the line below to run the CIFAR-100 training
# cifar_model, cifar_history, cifar_class_info = train_cifar100()

In [None]:
# Cell 14: Run the training for Tiny-ImageNet
# Uncomment the line below to run the Tiny-ImageNet training
# tiny_model, tiny_history, tiny_class_info = train_tiny_imagenet()