# Merge Method Microscope

Investigate how different merging strategies modify weights, attention patterns, and activations for a specific target locale. Edit the configuration cells below to choose the baseline and dynamic merges you want to analyse.

## 0. Imports

Helper utilities live in `merginguriel/analysis_helpers.py`. They load checkpoints, compute parameter deltas, and summarise attention/hidden-state statistics.

In [1]:
from collections import OrderedDict
from pathlib import Path

import pandas as pd
import plotly.express as px
import torch
from IPython.display import display

from merginguriel.analysis_helpers import (
    aggregate_parameter_stats,
    collect_model_signals,
    compute_weight_deltas,
    ensure_text_samples,
    load_model_artifacts,
    merge_models_in_memory,
    summarize_attentions,
    summarize_hidden_states,
    summarize_logits,
)
from merginguriel.run_merging_pipeline_refactored import MergeConfig

# Raise pandas' display limits (up to 200 rows / 50 columns) for the tables
# rendered by this notebook.
for option_name, option_value in (
    ("display.max_rows", 200),
    ("display.max_columns", 50),
):
    pd.set_option(option_name, option_value)

  import pkg_resources


## 1. Configure target & merge recipes

Update `TARGET_LOCALE` and tweak `MODEL_SETUPS` to point at the baseline model and the on-the-fly merge configurations you want to compare. Remove any entries that do not make sense for your experiment (e.g., if a dataset is unavailable).

In [2]:
# MergeConfig describes one on-the-fly merge recipe (mode, target locale,
# base model, number of source languages, ...). Every recipe below uses the
# same, correctly spelled name.
from merginguriel.run_merging_pipeline_refactored import MergeConfig

# Checkpoints are resolved relative to the parent of the current working
# directory (assumes the notebook is launched from a sub-directory of the
# project -- TODO confirm).
PROJECT_ROOT = Path.cwd().resolve().parent
BASE_DIR = PROJECT_ROOT / "haryos_model"

TARGET_LOCALE = "af-ZA"  # <-- adjust target locale
DEFAULT_NUM_LANGUAGES = 5  # passed as `num_languages` to every merge recipe

# Ordered mapping: setup name -> description of how to obtain the model.
#   kind == "pretrained": load the checkpoint stored at "path".
#   kind == "merge":      build the model in memory from "config".
# "notes" is free text shown in the summary table.
MODEL_SETUPS = OrderedDict({
    "baseline": {
        "kind": "pretrained",
        "path": BASE_DIR / f"xlm-roberta-base_massive_k_{TARGET_LOCALE}",
        "notes": "Target-specific fine-tuned baseline",
    },
    "average_merge": {
        "kind": "merge",
        "config": MergeConfig(
            mode="average",
            target_lang=TARGET_LOCALE,
            base_model="xlm-roberta-base",
            num_languages=DEFAULT_NUM_LANGUAGES,
        ),
        "notes": "Equal-weight merge of top-K sources",
    },
    "similarity_merge": {
        "kind": "merge",
        "config": MergeConfig(
            mode="similarity",
            target_lang=TARGET_LOCALE,
            base_model="xlm-roberta-base",
            num_languages=DEFAULT_NUM_LANGUAGES,
            similarity_type="URIEL",
        ),
        "notes": "URIEL-weighted similarity merge",
    },
    "task_arithmetic": {
        "kind": "merge",
        "config": MergeConfig(
            mode="task_arithmetic",
            target_lang=TARGET_LOCALE,
            base_model="xlm-roberta-base",
            num_languages=DEFAULT_NUM_LANGUAGES,
        ),
        "notes": "Task arithmetic using similarity-selected sources",
    },
    # Optional recipe, disabled by default because it needs dataset access.
    # Uncomment to include it in the comparison.
    # "fisher_dataset": {
    #     "kind": "merge",
    #     "config": MergeConfig(
    #         mode="fisher_dataset",
    #         target_lang=TARGET_LOCALE,
    #         base_model="xlm-roberta-base",
    #         num_languages=DEFAULT_NUM_LANGUAGES,
    #         dataset_name="AmazonScience/massive",
    #         dataset_split="train",
    #         text_column="utt",
    #         num_fisher_examples=500,
    #         fisher_data_mode="target",
    #         preweight="uriel",
    #     ),
    #     "notes": "Fisher merge (requires dataset access)",
    # },
})

## 2. Inspect available setups

The table summarises which resources will be used. Pretrained checkpoints whose path does not exist (`exists` is `False`) are skipped when loading; merge recipes are always attempted.

In [3]:
# Build one overview row per configured setup: pretrained entries report
# their checkpoint path and whether it exists, merge entries report the key
# fields of their merge configuration.
summary_rows = []
for setup_name, setup in MODEL_SETUPS.items():
    entry = {
        "model": setup_name,
        "kind": setup["kind"],
        "notes": setup.get("notes", ""),
    }
    if setup["kind"] != "pretrained":
        merge_conf = setup["config"]
        entry.update(
            mode=merge_conf.mode,
            target_lang=merge_conf.target_lang,
            num_languages=merge_conf.num_languages,
            # Not every config object is guaranteed to carry `preweight`.
            preweight=getattr(merge_conf, "preweight", None),
        )
    else:
        checkpoint = setup["path"]
        entry.update(path=str(checkpoint), exists=checkpoint.exists())
    summary_rows.append(entry)

display(pd.DataFrame(summary_rows))

Unnamed: 0,model,kind,notes,path,exists,mode,target_lang,num_languages,preweight
0,baseline,pretrained,Target-specific fine-tuned baseline,/home/coder/Python_project/MergingUriel/haryos...,True,,,,
1,average_merge,merge,Equal-weight merge of top-K sources,,,average,af-ZA,5.0,equal
2,similarity_merge,merge,URIEL-weighted similarity merge,,,similarity,af-ZA,5.0,equal
3,task_arithmetic,merge,Task arithmetic using similarity-selected sources,,,task_arithmetic,af-ZA,5.0,equal


## 3. Load baseline & run dynamic merges

Each merge configuration is executed on demand, so no merged checkpoints need to be saved to disk. Results are cached in memory for the rest of the notebook session.

In [4]:
# Prefer the GPU when torch can see one; otherwise run everything on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
artifacts = {}       # setup name -> loaded or merged model artifact
merge_metadata = {}  # setup name -> metadata returned by the merge (merge setups only)

for name, cfg in MODEL_SETUPS.items():
    if cfg["kind"] == "pretrained":
        path = cfg["path"]
        if not path.exists():
            # A missing checkpoint is reported and skipped rather than
            # aborting the whole cell.
            print(f"⚠️ Skipping {name}: checkpoint not found at {path}")
            continue
        artifacts[name] = load_model_artifacts(path, device=DEVICE)
    else:
        # Merge recipes are executed on the fly; nothing is written to disk here.
        artifact, meta = merge_models_in_memory(cfg["config"], device=DEVICE)
        artifacts[name] = artifact
        merge_metadata[name] = meta

print(f"Prepared {len(artifacts)} models on {DEVICE}.")

# Without at least one model there is no reference to compare against. Fail
# with an explicit message instead of the bare StopIteration that
# `next(iter({}))` would raise.
if not artifacts:
    raise RuntimeError(
        "No models were prepared: MODEL_SETUPS is empty or every entry was "
        "skipped. Check the checkpoint paths and merge configurations."
    )

# Use the fine-tuned baseline as reference when it was loaded; otherwise fall
# back to the first model that was prepared.
REFERENCE_KEY = "baseline" if "baseline" in artifacts else next(iter(artifacts))
print(f"Reference model: {REFERENCE_KEY}")

2025-10-29 06:04:36.761795: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761717876.770292  763754 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761717876.773202  763754 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761717876.783352  763754 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761717876.783362  763754 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761717876.783363  763754 computation_placer.cc:177] computation placer alr


--- Setting Up Average (Equal) Weights for af-ZA ---

--- Computing Similarity Weights for af-ZA ---
Using URIEL similarity matrix with top-k + Sinkhorn normalization
Loading similarity matrix from /home/coder/Python_project/MergingUriel/language_similarity_matrix_unified.csv
Loaded similarity matrix with shape: (50, 50)
Available languages: ['af-ZA', 'am-ET', 'ar-SA', 'az-AZ', 'bn-BD', 'ca-ES', 'cy-GB', 'da-DK', 'de-DE', 'el-GR', 'en-US', 'es-ES', 'fa-IR', 'fi-FI', 'fr-FR', 'hi-IN', 'hu-HU', 'hy-AM', 'id-ID', 'is-IS', 'it-IT', 'ja-JP', 'jv-ID', 'ka-GE', 'km-KH', 'kn-IN', 'ko-KR', 'lv-LV', 'ml-IN', 'mn-MN', 'ms-MY', 'my-MM', 'nb-NO', 'nl-NL', 'pl-PL', 'pt-PT', 'ro-RO', 'ru-RU', 'sl-SL', 'sq-AL', 'sw-KE', 'ta-IN', 'te-IN', 'th-TH', 'tl-PH', 'tr-TR', 'ur-PK', 'vi-VN', 'zh-TW', 'zh-TW']
Processing similarity matrix for af-ZA
Matrix shape: (50, 50)
Applying top-k filtering (k=20)...
50
Applying Sinkhorn normalization (20 iterations)...
Generated processed matrix: (50, 50)
Top 5 similar la

In [5]:
# Summarise, per merged model, the base model and the weight assigned to each
# source. `meta_rows` is always bound (possibly to an empty list) so that
# later cells can reference it even when no merge was run.
meta_rows = []
for name, meta in merge_metadata.items():
    sources = []
    for path, info in meta["models_and_weights"].items():
        # Fall back to the checkpoint's final path component when no locale
        # is recorded for the source.
        label = info.locale or Path(path).name
        sources.append(f"{label}:{info.weight:.3f}")
    meta_rows.append({
        "model": name,
        "base_model": meta["base_model"].model_name,
        "sources": ", ".join(sources),
    })

# Only render a table when there is something to show.
if meta_rows:
    display(pd.DataFrame(meta_rows))

Unnamed: 0,model,base_model,sources
0,average_merge,/home/coder/Python_project/MergingUriel/haryos...,"hy-AM:0.200, hu-HU:0.200, ka-GE:0.200, nl-NL:0..."
1,similarity_merge,/home/coder/Python_project/MergingUriel/haryos...,"hy-AM:0.283, hu-HU:0.177, ka-GE:0.114, nl-NL:0..."
2,task_arithmetic,/home/coder/Python_project/MergingUriel/haryos...,"hy-AM:0.283, hu-HU:0.177, ka-GE:0.114, nl-NL:0..."


In [6]:
# Echo the raw rows so the base-model paths and source weights are shown in
# full rather than as abbreviated table cells.
# NOTE: relies on `meta_rows` having been bound by an earlier cell.
meta_rows

[{'model': 'average_merge',
  'base_model': '/home/coder/Python_project/MergingUriel/haryos_model/xlm-roberta-base_massive_k_id-ID',
  'sources': 'hy-AM:0.200, hu-HU:0.200, ka-GE:0.200, nl-NL:0.200'},
 {'model': 'similarity_merge',
  'base_model': '/home/coder/Python_project/MergingUriel/haryos_model/xlm-roberta-base_massive_k_id-ID',
  'sources': 'hy-AM:0.283, hu-HU:0.177, ka-GE:0.114, nl-NL:0.091'},
 {'model': 'task_arithmetic',
  'base_model': '/home/coder/Python_project/MergingUriel/haryos_model/xlm-roberta-base_massive_k_id-ID',
  'sources': 'hy-AM:0.283, hu-HU:0.177, ka-GE:0.114, nl-NL:0.091'}]

## 4. Parameter deltas

Compute parameter-level differences relative to the reference model and inspect the heaviest changing layers.

In [7]:
# Compare every non-reference model against the reference model, keeping both
# the per-parameter deltas and their per-layer aggregation.
weight_delta_frames, layer_aggregates = [], []

reference_model = artifacts[REFERENCE_KEY].model

for candidate_name, candidate in artifacts.items():
    if candidate_name != REFERENCE_KEY:
        param_deltas = compute_weight_deltas(reference_model, candidate.model)
        # Tag the rows before aggregating, exactly as the per-layer summary
        # receives them.
        param_deltas["model"] = candidate_name
        weight_delta_frames.append(param_deltas)

        per_layer = aggregate_parameter_stats(param_deltas)
        per_layer["model"] = candidate_name
        layer_aggregates.append(per_layer)

# Fall back to empty frames when there was nothing to compare.
if weight_delta_frames:
    weight_deltas_df = pd.concat(weight_delta_frames, ignore_index=True)
else:
    weight_deltas_df = pd.DataFrame()

if layer_aggregates:
    layer_deltas_df = pd.concat(layer_aggregates, ignore_index=True)
else:
    layer_deltas_df = pd.DataFrame()

display(layer_deltas_df)

Unnamed: 0,layer,delta_l2_sum,delta_l2_mean,delta_mean_abs,cosine_mean,reference_norm_mean,candidate_norm_mean,model
0,classifier.dense,18.083953,9.041976,0.01025,-0.021167,8.103898,4.01318,average_merge
1,classifier.out_proj,5.604971,2.802486,0.010836,0.481004,2.529677,1.239717,average_merge
2,layer.0.attention.output,0.846175,0.211544,0.000597,0.99988,12.162775,12.160544,average_merge
3,layer.0.attention.self,2.579264,0.429877,0.000744,0.842738,37.60254,37.598843,average_merge
4,layer.0.intermediate.dense,1.857648,0.928824,0.000849,0.999905,54.537057,54.532332,average_merge
5,layer.0.output.LayerNorm,0.038962,0.019481,0.000557,0.999996,8.321826,8.321916,average_merge
6,layer.0.output.dense,1.590576,0.795288,0.00057,0.999858,36.394897,36.390967,average_merge
7,layer.1.attention.output,0.875673,0.218918,0.000602,0.99987,12.027309,12.02605,average_merge
8,layer.1.attention.self,2.562401,0.427067,0.000722,0.996896,29.436775,29.435452,average_merge
9,layer.1.intermediate.dense,1.88328,0.94164,0.000837,0.999894,53.003073,52.997586,average_merge


In [8]:
# Grouped bar chart of summed L2 parameter movement per layer, one bar group
# per model. Skipped when there are no deltas to plot.
if not layer_deltas_df.empty:
    bar_options = {
        "x": "layer",
        "y": "delta_l2_sum",
        "color": "model",
        "barmode": "group",
        "title": "Layer-level parameter movement (L2 sum)",
    }
    fig = px.bar(layer_deltas_df, **bar_options)
    fig.show()

## 5. Probe texts

Select a small batch of utterances for forward passes. Point `file_hint` at a locale-specific corpus (one sentence per line) to replace the defaults.

In [9]:
# Probe utterances for the forward passes: at most 12 texts, with
# `file_hint` pointing at <project root>/assets/sample_prompts.txt.
SAMPLE_TEXTS = ensure_text_samples(
    file_hint=PROJECT_ROOT.joinpath("assets", "sample_prompts.txt"),
    limit=12,
)
# Trailing expression so the notebook displays the selected texts.
SAMPLE_TEXTS

['How can I upgrade my flight booking?',
 'Show me the weather forecast for tomorrow evening.',
 'I need to reset the password for my online banking.',
 'Find vegetarian restaurants near my location.',
 'Translate this sentence into French.',
 'Remind me to call my mom at 6 PM.']

## 6. Collect signals

Run each model on the probe texts and summarise attentions, hidden states, and logits.

In [10]:
# Run every prepared model on the probe texts and keep the collected signals
# keyed by model name.
signals = {}
for model_name, model_artifact in artifacts.items():
    signals[model_name] = collect_model_signals(
        model_artifact, SAMPLE_TEXTS, device=DEVICE
    )

# Summarise attentions per model; models without attention output are left out.
attention_records = []
for model_name, model_signal in signals.items():
    model_attentions = model_signal.get("attentions")
    if model_attentions:
        per_model_attention = summarize_attentions(model_attentions)
        per_model_attention["model"] = model_name
        attention_records.append(per_model_attention)

if attention_records:
    attention_df = pd.concat(attention_records, ignore_index=True)
else:
    attention_df = pd.DataFrame()

display(attention_df.head())



Unnamed: 0,layer,head,mean_prob,entropy,cls_focus,diagonal_focus,model
0,0,0,0.071429,1.851026,0.101671,0.052829,baseline
1,0,1,0.071429,2.00297,0.193926,0.057709,baseline
2,0,2,0.071429,1.730103,0.094421,0.01956,baseline
3,0,3,0.071429,1.52721,0.07678,0.312523,baseline
4,0,4,0.071429,2.062032,0.159642,0.108397,baseline


In [11]:
# Two views of the attention summary; both are skipped when nothing was collected.
if not attention_df.empty:
    # Distribution of attention entropy per layer, coloured by model.
    entropy_options = {
        "x": "layer",
        "y": "entropy",
        "color": "model",
        "points": "all",
        "title": "Attention entropy by layer",
    }
    fig = px.box(attention_df, **entropy_options)
    fig.show()

    # Mean of the numeric columns per (model, layer) pair, then plot the
    # averaged CLS focus across layers.
    per_layer_means = (
        attention_df.groupby(["model", "layer"])
        .mean(numeric_only=True)
        .reset_index()
    )
    fig = px.line(
        per_layer_means,
        x="layer",
        y="cls_focus",
        color="model",
        markers=True,
        title="Average CLS attention focus",
    )
    fig.show()

In [12]:
# Summarise hidden states per model; models without hidden-state output are
# left out.
hidden_records = []
for model_name, model_signal in signals.items():
    model_hidden_states = model_signal.get("hidden_states")
    if model_hidden_states:
        per_model_hidden = summarize_hidden_states(model_hidden_states)
        per_model_hidden["model"] = model_name
        hidden_records.append(per_model_hidden)

# `hidden_df` holds the combined summary (empty when nothing was collected).
if hidden_records:
    hidden_df = pd.concat(hidden_records, ignore_index=True)
else:
    hidden_df = pd.DataFrame()

display(hidden_df.head())

Unnamed: 0,layer,mean_token_norm,max_token_norm,std_token_norm,sequence_mean_norm,model
0,0,7.38319,11.213955,2.026369,7.38319,baseline
1,1,13.42983,25.50482,3.641388,13.42983,baseline
2,2,19.884916,26.351151,2.411303,19.884916,baseline
3,3,21.892551,26.28574,1.455102,21.892553,baseline
4,4,21.050301,26.78878,1.688169,21.050301,baseline


In [13]:
# Line chart of the mean token norm per layer, one line per model. Skipped
# when no hidden-state summary is available.
if not hidden_df.empty:
    norm_plot_options = {
        "x": "layer",
        "y": "mean_token_norm",
        "color": "model",
        "markers": True,
        "title": "Hidden state mean token norm",
    }
    fig = px.line(hidden_df, **norm_plot_options)
    fig.show()

In [14]:
# One row of logit statistics per model, tagged with the model name.
logit_rows = []
for model_name, model_signal in signals.items():
    logit_stats = summarize_logits(model_signal["logits"])
    logit_stats["model"] = model_name
    logit_rows.append(logit_stats)

logit_df = pd.DataFrame(logit_rows)
display(logit_df)

Unnamed: 0,logit_mean,logit_std,confidence_mean,confidence_std,entropy_mean,model
0,-0.056509,1.312831,0.62881,0.31459,1.444485,baseline
1,-0.022508,0.172479,0.036324,0.010476,4.074453,average_merge
2,-0.04068,0.217057,0.037611,0.009989,4.06565,similarity_merge
3,-0.011449,0.180217,0.037792,0.012902,4.071555,task_arithmetic
