In [8]:
# Cell 1 — Install dependencies
!pip install torchvision torchaudio wandb
import torchvision
import torchaudio



In [2]:
# Authenticate this session with Weights & Biases before any run is created.
# As the cell output shows, wandb.login() prompts for an API key when none is
# stored yet and then appends it to the user's netrc file.
import wandb
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m142502015[0m ([33m142502015-indian-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
import os
import random
import time
from pathlib import Path
from typing import Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.models import resnet18

import wandb

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

def set_seed(seed=42):
    """Seed every random number generator this notebook uses.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and CUDA generators with the same value so a run can be repeated.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # NumPy is imported and used by this notebook, so its global RNG must be
    # seeded too or NumPy-based randomness would differ between runs.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)


Using device: cuda


In [2]:
# Cell 4 — Dataloaders factory
def get_cifar_loaders(name: str, batch_size=256, num_workers=2, augment=True):
    """Build train/test DataLoaders (plus the underlying datasets) for a CIFAR variant.

    Args:
        name: "CIFAR10" or "CIFAR100"; any other value trips the assertion.
        batch_size: batch size used by both loaders.
        num_workers: worker count passed to both loaders.
        augment: when True, training images get a padded random crop and a
            random horizontal flip before tensor conversion. Test images are
            never augmented.

    Returns:
        (train_loader, test_loader, trainset, testset). The training loader
        shuffles; the test loader does not.
    """
    assert name in ("CIFAR10", "CIFAR100")
    dataset_class = torchvision.datasets.CIFAR10 if name == "CIFAR10" else torchvision.datasets.CIFAR100

    # One set of per-channel statistics is applied to both datasets.
    # NOTE(review): these look like the commonly quoted CIFAR-10 statistics —
    # confirm that sharing them with CIFAR-100 is intentional.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    def to_normalized_tensor():
        # Fresh transform objects for each pipeline.
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    augmentation = []
    if augment:
        augmentation = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    train_pipeline = transforms.Compose(augmentation + to_normalized_tensor())
    test_pipeline = transforms.Compose(to_normalized_tensor())

    trainset = dataset_class(root='./data', train=True, download=True, transform=train_pipeline)
    testset = dataset_class(root='./data', train=False, download=True, transform=test_pipeline)

    # Options common to both loaders; only the shuffling differs.
    loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(trainset, shuffle=True, **loader_options)
    test_loader = DataLoader(testset, shuffle=False, **loader_options)
    return train_loader, test_loader, trainset, testset


In [3]:
# Cell 5 — Model builder (ResNet18 adapted to num_classes)
def build_model(num_classes: int, pretrained=False):
    """Create a ResNet-18 adapted to 32x32 CIFAR images.

    Args:
        num_classes: number of outputs of the final fully connected layer.
        pretrained: when True, start from torchvision's ImageNet weights;
            otherwise the network is randomly initialised. The stem conv and
            the classifier are always replaced by freshly initialised layers.

    Returns:
        The modified ``resnet18`` module (left on the default device).
    """
    # `pretrained=` is deprecated in torchvision's multi-weight API; the
    # equivalent of pretrained=True is the IMAGENET1K_V1 weight set.
    weights = "IMAGENET1K_V1" if pretrained else None
    model = resnet18(weights=weights)
    # adapt first conv for CIFAR (32x32): change kernel_size=3, stride=1, padding=1
    model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Drop the stem max-pool so small inputs are not downsampled right away.
    model.maxpool = torch.nn.Identity()
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model


In [4]:
# Cell 6 — Training/validation utilities + logging helpers
from sklearn.metrics import confusion_matrix

def evaluate(model, dataloader, device):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: module mapping an input batch to per-class logits.
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        (avg_loss, acc, preds, targets): the per-sample mean cross-entropy,
        the fraction of correct predictions, and NumPy arrays holding the
        predicted and true class indices in iteration order.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    correct = 0
    total = 0
    preds = []
    targets = []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            n_samples = y.size(0)
            # The criterion returns a batch mean; weight it by the batch size
            # so a short final batch does not skew the dataset-level average.
            loss_sum += criterion(out, y).item() * n_samples
            _, p = out.max(1)
            correct += (p == y).sum().item()
            total += n_samples
            preds.append(p.cpu().numpy())
            targets.append(y.cpu().numpy())
    if total == 0:
        raise ValueError("evaluate() received a dataloader that yielded no samples")
    avg_loss = loss_sum / total
    acc = correct / total
    preds = np.concatenate(preds)
    targets = np.concatenate(targets)
    return avg_loss, acc, preds, targets

def log_confusion_matrix(targets, preds, class_labels, step=None, prefix=""):
    """Log a confusion matrix for ``preds`` vs ``targets`` to the active W&B run.

    The preferred path logs a ``wandb.plot.confusion_matrix`` chart under
    ``"<prefix>confusion_matrix"``. If building or logging the chart fails, a
    warning is emitted and the raw count matrix is logged as a nested list
    under ``"<prefix>confusion_matrix_array"`` instead.

    Args:
        targets: true class indices (array-like).
        preds: predicted class indices (array-like).
        class_labels: class names passed to the W&B chart.
        step: W&B step to log at (``None`` lets W&B choose).
        prefix: string prepended to the logged key.
    """
    # Accept both NumPy arrays and plain sequences.
    y_true = np.asarray(targets).tolist()
    y_pred = np.asarray(preds).tolist()
    try:
        chart = wandb.plot.confusion_matrix(probs=None,
                                            y_true=y_true,
                                            preds=y_pred,
                                            class_names=class_labels)
        wandb.log({f"{prefix}confusion_matrix": chart}, step=step)
    except Exception as exc:
        # Best-effort fallback: keep the run going, but make the failure visible.
        import warnings
        warnings.warn(f"wandb confusion-matrix chart failed ({exc!r}); logging raw counts instead")
        # The count matrix is only needed on this path, so compute it lazily.
        cm = confusion_matrix(y_true, y_pred)
        wandb.log({f"{prefix}confusion_matrix_array": cm.tolist()}, step=step)


In [5]:
# Cell 7 — Training loop that supports sequential datasets, and forgetting measurement
def run_sequential_experiment(seq, base_seed=42, epochs_per_task=100, batch_size=256, lr=0.01, weight_decay=5e-4):
    """
    Train one ResNet-18 on the datasets in ``seq`` one after another and log
    training, validation and forgetting metrics to a W&B run.

    seq: list like ["CIFAR100", "CIFAR10"] in the order to train
    base_seed: seed applied before the model is created, so every experiment
        starts from the same initial weights.
    epochs_per_task: number of epochs spent on each dataset.
    batch_size: batch size for all train/test loaders.
    lr: initial SGD learning rate at the start of EVERY task.
    weight_decay: SGD weight decay.

    Returns a dict mapping "initial_<name>", "after_<task>_<name>" and
    "final_<name>" to ``(loss, acc)`` tuples measured on the test split.

    The W&B run is always closed, even if training raises or is interrupted.
    """
    set_seed(base_seed)
    # Same seed -> same initial weights for every experiment. Start with the
    # largest (100-way) head; it is swapped when a task needs another size.
    model = build_model(num_classes=100).to(device)

    # Initialize a W&B run for this experiment
    run_name = "_then_".join(seq)
    wandb.init(project="cifar-sequential-wandb", name=run_name, config={
        "sequence": seq,
        "epochs_per_task": epochs_per_task,
        "batch_size": batch_size,
        "lr": lr,
        "weight_decay": weight_decay,
        "seed": base_seed,
        "model": "resnet18_cifar"
    })

    performance = {}
    try:
        criterion = nn.CrossEntropyLoss()

        # Only the datasets named in seq are ever trained or evaluated on.
        loaders = {}  # name -> (train_loader, test_loader)
        for name in dict.fromkeys(seq):
            train_loader, test_loader, _, _ = get_cifar_loaders(name, batch_size=batch_size, augment=True)
            loaders[name] = (train_loader, test_loader)

        # Classifier heads kept after each task. Re-attaching an old head to
        # the current backbone is what makes forgetting measurable; a fresh
        # random head would always score at chance level.
        trained_heads = {}  # name -> nn.Linear trained on that dataset

        def eval_on_dataset(model, name):
            """Evaluate the current backbone on ``name``'s test split.

            Head choice: the head trained for ``name`` if one was saved, else
            the current head if its size fits, else a temporary untrained head.
            The model's own head is restored afterwards.
            """
            n_classes = 10 if name == "CIFAR10" else 100
            current_head = model.fc
            if name in trained_heads:
                head = trained_heads[name]
            elif current_head.out_features == n_classes:
                head = current_head
            else:
                head = nn.Linear(current_head.in_features, n_classes).to(device)
            model.fc = head
            try:
                return evaluate(model, loaders[name][1], device)
            finally:
                model.fc = current_head

        # Initial eval
        for name in seq:
            loss, acc, _, _ = eval_on_dataset(model, name)
            wandb.log({f"initial/{name}_loss": loss, f"initial/{name}_acc": acc}, step=0)
            performance[f"initial_{name}"] = (loss, acc)

        global_step = 0
        # Now train sequentially
        for task_idx, task in enumerate(seq):
            train_loader, test_loader = loaders[task]
            n_classes = 10 if task == "CIFAR10" else 100

            # If model.fc size != n_classes, replace final layer with a new one
            if model.fc.out_features != n_classes:
                model.fc = nn.Linear(model.fc.in_features, n_classes).to(device)

            # Build the optimizer AFTER any head swap so the new head's
            # parameters are actually optimised, and restart the LR schedule
            # so every task begins at the configured learning rate.
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

            wandb.log({"task_started": task, "task_index": task_idx}, step=global_step)
            class_labels = [str(i) for i in range(n_classes)]
            for epoch in range(1, epochs_per_task + 1):
                train_loss, train_acc, n_batches = _train_one_epoch(
                    model, train_loader, optimizer, criterion, device)
                global_step += n_batches  # global_step counts optimizer steps
                scheduler.step()

                val_loss, val_acc, val_preds, val_targets = evaluate(model, test_loader, device)

                # Log to W&B
                wandb.log({
                    f"{task}/train_loss": train_loss,
                    f"{task}/train_acc": train_acc,
                    f"{task}/val_loss": val_loss,
                    f"{task}/val_acc": val_acc,
                    "epoch": epoch,
                }, step=global_step)

                # Periodic confusion matrix (less frequent to save bandwidth)
                if epoch % 25 == 0 or epoch == epochs_per_task:
                    log_confusion_matrix(val_targets, val_preds, class_labels, step=global_step, prefix=f"{task}/")

            # Remember this task's head. The module is no longer updated once
            # the next task swaps in a new head and rebuilds the optimizer.
            trained_heads[task] = model.fc

            # After finishing this task, evaluate performance on all tasks (measure forgetting)
            for other in seq:
                loss_o, acc_o, _, _ = eval_on_dataset(model, other)
                wandb.log({f"after_{task}/{other}_loss": loss_o, f"after_{task}/{other}_acc": acc_o}, step=global_step)
                performance[f"after_{task}_{other}"] = (loss_o, acc_o)

            # Save model artifact snapshot
            artifact = wandb.Artifact(f"{run_name}_after_{task}", type="model")
            model_file = f"model_{run_name}_after_{task}.pth"
            torch.save(model.state_dict(), model_file)
            artifact.add_file(model_file)
            wandb.log_artifact(artifact)

        # Final evaluations and summary
        for name in seq:
            loss, acc, _, _ = eval_on_dataset(model, name)
            wandb.log({f"final/{name}_loss": loss, f"final/{name}_acc": acc}, step=global_step)
            performance[f"final_{name}"] = (loss, acc)
    finally:
        # Close the run even on errors / KeyboardInterrupt so it is not left open.
        wandb.finish()
    return performance


def _train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    Returns ``(mean_loss, accuracy, n_batches)`` where the loss is averaged
    per sample and ``n_batches`` is the number of optimizer steps taken.
    Raises ValueError if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    n_batches = 0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        n_samples = yb.size(0)
        # Weight the batch-mean loss by batch size for a true per-sample mean.
        loss_sum += loss.item() * n_samples
        _, p = out.max(1)
        correct += (p == yb).sum().item()
        total += n_samples
        n_batches += 1
    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / total, correct / total, n_batches


In [11]:
# Cell 8 — Run both experiments sequentially in the notebook (this will take time — 200 epochs per experiment total)
# If you want to run them in separate Colab sessions / sequentially, comment/uncomment as needed.

import json

# Each summary is written to disk as soon as its experiment finishes, so an
# interruption during a later experiment does not lose earlier results.

# EXPERIMENT A: CIFAR100 -> CIFAR10
perf_A = run_sequential_experiment(["CIFAR100", "CIFAR10"], base_seed=42, epochs_per_task=100, batch_size=256, lr=0.05)
with open("perf_A.json", "w") as f:
    json.dump(perf_A, f)

# EXPERIMENT B: CIFAR10 -> CIFAR100
perf_B = run_sequential_experiment(["CIFAR10", "CIFAR100"], base_seed=42, epochs_per_task=100, batch_size=256, lr=0.05)
with open("perf_B.json", "w") as f:
    json.dump(perf_B, f)

print("Experiments finished. Check W&B project: cifar-sequential-wandb")


KeyboardInterrupt: 