# PLeaS Merging - Basic Usage

This notebook demonstrates how to use the PLeaS merging algorithm to merge two pretrained neural networks.

In [None]:
import torch
import torchvision
from pleas.core.compiler import get_permutation_spec
from pleas.methods.activation_matching import activation_matching
from pleas.methods.partial_matching import partial_merge, get_blocks
from pleas.methods.pleas_merging import train
from pleas.core.utils import Axis

## Load Models

First, let's load two pretrained models.

In [None]:
# Instantiate two ResNet-50 networks with pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version before changing it.
model1 = torchvision.models.resnet50(pretrained=True)
model2 = torchvision.models.resnet50(pretrained=True)

# Both models start from identical weights, so perturb every parameter of the
# second one with small Gaussian noise (scale 0.01) to make them differ.
# In a real scenario, these would be models trained on different datasets.
with torch.no_grad():
    for weight in model2.parameters():
        noise = torch.randn_like(weight.data)
        weight.data = weight.data + 0.01 * noise

## Create a Dataloader

Activation matching runs data through both models, so we need a dataloader. Random tensors are used here as a placeholder.

In [None]:
# Synthetic stand-in dataset: 100 random 3x224x224 tensors, each paired with
# the constant label 0, served in batches of `batch_size`.
batch_size = 8
random_samples = [(torch.randn(3, 224, 224), 0) for _ in range(100)]
dataloader = torch.utils.data.DataLoader(random_samples, batch_size=batch_size)

## Generate Permutation Specification

Now we need to create a permutation specification that defines which axes can be permuted.

In [None]:
# Build the permutation spec for model1 from a single input shape
# of (1, 3, 224, 224), then report how many entries it contains.
input_shapes = ((1, 3, 224, 224),)
spec = get_permutation_spec(model1, input_shapes)
print(f"Found {len(spec)} permutable axes")

## Find Permutations with Activation Matching

In [None]:
# Run activation matching between model1 and model2 over the dataloader.
# With output_costs=True the result is unpacked into a permutation and costs.
# num_batches=10 presumably caps how many batches are consumed — confirm in
# pleas.methods.activation_matching.
perm, costs = activation_matching(
    spec, model1, model2, dataloader, num_batches=10, output_costs=True
)

## Define Budget Ratios for Partial Merging

In [None]:
# Target budget, described by the notebook as 50% extra computation cost.
# NOTE(review): `budget_ratio` is not referenced by any other visible cell.
budget_ratio = 1.5
# Give every spec key the same ratio of 0.5, keyed as Axis(<key>, 0).
budget_ratios = {Axis(name, 0): 0.5 for name in spec.keys()}

## Create Initial Merged Model with Partial Merging

In [None]:
# Build the starting merged model from both networks, using the permutation,
# costs and per-axis budget ratios computed in the earlier cells.
model3 = partial_merge(
    spec,
    model1,
    model2,
    perm,
    costs,
    budget_ratios,
)

## Optimize with PLeaS

In [None]:
# Refine the partially merged model (model3) with the PLeaS training routine.
# WANDB=False / wandb_run=None presumably turn off Weights & Biases logging and
# MAX_STEPS=50 presumably bounds the run — confirm in pleas.methods.pleas_merging.
optimized_model = train(
    dataloader, model1, model2, model3,
    spec, perm, costs, budget_ratios,
    MAX_STEPS=50,
    WANDB=False,
    wandb_run=None,
)

## Save the Merged Model

In [None]:
# Write only the optimized model's state dict to disk, not the module object.
merged_weights = optimized_model.state_dict()
torch.save(merged_weights, "merged_model.pth")