# In [None]:
import torch
import torch.nn as nn
import numpy as np
from avalanche.benchmarks import RotatedMNIST
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive
from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin

# Import Localize and Stitch classes
from localize_and_stitch import Localizer, Stitcher

# Helper function to compute task vectors
def compute_task_vector(pretrained_model, finetuned_model, device):
    """Return the per-parameter difference ``finetuned - pretrained``.

    The two models are walked in ``.parameters()`` order, so they must share
    the same architecture.

    Args:
        pretrained_model: Reference ``nn.Module``.
        finetuned_model: ``nn.Module`` with the same parameter layout as
            ``pretrained_model``.
        device: Device (or device string) the resulting tensors are moved to.

    Returns:
        A list of tensors, one per parameter, each equal to
        ``p_fine - p_pre``. The tensors are detached from autograd.

    Raises:
        ValueError: If the models have a different number of parameters or a
            parameter pair has mismatched shapes. Without this check, ``zip``
            would silently truncate and broadcasting could silently produce
            a wrong-shaped vector.
    """
    pre_params = list(pretrained_model.parameters())
    fine_params = list(finetuned_model.parameters())
    if len(pre_params) != len(fine_params):
        raise ValueError(
            "Models have a different number of parameters: "
            f"{len(pre_params)} (pretrained) vs {len(fine_params)} (finetuned)"
        )

    task_vector = []
    # no_grad: the difference is a plain tensor and must not build a graph.
    with torch.no_grad():
        for idx, (p_pre, p_fine) in enumerate(zip(pre_params, fine_params)):
            if p_pre.shape != p_fine.shape:
                raise ValueError(
                    f"Parameter {idx} shape mismatch: "
                    f"{tuple(p_pre.shape)} vs {tuple(p_fine.shape)}"
                )
            task_vector.append((p_fine - p_pre).to(device))
    return task_vector

# Function to train on a single task and compute task vector
def train_task(experience, model, pretrained_model, optimizer, criterion, device,
               *, epochs=5, train_mb_size=128):
    """Fine-tune ``model`` on one experience and return its task vector.

    Args:
        experience: Avalanche training experience to fine-tune on.
        model: Model to fine-tune in place; it is moved to ``device``.
        pretrained_model: Reference model the task vector is measured
            against. It is temporarily moved to ``device`` and restored to
            the device it was on when the function was called.
        optimizer: Optimizer passed to the Avalanche ``Naive`` strategy.
            It should be bound to ``model``'s parameters.
        criterion: Loss function passed to ``Naive``.
        device: Device used for training and for the returned tensors.
        epochs: Number of training epochs (keyword-only, default 5).
        train_mb_size: Training mini-batch size (keyword-only, default 128).

    Returns:
        The list of per-parameter ``model - pretrained_model`` tensors
        produced by ``compute_task_vector``.
    """
    # Remember where the caller keeps the reference model so it can be put
    # back there; assuming CPU would leave it on a different device than the
    # models it is later compared with.
    first_param = next(pretrained_model.parameters(), None)
    original_device = first_param.device if first_param is not None else torch.device("cpu")

    model.to(device)
    pretrained_model.to(device)
    try:
        # Avalanche reads the epoch count from the constructor's
        # ``train_epochs``. Extra kwargs given to ``train()`` (such as the
        # former ``epochs=5``) are not used to set it.
        trainer = Naive(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=epochs,
            device=device,
        )
        trainer.train(experience)

        task_vector = compute_task_vector(pretrained_model, model, device)
    finally:
        # Restore the reference model even if training fails.
        pretrained_model.to(original_device)
    return task_vector

# Main function for continual learning
def continual_learning_with_localize_and_stitch():
    """Fine-tune on each RotatedMNIST task, localize each update, then stitch.

    Pipeline:

    1. Build a RotatedMNIST benchmark with one experience per entry of
       ``rotation_angles``. Experience ``i`` uses rotation
       ``rotation_angles[i]``.
    2. For each training experience:
       - copy the current reference model and fine-tune the copy with
         ``train_task``;
       - run the project's ``Localizer`` on it to obtain that task's mask.
    3. Keep every fine-tuned model together with its mask. Then copy the
       fine-tuned weights into the reference model, so tasks are fine-tuned
       sequentially.
    4. Combine the per-task models with ``Stitcher``, passing the initial
       ``model_base`` as the pretrained model.
    5. Evaluate the stitched model on the full test stream.

    Raises:
        RuntimeError: If the benchmark yields no training experiences, since
            there would be nothing to stitch.
    """
    # Local imports keep this script entry point self-contained.
    import copy

    from torch.utils.data import DataLoader

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One rotation (degrees) per experience. Passing the list to the benchmark
    # makes the generated data match the angles printed below. Without
    # ``rotations_list``, Avalanche picks the rotations itself from the seed.
    rotation_angles = [0, 15, 30, 45, 60, 75, 90, 115, 145, 175]
    rotated_benchmark = RotatedMNIST(
        n_experiences=len(rotation_angles),
        seed=1234,
        rotations_list=rotation_angles,
    )

    batch_size = 128
    learning_rate = 0.01
    criterion = nn.CrossEntropyLoss()

    # Common starting point for all tasks. It is not pretrained here; it is
    # a freshly initialised SimpleMLP.
    model_base = SimpleMLP(num_classes=10).to(device)

    # Reference model each task is fine-tuned from; advanced after each task.
    model_pretrained = copy.deepcopy(model_base)

    # Per-task outputs, index-aligned with the training stream.
    task_vectors_active = []
    finetuned_models = []
    masks = []

    # Localize and Stitch arguments
    graft_args = {
        'sparsity': 0.01,  # 1% sparsity
        'sigmoid_bias': 1.0,
        'learning_rate': 0.01,
        'l1_strength': 0.0001,
        'num_train_epochs': 5,
    }

    for task_id, experience in enumerate(rotated_benchmark.train_stream):
        angle = rotation_angles[task_id]
        print(f"\n### Training on Task {task_id+1} (Rotation: {angle}°) ###")

        # Start this task from the current reference weights. Each iteration
        # builds a new module, so the models stored below stay distinct.
        model_current = copy.deepcopy(model_pretrained).to(device)

        # A fresh optimizer bound to this task's model. The previous shared
        # optimizer was built on model_base's parameters, not model_current's.
        optimizer = torch.optim.SGD(model_current.parameters(), lr=learning_rate)
        task_vector = train_task(
            experience, model_current, model_pretrained, optimizer, criterion, device
        )
        # Ensure the reference model sits on the same device as model_current
        # before both are handed to the Localizer.
        model_pretrained.to(device)

        # NOTE(review): trainable_params receives a state_dict (detached
        # tensors) and model_type is "vit" although SimpleMLP is an MLP.
        # Confirm both against the localize_and_stitch.Localizer API.
        localizer = Localizer(
            trainable_params=model_current.state_dict(),
            model=model_current,
            pretrained_model=model_pretrained,
            finetuned_model=model_current,
            dataset_name="RotatedMNIST",
            args=None,
            graft_args=graft_args,
            model_type="vit",
        )

        # Create binary masks and base patch (the base patch's return value
        # was never used, so it is not kept).
        localizer.create_binary_masks()
        localizer.create_basepatch()

        # train_graft takes a dataloader. Iterating the raw dataset would
        # yield single, unbatched samples.
        # NOTE(review): Avalanche classification samples are typically
        # (x, y, task_label); confirm train_graft accepts such batches.
        train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)
        mask, _proportion, _val = localizer.train_graft(
            dataloader=train_loader,
            dataset_name=f"RotatedMNIST-{angle}",
        )

        # Record this task's outputs so the stitcher gets one model and one
        # mask per task, instead of only the last task's.
        task_vectors_active.append(task_vector)
        masks.append(mask)
        finetuned_models.append(model_current)

        # Sequential scheme: the next task starts from this task's weights.
        model_pretrained.load_state_dict(model_current.state_dict())

    if not finetuned_models:
        raise RuntimeError("No training experiences were produced; nothing to stitch.")

    print("\n### Stitching Models Together ###")
    # NOTE(review): each mask was localized against the previous task's model,
    # while Stitcher is given model_base as the pretrained reference. Confirm
    # this is the intended reference for the sequential scheme.
    stitcher = Stitcher(
        trainable_params=model_pretrained.state_dict(),
        model=model_pretrained,
        pretrained_model=model_base,
        finetuned_models=finetuned_models,  # one model per task
        masks=masks,  # one mask per task, same order
    )
    generalized_model = stitcher.interpolate_models()

    print("\n### Evaluating Stitched Model ###")
    # Report accuracy per experience (per rotation) and over the whole stream.
    # Epoch-level accuracy is a training-phase metric and would report
    # nothing during eval().
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True, stream=True),
        loggers=[InteractiveLogger()],
    )
    # Naive's constructor takes an optimizer; only eval() is called below.
    eval_optimizer = torch.optim.SGD(generalized_model.parameters(), lr=learning_rate)
    evaluator = Naive(
        generalized_model,
        eval_optimizer,
        criterion,
        eval_mb_size=batch_size,
        device=device,
        evaluator=eval_plugin,
    )
    evaluator.eval(rotated_benchmark.test_stream)

if __name__ == "__main__":
    # Run the full train -> localize -> stitch -> evaluate pipeline only when
    # this file is executed directly, not when it is imported.
    continual_learning_with_localize_and_stitch()