# 

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import math

# -------------------------------
# 1. Data Preparation
# -------------------------------

# Transform: convert to tensor and normalize.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Download CIFAR-10 training and test datasets.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True, transform=transform)

# Helper function: Given a dataset and a list of target classes, return a Subset.
def get_subset_by_classes(dataset, class_list):
    indices = [i for i, (img, label) in enumerate(dataset) if label in class_list]
    return Subset(dataset, indices)

# Define our 5 tasks: each task is a pair of classes.
tasks = [
    [0, 1],  # airplane, automobile
    [2, 3],  # bird, cat
    [4, 5],  # deer, dog
    [6, 7],  # frog, horse
    [8, 9]   # ship, truck
]

# -------------------------------
# 2. Define the MLP and Linear Model
# -------------------------------

# MLP that returns both output and hidden layer activations
class MLP(nn.Module):
    def __init__(self, input_size=3*32*32, hidden_size=512, num_classes=10):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        # Flatten the input: [batch, 3,32,32] -> [batch, 3072]
        x = x.view(x.size(0), -1)
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        out = self.fc3(h2)
        return out, h1, h2  # returning activations for layers fc1 and fc2

# A simple linear model (single linear layer)
class LinearModel(nn.Module):
    def __init__(self, input_size=3*32*32, num_classes=10):
        super(LinearModel, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.fc(x)

# -------------------------------
# 3. Utility Functions for Continual Metrics
# -------------------------------

# Compute effective rank of a representation (assumes activations is 2D: [N, D])
def compute_effective_rank(activations):
    # activations: [N, D]
    # Use torch.svd (or torch.linalg.svd in recent versions)
    try:
        u, s, vh = torch.svd(activations)
    except RuntimeError:
        # Fall back to torch.linalg.svd if necessary
        s = torch.linalg.svdvals(activations)
    s_sum = torch.sum(s) + 1e-8
    p = s / s_sum
    erank = torch.exp(-torch.sum(p * torch.log(p + 1e-8)))
    return erank.item()

# Compute contribution utility for a layer.
# activations: [N, D] for the given layer.
# next_weight: weight matrix of the next layer, shape [out_dim, D].
def compute_contrib_utility(activations, next_weight):
    mean_abs = torch.mean(torch.abs(activations), dim=0)  # [D]
    sum_out = torch.sum(torch.abs(next_weight), dim=0)      # [D]
    return mean_abs * sum_out  # [D]

# Compute adaptation utility for a layer given its incoming weight matrix.
# weight_in: shape [D, in_features]
def compute_adaptation_utility(weight_in):
    sum_in = torch.sum(torch.abs(weight_in), dim=1)  # [D]
    return 1.0 / (sum_in + 1e-8)

# Overall utility: here defined as product of contribution and adaptation utilities.
def compute_overall_utility(contrib, adaptation):
    return contrib * adaptation

# Exponential moving average update.
def update_running_avg(old, current, eta):
    if old is None:
        return current
    else:
        return (1 - eta) * current + eta * old

# Bias-correct the running average given the age (number of updates).
def bias_corrected(running_avg, age, eta):
    # age is assumed to be an integer or a torch scalar.
    denom = 1 - eta ** (age + 1)
    return running_avg / (denom + 1e-8)

# -------------------------------
# 4. Metric Measurement Routine
# -------------------------------

def measure_metrics(model, dataloader, device, eta, layer_names, running_avgs, ages):
    """
    For each layer in layer_names (e.g., ['fc1','fc2']), this function:
     - Collects activations from the model (by running through dataloader)
     - Computes effective rank of the activations.
     - Computes contribution utility using the next layer's weights.
     - Computes adaptation utility from the layer's incoming weights.
     - Computes overall utility = contrib * adaptation.
     - Updates a running average for contribution and overall utilities using exponential smoothing.
     - Increments the "age" (number of epochs seen).
     - Computes bias-corrected running average utilities.
     
    running_avgs is a dict with keys: 'contrib' and 'overall', each mapping layer name -> tensor of shape [D] (or None initially).
    ages is a dict mapping layer name -> age (integer).
    
    Returns a dict of metrics and updated running_avgs and ages.
    """
    # We will aggregate activations for each layer over the dataloader.
    activations = {ln: [] for ln in layer_names}
    model.eval()
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            out, h1, h2 = model(images)
            # Save activations for each layer:
            activations['fc1'].append(h1.cpu())
            activations['fc2'].append(h2.cpu())
    # Concatenate activations along the batch dimension.
    for ln in layer_names:
        activations[ln] = torch.cat(activations[ln], dim=0)  # shape: [N, D]
    
    metrics = {}
    # Effective rank of each layer's representation.
    for ln in layer_names:
        er = compute_effective_rank(activations[ln])
        metrics[f'erank_{ln}'] = er

    # For contribution utility we need the weight matrices of the next layer.
    # For fc1, next layer is fc2; for fc2, next layer is fc3.
    with torch.no_grad():
        # Assume model has attributes fc1, fc2, fc3.
        # Contribution utility:
        contrib_fc1 = compute_contrib_utility(activations['fc1'], model.fc2.weight.data.cpu())
        contrib_fc2 = compute_contrib_utility(activations['fc2'], model.fc3.weight.data.cpu())
        # Adaptation utility:
        # For fc1, incoming weights: model.fc1.weight, shape [hidden, input]
        adapt_fc1 = compute_adaptation_utility(model.fc1.weight.data.cpu())
        # For fc2, incoming weights: model.fc2.weight, shape [hidden, hidden]
        adapt_fc2 = compute_adaptation_utility(model.fc2.weight.data.cpu())
        # Overall utility:
        overall_fc1 = compute_overall_utility(contrib_fc1, adapt_fc1)
        overall_fc2 = compute_overall_utility(contrib_fc2, adapt_fc2)
        
    # Update running averages and ages for each layer.
    for ln, current_contrib in zip(layer_names, [contrib_fc1, contrib_fc2]):
        # running average for contribution:
        if running_avgs['contrib'].get(ln) is None:
            running_avgs['contrib'][ln] = current_contrib.clone()
        else:
            running_avgs['contrib'][ln] = update_running_avg(running_avgs['contrib'][ln], current_contrib, eta)
        # running average for overall:
        current_overall = overall_fc1 if ln=='fc1' else overall_fc2
        if running_avgs['overall'].get(ln) is None:
            running_avgs['overall'][ln] = current_overall.clone()
        else:
            running_avgs['overall'][ln] = update_running_avg(running_avgs['overall'][ln], current_overall, eta)
        # Increase age
        ages[ln] = ages.get(ln, 0) + 1
        # Bias-corrected running averages:
        bc_contrib = bias_corrected(running_avgs['contrib'][ln], ages[ln], eta)
        bc_overall = bias_corrected(running_avgs['overall'][ln], ages[ln], eta)
        metrics[f'avg_contrib_{ln}'] = torch.mean(running_avgs['contrib'][ln]).item()
        metrics[f'avg_overall_{ln}'] = torch.mean(running_avgs['overall'][ln]).item()
        metrics[f'bc_avg_contrib_{ln}'] = torch.mean(bc_contrib).item()
        metrics[f'bc_avg_overall_{ln}'] = torch.mean(bc_overall).item()
    return metrics, running_avgs, ages

# -------------------------------
# 5. Training and Evaluation Functions (Same as before)
# -------------------------------

def train_epoch(model, optimizer, criterion, dataloader, device):
    model.train()
    running_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs, _, _ = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss

def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs, _, _ = model(images)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# -------------------------------
# 6. Continual Learning Training with Metrics Logging
# -------------------------------

def continual_training(model, tasks, train_dataset, eval_loader, epochs_per_task, batch_size, device, eta=0.99):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # We will track metrics for layers 'fc1' and 'fc2'
    layer_names = ['fc1', 'fc2']
    # Dictionaries to store running averages and ages per layer.
    running_avgs = {'contrib': {}, 'overall': {}}
    ages = {}
    
    # Before training, measure initial metrics.
    print("=== Before Training: Metrics ===")
    metrics, running_avgs, ages = measure_metrics(model, eval_loader, device, eta, layer_names, running_avgs, ages)
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
        
    # Now loop over tasks sequentially.
    for task in tasks:
        print(f"\n=== Training on Task (classes {task}) ===")
        task_subset = get_subset_by_classes(train_dataset, task)
        task_loader = DataLoader(task_subset, batch_size=batch_size, shuffle=True, num_workers=2)
        for epoch in range(epochs_per_task):
            loss = train_epoch(model, optimizer, criterion, task_loader, device)
            print(f"Task {task} - Epoch {epoch+1}/{epochs_per_task}: Loss = {loss:.4f}")
            # At the end of each epoch, measure the metrics on the eval_loader.
            metrics, running_avgs, ages = measure_metrics(model, eval_loader, device, eta, layer_names, running_avgs, ages)
            print("Metrics after epoch:")
            for k, v in metrics.items():
                print(f"  {k}: {v:.4f}")
    return model

# -------------------------------
# 7. Joint Training (Train on Entire CIFAR10) -- without continual metrics
# -------------------------------

def joint_training(model, train_dataset, epochs, batch_size, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    for epoch in range(epochs):
        loss = train_epoch(model, optimizer, criterion, dataloader, device)
        print(f"Joint Training - Epoch {epoch+1}/{epochs}: Loss = {loss:.4f}")
    return model

# -------------------------------
# 8. Main Experiment
# -------------------------------

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    batch_size = 128
    epochs_per_task = 5  # epochs per paired-task
    num_tasks = len(tasks)
    total_epochs_joint = epochs_per_task * num_tasks  # same total epochs for joint training
    
    # We also prepare an evaluation DataLoader (for metrics) using the test set.
    eval_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    print("=== Continual Learning Experiment (Sequential Training on Task Pairs) ===")
    continual_model = MLP().to(device)
    continual_model = continual_training(continual_model, tasks, train_dataset, eval_loader,
                                          epochs_per_task, batch_size, device, eta=0.99)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    acc_continual = evaluate(continual_model, test_loader, device)
    print(f"\nTest Accuracy of Continual Learning MLP: {acc_continual*100:.2f}%")
    
    print("\n=== Joint Training Experiment (Train on entire CIFAR10) ===")
    joint_model = MLP().to(device)
    joint_model = joint_training(joint_model, train_dataset, total_epochs_joint, batch_size, device)
    acc_joint = evaluate(joint_model, test_loader, device)
    print(f"\nTest Accuracy of Joint Training MLP: {acc_joint*100:.2f}%")
    
    print("\n=== Linear Model Experiment (Train on entire CIFAR10) ===")
    linear_model = LinearModel().to(device)
    linear_optimizer = optim.SGD(linear_model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    joint_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    for epoch in range(total_epochs_joint):
        loss = train_epoch(linear_model, linear_optimizer, criterion, joint_loader, device)
        print(f"Linear Model - Epoch {epoch+1}/{total_epochs_joint}: Loss = {loss:.4f}")
    acc_linear = evaluate(linear_model, test_loader, device)
    print(f"\nTest Accuracy of Linear Model: {acc_linear*100:.2f}%")


Files already downloaded and verified
Files already downloaded and verified
=== Continual Learning Experiment (Sequential Training on Task Pairs) ===
=== Before Training: Metrics ===
erank_fc1: 354.1995
erank_fc2: 321.8415
avg_contrib_fc1: 1.2890
avg_overall_fc1: 0.0465
bc_avg_contrib_fc1: 64.7742
bc_avg_overall_fc1: 2.3369
avg_contrib_fc2: 0.0103
avg_overall_fc2: 0.0009
bc_avg_contrib_fc2: 0.5161
bc_avg_overall_fc2: 0.0456

=== Training on Task (classes [0, 1]) ===
Task [0, 1] - Epoch 1/5: Loss = 0.7206
Metrics after epoch:
  erank_fc1: 219.0818
  erank_fc2: 101.4673
  avg_contrib_fc1: 1.3036
  avg_overall_fc1: 0.0470
  bc_avg_contrib_fc1: 43.8918
  bc_avg_overall_fc1: 1.5832
  avg_contrib_fc2: 0.0111
  avg_overall_fc2: 0.0010
  bc_avg_contrib_fc2: 0.3732
  bc_avg_overall_fc2: 0.0329
Task [0, 1] - Epoch 2/5: Loss = 0.3747
Metrics after epoch:
  erank_fc1: 239.7315
  erank_fc2: 130.0002
  avg_contrib_fc1: 1.3151
  avg_overall_fc1: 0.0474
  bc_avg_contrib_fc1: 33.3751
  bc_avg_overall_f

ValueError: too many values to unpack (expected 3)

Files already downloaded and verified
Files already downloaded and verified
epoch 0 CC stats using threshold = 0.999
layer layer_0 feature dim = 1000 # of connected components: 1000
layer layer_1 feature dim = 1000 # of connected components: 1000
layer layer_2 feature dim = 1000 # of connected components: 1000
layer layer_3 feature dim = 1000 # of connected components: 1000
layer layer_4 feature dim = 1000 # of connected components: 1000
layer layer_5 feature dim = 1000 # of connected components: 1000
layer layer_6 feature dim = 1000 # of connected components: 1000
Epoch 1: Train Loss: 0.0000, Val Loss: 18.0068 | Effective Rank per layer: layer_0: 657.69, layer_1: 671.36, layer_2: 675.24, layer_3: 678.24, layer_4: 681.57, layer_5: 684.71, layer_6: 686.90
epoch 1 CC stats using threshold = 0.999
layer layer_0 feature dim = 1000 # of connected components: 1000
layer layer_1 feature dim = 1000 # of connected components: 939
layer layer_2 feature dim = 1000 # of connected components: 263
l

KeyboardInterrupt: 

In [12]:
model

ResNetWithHooks(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

Files already downloaded and verified
Files already downloaded and verified
Epoch 0/24
----------
train Loss: 1.8462 Acc: 0.3340
val Loss: 1.5380 Acc: 0.4425
report CC stats for phase val
key =  initial_relu  CC =  633  rank(A) =  640
key =  layer1  CC =  640  rank(A) =  640
key =  layer2  CC =  640  rank(A) =  640
key =  layer3  CC =  639  rank(A) =  640
key =  layer4  CC =  539  rank(A) =  640
key =  avgpool  CC =  58  rank(A) =  512
key =  logits  CC =  8  rank(A) =  10

Epoch 1/24
----------
train Loss: 1.3833 Acc: 0.4971
val Loss: 1.2382 Acc: 0.5549
report CC stats for phase val
key =  initial_relu  CC =  640  rank(A) =  640
key =  layer1  CC =  640  rank(A) =  640
key =  layer2  CC =  640  rank(A) =  640
key =  layer3  CC =  640  rank(A) =  640
key =  layer4  CC =  629  rank(A) =  640
key =  avgpool  CC =  186  rank(A) =  512
key =  logits  CC =  22  rank(A) =  10

Epoch 2/24
----------
train Loss: 1.0947 Acc: 0.6105
val Loss: 1.1704 Acc: 0.6003
report CC stats for phase val
key 

Files already downloaded and verified
Files already downloaded and verified
Epoch: 1/100
Train Loss: 1.8716 | Train Acc: 28.60%
Val Loss: 1.7739 | Val Acc: 33.96%
Epoch: 2/100
Train Loss: 1.6630 | Train Acc: 37.41%
Val Loss: 1.5279 | Val Acc: 43.72%
Epoch: 3/100
Train Loss: 1.5240 | Train Acc: 43.63%
Val Loss: 1.4582 | Val Acc: 47.39%
Epoch: 4/100
Train Loss: 1.4439 | Train Acc: 47.15%
Val Loss: 1.4285 | Val Acc: 48.79%
Epoch: 5/100
Train Loss: 1.3721 | Train Acc: 50.10%
Val Loss: 1.3193 | Val Acc: 52.65%
Epoch: 6/100
Train Loss: 1.3198 | Train Acc: 52.09%
Val Loss: 1.2573 | Val Acc: 54.83%
Epoch: 7/100
Train Loss: 1.2773 | Train Acc: 53.61%
Val Loss: 1.1865 | Val Acc: 57.50%
Epoch: 8/100
Train Loss: 1.2440 | Train Acc: 54.96%
Val Loss: 1.1456 | Val Acc: 58.21%
Epoch: 9/100
Train Loss: 1.2049 | Train Acc: 56.28%
Val Loss: 1.1430 | Val Acc: 59.03%
Epoch: 10/100
Train Loss: 1.1715 | Train Acc: 57.47%
Val Loss: 1.1197 | Val Acc: 59.52%
Epoch: 11/100
Train Loss: 1.1337 | Train Acc: 59.10%


KeyboardInterrupt: 

In [2]:
model

NameError: name 'model' is not defined