## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Import MirrorMind; if the package is missing, install it in editable mode
# from the parent directory and retry the import.
try:
    from airbornehrs import AdaptiveFramework, AdaptiveFrameworkConfig
except ImportError:
    print("Installing airbornehrs...")
    import importlib
    import subprocess
    import sys

    # Invoke pip through the running interpreter so the package is installed
    # into this kernel's environment rather than whichever `pip` is on PATH.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", ".."])
    # Packages installed while the process is running may be hidden by cached
    # finder state; clear it before retrying the import.
    importlib.invalidate_caches()
    from airbornehrs import AdaptiveFramework, AdaptiveFrameworkConfig

print("✓ All imports successful")

## 2. Data Preparation

In [None]:
# Fetch the scikit-learn digits data as feature and label arrays
digits = load_digits()
X, y = digits.data, digits.target

# Standardise the features with StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Wrap features and labels in torch tensors
X_tensor = torch.FloatTensor(X_scaled)
y_tensor = torch.LongTensor(y)

# Each task: (first label, one-past-last label, list of digits it contains)
tasks = {
    "Task 1 (0-3)": (0, 4, [0, 1, 2, 3]),
    "Task 2 (4-6)": (4, 7, [4, 5, 6]),
    "Task 3 (7-9)": (7, 10, [7, 8, 9]),
}

# Slice out one subset per task and re-index its labels from zero
task_datasets = {}
for task_name, (start, end, classes) in tasks.items():
    # Boolean selector for samples whose label lies in [start, end)
    in_range = np.logical_and(y >= start, y < end)
    X_task = X_tensor[in_range]
    y_task = y_tensor[in_range]

    # Position of each digit inside `classes` becomes its task-local label
    label_index = {digit: position for position, digit in enumerate(classes)}
    remapped = torch.tensor([label_index[label] for label in y_task.tolist()])

    task_datasets[task_name] = dict(
        X=X_task,
        y=remapped,
        num_classes=len(classes),
        original_classes=classes,
    )

    print(f"✓ {task_name}: {X_task.shape[0]} samples, {len(classes)} classes")

print("\n✓ Data preparation complete")

## 3. Model Definition

In [None]:
class MultiTaskNet(nn.Module):
    """Multi-task classifier: one shared trunk feeding a linear head per task.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the two hidden layers of the shared trunk.
        num_tasks: Stored on the instance as ``num_tasks``. It does not drive
            head construction; when ``task_classes`` is supplied it is
            replaced by ``len(task_classes)`` so the attribute stays accurate.
        task_classes: Optional mapping from task name to its number of output
            classes. ``None`` (the default) builds the three digit-range heads
            used throughout this notebook.

    Raises:
        ValueError: If ``task_classes`` is given but empty.
    """

    # Width of the feature vector emitted by the shared trunk.
    FEATURE_DIM = 64

    # Heads built when no explicit ``task_classes`` mapping is supplied.
    DEFAULT_TASK_CLASSES = (
        ('Task 1 (0-3)', 4),
        ('Task 2 (4-6)', 3),
        ('Task 3 (7-9)', 3),
    )

    def __init__(self, input_dim=64, hidden_dim=128, num_tasks=3, task_classes=None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        if task_classes is None:
            head_sizes = dict(self.DEFAULT_TASK_CLASSES)
            self.num_tasks = num_tasks
        else:
            head_sizes = dict(task_classes)
            if not head_sizes:
                raise ValueError("task_classes must contain at least one task")
            self.num_tasks = len(head_sizes)

        # Shared feature extraction layers
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, self.FEATURE_DIM),
        )

        # Task-specific heads, created in mapping order so that default
        # construction matches the original layer order exactly.
        self.task_heads = nn.ModuleDict()
        for name, n_classes in head_sizes.items():
            self.task_heads[name] = nn.Linear(self.FEATURE_DIM, n_classes)

    def forward(self, x, task_name=None):
        """Run the shared trunk and one (or every) task head.

        Args:
            x: Input batch whose last dimension is ``input_dim``.
            task_name: Name of the head to apply. ``None`` applies all heads.

        Returns:
            The logits of the selected head, or a dict mapping every task
            name to its logits when ``task_name`` is ``None``.

        Raises:
            KeyError: If ``task_name`` does not name an existing head.
        """
        features = self.shared(x)

        if task_name is None:
            # Return all task outputs for analysis
            return {task: head(features) for task, head in self.task_heads.items()}

        if task_name not in self.task_heads:
            raise KeyError(
                f"Unknown task {task_name!r}; available tasks: {list(self.task_heads.keys())}"
            )
        return self.task_heads[task_name](features)

print("✓ MultiTaskNet model defined")

## 4. Baseline: Vanilla PyTorch (Shows Catastrophic Forgetting)

In [None]:
def train_vanilla_multitask():
    """Fit a single MultiTaskNet on each task in turn using plain Adam.

    Once a task has been trained, the model is scored on the data of every
    task so that accuracy drops on earlier tasks become visible.

    Returns:
        Tuple ``(model, results)`` where
        ``results[task]['task_scores'][k]`` holds the accuracy on ``task``
        measured right after training on the k-th task of the sequence.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("VANILLA PYTORCH: Multi-Task Learning")
    print(banner)

    n_epochs = 15
    batch_size = 32

    model = MultiTaskNet(input_dim=64, hidden_dim=128)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    def _empty_record():
        # Fresh per-task container, created on first access.
        return {'train': [], 'task_scores': {}}

    results = defaultdict(_empty_record)

    task_order = ['Task 1 (0-3)', 'Task 2 (4-6)', 'Task 3 (7-9)']

    for task_idx, task_name in enumerate(task_order):
        print(f"\nTraining on {task_name}...")

        current = task_datasets[task_name]
        features, labels = current['X'], current['y']
        n_samples = len(features)

        # Mini-batch passes over a freshly shuffled ordering each epoch
        for _ in range(n_epochs):
            shuffled = torch.randperm(n_samples)
            for offset in range(0, n_samples, batch_size):
                picked = shuffled[offset:offset + batch_size]

                optimizer.zero_grad()
                logits = model(features[picked], task_name)
                batch_loss = criterion(logits, labels[picked])
                batch_loss.backward()
                optimizer.step()

        # Score every task (not just the current one) with dropout disabled
        print(f"\nEvaluating after {task_name}...")
        model.eval()

        with torch.no_grad():
            for eval_task in task_order:
                held = task_datasets[eval_task]
                guesses = model(held['X'], eval_task).argmax(dim=1)
                accuracy = guesses.eq(held['y']).float().mean().item()

                results[eval_task]['task_scores'][task_idx] = accuracy
                print(f"  {eval_task}: {accuracy:.1%}")

        # Back to training mode before the next task starts
        model.train()

    return model, results

vanilla_model, vanilla_results = train_vanilla_multitask()

## 5. MirrorMind: Multi-Task Learning with EWC

In [None]:
def train_mirrorming_multitask():
    """Train a MultiTaskNet task-by-task through the MirrorMind framework.

    The model is wrapped in an ``AdaptiveFramework`` configured with
    ``memory_type='ewc'``. Each task is trained for 15 epochs in batches of
    32 using the framework's optimizer. Afterwards, if the framework exposes
    an ``ewc`` handler, the first 100 samples of the task are handed to it
    for consolidation. Accuracy on every task is then recorded.

    Returns:
        Tuple ``(framework, results)`` where
        ``results[task]['task_scores'][k]`` holds the accuracy on ``task``
        measured right after training on the k-th task of the sequence.
    """

    print("\n" + "="*60)
    print("MIRRORMIND: Multi-Task Learning with EWC")
    print("="*60)

    model = MultiTaskNet(input_dim=64, hidden_dim=128)

    config = AdaptiveFrameworkConfig(
        learning_rate=0.001,
        meta_learning_rate=0.0001,
        memory_type='ewc',
        consolidation_criterion='time',
        device='cpu',
        enable_consciousness=True
    )

    framework = AdaptiveFramework(model, config, device='cpu')
    framework.train()

    # The loss is stateless, so a single instance serves every task.
    criterion = nn.CrossEntropyLoss()

    results = defaultdict(lambda: {'train': [], 'task_scores': {}})
    task_order = ['Task 1 (0-3)', 'Task 2 (4-6)', 'Task 3 (7-9)']

    # Train on each task sequentially
    for task_idx, task_name in enumerate(task_order):
        print(f"\nTraining on {task_name}...")

        X_task = task_datasets[task_name]['X']
        y_task = task_datasets[task_name]['y']

        # Train for 15 epochs
        for epoch in range(15):
            indices = torch.randperm(len(X_task))
            for i in range(0, len(X_task), 32):
                batch_idx = indices[i:i+32]
                X_batch = X_task[batch_idx]
                y_batch = y_task[batch_idx]

                output = framework.model(X_batch, task_name)
                loss = criterion(output, y_batch)

                framework.optimizer.zero_grad()
                loss.backward()
                framework.optimizer.step()

                # Add to feedback buffer for potential replay
                if hasattr(framework, 'feedback_buffer'):
                    framework.feedback_buffer.add(
                        X_batch, output.detach(), y_batch,
                        reward=1.0, loss=loss.item()
                    )

        # Consolidate memory after each task
        print(f"Consolidating memory after {task_name}...")
        if hasattr(framework, 'ewc') and framework.ewc is not None:
            # Compute Fisher information on this task
            X_sample = X_task[:100]  # Sample for Fisher
            y_sample = y_task[:100]

            # optimizer.step() leaves the last training batch's gradients in
            # place; clear them so this backward pass reflects only the
            # Fisher sample instead of accumulating on top of stale values.
            framework.optimizer.zero_grad()
            with torch.enable_grad():
                output_sample = framework.model(X_sample, task_name)
                loss_sample = criterion(output_sample, y_sample)
                loss_sample.backward()

            # Store Fisher information
            # NOTE(review): whether the handler reads the gradients computed
            # above or recomputes its own is not visible here - confirm.
            if hasattr(framework.ewc, 'compute_fisher_on_dataset'):
                framework.ewc.compute_fisher_on_dataset(
                    framework.model, X_sample, y_sample, task_name
                )

        # Evaluate on ALL tasks after training current task
        print(f"\nEvaluating after {task_name}...")
        framework.eval()

        for eval_task in task_order:
            X_eval = task_datasets[eval_task]['X']
            y_eval = task_datasets[eval_task]['y']

            with torch.no_grad():
                output = framework.model(X_eval, eval_task)
                predictions = output.argmax(dim=1)
                accuracy = (predictions == y_eval).float().mean().item()

            results[eval_task]['task_scores'][task_idx] = accuracy
            print(f"  {eval_task}: {accuracy:.1%}")

        framework.train()

    return framework, results

mirror_framework, mirror_results = train_mirrorming_multitask()

## 6. Results Comparison

In [None]:
# Compare results
print("\n" + "="*70)
print("COMPARISON: Vanilla vs MirrorMind")
print("="*70)

task_names = ['Task 1 (0-3)', 'Task 2 (4-6)', 'Task 3 (7-9)']

# Index of the last training stage; scores stored under it are the final ones.
final_stage = len(task_names) - 1

# Per-task forgetting for both approaches, keyed by task name
comparison = {}
for learned_stage, task in enumerate(task_names):
    # Forgetting is measured from the accuracy reached right after the task
    # itself was trained (stage == the task's position in the sequence).
    # Using stage 0 for every task would compare later tasks against their
    # untrained accuracy and report meaningless negative "forgetting".
    vanilla_acc_start = vanilla_results[task]['task_scores'].get(learned_stage, 0.0)
    vanilla_acc_final = vanilla_results[task]['task_scores'].get(final_stage, 0.0)

    mirror_acc_start = mirror_results[task]['task_scores'].get(learned_stage, 0.0)
    mirror_acc_final = mirror_results[task]['task_scores'].get(final_stage, 0.0)

    vanilla_forgetting = vanilla_acc_start - vanilla_acc_final
    mirror_forgetting = mirror_acc_start - mirror_acc_final

    print(f"\n{task}")
    print("-" * 70)
    print("  Vanilla:")
    print(f"    Initial accuracy: {vanilla_acc_start:.1%}")
    print(f"    Final accuracy:   {vanilla_acc_final:.1%}")
    print(f"    Forgetting:       {vanilla_forgetting:.1%}")
    print("  MirrorMind:")
    print(f"    Initial accuracy: {mirror_acc_start:.1%}")
    print(f"    Final accuracy:   {mirror_acc_final:.1%}")
    print(f"    Forgetting:       {mirror_forgetting:.1%}")

    # A relative reduction is only defined when the baseline forgot something.
    if vanilla_forgetting > 0:
        improvement = (vanilla_forgetting - mirror_forgetting) / vanilla_forgetting * 100
        print(f"  Improvement:      {improvement:.0f}% reduction in forgetting")

    comparison[task] = {
        'vanilla_forgetting': vanilla_forgetting,
        'mirror_forgetting': mirror_forgetting
    }

print("\n" + "="*70)

## 7. Visualization

In [None]:
# One accuracy panel per task, with side-by-side bars for both approaches
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

bar_width = 0.35
half_width = bar_width / 2
stage_positions = np.arange(3)

for panel, task in zip(axes, task_names):
    # Accuracy on this task after each of the three training stages
    vanilla_bars = [vanilla_results[task]['task_scores'].get(stage, 0) for stage in range(3)]
    mirror_bars = [mirror_results[task]['task_scores'].get(stage, 0) for stage in range(3)]

    panel.bar(stage_positions - half_width, vanilla_bars, bar_width,
              label='Vanilla', alpha=0.8, color='red')
    panel.bar(stage_positions + half_width, mirror_bars, bar_width,
              label='MirrorMind', alpha=0.8, color='green')

    panel.set_title(task)
    panel.set_xlabel('After Training on Task')
    panel.set_ylabel('Accuracy')
    panel.set_xticks(stage_positions)
    panel.set_xticklabels(['Task 1', 'Task 2', 'Task 3'])
    panel.set_ylim([0, 1.0])
    panel.legend()
    panel.grid(axis='y', alpha=0.3)

    # Percentage caption just above the top of every bar
    for stage, (vanilla_acc, mirror_acc) in enumerate(zip(vanilla_bars, mirror_bars)):
        for shift, acc in ((-half_width, vanilla_acc), (half_width, mirror_acc)):
            panel.text(stage + shift, acc + 0.02, f'{acc:.0%}',
                       ha='center', va='bottom', fontsize=9)

plt.suptitle('Multi-Task Learning: Catastrophic Forgetting Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('mirrorming_multitask_results.png', dpi=150, bbox_inches='tight')
print("✓ Visualization saved to mirrorming_multitask_results.png")
plt.show()

## 8. Summary Statistics

In [None]:
# Aggregate the per-task forgetting figures for each approach
rule = "=" * 70
print("\n" + rule)
print("SUMMARY METRICS")
print(rule)

vanilla_total_forgetting = 0
mirror_total_forgetting = 0
for entry in comparison.values():
    vanilla_total_forgetting += entry['vanilla_forgetting']
    mirror_total_forgetting += entry['mirror_forgetting']

print("\nTotal Catastrophic Forgetting (across all tasks):")
print(f"  Vanilla PyTorch: {vanilla_total_forgetting:.1%}")
print(f"  MirrorMind:      {mirror_total_forgetting:.1%}")

# The relative reduction is reported only when the baseline forgot something
if vanilla_total_forgetting > 0:
    forgetting_saved = vanilla_total_forgetting - mirror_total_forgetting
    overall_improvement = forgetting_saved / vanilla_total_forgetting * 100
    print(f"  Improvement:    {overall_improvement:.0f}% reduction")

print("\n✓ Multi-task learning demonstration complete!")