In [1]:
!pip install matplotlib numpy torchvision tqdm

Collecting matplotlib
  Downloading matplotlib-3.10.6-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.59.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (109 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.7/109.7 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x8

In [2]:
# ================================================
# OGD on Split CIFAR-100 (10 tasks × 10 classes)
# ================================================
# - Memory-efficient task construction
# - OGD with orthonormal gradient memory
# - Stores optimizer/criterion in OGD
# - Accuracy matrix + CL metrics + plot
# ================================================

import os
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

# -----------------
# Repro & Device
# -----------------
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -----------------
# Hyperparameters
# -----------------
root = "./data"
num_tasks = 10
num_classes = 100
classes_per_task = num_classes // num_tasks  # 10
batch_size = 32
num_epochs = 2
download = True

# Optimizer/loss
learning_rate = 0.001
weight_decay = 5e-4
momentum = 0.9

# OGD memory
max_mem_dirs = 1000    # cap on number of stored gradient directions (global)
dirs_per_task = 120   # target number of new directions to add per task
harvest_batches = 30  # batches to sample for memory after each task
grad_eps = 1e-6       # min norm to accept a direction

# -----------------
# Transforms
# -----------------
normalize = transforms.Normalize(mean=(0.5071, 0.4867, 0.4408),
                                 std=(0.2675, 0.2565, 0.2761))
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

# -----------------
# Datasets
# -----------------
train_full = datasets.CIFAR100(root=root, train=True,  download=download, transform=train_transform)
test_full  = datasets.CIFAR100(root=root, train=False, download=download, transform=test_transform)

train_targets = np.array(train_full.targets)
test_targets  = np.array(test_full.targets)

# -----------------
# Task splits (indices & class lists)
# -----------------
task_class_lists = []
train_indices_per_task, test_indices_per_task = [], []

for t in range(num_tasks):
    cls_start = t * classes_per_task
    cls_end = cls_start + classes_per_task
    task_classes = list(range(cls_start, cls_end))
    task_class_lists.append(task_classes)

    train_idx = np.where(np.isin(train_targets, task_classes))[0].tolist()
    test_idx  = np.where(np.isin(test_targets,  task_classes))[0].tolist()

    train_indices_per_task.append(train_idx)
    test_indices_per_task.append(test_idx)

    print(f"Task {t}: classes {task_classes[0]}..{task_classes[-1]} | "
          f"train {len(train_idx)}, test {len(test_idx)}")

# -----------------
# Per-task label mapping dataset
# -----------------
class MapLabelsDataset(Dataset):
    def __init__(self, base_dataset, indices, class_map):
        self.base = base_dataset
        self.indices = indices
        self.class_map = class_map

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        x, y = self.base[self.indices[i]]
        return x, self.class_map[int(y)]

# -----------------
# WideResNet (WRN-28-10)
# -----------------
class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, 3, 1, 1, bias=False)
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, 1, stride, 0, bias=False
        )

    def forward(self, x):
        out = self.relu1(self.bn1(x))
        shortcut = x if self.equalInOut else self.convShortcut(x)
        out = self.conv1(out)
        out = self.relu2(self.bn2(out))
        out = self.conv2(out)
        return out + shortcut

class NetworkBlock(nn.Module):
    def __init__(self, n, in_planes, out_planes, block, stride):
        super().__init__()
        layers = []
        for i in range(n):
            layers.append(
                block(in_planes if i == 0 else out_planes,
                      out_planes,
                      stride if i == 0 else 1)
            )
        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)

class WideResNet(nn.Module):
    def __init__(self, depth=28, widen_factor=10, num_classes=10):
        super().__init__()
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6
        k = widen_factor
        nChannels = [16, 16*k, 32*k, 64*k]

        self.conv1 = nn.Conv2d(3, nChannels[0], 3, 1, 1, bias=False)
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], BasicBlock, 1)
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], BasicBlock, 2)
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], BasicBlock, 2)
        self.bn = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

    def forward(self, x):
        x = self.conv1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.relu(self.bn(x))
        x = F.adaptive_avg_pool2d(x, 1).view(-1, self.nChannels)
        return self.fc(x)

"""
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))   # 32x16x16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))   # 64x8x8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))   # 128x4x4
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
"""

# -----------------
# Utils to flatten/assign grads
# -----------------
def flatten_grads(model):
    parts = []
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            parts.append(p.grad.view(-1))
    return torch.cat(parts) if parts else torch.tensor([], device=next(model.parameters()).device)

def assign_grads_from_vector(model, grad_vec):
    offset = 0
    for p in model.parameters():
        if not p.requires_grad:
            continue
        n = p.numel()
        if p.grad is None:
            p.grad = torch.zeros_like(p)
        p.grad.copy_(grad_vec[offset:offset+n].view_as(p))
        offset += n

# -----------------
# OGD (Option 2: stores optimizer & criterion internally)
# -----------------
class OGD:
    def __init__(self, model, optimizer, criterion, device,
                 max_mem_dirs=1000, grad_eps=1e-6):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.P = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.memory = None   # [k, P], orthonormal rows (unit-norm)
        self.max_mem_dirs = max_mem_dirs
        self.grad_eps = grad_eps

    @torch.no_grad()
    def _project_onto_complement(self, g):
        # g: [P]
        if self.memory is None or self.memory.size(0) == 0:
            return g
        # Memory rows are orthonormal ⇒ projection is g_perp = g - M^T (M g)
        Mg = torch.mv(self.memory, g)            # [k]
        g_perp = g - torch.mv(self.memory.t(), Mg)
        return g_perp

    @torch.no_grad()
    def _add_dir_to_memory(self, g):
        # Gram-Schmidt orth against existing memory; if large enough, normalize & append
        if self.memory is not None and self.memory.size(0) > 0:
            Mg = torch.mv(self.memory, g)
            g = g - torch.mv(self.memory.t(), Mg)
        norm = torch.linalg.norm(g)
        if norm > self.grad_eps:
            g = g / norm
            if self.memory is None:
                self.memory = g.unsqueeze(0)
            else:
                if self.memory.size(0) < self.max_mem_dirs:
                    self.memory = torch.vstack([self.memory, g])
                else:
                    # Replace a random row to maintain diversity (FIFO/random policy)
                    idx = torch.randint(0, self.memory.size(0), (1,)).item()
                    self.memory[idx] = g

    def observe(self, x, y):
        """One OGD training step on a batch."""
        self.model.train()
        x, y = x.to(self.device), y.to(self.device)

        self.optimizer.zero_grad(set_to_none=True)
        logits = self.model(x)
        loss = self.criterion(logits, y)
        loss.backward()

        # Flatten grads → project → assign → step
        g = flatten_grads(self.model).detach()
        g_perp = self._project_onto_complement(g)
        assign_grads_from_vector(self.model, g_perp)
        self.optimizer.step()
        return loss.item()

    def end_task(self, dataloader, dirs_to_add=100, harvest_batches=30):
        """Harvest gradient directions (from current task) to expand memory."""
        self.model.train()
        for p in self.model.parameters():
            p.requires_grad_(True)
        torch.set_grad_enabled(True)

        added, seen = 0, 0
        for xb, yb in dataloader:
            if seen >= harvest_batches or added >= dirs_to_add:
                break
            seen += 1

            xb, yb = xb.to(self.device), yb.to(self.device)
            self.model.zero_grad(set_to_none=True)
            logits = self.model(xb)
            loss = self.criterion(logits, yb)
            loss.backward()

            g = flatten_grads(self.model).detach()
            self._add_dir_to_memory(g)
            added += 1

        print(f"[OGD] Harvested {added} dirs (seen {seen} batches). "
              f"Memory size: {0 if self.memory is None else self.memory.size(0)}")

# -----------------
# Eval
# -----------------
@torch.no_grad()
def evaluate(model, dataloader, device):
    model.eval()
    correct, total = 0, 0
    for xb, yb in dataloader:
        xb, yb = xb.to(device), yb.to(device)
        preds = model(xb).argmax(1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
    return correct / max(1, total)

# -----------------
# Init model/optimizer/criterion/OGD
# -----------------
"""
model = SimpleCNN(num_classes=classes_per_task).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

ogd = OGD(model, optimizer, criterion, device,
          max_mem_dirs=max_mem_dirs, grad_eps=grad_eps)
"""

# -----------------
# Init model/optimizer/criterion/OGD
# -----------------
model = WideResNet(depth=28, widen_factor=10, num_classes=classes_per_task).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                            momentum=momentum, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
ogd = OGD(model, optimizer, criterion, device,
          max_mem_dirs=max_mem_dirs, grad_eps=grad_eps)

# -----------------
# Train across tasks
# -----------------
accuracy_matrix = np.zeros((num_tasks, num_tasks), dtype=np.float32)

loader_kwargs = dict(num_workers=2, pin_memory=True) if device.type == "cuda" else dict(num_workers=0)

for t in range(num_tasks):
    # Datasets & loaders for task t (labels remapped to 0..9)
    class_map = {orig: i for i, orig in enumerate(task_class_lists[t])}
    train_ds_t = MapLabelsDataset(train_full, train_indices_per_task[t], class_map)
    test_ds_t  = MapLabelsDataset(test_full,  test_indices_per_task[t],  class_map)

    train_loader = DataLoader(train_ds_t, batch_size=batch_size, shuffle=True,  **loader_kwargs)
    test_loader  = DataLoader(test_ds_t,  batch_size=batch_size, shuffle=False, **loader_kwargs)

    print(f"\n=== Task {t+1}/{num_tasks} | classes {task_class_lists[t][0]}..{task_class_lists[t][-1]} ===")
    for epoch in range(1, num_epochs + 1):
        running = 0.0
        for xb, yb in tqdm(train_loader, desc=f"Task {t+1} Epoch {epoch}"):
            running += ogd.observe(xb, yb)
        print(f"Epoch {epoch}: loss={running/len(train_loader):.4f}")

    # Evaluate on all seen tasks so far
    for j in range(t + 1):
        class_map_eval = {orig: i for i, orig in enumerate(task_class_lists[j])}
        test_ds_eval = MapLabelsDataset(test_full, test_indices_per_task[j], class_map_eval)
        test_loader_eval = DataLoader(test_ds_eval, batch_size=batch_size, shuffle=False, **loader_kwargs)
        acc = evaluate(ogd.model, test_loader_eval, device)
        accuracy_matrix[t, j] = acc
        print(f"Accuracy on Task {j+1}: {acc:.3f}")

    # Grow OGD memory from current task data
    ogd.end_task(train_loader, dirs_to_add=dirs_per_task, harvest_batches=harvest_batches)


# Compute Continual Learning Metrics
# Average Accuracy (ACC)
ACC = accuracy_matrix[-1].mean()  # Final row averaged across tasks

# Average Forgetting (F)
F = np.mean([
    np.max(accuracy_matrix[:num_tasks-1, j]) - accuracy_matrix[-1, j]
    for j in range(num_tasks-1)
])

# Backward Transfer (BWT)
BWT = np.mean([
    accuracy_matrix[-1, j] - accuracy_matrix[j, j]
    for j in range(num_tasks-1)
])

# Forward Transfer (FWT)
# Measures how much previous tasks helped the next task before it was trained
FWT = np.mean([
    accuracy_matrix[i, i+1]
    for i in range(num_tasks-1)
])

# Memory Usage (in MB)
num_params = sum(p.numel() for p in ogd.model.parameters())
dirs_count = 0 if ogd.memory is None else ogd.memory.size(0)
mem_usage = dirs_count * num_params * 4 / (1024**2)  # 4 bytes per float32

# Computation Cost (approximate as #dirs, since each step projects on memory)
comp_cost = dirs_count

# Plasticity-Stability Measure (PSM) - Normalized 0 to 1

# Define components
stability = 1 - F                 # High if forgetting is low
plasticity = max(FWT, 0)          # High if positive forward transfer

alpha = 0.5                       # Balance between stability and plasticity
PSM = alpha * stability + (1 - alpha) * plasticity

# Print Metrics
print("=== Continual Learning Metrics ===")
print(f"Average Accuracy (ACC):       {ACC:.4f}")
print(f"Forgetting (F):              {F:.4f}")
print(f"Backward Transfer (BWT):     {BWT:.4f}")
print(f"Forward Transfer (FWT):      {FWT:.4f}")
print(f"Memory Usage:                {mem_usage:.2f} MB")
print(f"Computation Cost:            {comp_cost} projections/batch")
print(f"Plasticity-Stability Measure (PSM): {PSM:.4f} (0-1 normalized)")

print("=== Metrics (Split CIFAR-100) ===")
print(f"ACC={ACC:.4f}, F={F:.4f}, BWT={BWT:.4f}, FWT={FWT:.4f}, Mem={mem_usage:.2f}MB, PSM={PSM:.4f}")

plt.figure(figsize=(6,5))
plt.imshow(accuracy_matrix, cmap='viridis', interpolation='nearest')
plt.colorbar(label='Accuracy')
plt.xlabel('Evaluation Task')
plt.ylabel('Training Task')
plt.title('OGD Accuracy Matrix (Split CIFAR-100)')
plt.show()

Device: cuda
Files already downloaded and verified
Files already downloaded and verified
Task 0: classes 0..9 | train 5000, test 1000
Task 1: classes 10..19 | train 5000, test 1000
Task 2: classes 20..29 | train 5000, test 1000
Task 3: classes 30..39 | train 5000, test 1000
Task 4: classes 40..49 | train 5000, test 1000
Task 5: classes 50..59 | train 5000, test 1000
Task 6: classes 60..69 | train 5000, test 1000
Task 7: classes 70..79 | train 5000, test 1000
Task 8: classes 80..89 | train 5000, test 1000
Task 9: classes 90..99 | train 5000, test 1000

=== Task 1/10 | classes 0..9 ===


Task 1 Epoch 1: 100%|██████████| 157/157 [00:07<00:00, 19.69it/s]


Epoch 1: loss=1.9256


Task 1 Epoch 2: 100%|██████████| 157/157 [00:06<00:00, 22.60it/s]

Epoch 2: loss=1.5942





Accuracy on Task 1: 0.506
[OGD] Harvested 30 dirs (seen 30 batches). Memory size: 30

=== Task 2/10 | classes 10..19 ===


Task 2 Epoch 1: 100%|██████████| 157/157 [00:09<00:00, 16.42it/s]


Epoch 1: loss=1.8596


Task 2 Epoch 2: 100%|██████████| 157/157 [00:09<00:00, 16.50it/s]

Epoch 2: loss=1.4806





Accuracy on Task 1: 0.119
Accuracy on Task 2: 0.517
[OGD] Harvested 30 dirs (seen 30 batches). Memory size: 60

=== Task 3/10 | classes 20..29 ===


Task 3 Epoch 1: 100%|██████████| 157/157 [00:12<00:00, 12.91it/s]


Epoch 1: loss=1.6873


Task 3 Epoch 2: 100%|██████████| 157/157 [00:12<00:00, 12.99it/s]

Epoch 2: loss=1.2760





Accuracy on Task 1: 0.157
Accuracy on Task 2: 0.122
Accuracy on Task 3: 0.560
[OGD] Harvested 30 dirs (seen 30 batches). Memory size: 90

=== Task 4/10 | classes 30..39 ===


Task 4 Epoch 1: 100%|██████████| 157/157 [00:14<00:00, 10.74it/s]


Epoch 1: loss=1.6177


Task 4 Epoch 2: 100%|██████████| 157/157 [00:14<00:00, 10.80it/s]

Epoch 2: loss=1.2175





Accuracy on Task 1: 0.071
Accuracy on Task 2: 0.160
Accuracy on Task 3: 0.122
Accuracy on Task 4: 0.607
[OGD] Harvested 30 dirs (seen 30 batches). Memory size: 120

=== Task 5/10 | classes 40..49 ===


Task 5 Epoch 1: 100%|██████████| 157/157 [00:17<00:00,  9.15it/s]


Epoch 1: loss=1.4118


Task 5 Epoch 2: 100%|██████████| 157/157 [00:17<00:00,  9.16it/s]

Epoch 2: loss=1.0027





Accuracy on Task 1: 0.126
Accuracy on Task 2: 0.111
Accuracy on Task 3: 0.085
Accuracy on Task 4: 0.159
Accuracy on Task 5: 0.673


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.25 GiB. GPU 0 has a total capacty of 44.45 GiB of which 20.05 GiB is free. Process 2969108 has 24.39 GiB memory in use. Of the allocated memory 20.82 GiB is allocated by PyTorch, and 3.23 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [4]:
# === Cell: prepare CIFAR-10 tasks for OGD (split) with CNN ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Hyperparameters ===
root = './data'
num_tasks = 5            
num_classes = 10         
classes_per_task = num_classes // num_tasks   # = 2
batch_size = 32
download = True
num_epochs = 2
learning_rate = 0.001
mem_size = 200   # how many gradient directions to store per task

# CIFAR-10 mean/std for normalization
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

# === Transforms ===
train_transform = transforms.ToTensor()
test_transform = transforms.ToTensor()

# === Load datasets ===
train_ds = datasets.CIFAR10(root=root, train=True, download=download, transform=train_transform)
test_ds = datasets.CIFAR10(root=root, train=False, download=download, transform=test_transform)

# === Helper: extract subset tensors for a given set of class IDs ===
def extract_subset_tensors(dataset, class_list):
    targets = np.array(dataset.targets)
    mask = np.isin(targets, class_list)
    indices = np.nonzero(mask)[0].tolist()
    imgs, labs = [], []
    for i in indices:
        img, lbl = dataset[i]
        imgs.append(img)
        labs.append(lbl)
    return torch.stack(imgs), torch.tensor(labs, dtype=torch.long)

# === Build tasks ===
train_tasks, test_tasks = [], []
for t in range(num_tasks):
    cls_start = t * classes_per_task
    cls_end = cls_start + classes_per_task
    task_classes = list(range(cls_start, cls_end))

    x_train, y_train = extract_subset_tensors(train_ds, task_classes)
    x_test, y_test = extract_subset_tensors(test_ds, task_classes)

    # Normalize
    x_train = (x_train - mean) / std
    x_test = (x_test - mean) / std

    # Map labels to 0..(classes_per_task-1)
    class_map = {orig: i for i, orig in enumerate(task_classes)}
    y_train_mapped = torch.tensor([class_map[int(v)] for v in y_train])
    y_test_mapped = torch.tensor([class_map[int(v)] for v in y_test])

    train_tasks.append(TensorDataset(x_train, y_train_mapped))
    test_tasks.append(TensorDataset(x_test, y_test_mapped))

    print(f"Task {t}: classes {task_classes[0]}-{task_classes[-1]}, train={len(x_train)}, test={len(x_test)}")

print(f"Prepared {len(train_tasks)} tasks (Split CIFAR-10)")

"""
# === CNN model ===
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=2):   # 2 classes per task
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 32x16x16
        x = self.pool(F.relu(self.conv2(x)))  # 64x8x8
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
"""

# -----------------
# WideResNet (WRN-28-10)
# -----------------
class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, 3, 1, 1, bias=False)
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, 1, stride, 0, bias=False
        )

    def forward(self, x):
        out = self.relu1(self.bn1(x))
        shortcut = x if self.equalInOut else self.convShortcut(x)
        out = self.conv1(out)
        out = self.relu2(self.bn2(out))
        out = self.conv2(out)
        return out + shortcut

class NetworkBlock(nn.Module):
    def __init__(self, n, in_planes, out_planes, block, stride):
        super().__init__()
        layers = []
        for i in range(n):
            layers.append(
                block(in_planes if i == 0 else out_planes,
                      out_planes,
                      stride if i == 0 else 1)
            )
        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)

class WideResNet(nn.Module):
    def __init__(self, depth=28, widen_factor=10, num_classes=10):
        super().__init__()
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6
        k = widen_factor
        nChannels = [16, 16*k, 32*k, 64*k]

        self.conv1 = nn.Conv2d(3, nChannels[0], 3, 1, 1, bias=False)
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], BasicBlock, 1)
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], BasicBlock, 2)
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], BasicBlock, 2)
        self.bn = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

    def forward(self, x):
        x = self.conv1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.relu(self.bn(x))
        x = F.adaptive_avg_pool2d(x, 1).view(-1, self.nChannels)
        return self.fc(x)


# === OGD class ===
class OGD:
    def __init__(self, model, lr=0.001, device="cpu", mem_size=200):
        self.model = model.to(device)
        self.opt = optim.Adam(self.model.parameters(), lr=lr)
        self.device = device
        self.S = []  # stored gradient directions
        self.mem_size = mem_size

    def project(self, grad_vec):
        if not self.S:
            return grad_vec
        proj_grad = grad_vec.clone()
        for g in self.S:
            proj_grad -= (proj_grad @ g) * g
        return proj_grad

    def observe(self, x, y):
        self.model.train()
        x, y = x.to(self.device), y.to(self.device)
        self.opt.zero_grad()
        loss = F.cross_entropy(self.model(x), y)
        loss.backward()

        grad_vec = torch.cat([p.grad.view(-1) for p in self.model.parameters()])
        grad_proj = self.project(grad_vec)
        idx = 0
        for p in self.model.parameters():
            numel = p.numel()
            p.grad.copy_(grad_proj[idx:idx+numel].view_as(p))
            idx += numel

        self.opt.step()
        return loss.item()

    def end_task(self, dataloader):
        self.model.eval()
        grads = []
        for x, y in dataloader:
            x, y = x.to(self.device), y.to(self.device)
            self.opt.zero_grad()
            loss = F.cross_entropy(self.model(x), y)
            loss.backward()
            g = torch.cat([p.grad.view(-1) for p in self.model.parameters()])
            grads.append(g / g.norm())
            if len(grads) >= self.mem_size:
                break
        if grads:
            mean_g = torch.stack(grads).mean(0)
            mean_g /= mean_g.norm()
            self.S.append(mean_g)
            if len(self.S) > self.mem_size:
                self.S.pop(0)

# === Training loop with OGD ===
# ogd = OGD(SimpleCNN(num_classes=classes_per_task), lr=learning_rate, device=device, mem_size=mem_size)
ogd = OGD(WideResNet(depth=28, widen_factor=10, num_classes=classes_per_task),
          lr=learning_rate, device=device, mem_size=mem_size)
accuracy_matrix = np.zeros((num_tasks, num_tasks))

def evaluate_task(model, dataloader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

for task_id, train_dataset in enumerate(train_tasks):
    print(f"\n=== Training Task {task_id+1} ===")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for x, y in tqdm(train_loader):
            epoch_loss += ogd.observe(x, y)
        print(f"Epoch {epoch+1}, Loss={epoch_loss/len(train_loader):.4f}")

    # Evaluate on all seen tasks
    for eval_id in range(task_id+1):
        test_loader = DataLoader(test_tasks[eval_id], batch_size=batch_size, shuffle=False)
        acc = evaluate_task(ogd.model, test_loader)
        accuracy_matrix[task_id, eval_id] = acc
        print(f"Accuracy on Task {eval_id+1}: {acc:.3f}")

    ogd.end_task(train_loader)

# === Metrics ===
ACC = accuracy_matrix[-1].mean()
F = np.mean([np.max(accuracy_matrix[:num_tasks-1, j]) - accuracy_matrix[-1, j] for j in range(num_tasks-1)]) if num_tasks>1 else 0.0
BWT = np.mean([accuracy_matrix[-1, j] - accuracy_matrix[j, j] for j in range(num_tasks-1)]) if num_tasks>1 else 0.0
FWT = np.mean([accuracy_matrix[i, i+1] for i in range(num_tasks-1)]) if num_tasks>1 else 0.0
num_params = sum(p.numel() for p in ogd.model.parameters())
mem_usage = len(ogd.S) * num_params * 4 / (1024**2)
PSM = 0.5*(1-F) + 0.5*max(FWT,0)

print("=== Metrics (Split CIFAR-10 CNN OGD) ===")
print(f"ACC={ACC:.4f}, F={F:.4f}, BWT={BWT:.4f}, FWT={FWT:.4f}, Mem={mem_usage:.2f}MB, PSM={PSM:.4f}")

plt.figure(figsize=(6,5))
plt.imshow(accuracy_matrix, cmap='viridis', interpolation='nearest')
plt.colorbar(label='Accuracy')
plt.xlabel('Evaluation Task')
plt.ylabel('Training Task')
plt.title('OGD Accuracy Matrix (Split CIFAR-10 CNN)')
plt.show()

Files already downloaded and verified
Files already downloaded and verified
Task 0: classes 0-1, train=10000, test=2000
Task 1: classes 2-3, train=10000, test=2000
Task 2: classes 4-5, train=10000, test=2000
Task 3: classes 6-7, train=10000, test=2000
Task 4: classes 8-9, train=10000, test=2000
Prepared 5 tasks (Split CIFAR-10)

=== Training Task 1 ===


100%|██████████| 313/313 [00:15<00:00, 20.18it/s]


Epoch 1, Loss=0.3703


100%|██████████| 313/313 [00:15<00:00, 20.13it/s]


Epoch 2, Loss=0.2434
Accuracy on Task 1: 0.871


OutOfMemoryError: CUDA out of memory. Tried to allocate 140.00 MiB. GPU 0 has a total capacty of 44.45 GiB of which 14.62 MiB is free. Process 2307347 has 44.43 GiB memory in use. Of the allocated memory 41.47 GiB is allocated by PyTorch, and 2.61 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF