In [1]:
from model import network_mnist,naive_train,test_taskwise,test,benchmark,train_stream,test_stream,compute_fisher_information,apply_importance_mask,create_masked_weight_dict,load_non_zero_weights
from torch.utils.data import DataLoader, Subset
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from plot import plot_parameter_importance
import os
import copy 

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
model=network_mnist(256,128)
model_2=network_mnist(256,128)
#print(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model_2.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs=5

In [None]:
def method_1_train(model,task_number, epochs,criterion,optimizer,device,weight_dicts):
    experience = train_stream[task_number]
    train_loader = DataLoader(experience.dataset, batch_size=64, shuffle=True)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        
        for images, labels, *_ in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
       

        #print(f"Task {task_number}, Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

In [17]:
os.makedirs('figures', exist_ok=True)
all_tasks_data = {}
weight_dicts=[]
for task in range(2):
    print(f"\n{'='*70}")
    print(f"Training on Task {task}")
    print(f"{'='*70}")

    method_1_train(model, task, epochs, criterion=criterion, optimizer=optimizer, device=device, weight_dicts=weight_dicts)
    acc = test_taskwise(model, task, device)
    print(f"Post-training accuracy on Task {task}: {acc:.2f}%")

    fisher_dict = compute_fisher_information(model, task_number=task, num_samples=500, 
                                             criterion=criterion, device=device)
    percent_list = list(range(50,51))
    accuracy_vs_percent = []
    

    original_weights = {name: param.clone() for name, param in model.state_dict().items()}
    for p in percent_list:
        model, mask_dict = apply_importance_mask(model, fisher_dict, importance_percent=p)
        weight_dicts.append(create_masked_weight_dict(model, mask_dict))
        acc_p = test_taskwise(model, task, device)
        accuracy_vs_percent.append(acc_p)
        model.load_state_dict(original_weights, strict=False)
       
    all_tasks_data[task] = (percent_list, accuracy_vs_percent)
    test(model, device)
    #plot_parameter_importance(percent_list, accuracy_vs_percent, task, save_path=f'figures/task_{task}_importance.png')



Training on Task 0
Accuracy on task 0: 99.95%
Post-training accuracy on Task 0: 99.95%
Accuracy on task 0: 99.95%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%

Training on Task 1
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 1: 99.36%
Post-training accuracy on Task 1: 99.36%
Accuracy on task 

In [None]:
for i in range(5):
    print(f"\n{'='*70}")
    print(f"Evaluating Model on Task {i} after applying masks from all tasks")
    print(f"{'='*70}")
   
    print(f"\n-- Using mask from Task {i} --")
    model_2=copy.deepcopy(model)
    masked_model = load_non_zero_weights(model_2, weight_dicts[i])
    acc = test_taskwise(masked_model, i, device)
    print(f"Accuracy on Task {i} with mask from Task {i}: {acc:.2f}%")
    masked_model.load_state_dict(weight_dicts[i], strict=False)
    acc = test_taskwise(masked_model, i, device)
    print(f"Accuracy on Task {i} with mask from Task {i} (direct load): {acc:.2f}%")

In [1]:
from model import network_mnist, test_taskwise, test, train_stream, test_stream
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import os
from collections import defaultdict
from typing import Dict, Tuple
import warnings

# ==================== Avalanche-style EWC Implementation ====================

def copy_params_dict(model):
    """Copy parameters from model (Avalanche utility function)"""
    return [(name, param.data.clone()) for name, param in model.named_parameters()]

def zerolike_params_dict(model):
    """Create zero-like parameter dictionary (Avalanche utility function)"""
    return [(name, torch.zeros_like(param)) for name, param in model.named_parameters()]

class EWCPlugin:
    """
    Elastic Weight Consolidation (EWC) plugin adapted from Avalanche.
    Exactly follows Avalanche's EWCPlugin implementation.
    """
    
    def __init__(self, ewc_lambda, mode='separate', decay_factor=None, keep_importance_data=False):
        """
        Args:
            ewc_lambda: hyperparameter to weigh the penalty inside the total loss
            mode: 'separate' to keep penalty for each previous experience,
                  'online' to keep single penalty with decay
            decay_factor: used only if mode is 'online'
            keep_importance_data: if True, keep in memory parameter values and importances
        """
        assert (decay_factor is None) or (mode == 'online'), \
            "You need to set `online` mode to use `decay_factor`."
        assert (decay_factor is not None) or (mode != 'online'), \
            "You need to set `decay_factor` to use the `online` mode."
        assert mode == 'separate' or mode == 'online', \
            'Mode must be separate or online.'

        self.ewc_lambda = ewc_lambda
        self.mode = mode
        self.decay_factor = decay_factor

        if self.mode == 'separate':
            self.keep_importance_data = True
        else:
            self.keep_importance_data = keep_importance_data

        self.saved_params = defaultdict(list)
        self.importances = defaultdict(list)

    def compute_ewc_penalty(self, model, exp_counter, device):
        """
        Compute EWC penalty (equivalent to before_backward in Avalanche).
        """
        if exp_counter == 0:
            return torch.tensor(0).float().to(device)

        penalty = torch.tensor(0).float().to(device)

        if self.mode == 'separate':
            for experience in range(exp_counter):
                for (_, cur_param), (_, saved_param), (_, imp) in zip(
                        model.named_parameters(),
                        self.saved_params[experience],
                        self.importances[experience]):
                    penalty += (imp * (cur_param - saved_param).pow(2)).sum()
        elif self.mode == 'online':
            prev_exp = exp_counter - 1
            for (_, cur_param), (_, saved_param), (_, imp) in zip(
                    model.named_parameters(),
                    self.saved_params[prev_exp],
                    self.importances[prev_exp]):
                penalty += (imp * (cur_param - saved_param).pow(2)).sum()
        else:
            raise ValueError('Wrong EWC mode.')

        return penalty

    def compute_importances(self, model, criterion, optimizer, dataset, device, batch_size):
        """
        Compute EWC importance matrix for each parameter.
        Exactly follows Avalanche's compute_importances method.
        """
        model.eval()

        # Set RNN-like modules on GPU to training mode to avoid CUDA error
        if device.type == 'cuda':
            for module in model.modules():
                if isinstance(module, torch.nn.RNNBase):
                    warnings.warn(
                        'RNN-like modules do not support '
                        'backward calls while in `eval` mode on CUDA '
                        'devices. Setting all `RNNBase` modules to '
                        '`train` mode. May produce inconsistent '
                        'output if such modules have `dropout` > 0.'
                    )
                    module.train()

        # Initialize importances as list of tuples
        importances = zerolike_params_dict(model)
        
        # Move importances to device
        importances = [(name, imp.to(device)) for name, imp in importances]
        
        dataloader = DataLoader(dataset, batch_size=batch_size)
        
        for i, batch in enumerate(dataloader):
            # Handle Avalanche batch format: (x, y, task_labels, ...)
            x, y = batch[0], batch[1]
            # For Split MNIST, we don't need task_labels for forward pass
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            out = model(x)  # Direct forward pass (no avalanche_forward needed)
            loss = criterion(out, y)
            loss.backward()

            for (k1, p), (k2, imp) in zip(model.named_parameters(), importances):
                assert (k1 == k2)
                if p.grad is not None:
                    imp += p.grad.data.clone().pow(2)

        # Average over mini batch length
        for _, imp in importances:
            imp /= float(len(dataloader))

        # Print statistics
        total_importance = sum(imp.sum().item() for _, imp in importances)
        num_params = sum(imp.numel() for _, imp in importances)
        print(f"  Total importance: {total_importance:.6f}")
        print(f"  Average importance per param: {total_importance/num_params:.8f}")

        return importances

    def update_importances(self, importances, t):
        """
        Update importance for each parameter based on currently computed importances.
        Exactly follows Avalanche's update_importances method.
        """
        if self.mode == 'separate' or t == 0:
            self.importances[t] = importances
        elif self.mode == 'online':
            self.importances[t] = []
            for (k1, old_imp), (k2, curr_imp) in \
                    zip(self.importances[t - 1], importances):
                assert k1 == k2, 'Error in importance computation.'
                self.importances[t].append(
                    (k1, (self.decay_factor * old_imp + curr_imp)))

            # Clear previous parameter importances
            if t > 0 and (not self.keep_importance_data):
                del self.importances[t - 1]
        else:
            raise ValueError("Wrong EWC mode.")

    def after_training_exp(self, model, criterion, optimizer, experience, device, batch_size, exp_counter):
        """
        Compute importances after each experience (equivalent to Avalanche's after_training_exp).
        """
        importances = self.compute_importances(model, criterion, optimizer, 
                                             experience.dataset, device, batch_size)
        self.update_importances(importances, exp_counter)
        self.saved_params[exp_counter] = copy_params_dict(model)
        
        # Clear previous parameter values
        if exp_counter > 0 and (not self.keep_importance_data):
            del self.saved_params[exp_counter - 1]


# ==================== Training Script ====================

# Initialize model
model = network_mnist(256, 128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()

# Hyperparameters
ewc_lambda = 5000
learning_rate = 0.001
epochs_per_task = 5
batch_size = 64

# Initialize EWC Plugin
ewc = EWCPlugin(ewc_lambda=ewc_lambda, mode='separate')

os.makedirs('figures', exist_ok=True)

print("="*70)
print("Training Split MNIST with EWC (Avalanche EWCPlugin Implementation)")
print("="*70)
print(f"EWC Lambda: {ewc_lambda}")
print(f"Learning Rate: {learning_rate}")
print(f"Epochs per task: {epochs_per_task}")
print(f"Batch size: {batch_size}")
print("="*70)

# Main training loop
for task in range(5):
    print(f"\n{'='*70}")
    print(f"Task {task} Training")
    print(f"{'='*70}")
    
    # Get experience for current task
    experience = train_stream[task]
    train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)
    
    # Create optimizer for this task
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Training loop
    model.train()
    for epoch in range(epochs_per_task):
        total_loss = 0
        total_loss_ce = 0
        total_loss_ewc = 0
        correct = 0
        total_samples = 0
        
        for images, labels, *_ in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(images)
            
            # Compute task loss
            loss_ce = criterion(outputs, labels)
            
            # Compute EWC penalty (equivalent to before_backward)
            ewc_penalty = ewc.compute_ewc_penalty(model, task, device)
            
            # Total loss (following Avalanche's formula exactly)
            loss = loss_ce + ewc_lambda * ewc_penalty
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Track metrics
            total_loss += loss.item()
            total_loss_ce += loss_ce.item()
            total_loss_ewc += ewc_penalty.item()
            
            # Accuracy
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
        
        # Epoch statistics
        avg_loss = total_loss / len(train_loader)
        avg_loss_ce = total_loss_ce / len(train_loader)
        avg_loss_ewc = total_loss_ewc / len(train_loader)
        accuracy = 100 * correct / total_samples
        
        print(f"  Epoch {epoch+1}/{epochs_per_task} - Loss: {avg_loss:.4f} "
              f"(CE: {avg_loss_ce:.4f}, EWC: {avg_loss_ewc:.4f}), Acc: {accuracy:.2f}%")
    
    # Test on current task
    print(f"\nPost-training accuracy:")
    acc_current = test_taskwise(model, task, device)
    
    # Compute importances after training (equivalent to after_training_exp)
    print(f"\nComputing importances for Task {task}...")
    ewc.after_training_exp(model, criterion, optimizer, experience, device, batch_size, task)
    
    # Test on all tasks seen so far
    print(f"\n{'='*70}")
    print("Testing on all tasks seen so far:")
    print(f"{'='*70}")
    task_accuracies = []
    for t in range(task + 1):
        acc = test_taskwise(model, t, device)
        task_accuracies.append(acc)
    
    avg_acc = sum(task_accuracies) / len(task_accuracies)
    print(f"Average accuracy: {avg_acc:.2f}%")

# Final evaluation
print("\n" + "="*70)
print("FINAL RESULTS - All 5 Tasks")
print("="*70)
final_acc, final_acc_list = test(model, device)

# Display task-wise breakdown
print("\nTask-wise Final Results:")
for i, acc in enumerate(final_acc_list):
    print(f"Task {i}: {acc:.2f}%")

print("="*70)

# Save model and EWC data
torch.save({
    'model_state_dict': model.state_dict(),
    'ewc_lambda': ewc_lambda,
    'saved_params': dict(ewc.saved_params),
    'importances': dict(ewc.importances),
    'final_accuracies': final_acc_list
}, 'ewc_avalanche_plugin_model.pth')
print("\nModel saved to 'ewc_avalanche_plugin_model.pth'")


  from .autonotebook import tqdm as notebook_tqdm


Training Split MNIST with EWC (Avalanche EWCPlugin Implementation)
EWC Lambda: 5000
Learning Rate: 0.001
Epochs per task: 5
Batch size: 64

Task 0 Training
  Epoch 1/5 - Loss: 0.0441 (CE: 0.0441, EWC: 0.0000), Acc: 99.36%
  Epoch 2/5 - Loss: 0.0057 (CE: 0.0057, EWC: 0.0000), Acc: 99.88%
  Epoch 3/5 - Loss: 0.0018 (CE: 0.0018, EWC: 0.0000), Acc: 99.94%
  Epoch 4/5 - Loss: 0.0032 (CE: 0.0032, EWC: 0.0000), Acc: 99.91%
  Epoch 5/5 - Loss: 0.0005 (CE: 0.0005, EWC: 0.0000), Acc: 99.99%

Post-training accuracy:
Accuracy on task 0: 99.95%

Computing importances for Task 0...
  Total importance: 0.002349
  Average importance per param: 0.00000001

Testing on all tasks seen so far:
Accuracy on task 0: 99.95%
Average accuracy: 99.95%

Task 1 Training
  Epoch 1/5 - Loss: 0.4571 (CE: 0.4553, EWC: 0.0000), Acc: 88.71%
  Epoch 2/5 - Loss: 0.0515 (CE: 0.0497, EWC: 0.0000), Acc: 98.30%
  Epoch 3/5 - Loss: 0.0307 (CE: 0.0291, EWC: 0.0000), Acc: 99.04%
  Epoch 4/5 - Loss: 0.0194 (CE: 0.0179, EWC: 0.0000

In [2]:
from model import network_mnist, test_taskwise, test, train_stream, test_stream
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import os
from collections import defaultdict

# ==================== DEBUG VERSION WITH FIXES ====================

def copy_params_dict(model):
    """Copy parameters from model"""
    return [(name, param.data.clone()) for name, param in model.named_parameters()]

def zerolike_params_dict(model):
    """Create zero-like parameter dictionary"""
    return [(name, torch.zeros_like(param)) for name, param in model.named_parameters()]

class EWCPluginFixed:
    """
    EWC Plugin with debugging and fixes for very small Fisher values.
    """
    
    def __init__(self, ewc_lambda, mode='separate'):
        self.ewc_lambda = ewc_lambda
        self.mode = mode
        self.saved_params = defaultdict(list)
        self.importances = defaultdict(list)
        self.keep_importance_data = True

    def compute_ewc_penalty(self, model, exp_counter, device):
        """Compute EWC penalty with debugging."""
        if exp_counter == 0:
            return torch.tensor(0).float().to(device)

        penalty = torch.tensor(0).float().to(device)

        if self.mode == 'separate':
            for experience in range(exp_counter):
                exp_penalty = torch.tensor(0).float().to(device)
                for (name_cur, cur_param), (name_saved, saved_param), (name_imp, imp) in zip(
                        model.named_parameters(),
                        self.saved_params[experience],
                        self.importances[experience]):
                    
                    # Compute parameter-wise penalty
                    param_penalty = (imp * (cur_param - saved_param).pow(2)).sum()
                    exp_penalty += param_penalty
                
                penalty += exp_penalty
                
        return penalty

    def compute_importances(self, model, criterion, optimizer, dataset, device, batch_size):
        """
        Compute importances - FIXED VERSION.
        The key issue: we need to use model in TRAIN mode to get meaningful gradients.
        """
        # CRITICAL: Keep model in train mode to get proper gradients
        model.train()
        
        # Initialize importances
        importances = zerolike_params_dict(model)
        importances = [(name, imp.to(device)) for name, imp in importances]
        
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        print(f"  Computing Fisher over {len(dataloader)} batches...")
        
        for i, batch in enumerate(dataloader):
            x, y = batch[0], batch[1]
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()

            # Accumulate squared gradients
            for (k1, p), (k2, imp) in zip(model.named_parameters(), importances):
                assert k1 == k2
                if p.grad is not None:
                    imp += p.grad.data.clone().pow(2)

        # Average over mini-batches
        for name, imp in importances:
            imp /= float(len(dataloader))

        # Detailed statistics
        total_importance = sum(imp.sum().item() for _, imp in importances)
        num_params = sum(imp.numel() for _, imp in importances)
        max_importance = max(imp.max().item() for _, imp in importances)
        min_importance = min(imp.min().item() for _, imp in importances)
        
        print(f"  Fisher Statistics:")
        print(f"    Total: {total_importance:.6f}")
        print(f"    Mean: {total_importance/num_params:.10f}")
        print(f"    Max: {max_importance:.10f}")
        print(f"    Min: {min_importance:.10f}")

        return importances

    def after_training_exp(self, model, criterion, optimizer, experience, device, batch_size, exp_counter):
        """Compute and store importances after training."""
        print(f"\n  Computing importances for Task {exp_counter}...")
        
        importances = self.compute_importances(
            model, criterion, optimizer, 
            experience.dataset, device, batch_size
        )
        
        # Store importances
        self.importances[exp_counter] = importances
        
        # Store parameters
        self.saved_params[exp_counter] = copy_params_dict(model)
        
        print(f"  ✓ Stored {len(importances)} parameter groups")


# ==================== Training Script ====================

model = network_mnist(256, 128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()

# Try with MUCH higher lambda since Fisher values are tiny
ewc_lambda = 1000000  # 1 million - compensate for tiny Fisher values
learning_rate = 0.001
epochs_per_task = 5
batch_size = 64

ewc = EWCPluginFixed(ewc_lambda=ewc_lambda, mode='separate')

os.makedirs('figures', exist_ok=True)

print("="*70)
print("EWC Training with Debugging")
print("="*70)
print(f"EWC Lambda: {ewc_lambda:,}")
print(f"Learning Rate: {learning_rate}")
print(f"Epochs per task: {epochs_per_task}")
print("="*70)

for task in range(5):
    print(f"\n{'='*70}")
    print(f"TASK {task} TRAINING")
    print(f"{'='*70}")
    
    experience = train_stream[task]
    train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    model.train()
    for epoch in range(epochs_per_task):
        total_loss = 0
        total_ce = 0
        total_ewc = 0
        correct = 0
        total_samples = 0
        
        for batch_idx, (images, labels, *_) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss_ce = criterion(outputs, labels)
            
            # Compute EWC penalty
            ewc_penalty = ewc.compute_ewc_penalty(model, task, device)
            
            # Total loss
            loss = loss_ce + ewc_lambda * ewc_penalty
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_ce += loss_ce.item()
            total_ewc += ewc_penalty.item()
            
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
        
        avg_loss = total_loss / len(train_loader)
        avg_ce = total_ce / len(train_loader)
        avg_ewc = total_ewc / len(train_loader)
        accuracy = 100 * correct / total_samples
        
        print(f"  Epoch {epoch+1} - Loss: {avg_loss:.4f} (CE: {avg_ce:.4f}, EWC: {avg_ewc:.6f}), Acc: {accuracy:.2f}%")
    
    print(f"\n  Post-training Test:")
    acc_current = test_taskwise(model, task, device)
    
    # Compute importances
    ewc.after_training_exp(model, criterion, optimizer, experience, device, batch_size, task)
    
    # Test all tasks
    print(f"\n{'='*70}")
    print("ALL TASKS ACCURACY:")
    print(f"{'='*70}")
    accuracies = []
    for t in range(task + 1):
        acc = test_taskwise(model, t, device)
        accuracies.append(acc)
    
    avg = sum(accuracies) / len(accuracies)
    print(f"Average: {avg:.2f}%")

print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
final_acc, final_list = test(model, device)

for i, acc in enumerate(final_list):
    print(f"Task {i}: {acc:.2f}%")

print("="*70)

torch.save({
    'model_state_dict': model.state_dict(),
    'ewc_lambda': ewc_lambda,
    'saved_params': dict(ewc.saved_params),
    'importances': dict(ewc.importances),
}, 'ewc_debug_model.pth')


EWC Training with Debugging
EWC Lambda: 1,000,000
Learning Rate: 0.001
Epochs per task: 5

TASK 0 TRAINING
  Epoch 1 - Loss: 0.0439 (CE: 0.0439, EWC: 0.000000), Acc: 99.31%
  Epoch 2 - Loss: 0.0058 (CE: 0.0058, EWC: 0.000000), Acc: 99.85%
  Epoch 3 - Loss: 0.0029 (CE: 0.0029, EWC: 0.000000), Acc: 99.92%
  Epoch 4 - Loss: 0.0012 (CE: 0.0012, EWC: 0.000000), Acc: 99.96%
  Epoch 5 - Loss: 0.0019 (CE: 0.0019, EWC: 0.000000), Acc: 99.93%

  Post-training Test:
Accuracy on task 0: 99.91%

  Computing importances for Task 0...
  Computing Fisher over 198 batches...
  Fisher Statistics:
    Total: 0.003849
    Mean: 0.0000000164
    Max: 0.0001002361
    Min: 0.0000000000
  ✓ Stored 6 parameter groups

ALL TASKS ACCURACY:
Accuracy on task 0: 99.91%
Average: 99.91%

TASK 1 TRAINING
  Epoch 1 - Loss: 0.5143 (CE: 0.4237, EWC: 0.000000), Acc: 90.36%
  Epoch 2 - Loss: 0.0648 (CE: 0.0454, EWC: 0.000000), Acc: 98.38%
  Epoch 3 - Loss: 0.0366 (CE: 0.0251, EWC: 0.000000), Acc: 99.13%
  Epoch 4 - Loss: 

In [9]:
from model import network_mnist, test_taskwise, test, train_stream, test_stream
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import os

class EWC_Final:
    """Final working version of EWC - simplified and fixed."""
    
    def __init__(self, ewc_lambda):
        self.ewc_lambda = ewc_lambda
        self.saved_params = {}
        self.importances = {}
    
    def penalty(self, model):
        """Compute EWC penalty."""
        if len(self.saved_params) == 0:
            return 0
        
        loss = 0
        for task_id in self.saved_params.keys():
            for n, p in model.named_parameters():
                _loss = self.importances[task_id][n] * (p - self.saved_params[task_id][n]) ** 2
                loss += _loss.sum()
        return loss
    
    def update(self, model, dataset, device, batch_size=64):
        """Compute and store Fisher information."""
        model.train()
        
        # Initialize Fisher
        fisher = {}
        for n, p in model.named_parameters():
            fisher[n] = torch.zeros_like(p)
        
        # Compute Fisher
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        criterion = nn.CrossEntropyLoss()
        
        for x, y, *_ in dataloader:
            x, y = x.to(device), y.to(device)
            model.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            
            for n, p in model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2)
        
        # Average
        for n in fisher:
            fisher[n] /= len(dataloader)
        
        # Store
        task_id = len(self.saved_params)
        self.importances[task_id] = fisher
        self.saved_params[task_id] = {n: p.clone().detach() for n, p in model.named_parameters()}
        
        total_fisher = sum(v.sum().item() for v in fisher.values())
        print(f"  Fisher sum: {total_fisher:.6f}")


# Initialize
model = network_mnist(256, 128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()

ewc_lambda = 400000
lr = 0.001
epochs = 5
batch_size = 64

ewc = EWC_Final(ewc_lambda)

print(f"EWC Lambda: {ewc_lambda:,}\n")

for task in range(5):
    print(f"{'='*70}")
    print(f"TASK {task}")
    print(f"{'='*70}")
    
    experience = train_stream[task]
    train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=lr)  # Try SGD instead of Adam
    
    # Train
    for epoch in range(epochs):
        model.train()
        total_loss, total_ce, total_ewc = 0, 0, 0
        
        for x, y, *_ in train_loader:
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            out = model(x)
            ce_loss = criterion(out, y)
            ewc_loss = ewc.penalty(model)
            loss = ce_loss + ewc_lambda * ewc_loss
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_ce += ce_loss.item()
            total_ewc += ewc_loss.item() if isinstance(ewc_loss, torch.Tensor) else ewc_loss
        
        avg_ce = total_ce / len(train_loader)
        avg_ewc = total_ewc / len(train_loader)
        print(f"  Epoch {epoch+1}: CE={avg_ce:.4f}, EWC={avg_ewc:.6f}")
    
    # Test current
    print()
    test_taskwise(model, task, device)
    
    # Update EWC
    print("  Computing Fisher...")
    ewc.update(model, experience.dataset, device, batch_size)
    
    # Test all
    print(f"\n  All tasks:")
    for t in range(task+1):
        test_taskwise(model, t, device)
    print()

print("\nFINAL:")
test(model, device)


EWC Lambda: 400,000

TASK 0
  Epoch 1: CE=1.7731, EWC=0.000000
  Epoch 2: CE=0.4885, EWC=0.000000
  Epoch 3: CE=0.1259, EWC=0.000000
  Epoch 4: CE=0.0656, EWC=0.000000
  Epoch 5: CE=0.0444, EWC=0.000000

Accuracy on task 0: 99.76%
  Computing Fisher...
  Fisher sum: 0.100991

  All tasks:
Accuracy on task 0: 99.76%

TASK 1
  Epoch 1: CE=2.3428, EWC=0.000000
  Epoch 2: CE=0.7495, EWC=0.000001
  Epoch 3: CE=0.3871, EWC=0.000000
  Epoch 4: CE=0.2735, EWC=0.000000
  Epoch 5: CE=0.2214, EWC=0.000000

Accuracy on task 1: 96.13%
  Computing Fisher...
  Fisher sum: 1.025975

  All tasks:
Accuracy on task 0: 38.44%
Accuracy on task 1: 96.13%

TASK 2
  Epoch 1: CE=nan, EWC=nan
  Epoch 2: CE=nan, EWC=nan
  Epoch 3: CE=nan, EWC=nan
  Epoch 4: CE=nan, EWC=nan
  Epoch 5: CE=nan, EWC=nan

Accuracy on task 2: 0.00%
  Computing Fisher...
  Fisher sum: nan

  All tasks:
Accuracy on task 0: 46.34%
Accuracy on task 1: 0.00%
Accuracy on task 2: 0.00%

TASK 3
  Epoch 1: CE=nan, EWC=nan
  Epoch 2: CE=nan, EW

(9.267139479905437, [46.335697399527184, 0.0, 0.0, 0.0, 0.0])

In [4]:
from model import network_mnist, test_taskwise, test, train_stream, test_stream
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import os

class EWC_Correct:
    """Corrected EWC implementation with proper lambda scaling."""
    
    def __init__(self, ewc_lambda):
        self.ewc_lambda = ewc_lambda
        self.saved_params = {}
        self.importances = {}
    
    def penalty(self, model):
        """Compute EWC penalty: Σ_t Σ_i F_t,i * (θ_i - θ*_t,i)^2"""
        if len(self.saved_params) == 0:
            return torch.tensor(0, device=next(model.parameters()).device, dtype=torch.float32)
        
        loss = torch.tensor(0, device=next(model.parameters()).device, dtype=torch.float32)
        
        for task_id in self.saved_params.keys():
            for n, p in model.named_parameters():
                fisher = self.importances[task_id][n]
                saved_p = self.saved_params[task_id][n]
                loss += (fisher * (p - saved_p).pow(2)).sum()
        
        return loss
    
    def update(self, model, criterion, dataset, device, batch_size=64):
        """Compute and store Fisher information after training on a task."""
        model.train()
        
        # Initialize Fisher
        fisher = {}
        for n, p in model.named_parameters():
            fisher[n] = torch.zeros_like(p, device=device)
        
        # Compute Fisher information
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        print(f"    Computing Fisher over {len(dataloader)} batches...")
        
        for x, y, *_ in dataloader:
            x, y = x.to(device), y.to(device)
            
            model.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            
            # Accumulate squared gradients (empirical Fisher)
            for n, p in model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2)
        
        # Average over batches
        for n in fisher:
            fisher[n] /= len(dataloader)
        
        # Store for this task
        task_id = len(self.saved_params)
        self.importances[task_id] = fisher
        self.saved_params[task_id] = {n: p.clone().detach() for n, p in model.named_parameters()}
        
        # Print statistics
        total_fisher = sum(v.sum().item() for v in fisher.values())
        mean_fisher = total_fisher / sum(v.numel() for v in fisher.values())
        print(f"    Fisher - Total: {total_fisher:.6f}, Mean: {mean_fisher:.10f}")


# ==================== Training Script ====================

model = network_mnist(256, 128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()

# CORRECTED HYPERPARAMETERS
ewc_lambda = 100  # Start with 100 (not 400,000!)
learning_rate = 0.001
epochs_per_task = 5
batch_size = 64

ewc = EWC_Correct(ewc_lambda)

print("="*70)
print("EWC Training - Corrected Lambda")
print("="*70)
print(f"EWC Lambda: {ewc_lambda} (Correct range: 1-5000)")
print(f"Learning Rate: {learning_rate}")
print(f"Epochs: {epochs_per_task}")
print("="*70)

for task in range(5):
    print(f"\n{'='*70}")
    print(f"TASK {task}")
    print(f"{'='*70}")
    
    experience = train_stream[task]
    train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Train
    model.train()
    for epoch in range(epochs_per_task):
        total_loss = 0
        total_ce = 0
        total_ewc_penalty = 0
        correct = 0
        total_samples = 0
        
        for x, y, *_ in train_loader:
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            out = model(x)
            ce_loss = criterion(out, y)
            
            # EWC penalty
            ewc_penalty = ewc.penalty(model)
            
            # Total loss
            total_task_loss = ce_loss + ewc_lambda * ewc_penalty
            
            total_task_loss.backward()
            optimizer.step()
            
            # Track metrics
            total_loss += total_task_loss.item()
            total_ce += ce_loss.item()
            total_ewc_penalty += ewc_penalty.item() if isinstance(ewc_penalty, torch.Tensor) else ewc_penalty
            
            # Accuracy
            _, pred = torch.max(out, 1)
            correct += (pred == y).sum().item()
            total_samples += y.size(0)
        
        avg_loss = total_loss / len(train_loader)
        avg_ce = total_ce / len(train_loader)
        avg_ewc = total_ewc_penalty / len(train_loader)
        acc = 100 * correct / total_samples
        
        print(f"  E{epoch+1}/5 - Total: {avg_loss:.4f} (CE: {avg_ce:.4f}, EWC: {avg_ewc:.6f}), Acc: {acc:.2f}%")
    
    # Test current task
    print(f"\n  Task {task} Accuracy:")
    test_taskwise(model, task, device)
    
    # Compute and store Fisher
    print(f"\n  Computing Fisher information for Task {task}...")
    ewc.update(model, criterion, experience.dataset, device, batch_size)
    
    # Test all tasks seen so far
    print(f"\n  All Tasks (Task 0 to {task}):")
    all_accs = []
    for t in range(task + 1):
        acc = test_taskwise(model, t, device)
        all_accs.append(acc)
    
    avg_acc = sum(all_accs) / len(all_accs)
    print(f"  >>> Average: {avg_acc:.2f}%")

# Final evaluation
print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
final_avg, final_list = test(model, device)
print(f"\nTask-wise breakdown:")
for i, acc in enumerate(final_list):
    print(f"  Task {i}: {acc:.2f}%")
print("="*70)

# Save
torch.save({
    'model': model.state_dict(),
    'ewc_lambda': ewc_lambda,
    'importances': ewc.importances,
    'saved_params': ewc.saved_params,
}, 'ewc_correct_model.pth')
print("✓ Model saved")


EWC Training - Corrected Lambda
EWC Lambda: 100 (Correct range: 1-5000)
Learning Rate: 0.001
Epochs: 5

TASK 0
  E1/5 - Total: 0.0461 (CE: 0.0461, EWC: 0.000000), Acc: 99.27%
  E2/5 - Total: 0.0044 (CE: 0.0044, EWC: 0.000000), Acc: 99.88%
  E3/5 - Total: 0.0031 (CE: 0.0031, EWC: 0.000000), Acc: 99.91%
  E4/5 - Total: 0.0030 (CE: 0.0030, EWC: 0.000000), Acc: 99.92%
  E5/5 - Total: 0.0015 (CE: 0.0015, EWC: 0.000000), Acc: 99.94%

  Task 0 Accuracy:
Accuracy on task 0: 99.95%

  Computing Fisher information for Task 0...
    Computing Fisher over 198 batches...
    Fisher - Total: 0.000855, Mean: 0.0000000036

  All Tasks (Task 0 to 0):
Accuracy on task 0: 99.95%
  >>> Average: 99.95%

TASK 1
  E1/5 - Total: 0.4519 (CE: 0.4519, EWC: 0.000000), Acc: 87.67%
  E2/5 - Total: 0.0516 (CE: 0.0516, EWC: 0.000000), Acc: 98.25%
  E3/5 - Total: 0.0261 (CE: 0.0261, EWC: 0.000000), Acc: 99.20%
  E4/5 - Total: 0.0175 (CE: 0.0175, EWC: 0.000000), Acc: 99.44%
  E5/5 - Total: 0.0109 (CE: 0.0109, EWC: 0.00

In [6]:
from model import network_mnist, test_taskwise, test, train_stream, test_stream
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import os

class EWC_ProperFisher:
    """
    EWC with PROPER Fisher computation.
    Key insight: Compute Fisher DURING training (when gradients are large), 
    not after convergence (when gradients are tiny).
    """
    
    def __init__(self, ewc_lambda):
        self.ewc_lambda = ewc_lambda
        self.saved_params = {}
        self.importances = {}
    
    def penalty(self, model):
        """Compute EWC penalty."""
        if len(self.saved_params) == 0:
            return torch.tensor(0, device=next(model.parameters()).device, dtype=torch.float32)
        
        loss = torch.tensor(0, device=next(model.parameters()).device, dtype=torch.float32)
        
        for task_id in self.saved_params.keys():
            for n, p in model.named_parameters():
                fisher = self.importances[task_id][n]
                saved_p = self.saved_params[task_id][n]
                loss += (fisher * (p - saved_p).pow(2)).sum()
        
        return loss
    
    def compute_fisher_training(self, model, criterion, train_loader, device):
        """
        Compute Fisher DURING early training.
        This captures importance when gradients are large (not zero).
        """
        model.train()
        
        fisher = {}
        for n, p in model.named_parameters():
            fisher[n] = torch.zeros_like(p, device=device)
        
        num_samples = 0
        
        # Collect from ONE epoch of training
        for x, y, *_ in train_loader:
            x, y = x.to(device), y.to(device)
            
            model.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            
            # Accumulate squared gradients
            for n, p in model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2)
            
            num_samples += x.size(0)
        
        # Normalize
        for n in fisher:
            fisher[n] /= float(num_samples)
        
        total_fisher = sum(v.sum().item() for v in fisher.values())
        mean_fisher = total_fisher / sum(v.numel() for v in fisher.values())
        print(f"    Fisher - Total: {total_fisher:.6f}, Mean: {mean_fisher:.10f}")
        
        return fisher
    
    def update_after_task(self, model, fisher_dict, task_id):
        """Store Fisher and parameters."""
        self.importances[task_id] = fisher_dict
        self.saved_params[task_id] = {n: p.clone().detach() for n, p in model.named_parameters()}


# ==================== Training ====================

model = network_mnist(256, 128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()

ewc_lambda = 400000  # Tuned for MNIST network
learning_rate = 0.001
epochs_per_task = 5
batch_size = 64

ewc = EWC_ProperFisher(ewc_lambda)

print("="*70)
print("EWC Training - Fisher Computed During Training")
print("="*70)
print(f"Lambda: {ewc_lambda}")
print(f"Key: Fisher computed on EARLY TRAINING (large gradients)")
print("="*70)

for task in range(5):
    print(f"\n{'='*70}")
    print(f"TASK {task}")
    print(f"{'='*70}")
    
    experience = train_stream[task]
    train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)
    
    # STEP 1: Compute Fisher from early training data (BEFORE convergence!)
    print(f"\n  Step 1: Computing Fisher from training data...")
    fisher_early = ewc.compute_fisher_training(model, criterion, train_loader, device)
    
    # STEP 2: Train with EWC penalty
    print(f"\n  Step 2: Training with EWC penalty...")
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    model.train()
    for epoch in range(epochs_per_task):
        total_loss = 0
        total_ce = 0
        total_ewc = 0
        correct = 0
        total_samples = 0
        
        for x, y, *_ in train_loader:
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            out = model(x)
            ce_loss = criterion(out, y)
            ewc_penalty = ewc.penalty(model)
            loss = ce_loss + ewc_lambda * ewc_penalty
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_ce += ce_loss.item()
            total_ewc += ewc_penalty.item()
            
            _, pred = torch.max(out, 1)
            correct += (pred == y).sum().item()
            total_samples += y.size(0)
        
        avg_loss = total_loss / len(train_loader)
        avg_ce = total_ce / len(train_loader)
        avg_ewc = total_ewc / len(train_loader)
        acc = 100 * correct / total_samples
        
        print(f"    E{epoch+1} - Loss: {avg_loss:.4f} (CE: {avg_ce:.4f}, EWC: {avg_ewc:.6f}), Acc: {acc:.2f}%")
    
    # STEP 3: Test and store
    print(f"\n  Step 3: Evaluation")
    print(f"    Task {task}:")
    test_taskwise(model, task, device)
    
    ewc.update_after_task(model, fisher_early, task)
    
    print(f"    All Tasks (0-{task}):")
    all_accs = []
    for t in range(task + 1):
        acc = test_taskwise(model, t, device)
        all_accs.append(acc)
    
    print(f"    Average: {sum(all_accs)/len(all_accs):.2f}%")

print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
final_avg, final_list = test(model, device)
print()
for i, acc in enumerate(final_list):
    print(f"  Task {i}: {acc:.2f}%")
print("="*70)

torch.save({'model': model.state_dict(), 'ewc': ewc}, 'ewc_final.pth')


EWC Training - Fisher Computed During Training
Lambda: 400000
Key: Fisher computed on EARLY TRAINING (large gradients)

TASK 0

  Step 1: Computing Fisher from training data...
    Fisher - Total: 0.084578, Mean: 0.0000003597

  Step 2: Training with EWC penalty...
    E1 - Loss: 0.0435 (CE: 0.0435, EWC: 0.000000), Acc: 99.16%
    E2 - Loss: 0.0030 (CE: 0.0030, EWC: 0.000000), Acc: 99.93%
    E3 - Loss: 0.0015 (CE: 0.0015, EWC: 0.000000), Acc: 99.96%
    E4 - Loss: 0.0017 (CE: 0.0017, EWC: 0.000000), Acc: 99.96%
    E5 - Loss: 0.0015 (CE: 0.0015, EWC: 0.000000), Acc: 99.95%

  Step 3: Evaluation
    Task 0:
Accuracy on task 0: 99.91%
    All Tasks (0-0):
Accuracy on task 0: 99.91%
    Average: 99.91%

TASK 1

  Step 1: Computing Fisher from training data...
    Fisher - Total: 10.848949, Mean: 0.0000461371

  Step 2: Training with EWC penalty...
    E1 - Loss: 0.7523 (CE: 0.5082, EWC: 0.000001), Acc: 85.87%
    E2 - Loss: 0.1175 (CE: 0.0613, EWC: 0.000000), Acc: 98.01%
    E3 - Loss: 0

In [10]:
from model import network_mnist, test_taskwise, test, train_stream, test_stream
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import os

class EWC_Working:
    """
    EWC that actually works for continual learning.
    Key: Store parameters AFTER training, compute Fisher on held-out validation set.
    """
    
    def __init__(self, ewc_lambda):
        self.ewc_lambda = ewc_lambda
        self.saved_params = {}
        self.importances = {}
    
    def penalty(self, model):
        """Compute EWC penalty: Σ_t Σ_i F_t,i * (θ_i - θ*_t,i)^2"""
        if len(self.saved_params) == 0:
            return torch.tensor(0, device=next(model.parameters()).device, dtype=torch.float32)
        
        loss = torch.tensor(0, device=next(model.parameters()).device, dtype=torch.float32)
        
        for task_id in self.saved_params.keys():
            for n, p in model.named_parameters():
                fisher = self.importances[task_id][n]
                saved_p = self.saved_params[task_id][n]
                loss += (fisher * (p - saved_p).pow(2)).sum()
        
        return loss
    
    def compute_fisher_on_data(self, model, criterion, train_loader, device):
        """
        Compute Fisher information on the TRAINING DATA.
        Use all batches, NOT just first epoch.
        """
        model.eval()  # Use eval mode to avoid randomness from dropout/batchnorm
        
        fisher = {}
        for n, p in model.named_parameters():
            fisher[n] = torch.zeros_like(p, device=device)
        
        num_samples = 0
        
        # Use ALL data to compute Fisher (more stable)
        for x, y, *_ in train_loader:
            x, y = x.to(device), y.to(device)
            
            model.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            
            for n, p in model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2)
            
            num_samples += x.size(0)
        
        # Normalize by number of samples
        for n in fisher:
            fisher[n] /= float(num_samples)
        
        total_fisher = sum(v.sum().item() for v in fisher.values())
        print(f"    Fisher stats - Total: {total_fisher:.6f}, Norm: {total_fisher/sum(v.numel() for v in fisher.values()):.10f}")
        
        return fisher
    
    def after_training(self, model, fisher_dict, task_id):
        """Store Fisher and current parameters."""
        self.importances[task_id] = fisher_dict
        self.saved_params[task_id] = {n: p.clone().detach() for n, p in model.named_parameters()}


# ==================== Main Training Loop ====================

model = network_mnist(256, 128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()

# Properly tuned hyperparameters
ewc_lambda = 50  # Much lower - Fisher values are actually reasonable
learning_rate = 0.001
epochs_per_task = 5
batch_size = 64

ewc = EWC_Working(ewc_lambda)

print("="*70)
print("EWC Training - Final Working Version")
print("="*70)
print(f"Lambda: {ewc_lambda}")
print(f"Key changes:")
print(f"  1. Store params AFTER training (not before)")
print(f"  2. Compute Fisher on all training data (not just first epoch)")
print(f"  3. Use model.eval() for Fisher (stable gradients)")
print("="*70)

for task in range(5):
    print(f"\n{'='*70}")
    print(f"TASK {task}")
    print(f"{'='*70}")
    
    experience = train_stream[task]
    train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Train on this task
    print(f"\n  Training...")
    model.train()
    for epoch in range(epochs_per_task):
        total_loss = 0
        total_ce = 0
        total_ewc = 0
        correct = 0
        total_samples = 0
        
        for x, y, *_ in train_loader:
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            out = model(x)
            ce_loss = criterion(out, y)
            ewc_penalty = ewc.penalty(model)
            loss = ce_loss + ewc_lambda * ewc_penalty
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_ce += ce_loss.item()
            total_ewc += ewc_penalty.item()
            
            _, pred = torch.max(out, 1)
            correct += (pred == y).sum().item()
            total_samples += y.size(0)
        
        avg_loss = total_loss / len(train_loader)
        avg_ce = total_ce / len(train_loader)
        avg_ewc = total_ewc / len(train_loader)
        acc = 100 * correct / total_samples
        
        print(f"    E{epoch+1}/5 - Loss: {avg_loss:.4f} (CE: {avg_ce:.4f}, EWC: {avg_ewc:.6f}), Acc: {acc:.2f}%")
    
    # Test current task
    print(f"\n  Task {task} Accuracy:")
    test_taskwise(model, task, device)
    
    # NOW compute Fisher AFTER training, but on training data
    print(f"\n  Computing Fisher information...")
    fisher = ewc.compute_fisher_on_data(model, criterion, train_loader, device)
    
    # Store parameters and Fisher
    ewc.after_training(model, fisher, task)
    
    # Test all tasks
    print(f"\n  All Tasks (0-{task}):")
    all_accs = []
    for t in range(task + 1):
        acc = test_taskwise(model, t, device)
        all_accs.append(acc)
    
    avg = sum(all_accs) / len(all_accs)
    print(f"  Average: {avg:.2f}%")

print("\n" + "="*70)
print("FINAL RESULTS")
print("="*70)
final_avg, final_list = test(model, device)
print()
for i, acc in enumerate(final_list):
    print(f"  Task {i}: {acc:.2f}%")
print("="*70)

torch.save({'model': model.state_dict()}, 'ewc_working.pth')
print("✓ Model saved")


EWC Training - Final Working Version
Lambda: 50
Key changes:
  1. Store params AFTER training (not before)
  2. Compute Fisher on all training data (not just first epoch)
  3. Use model.eval() for Fisher (stable gradients)

TASK 0

  Training...
    E1/5 - Loss: 0.0519 (CE: 0.0519, EWC: 0.000000), Acc: 99.13%
    E2/5 - Loss: 0.0031 (CE: 0.0031, EWC: 0.000000), Acc: 99.92%
    E3/5 - Loss: 0.0032 (CE: 0.0032, EWC: 0.000000), Acc: 99.89%
    E4/5 - Loss: 0.0015 (CE: 0.0015, EWC: 0.000000), Acc: 99.94%
    E5/5 - Loss: 0.0012 (CE: 0.0012, EWC: 0.000000), Acc: 99.96%

  Task 0 Accuracy:
Accuracy on task 0: 99.91%

  Computing Fisher information...
    Fisher stats - Total: 0.000192, Norm: 0.0000000008

  All Tasks (0-0):
Accuracy on task 0: 99.91%
  Average: 99.91%

TASK 1

  Training...
    E1/5 - Loss: 0.4759 (CE: 0.4759, EWC: 0.000000), Acc: 88.90%
    E2/5 - Loss: 0.0400 (CE: 0.0400, EWC: 0.000000), Acc: 98.61%
    E3/5 - Loss: 0.0237 (CE: 0.0237, EWC: 0.000000), Acc: 99.19%
    E4/5 