# GD

In [6]:
import torch
from torch import nn
from torch.optim import SGD
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

torch.manual_seed(42)
np.random.seed(42)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_labels = mnist_train.targets.numpy()
test_labels  = mnist_test.targets.numpy()

idx_A_train = np.where(train_labels < 5)[0].tolist()
idx_B_train = np.where(train_labels >= 5)[0].tolist()
idx_A_test  = np.where(test_labels  < 5)[0].tolist()

loader_A_train = DataLoader(
    Subset(mnist_train, idx_A_train),
    batch_size=64, shuffle=True, drop_last=True
)
loader_B_train = DataLoader(
    Subset(mnist_train, idx_B_train),
    batch_size=64, shuffle=True, drop_last=True
)
loader_A_test  = DataLoader(
    Subset(mnist_test, idx_A_test),
    batch_size=256, shuffle=False, drop_last=False
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 100), nn.ReLU(),
    nn.Linear(100, 100),    nn.ReLU(),
    nn.Linear(100, 10),
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=1e-3, momentum=0.9)

def eval_accuracy(model, loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total   += y.size(0)
    return correct / total

num_epochs_A = 3
for epoch in range(1, num_epochs_A + 1):
    model.train()
    running_loss = 0.0
    for x, y in loader_A_train:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)
    avg_loss = running_loss / len(loader_A_train.dataset)
    print(f"[Task A] Epoch {epoch}/{num_epochs_A} — Loss: {avg_loss:.4f}")

acc_before = eval_accuracy(model, loader_A_test)
print(f">>> Task A test accuracy BEFORE Task B: {acc_before*100:.2f}%")

num_epochs_B = 3
for epoch in range(1, num_epochs_B + 1):
    model.train()
    running_loss = 0.0
    for x, y in loader_B_train:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)
    avg_loss = running_loss / len(loader_B_train.dataset)
    print(f"[Task B] Epoch {epoch}/{num_epochs_B} — Loss: {avg_loss:.4f}")

acc_after = eval_accuracy(model, loader_A_test)
print(f">>> Task A test accuracy AFTER  Task B:  {acc_after*100:.2f}%")


[Task A] Epoch 1/3 — Loss: 0.5934
[Task A] Epoch 2/3 — Loss: 0.1326
[Task A] Epoch 3/3 — Loss: 0.1095
>>> Task A test accuracy BEFORE Task B: 97.55%
[Task B] Epoch 1/3 — Loss: 0.7293
[Task B] Epoch 2/3 — Loss: 0.2079
[Task B] Epoch 3/3 — Loss: 0.1673
>>> Task A test accuracy AFTER  Task B:  0.00%


# OGD-GTL

In [7]:
import torch
from torch import nn
from torch.optim.optimizer import Optimizer
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# Orthogonal Gradient Descent (OGD) Utilities
def gram_schmidt(vectors, eps=1e-10):
    """Orthonormalize a list of gradient vectors via Gram-Schmidt."""
    ortho = []
    for v in vectors:
        w = v.clone()
        for u in ortho:
            w -= (u @ v) / (u @ u + eps) * u
        norm = w.norm()
        if norm > eps:
            ortho.append(w / norm)
    return ortho

class OGD(Optimizer):
    """Orthogonal Gradient Descent optimizer."""
    def __init__(self, params, lr=1e-3):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        self.directions = []

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        grads = [p.grad.view(-1) for p in self.param_groups[0]['params'] if p.grad is not None]
        if not grads:
            return loss
        g = torch.cat(grads)
        if self.directions:
            V = torch.stack(self.directions)
            alphas = V.mv(g)
            g = g - V.t().mv(alphas)
        offset = 0
        for p in self.param_groups[0]['params']:
            if p.grad is None:
                continue
            numel = p.numel()
            p.grad.copy_(g[offset:offset+numel].view_as(p))
            offset += numel
        lr = self.param_groups[0]['lr']
        for p in self.param_groups[0]['params']:
            if p.grad is not None:
                p.add_(p.grad, alpha=-lr)
        return loss

    def store_directions(self, model, dataloader, device='cpu', max_samples=2000):
        """Compute and store orthonormal gradient directions from dataloader."""
        model.eval()
        collected = []
        seen = 0
        for x, y in dataloader:
            if seen >= max_samples:
                break
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            with torch.enable_grad():
                logits = model(x)
                true_logits = logits[torch.arange(len(y)), y].sum()
                true_logits.backward()
            vec = [p.grad.view(-1).detach().cpu() for p in model.parameters()]
            collected.append(torch.cat(vec))
            seen += x.size(0)
        new_dirs = gram_schmidt(collected)
        self.directions = gram_schmidt(self.directions + new_dirs)

# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
mnist_test  = datasets.MNIST('./data', train=False, download=True, transform=transform)

y_train = np.array(mnist_train.targets)
y_test  = np.array(mnist_test.targets)
idx_A_train = np.where(y_train < 5)[0]
idx_B_train = np.where(y_train >= 5)[0]
idx_A_test  = np.where(y_test  < 5)[0]

loader_A_train = DataLoader(Subset(mnist_train, idx_A_train), batch_size=64, shuffle=True)
loader_B_train = DataLoader(Subset(mnist_train, idx_B_train), batch_size=64, shuffle=True)
loader_A_test  = DataLoader(Subset(mnist_test,  idx_A_test),  batch_size=256, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 100), nn.ReLU(),
    nn.Linear(100, 100),   nn.ReLU(),
    nn.Linear(100, 10)
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = OGD(model.parameters(), lr=1e-3)

def eval_accuracy(model, loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total   += y.size(0)
    return correct / total

# === TRAIN Task A ===
print("Training Task A (digits 0-4)...")
for epoch in range(1, 4):
    model.train()
    cum_loss = 0
    for x, y in loader_A_train:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        cum_loss += loss.item() * x.size(0)
    print(f"[Task A] Epoch {epoch}/3 — Loss: {cum_loss/len(loader_A_train.dataset):.4f}")

# Eval before Task B
acc_before = eval_accuracy(model, loader_A_test)
print(f"Task A accuracy BEFORE Task B: {acc_before*100:.2f}%")

# Store Task A directions
optimizer.store_directions(model, loader_A_train, device=device, max_samples=2000)
print(f"Stored {len(optimizer.directions)} orthonormal directions from Task A.")

# Adjust LR for Task B
optimizer.param_groups[0]['lr'] = 1e-4

# === TRAIN Task B without rehearsal ===
print("Training Task B (digits 5-9) with OGD projection...")
for epoch in range(1, 4):
    model.train()
    cum_loss = 0
    for x, y in loader_B_train:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        cum_loss += loss.item() * x.size(0)
    print(f"[Task B] Epoch {epoch}/3 — Loss: {cum_loss/len(loader_B_train.dataset):.4f}")

# Eval after Task B
acc_after = eval_accuracy(model, loader_A_test)
print(f"Task A accuracy AFTER Task B: {acc_after*100:.2f}%")


Training Task A (digits 0-4)...
[Task A] Epoch 1/3 — Loss: 1.9245
[Task A] Epoch 2/3 — Loss: 0.9401
[Task A] Epoch 3/3 — Loss: 0.4406
Task A accuracy BEFORE Task B: 93.75%
Stored 32 orthonormal directions from Task A.
Training Task B (digits 5-9) with OGD projection...
[Task B] Epoch 1/3 — Loss: 5.3471
[Task B] Epoch 2/3 — Loss: 3.8714
[Task B] Epoch 3/3 — Loss: 2.7151
Task A accuracy AFTER Task B: 93.05%
