In [None]:
!pip install -q accelerate safetensors

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict

In [None]:
# ✅ Colab-Ready Notebook: Functional Similarity Testing for MoE Experts (Decoder-Only)



# Checkpoint under study: the decoder-only Qwen1.5 mixture-of-experts model.
model_id = "Qwen/Qwen1.5-MoE-A2.7B"

# Tokenizer first, then the half-precision weights with automatic device placement.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)




Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

# Dataset

In [None]:
!pip install -U datasets



In [None]:
from datasets import load_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from collections import defaultdict

In [None]:
# Raw WikiText-2 corpus; only its "train" split is tokenized and used below.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

def tokenize(example):
    """Tokenize a batch of rows to exactly 128 token ids each (truncated or padded)."""
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=128)

tok_ds = dataset["train"].map(tokenize, batched=True, remove_columns=["text"])
train_size = int(0.8 * len(tok_ds))

# Seed both the 80/20 split and the shuffling so that a fresh
# "Restart & Run All" sees the same batches in the same order.
split_seed = 0
train_ds, val_ds = random_split(
    tok_ds,
    [train_size, len(tok_ds) - train_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# NOTE: the tokenized columns are plain Python lists, so the default collate
# function returns batch["input_ids"] as a list with one (batch,)-shaped tensor
# per token position. Stack it along dim=1 to obtain a (batch, seq) tensor.
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    generator=torch.Generator().manual_seed(split_seed),
)
val_loader = DataLoader(val_ds, batch_size=8)


# New Code

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict


In [None]:

# -------------------
# Adapter Definition
# -------------------
class Adapter(nn.Module):
    """Square linear projection followed by layer normalisation.

    Both sub-layers operate at width ``dim``, so the output has the same
    shape as the input.
    """

    def __init__(self, dim):
        super().__init__()
        # The attribute names double as state_dict key prefixes; keep them stable.
        self.linear = nn.Linear(in_features=dim, out_features=dim)
        self.norm = nn.LayerNorm(normalized_shape=dim)

    def forward(self, x):
        """Project ``x`` and normalise the result over the last dimension."""
        projected = self.linear(x)
        return self.norm(projected)

In [None]:


from collections import defaultdict
import torch

# Captured (input_row, output_row) tensor pairs, keyed by (layer_id, expert_id).
expert_pairs = defaultdict(list)

def make_aligned_expert_hook(layer_id, expert_id):
    """Build a forward hook that records row-aligned (input, output) pairs.

    The returned callable takes ``(module, input, output)`` positionally, as
    PyTorch forward hooks do. On every call it detaches the first positional
    input and the output, copies both to the CPU, and appends one
    ``(input[i], output[i])`` tuple per index of the input's first dimension
    to ``expert_pairs[(layer_id, expert_id)]``.
    """
    key = (layer_id, expert_id)

    def hook_fn(_module, args, result):
        rows_in = args[0].detach().cpu()
        rows_out = result.detach().cpu()
        n_rows = rows_in.shape[0]
        # Only touch the store when there is something to add, so an empty
        # call does not create an empty entry.
        if n_rows:
            expert_pairs[key].extend((rows_in[i], rows_out[i]) for i in range(n_rows))

    return hook_fn

# def get_aligned_data(layer_id, expert_id, device):
#     z_list, fz_list = zip(*expert_pairs[(layer_id, expert_id)])
#     z1 = torch.stack(z_list).to(device)
#     fz1 = torch.stack(fz_list).to(device)
#     return z1, fz1

def get_aligned_data(layer_id, expert_id, device, max_samples=None):
    """Stack the captured (input, output) rows of one expert into two tensors.

    Args:
        layer_id: Layer index of the expert whose samples are wanted.
        expert_id: Expert index within that layer.
        device: Device the returned tensors are moved to.
        max_samples: If not None, only the first ``max_samples`` captured
            pairs are used.

    Returns:
        ``(z1, fz1)``: stacked expert inputs and stacked expert outputs; row
        ``i`` of one corresponds to row ``i`` of the other.

    Raises:
        ValueError: If no pairs are available for ``(layer_id, expert_id)``
            (nothing captured, or ``max_samples`` selects nothing).
    """
    key = (layer_id, expert_id)
    # .get() instead of indexing: a lookup must not create an empty entry
    # in the defaultdict as a side effect.
    pairs = expert_pairs.get(key) or []
    if max_samples is not None:
        # Slice before unzipping so only the requested pairs are processed.
        pairs = pairs[:max_samples]
    if not pairs:
        raise ValueError(
            f"No captured samples for layer {layer_id}, expert {expert_id}; "
            "run the model with the capture hooks attached first."
        )
    z_list, fz_list = zip(*pairs)
    z1 = torch.stack(z_list).to(device)
    fz1 = torch.stack(fz_list).to(device)
    return z1, fz1


# This captures the input before routing
# (one CPU tensor per forward call, keyed by layer_id).
layer_inputs = defaultdict(list)
def make_router_input_hook(layer_id, hidden_size=2048):
    """Build a forward pre-hook that records a module's first positional input.

    Args:
        layer_id: Key under which captured tensors are stored in ``layer_inputs``.
        hidden_size: Only inputs whose last dimension equals this value are
            recorded. Defaults to 2048 (the previously hard-coded width);
            pass ``None`` to record inputs of any width.

    Returns:
        A callable with the ``(module, input)`` pre-hook signature. It returns
        ``None``, so the module's input is never replaced.
    """
    def hook_fn(module, input):
        hidden = input[0]
        if hidden_size is None or hidden.size(-1) == hidden_size:
            layer_inputs[layer_id].append(hidden.detach().cpu())
    return hook_fn

# Attach hooks for two experts at (layer1, expert1) and (layer2, expert2)
def attach_moe_hooks(model, pair1, pair2):
    """Register capture hooks for two ``(layer_id, expert_id)`` pairs.

    For every distinct layer a forward pre-hook is put on that layer's ``mlp``
    block, and for every distinct pair a forward hook is put on the selected
    expert module.

    Args:
        model: Model exposing ``model.model.layers[i].mlp.experts[j]``.
        pair1: First ``(layer_id, expert_id)`` pair.
        pair2: Second ``(layer_id, expert_id)`` pair.

    Returns:
        List of hook handles; call ``.remove()`` on each to detach.

    Raises:
        Whatever the lookup or registration raises (e.g. ``IndexError`` for an
        out-of-range index). Hooks registered before the failure are removed.
    """
    # De-duplicate while keeping order: registering the same hook twice would
    # record every sample twice.
    pairs = list(dict.fromkeys(tuple(p) for p in (pair1, pair2)))
    layer_ids = list(dict.fromkeys(p[0] for p in pairs))

    hooks = []
    try:
        for lid in layer_ids:
            moe_block = model.model.layers[lid].mlp
            hooks.append(moe_block.register_forward_pre_hook(make_router_input_hook(lid)))
        for (lid, eid) in pairs:
            expert = model.model.layers[lid].mlp.experts[eid]
            hooks.append(expert.register_forward_hook(make_aligned_expert_hook(lid, eid)))
    except Exception:
        # The caller never receives the handles on failure, so detach them here
        # rather than leaving half-registered hooks on the model.
        for handle in hooks:
            handle.remove()
        raise
    return hooks


# Manually run z1 through expert1 to get f1_output
def get_z1_and_f1_output(layer_id, expert_id, device):
    """Feed every captured input of a layer through one of its experts.

    Args:
        layer_id: Layer whose captured inputs (``layer_inputs[layer_id]``) are used.
        expert_id: Index of the expert in that layer to evaluate.
        device: Device the concatenated inputs are moved to before the forward pass.

    Returns:
        ``(z1, f1_output)``: the captured inputs concatenated along dim 0, and
        the expert's output for them (computed without an autograd graph).
    """
    z1 = torch.cat(layer_inputs[layer_id], dim=0).to(device)
    expert = model.model.layers[layer_id].mlp.experts[expert_id]
    # Evaluation only: without no_grad the output would keep a graph through
    # the expert's weights alive for as long as the result is referenced.
    with torch.no_grad():
        f1_output = expert(z1)
    return z1, f1_output



In [None]:
# Detach hooks left over from a previous run of the cells below.
# On a fresh kernel `hooks` does not exist yet, so start from an empty list
# instead of failing with a NameError.
try:
    hooks
except NameError:
    hooks = []
for h in hooks: h.remove()

In [None]:
# -------------------
# Run Inference + Compare Experts
# -------------------
# Source expert (layer1, expert1) and target expert (layer2, expert2).
layer1 = 1
expert1 = 1
layer2 = 2
expert2 = 0
hooks = attach_moe_hooks(model, pair1=(layer1, expert1), pair2=(layer2, expert2))

In [None]:
model.eval()
samples_needed = 20000  # stop once expert (layer1, expert1) has this many captured rows
try:
    with torch.no_grad():
        for batch in train_loader:
            # Default collation of list-valued columns yields one (batch,)-shaped
            # tensor per token position; stacking along dim=1 gives (batch, seq).
            # Stacking along dim=0 would feed the model transposed sequences.
            input_ids = torch.stack(batch['input_ids'], dim=1).to('cuda')
            model(input_ids)
            if len(expert_pairs[(layer1, expert1)]) >= samples_needed:
                break
finally:
    # Detach the hooks even if the loop is interrupted or raises.
    for h in hooks: h.remove()



In [None]:
# Captured inputs/outputs of the source expert, moved to the model's device.
z1, f1_out = get_aligned_data(layer_id=layer1, expert_id=expert1, device=model.device)
# Target expert module the adapter output will be fed into.
target_layer = model.model.layers[layer2]
f2 = target_layer.mlp.experts[expert2]

In [None]:
def train_adapter_aligning_experts(y1, f1_output, f2, lr=1e-3, steps=10):
    """Fit an ``Adapter`` so that ``f2(adapter(y1))`` approximates ``f1_output``.

    Only the adapter is optimised. ``f2`` is evaluated in float32 with its
    parameters frozen and is restored to its original dtype and
    ``requires_grad`` flags before returning, even if training raises.

    Args:
        y1: Inputs fed to the adapter; the last dimension sets the adapter width.
        f1_output: Regression target. It is detached and treated as a constant.
            A 3-D target is flattened to 2-D over its last dimension.
        f2: Module applied after the adapter.
        lr: Adam learning rate.
        steps: Number of full-batch optimisation steps.

    Returns:
        ``(adapter, loss)``: the trained adapter and the MSE of the last
        completed step as a float. The loss is NaN if no step completed
        (``steps == 0`` or NaNs encountered before the first update).
    """
    adapter = Adapter(y1.size(-1)).to(y1.device).train()
    opt = torch.optim.Adam(adapter.parameters(), lr=lr)

    y1 = y1.detach().to(torch.float32)
    # Detaching the target means no graph has to be retained between steps.
    target = f1_output.detach().to(torch.float32)
    if target.ndim == 3:
        target = target.reshape(-1, target.size(-1))

    final_loss = float("nan")
    # The target never changes, so it is checked once instead of every step.
    target_ok = not torch.isnan(target).any()
    if not target_ok:
        print("❌ NaNs in f1_output")

    f2_params = list(f2.parameters())
    saved_dtype = f2_params[0].dtype if f2_params else None
    saved_flags = [p.requires_grad for p in f2_params]
    try:
        f2.to(torch.float32)
        # Freeze f2: gradients still flow through it to the adapter, but none
        # are accumulated on f2's own weights.
        for p in f2.parameters():
            p.requires_grad_(False)

        for _ in range(steps if target_ok else 0):
            f2_output = f2(adapter(y1))
            if torch.isnan(f2_output).any():
                print("❌ NaNs in f2_output")
                break
            if f2_output.ndim == 3:
                f2_output = f2_output.reshape(-1, f2_output.size(-1))

            loss = F.mse_loss(f2_output, target)
            opt.zero_grad()
            loss.backward()
            opt.step()
            final_loss = loss.item()
    finally:
        # Put f2 back exactly as it was handed in.
        for p, flag in zip(f2.parameters(), saved_flags):
            p.requires_grad_(flag)
        if saved_dtype is not None:
            f2.to(saved_dtype)

    print(final_loss)
    return adapter, final_loss

In [None]:
# Fit the adapter on the captured samples and report the final loss.
adapter, loss = train_adapter_aligning_experts(y1=z1, f1_output=f1_out, f2=f2)
print(f"✅ Final Adapter MSE Loss: {loss:.4f}")

0.24927015602588654
0.050167590379714966
0.037436481565237045
0.027257481589913368
0.02222025953233242
0.01993030495941639
0.01888357847929001
0.01834593527019024
0.0180082768201828
0.017751488834619522
✅ Final Adapter MSE Loss: 0.0178


In [None]:


# text = "The economy is improving and markets are responding positively."
# inputs = tokenizer(text, return_tensors="pt").to(model.device)

# print(inputs)
# with torch.no_grad():
#     _ = model(**inputs)


{'input_ids': tensor([[  785,  8584,   374, 18392,   323, 11725,   525, 29338, 39546,    13]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}


In [None]:

# Gather data
# y1 = torch.cat(layer_inputs[layer1], dim=0).to("cuda")
# f1_output = torch.cat(expert_outputs[(layer1, expert1)], dim=0).to("cuda")

# y1, f1_output = get_z1_and_f1_output(layer1, expert1, device="cuda")
# f2 = model.model.layers[layer2].mlp.experts[expert2]



In [None]:
# Train adapter: f2(A(y1)) ≈ f1(z1)
# y1 / f1_output are built here so that the cell does not rely on variables
# that only exist if some other (commented-out) cell was executed earlier.
# NOTE(review): this uses every captured layer input, not only the rows routed
# to expert1 — confirm that this is the intended comparison.
y1, f1_output = get_z1_and_f1_output(layer1, expert1, device="cuda")
adapter, loss = train_adapter_aligning_experts(y1, f1_output, f2)
print(f"✅ Final Adapter MSE Loss (f2(A(y1)) ≈ f1(z1)): {loss:.6f}")

# Cleanup
for h in hooks:
    h.remove()

0.09067047387361526
38.54270553588867
0.20621776580810547
0.14524886012077332
0.09172222763299942
0.061609137803316116
0.0469362847507
0.03986551612615585
0.03613929823040962
0.03386891633272171
0.03226346895098686
0.030997002497315407
0.029936401173472404
0.029026532545685768
0.028242135420441628
0.027567831799387932
0.02699037455022335
0.02649623155593872
0.026071518659591675
0.02570287697017193
0.02537832036614418
0.025087663903832436
0.024822819977998734
0.02457754872739315
0.0243472121655941
0.024128450080752373
0.023918839171528816
0.023716680705547333
0.02352074719965458
0.023330166935920715
0.023144308477640152
0.022962680086493492
0.022784942761063576
0.02261078916490078
0.02243999019265175
0.022272348403930664
0.022107703611254692
0.021945904940366745
0.021786827594041824
0.021630359813570976
0.021476412191987038
0.021324878558516502
0.02117571048438549
0.02102884091436863
0.020884212106466293
0.02074178121984005
0.020601525902748108
0.020463405176997185
0.020327381789684296


In [None]:
# Detach every hook handle currently listed in `hooks`.
for handle in hooks:
    handle.remove()

In [None]:
# Display the module tree of the loaded model (layers, MoE blocks, experts).
model

Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
        

In [None]:
import os

# Output folder for trained adapter weights.
# NOTE(review): presumably a mounted Google Drive path — the mount call is not
# visible in this notebook; confirm it exists before relying on persistence.
adapter_dir = "/content/drive/MyDrive/AdaptersMoE"
os.makedirs(name=adapter_dir, exist_ok=True)


In [None]:
# ✅ Colab-Ready Notebook: Functional Similarity Testing for MoE Experts (Decoder-Only)

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
import pandas as pd

layer1 = 1
layer2 = 2
num_experts = len(model.model.layers[layer1].mlp.experts)
results = []

# To sweep every expert instead, use selected_experts = list(range(num_experts)).
selected_experts = [0, 1, 2, 3, 4, 5, 6, 7]  # any 8 unique IDs from 0–59
samples_per_pair = 5000  # stop capturing once expert (layer1, eid1) has this many rows

for eid1 in selected_experts:
    for eid2 in selected_experts:
        # Skip a pair that already has a result (only possible when
        # selected_experts contains a repeated ID).
        if any((l1, e1, l2, e2) == (layer1, eid1, layer2, eid2) for (l1, e1, l2, e2, *_) in results):
            continue

        print(f"Analyzing: Layer {layer1} Expert {eid1} ↔ Layer {layer2} Expert {eid2}")
        expert_pairs.clear()
        layer_inputs.clear()
        hooks = attach_moe_hooks(model, (layer1, eid1), (layer2, eid2))

        try:
            with torch.no_grad():
                for batch in train_loader:
                    # Default collation of list-valued columns yields one
                    # (batch,)-shaped tensor per token position; stacking along
                    # dim=1 gives (batch, seq). dim=0 would transpose the input.
                    input_ids = torch.stack(batch['input_ids'], dim=1).to('cuda')
                    model(input_ids)
                    if len(expert_pairs[(layer1, eid1)]) >= samples_per_pair:
                        break
        finally:
            # Detach the hooks even when the forward pass is interrupted, so
            # they cannot pile up on the model across runs.
            for h in hooks: h.remove()

        try:
            z1, f1_out = get_aligned_data(layer1, eid1, model.device)
            f2 = model.model.layers[layer2].mlp.experts[eid2]
            adapter, loss = train_adapter_aligning_experts(z1, f1_out, f2)
            results.append((layer1, eid1, layer2, eid2, loss))

            # Save adapter
            fname = f"{adapter_dir}/adapter_L{layer1}E{eid1}_to_L{layer2}E{eid2}.pt"
            torch.save(adapter.state_dict(), fname)
            print(f"✅ Saved: {fname} | Loss: {loss:.4f}")
        except Exception as e:
            # Best effort: report the failing pair and continue with the next one.
            print(f"⚠️ Skipped pair ({layer1},{eid1}) → ({layer2},{eid2}) due to error: {e}")


# -------------------
# Result Table
# -------------------
print("\n\n====== Functional Similarity Table (MSE Loss) ======")
# Rows are layer1 experts, columns are layer2 experts. Only the experts that
# were actually analysed get a row/column; "-" marks a pair without a result.
table_labels = [f"E{eid}" for eid in dict.fromkeys(selected_experts)]
df = pd.DataFrame("-", index=table_labels, columns=table_labels)
for _, eid1, _, eid2, loss in results:
    df.loc[f"E{eid1}", f"E{eid2}"] = f"{loss:.4f}"

print(df)


Analyzing: Layer 1 Expert 0 ↔ Layer 2 Expert 0
9.779531478881836
✅ Saved: /content/drive/MyDrive/AdaptersMoE/adapter_L1E0_to_L2E0.pt | Loss: 9.7795
Analyzing: Layer 1 Expert 0 ↔ Layer 2 Expert 1
6.640830993652344
✅ Saved: /content/drive/MyDrive/AdaptersMoE/adapter_L1E0_to_L2E1.pt | Loss: 6.6408
Analyzing: Layer 1 Expert 0 ↔ Layer 2 Expert 2
6.242130756378174
✅ Saved: /content/drive/MyDrive/AdaptersMoE/adapter_L1E0_to_L2E2.pt | Loss: 6.2421
Analyzing: Layer 1 Expert 0 ↔ Layer 2 Expert 3
7.418803691864014
✅ Saved: /content/drive/MyDrive/AdaptersMoE/adapter_L1E0_to_L2E3.pt | Loss: 7.4188
Analyzing: Layer 1 Expert 0 ↔ Layer 2 Expert 4
7.009343147277832
✅ Saved: /content/drive/MyDrive/AdaptersMoE/adapter_L1E0_to_L2E4.pt | Loss: 7.0093
Analyzing: Layer 1 Expert 0 ↔ Layer 2 Expert 5
6.691318511962891
✅ Saved: /content/drive/MyDrive/AdaptersMoE/adapter_L1E0_to_L2E5.pt | Loss: 6.6913
Analyzing: Layer 1 Expert 0 ↔ Layer 2 Expert 6
5.898207664489746
✅ Saved: /content/drive/MyDrive/AdaptersMoE/ada

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-2797085562>", line 34, in <cell line: 0>
    model(input_ids)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 1161, in forward
    outputs: MoeModelOutputW

TypeError: object of type 'NoneType' has no len()

In [None]:
# Full input tensors seen by hooked experts whose layer equals `layer2`,
# one entry per hook call.
z2_all = []

def make_aligned_expert_hook(layer_id, expert_id):
    """Build a forward hook that records row-aligned (input, output) pairs.

    Same contract as the earlier definition of this name, which it replaces:
    one ``(input[i], output[i])`` tuple per index of the input's first dimension
    is appended to ``expert_pairs[(layer_id, expert_id)]``. In addition, when
    ``layer_id`` equals the global ``layer2`` at call time, the whole detached
    CPU input tensor is appended to ``z2_all``.
    """
    def hook_fn(_module, args, result):
        expert_in = args[0].detach().cpu()
        expert_out = result.detach().cpu()
        n_rows = expert_in.shape[0]
        # Only touch the store when there is something to add, so an empty
        # call does not create an empty entry.
        if n_rows:
            expert_pairs[(layer_id, expert_id)].extend(
                (expert_in[i], expert_out[i]) for i in range(n_rows)
            )
        if layer_id == layer2:
            z2_all.append(expert_in)

    return hook_fn



In [None]:
# z2_all is only filled while the capture hooks run during forward passes.
# Fail with an actionable message instead of torch.cat's generic
# "expected a non-empty list" error when nothing has been captured yet.
if not z2_all:
    raise RuntimeError(
        "z2_all is empty: attach the expert hooks and run the capture loop before building z2."
    )
z2 = torch.cat(z2_all, dim=0).to(model.device)


In [None]:
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA

In [None]:
def visualize_distribution_stats(z1, y1, z2):
    """Draw three side-by-side histograms of the raw values in z1, y1 and z2.

    Each tensor is flattened to 1-D; its mean and standard deviation are shown
    in the panel title.
    """
    panels = (
        ('z1 (input L1)', z1),
        ('y1 (output L1)', y1),
        ('z2 (input L2)', z2),
    )
    plt.figure(figsize=(12, 5))

    for position, (label, tensor) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        values = tensor.detach().cpu().numpy().flatten()
        # Matplotlib's default colour cycle: C0, C1, C2.
        plt.hist(values, bins=100, alpha=0.7, color=f"C{position - 1}")
        plt.title(f"{label}\nMean: {values.mean():.4f}, Std: {values.std():.4f}")
        plt.grid(True)

    plt.suptitle("Distribution of Activation Values")
    plt.tight_layout()
    plt.show()


In [None]:
def visualize_pca_distributions(z1, y1, z2):
    """Scatter-plot z1, y1 and z2 in a shared 2-component PCA projection.

    The three tensors are concatenated along dim 0, projected together, and
    drawn as three labelled point clouds. They may have different numbers of
    rows; each cloud is sliced by its own row count.

    Args:
        z1, y1, z2: 2-D tensors with the same number of columns.
    """
    groups = [
        ("z1 (Input Layer 1)", z1),
        ("y1 (Output Layer 1)", y1),
        ("z2 (Input Layer 2)", z2),
    ]
    pca = PCA(n_components=2)
    # detach() so tensors that still track gradients can be converted to numpy.
    all_data = torch.cat([t.detach() for _, t in groups], dim=0).cpu().numpy()
    reduced = pca.fit_transform(all_data)

    plt.figure(figsize=(8, 6))
    start = 0
    for label, t in groups:
        # Offsets follow each tensor's own length; assuming all three share
        # z1's row count would mislabel points whenever the sizes differ.
        stop = start + t.size(0)
        plt.scatter(reduced[start:stop, 0], reduced[start:stop, 1], label=label, alpha=0.5)
        start = stop
    plt.legend()
    plt.title("PCA of z1, y1, z2")
    plt.grid(True)
    plt.show()


In [None]:
def attach_all_hooks(model, layer1, layer2, expert_ids):
    """Register capture hooks on several experts of two layers.

    For each distinct layer a forward pre-hook is put on the layer's ``mlp``
    block, and a forward hook on every distinct expert in ``expert_ids``.

    Args:
        model: Model exposing ``model.model.layers[i].mlp.experts[j]``.
        layer1: First layer index.
        layer2: Second layer index (may equal ``layer1``).
        expert_ids: Iterable of expert indices to hook in both layers.

    Returns:
        List of hook handles; call ``.remove()`` on each to detach.

    Raises:
        Whatever the lookup or registration raises (e.g. ``IndexError`` for an
        out-of-range index). Hooks registered before the failure are removed.
    """
    # Materialise once (an iterator would be exhausted after the first layer)
    # and drop repeats, which would otherwise record every sample twice.
    expert_ids = list(dict.fromkeys(expert_ids))

    hooks = []
    try:
        for lid in dict.fromkeys((layer1, layer2)):
            moe_block = model.model.layers[lid].mlp
            hooks.append(moe_block.register_forward_pre_hook(make_router_input_hook(lid)))
            for eid in expert_ids:
                expert = moe_block.experts[eid]
                hooks.append(expert.register_forward_hook(make_aligned_expert_hook(lid, eid)))
    except Exception:
        # The caller never receives the handles on failure, so detach them here.
        for handle in hooks:
            handle.remove()
        raise
    return hooks

In [None]:
# Expert indices to hook and compare; the commented line is an alternative
# selection kept for reference.
# selected_experts = [0, 4, 8, 12, 16, 20, 24, 28]
selected_experts = [1, 15, 20, 25, 45, 59]

In [None]:
# Detach any hooks still registered before attaching a new set.
for handle in hooks:
    handle.remove()

In [None]:
layer1 = 1
layer2 = 20
min_samples_per_expert = 100  # stop once every selected layer1 expert has this many rows
hooks = attach_all_hooks(model, layer1, layer2, selected_experts)

try:
    with torch.no_grad():
        for batch in train_loader:
            # Default collation of list-valued columns yields one (batch,)-shaped
            # tensor per token position; stacking along dim=1 gives (batch, seq).
            # Stacking along dim=0 would feed the model transposed sequences.
            input_ids = torch.stack(batch['input_ids'], dim=1).to(model.device)
            model(input_ids)
            # You can optionally stop early
            counts = [len(expert_pairs[(layer1, eid)]) for eid in selected_experts]
            if all(count >= min_samples_per_expert for count in counts):
                break
            # Print the actual per-expert counts (a bare generator expression
            # would only print "<generator object ...>").
            print(counts)
finally:
    # Detach the hooks even if the loop is interrupted or raises.
    for h in hooks: h.remove()


<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>
<generator object <genexpr> at 0x7f4ab0c47c40>


KeyboardInterrupt: 

In [None]:
import os
# Output folder for the adapters trained in the sweep below; created if missing.
adapter_dir = "/content/drive/MyDrive/AdaptersMoE"
os.makedirs(name=adapter_dir, exist_ok=True)

In [None]:
# Train one adapter per (layer1 expert, layer2 expert) combination from the
# samples already captured, using at most 200 rows per source expert.
results = []
for eid1 in selected_experts:
    for eid2 in selected_experts:
        # Skip a combination that already has a recorded result.
        already_done = any(
            (l1, e1, l2, e2) == (layer1, eid1, layer2, eid2)
            for (l1, e1, l2, e2, *_) in results
        )
        if already_done:
            continue
        try:
            z1, f1_out = get_aligned_data(layer1, eid1, model.device, max_samples=200)
            f2 = model.model.layers[layer2].mlp.experts[eid2]
            adapter, loss = train_adapter_aligning_experts(z1, f1_out, f2)
            results.append((layer1, eid1, layer2, eid2, loss))
            # Persist the adapter weights next to the others.
            out_path = f"{adapter_dir}/adapter_L{layer1}E{eid1}_to_L{layer2}E{eid2}.pt"
            torch.save(adapter.state_dict(), out_path)
            print(f"✅ Saved: {out_path} | Loss: {loss:.4f}")
        except Exception as err:
            # Best effort: report the failing combination and move on.
            print(f"⚠️ Failed: ({eid1} → {eid2}) - {err}")


In [None]:
import torch
# Ask PyTorch to release cached GPU memory that is no longer in use.
torch.cuda.empty_cache()
# Also clean up CUDA IPC memory handles.
torch.cuda.ipc_collect()

In [None]:
# Drop the reference to the training DataLoader. Any cell that iterates
# `train_loader` cannot be re-run after this until the loader is re-created.
del train_loader

In [None]:
# Number of captured (input, output) rows for two layer-1 experts.
for key in ((1, 0), (1, 30)):
    print(len(expert_pairs[key]))

1908
188
