In [None]:
from __future__ import annotations

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots as sp
import torch


In [168]:
def _discrete_colorscale(name: str, levels: int):
    """Turn a continuous Plotly colorscale into a stepped one with `levels` flat bands.

    The named colorscale is sampled at `levels` evenly spaced positions in
    [0, 1]. Every sampled color is then emitted twice, once at the start and
    once at the end of its band, so the result has the form
    `[[t0, c0], [t1, c0], [t1, c1], [t2, c1], ...]`, where band `i` spans
    `[i / levels, (i + 1) / levels]`.
    """
    sampled = px.colors.sample_colorscale(name, np.linspace(0, 1, levels))

    stops = []
    for idx, color in enumerate(sampled):
        band_start = idx / levels
        band_end = (idx + 1) / levels
        # Same color at both ends of the band keeps it flat.
        stops += [[band_start, color], [band_end, color]]
    return stops

def show_matrix(m: torch.Tensor, title: str, levels: int = 7):
    """Build a heatmap figure of a tensor using a stepped Viridis colorscale.

    Args:
        m: tensor to plot. It is detached and moved to CPU before being
            converted to NumPy, so tensors that require grad or live on an
            accelerator are accepted as well.
        title: figure title.
        levels: number of flat color bands; also the number of colorbar ticks.

    Returns:
        The `go.Figure`. Nothing is displayed here; the caller decides whether
        to call `.show()` on the result.
    """
    # Tensor.numpy() raises for tensors that require grad or are not on CPU.
    arr = m.detach().cpu().numpy()
    vmin, vmax = float(arr.min()), float(arr.max())
    fig = go.Figure(
        go.Heatmap(
            z=arr,
            zmin=vmin,
            zmax=vmax,
            colorscale=_discrete_colorscale("Viridis", levels),
            colorbar=dict(
                tickmode="array",
                # NOTE(review): these are `levels` evenly spaced ticks, while the
                # stepped colorscale has `levels + 1` band edges, so the ticks
                # do not coincide with the band edges.
                tickvals=np.linspace(vmin, vmax, levels),
            )
        )
    )
    fig.update_layout(title=title, height=400, width=500)
    return fig

In [169]:
def make_qkv(H: int = 2, hd: int = 2) -> torch.Tensor:
    """Build a toy fused QKV matrix of shape [H*hd, 3*H*hd] with constant head blocks.

    Columns are laid out as Q | K | V. Inside each section, head `h` is a
    [H*hd, hd] block filled with a single value: `h` for Q, `2*h` for K and
    `3*h` for V.
    """
    D = H * hd

    def _section(scale: int) -> torch.Tensor:
        # One constant block per head, concatenated along the columns: [D, H*hd].
        blocks = []
        for head in range(H):
            blocks.append(torch.full((D, hd), scale * head, dtype=torch.float))
        return torch.cat(blocks, dim=1)

    return torch.cat([_section(1), _section(2), _section(3)], dim=1)  # [D, 3*H*hd]

def make_qkv_checkered(H: int, hd: int) -> torch.Tensor:
    """Build a toy fused QKV matrix of shape [H*hd, 3*H*hd] with striped head blocks.

    Every head gets the same [H*hd, hd] base pattern: rows alternate between
    0 (even rows) and 1 (odd rows), constant across the columns of the block.
    Heads are told apart by an additive offset: `h` for Q, `2*h` for K and
    `3*h` for V. Columns are laid out as Q | K | V.
    """
    D = H * hd

    # [D, 1] column of alternating 0/1, widened to one head block [D, hd].
    row_parity = (torch.arange(D) % 2).unsqueeze(1)
    stripes = row_parity.repeat(1, hd).float()

    sections = []
    for step in (1, 2, 3):  # offset multiplier for Q, K, V respectively
        head_blocks = [stripes + step * h for h in range(H)]
        sections.append(torch.cat(head_blocks, dim=1))  # [D, H*hd]

    return torch.cat(sections, dim=1)  # [D, 3*H*hd]

def permute_heads(qkv: torch.Tensor, perm: torch.Tensor, H: int, hd: int) -> torch.Tensor:
    """Reorder the heads of a fused QKV matrix.

    Args:
        qkv: [D_in, 3*H*hd] matrix whose columns are laid out as Q | K | V,
            each section split into H heads of hd columns.
        perm: index applied along the head axis; output head `j` is input
            head `perm[j]`.
        H: number of heads.
        hd: number of columns per head.

    Returns:
        A new [D_in, 3*H*hd] tensor with the same head reordering applied to
        the Q, K and V sections.

    The row count is read from `qkv` instead of being assumed equal to H*hd,
    so matrices with any number of input rows are accepted.
    """
    rows = qkv.shape[0]
    # [D_in, 3*H*hd] -> [D_in, 3, H, hd]; reshape (unlike view) has no stride requirements.
    qkv4 = qkv.reshape(rows, 3, H, hd)
    qkv4 = qkv4[:, :, perm, :]  # permute head axis
    return qkv4.reshape(rows, 3 * H * hd)


In [170]:
# Two heads of width two: swap them and compare the heatmaps before and after.
H, hd = 2, 2
qkv = make_qkv(H, hd)
# qkv = make_qkv_checkered(H, hd)
# Index along the head axis: output head j takes input head perm_true[j].
perm_true = torch.tensor([1, 0])

show_matrix(qkv, "Original QKV").show()

qkv_perm = permute_heads(qkv, perm_true, H, hd)
show_matrix(qkv_perm, "Permuted QKV").show()
print(perm_true)

tensor([1, 0])


In [171]:
# Show the raw values of the toy QKV matrix (columns: Q | K | V).
qkv

tensor([[0., 0., 1., 1., 0., 0., 2., 2., 0., 0., 3., 3.],
        [0., 0., 1., 1., 0., 0., 2., 2., 0., 0., 3., 3.],
        [0., 0., 1., 1., 0., 0., 2., 2., 0., 0., 3., 3.],
        [0., 0., 1., 1., 0., 0., 2., 2., 0., 0., 3., 3.]])

In [172]:
def extract_head_features(qkv: torch.Tensor, H: int, hd: int) -> torch.Tensor:
    """Flatten each head's Q, K and V slices into one feature vector per head.

    Args:
        qkv: [D_in, 3*H*hd] matrix whose columns are laid out as Q | K | V,
            each section split into H heads of hd columns.
        H: number of heads.
        hd: number of columns per head.

    Returns:
        A [H, D_in*3*hd] tensor. Row `h` holds, for every input row in order,
        that row's Q, K and V entries belonging to head `h`.

    The row count is read from `qkv` instead of being assumed equal to H*hd,
    so matrices with any number of input rows are accepted.
    """
    rows = qkv.shape[0]
    qkv4 = qkv.reshape(rows, 3, H, hd)  # [D_in, 3, H, hd]
    # Bring the head axis to the front, then flatten everything else.
    return qkv4.permute(2, 0, 1, 3).reshape(H, -1)


def greedy_match(cost: torch.Tensor) -> torch.Tensor:
    """Greedily assign one client head to every reference head.

    `cost[i, j]` is the distance between head `i` (ref) and head `j` (client).
    Reference heads are processed in index order; each one takes the cheapest
    client head that has not been taken yet. The assignment is greedy, so it
    is not guaranteed to minimise the total cost.

    Returns a long tensor `perm` where `perm[i]` is the client head assigned
    to reference head `i`.
    """
    n_heads = cost.shape[0]
    assignment = torch.empty(n_heads, dtype=torch.long)
    available = set(range(n_heads))
    for ref_idx in range(n_heads):
        # Pull the row out as plain Python numbers once instead of per lookup.
        row_values = cost[ref_idx].tolist()
        pick = min(available, key=row_values.__getitem__)
        assignment[ref_idx] = pick
        available.remove(pick)
    return assignment

In [173]:
# Head-matching demo: permute the reference heads, then recover the assignment
# from pairwise distances between per-head feature vectors.
H, hd = 3, 2
qkv_ref = make_qkv(H, hd)  # [D, 3*H*hd]
# qkv_ref = make_qkv_checkered(H, hd) # [D, 3*H*hd]
# print(qkv_ref.shape)

# reverse the head order
perm_true = torch.tensor([2, 1, 0])
qkv_client = permute_heads(qkv_ref, perm_true, H, hd)

ref_feat = extract_head_features(qkv_ref, H, hd)
cli_feat = extract_head_features(qkv_client, H, hd)
# cost[i, j] = Euclidean distance between reference head i and client head j
cost = torch.cdist(ref_feat, cli_feat, p=2)

px.imshow(
    cost.numpy(), text_auto=True, color_continuous_scale="Plasma", title="Cost Matrix"
).show()

# NOTE(review): client head j is reference head perm_true[j], while perm_est[i]
# is the client head matched to reference head i. perm_est is therefore the
# inverse of perm_true in general; the two printouts agree here because
# [2, 1, 0] is its own inverse.
perm_est = greedy_match(cost)
print("True permutation:", perm_true.tolist())
print("Recovered permutation:", perm_est.tolist())

True permutation: [2, 1, 0]
Recovered permutation: [2, 1, 0]


In [140]:
def decompose_qkv(qkv: torch.Tensor, H: int, hd: int):
    """Split a fused QKV matrix into per-head Q, K and V blocks.

    Args:
        qkv: [D_in, 3*H*hd] matrix whose columns are laid out as Q | K | V,
            each section split into H heads of hd columns.
        H: number of heads.
        hd: number of columns per head.

    Returns:
        A tuple `(q, k, v)`, each of shape [H, D_in, hd], indexed by head
        along the first axis.

    The row count is read from `qkv` instead of being assumed equal to H*hd,
    so matrices with any number of input rows are accepted.
    """
    rows = qkv.shape[0]
    qkv4 = qkv.reshape(rows, 3, H, hd)  # [D_in, 3*H*hd] to [D_in, 3, H, hd]
    q = qkv4[:, 0]                      # [D_in, H, hd]
    k = qkv4[:, 1]
    v = qkv4[:, 2]
    # transpose to [H, D_in, hd]
    return q.permute(1, 0, 2), k.permute(1, 0, 2), v.permute(1, 0, 2)


def plot_heads(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, title_prefix: str):
    """Show a grid of heatmaps with one row per head and one column each for Q, K, V.

    `q`, `k` and `v` are indexed by head along their first axis; the number of
    rows in the grid is `q.shape[0]`. Each subplot is titled
    "<title_prefix> Head <i> <Q|K|V>". The figure is displayed with
    `fig.show()`; nothing is returned.
    """
    n_heads = q.shape[0]
    # Subplot titles are consumed row by row, so Q/K/V vary fastest.
    panel_titles = [
        f"{title_prefix} Head {head} {part}"
        for head in range(n_heads)
        for part in ("Q", "K", "V")
    ]
    fig = sp.make_subplots(rows=n_heads, cols=3, subplot_titles=panel_titles)

    for head in range(n_heads):
        for col, block in enumerate((q, k, v), start=1):
            # Reuse the heatmap trace that px.imshow builds for this block.
            trace = px.imshow(block[head].numpy()).data[0]
            fig.add_trace(trace, row=head + 1, col=col)

    fig.update_layout(height=300 * n_heads, width=900, title_text=f"{title_prefix} Q/K/V Blocks")
    fig.show()

In [None]:
# End-to-end demo on a hand-built 2-head example: permute the heads, plot the
# blocks and the difference, then recover the assignment from the cost matrix.
H, hd = 2, 2
D = H * hd

# Head 0 alternates 0/1 in both directions; head 1 is constant per column.
# (Uses torch.tensor: the bare name `tensor` is never imported in this notebook.)
h0 = torch.tensor([[0,1],[1,0],[0,1],[1,0]], dtype=torch.float)
h1 = torch.tensor([[5,3],[5,3],[5,3],[5,3]], dtype=torch.float)
# The same [h0 | h1] pair is repeated for the Q, K and V sections.
qkv_ref = torch.cat([torch.cat([h0, h1], dim=1)]*3, dim=1)

perm_true = torch.tensor([1, 0])

# permute_heads is the function defined earlier in the notebook; it is not
# redefined here so that there is a single definition to maintain.
qkv_perm = permute_heads(qkv_ref, perm_true, H, hd)
q_ref, k_ref, v_ref = decompose_qkv(qkv_ref, H, hd)
plot_heads(q_ref, k_ref, v_ref, "Original")
q_p, k_p, v_p = decompose_qkv(qkv_perm, H, hd)
plot_heads(q_p, k_p, v_p, "Permuted")
# plot_diff(qkv_ref, qkv_perm, H, hd, title="|Permuted - Original|")
diff = (qkv_perm - qkv_ref).abs()
# show_matrix only builds the figure; without .show() a mid-cell figure is never displayed.
show_matrix(diff, title="|Permuted - Original|").show()

ref_feat = extract_head_features(qkv_ref, H, hd)
cli_feat = extract_head_features(qkv_perm, H, hd)
# cost[i, j] = Euclidean distance between reference head i and permuted head j
cost = torch.cdist(ref_feat, cli_feat, p=2)

px.imshow(
    cost.numpy(), text_auto=True, color_continuous_scale="Plasma", title="Cost Matrix"
).show()

perm_est = greedy_match(cost)
print("True permutation:", perm_true.tolist())
print("Recovered permutation:", perm_est.tolist())

True permutation: [1, 0]
Recovered permutation: [1, 0]


In [143]:
import torch

# -------------------------------
# QKV matrix
# Shape = [4 rows, 12 columns], all four rows identical
#
# Column groups, two columns each:
# Q0 Q1 K0 K1 V0 V1
#
# Q0 = 10
# Q1 = 20
# K0 = 30
# K1 = 40
# V0 = 50
# V1 = 60
# -------------------------------

qkv = torch.tensor([
    [10,10, 20,20, 30,30, 40,40, 50,50, 60,60],
    [10,10, 20,20, 30,30, 40,40, 50,50, 60,60],
    [10,10, 20,20, 30,30, 40,40, 50,50, 60,60],
    [10,10, 20,20, 30,30, 40,40, 50,50, 60,60],
], dtype=torch.float)

# -------------------------------
# expected result of perm = [1, 0]
#
# The two heads swap inside each of the Q, K and V sections:
#  Q0 <-> Q1
#  K0 <-> K1
#  V0 <-> V1
#
# Resulting column groups:
# [Q1 Q0 K1 K0 V1 V0]
# -------------------------------

expected = torch.tensor([
    [20,20, 10,10, 40,40, 30,30, 60,60, 50,50],
    [20,20, 10,10, 40,40, 30,30, 60,60, 50,50],
    [20,20, 10,10, 40,40, 30,30, 60,60, 50,50],
    [20,20, 10,10, 40,40, 30,30, 60,60, 50,50],
], dtype=torch.float)

In [144]:
def permute_heads1(qkv, perm, H, hd):
    """Reorder the heads of a fused QKV matrix (variant finishing with `reshape`).

    `qkv` is viewed as [H*hd, 3, H, hd], the head axis is indexed with `perm`
    (output head j is input head perm[j]), and the result is flattened back
    to [H*hd, 3*H*hd].
    """
    n_rows = H * hd
    by_head = qkv.view(n_rows, 3, H, hd)   # e.g. [4, 3, 2, 2] for H = hd = 2
    shuffled = by_head[..., perm, :]       # index the head axis (second to last)
    return shuffled.reshape(n_rows, 3 * H * hd)

def permute_heads2(qkv, perm: torch.Tensor, H, hd):
    """Reorder the heads of a fused QKV matrix.

    Args:
        qkv: [D_in, 3*H*hd] matrix whose columns are laid out as Q | K | V,
            each section split into H heads of hd columns.
        perm: index applied along the head axis; output head `j` is input
            head `perm[j]`.
        H: number of heads.
        hd: number of columns per head.

    Returns:
        A new [D_in, 3*H*hd] tensor with the heads reordered in every section.

    `reshape` is used instead of `view` for the final flattening because
    `view` only works when the indexed result happens to have compatible
    strides. The row count is read from `qkv` rather than assumed to be H*hd.
    """
    rows = qkv.shape[0]

    qkv = qkv.reshape(rows, 3, H, hd)
    qkv = qkv[:, :, perm, :]
    qkv = qkv.reshape(rows, 3 * H * hd)
    # layer.mha.qkv.data.copy_(qkv)
    return qkv
# def permute_heads2(qkv, perm, H, hd):
#     return qkv4.reshape(D, 3*H*hd)

In [None]:
# Pin the head layout of the 4x12 fixture here instead of relying on whatever
# H and hd an earlier cell happened to leave behind.
H, hd = 2, 2
perm = torch.tensor([1, 0])
print(qkv.shape, H, hd)
print("\nOriginal:\n", qkv)

# Run both implementations against the same hand-written expectation.
for permute_fn in (permute_heads1, permute_heads2):
    out = permute_fn(qkv, perm, H, hd)
    print("\nOutput:\n", out)
    print("\nExpected:\n", expected)
    print("\nMatch?", torch.equal(out, expected))

torch.Size([4, 12]) 2 2

Original:
 tensor([[10., 10., 20., 20., 30., 30., 40., 40., 50., 50., 60., 60.],
        [10., 10., 20., 20., 30., 30., 40., 40., 50., 50., 60., 60.],
        [10., 10., 20., 20., 30., 30., 40., 40., 50., 50., 60., 60.],
        [10., 10., 20., 20., 30., 30., 40., 40., 50., 50., 60., 60.]])

Output:
 tensor([[20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.],
        [20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.],
        [20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.],
        [20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.]])

Expected:
 tensor([[20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.],
        [20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.],
        [20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.],
        [20., 20., 10., 10., 40., 40., 30., 30., 60., 60., 50., 50.]])

Match? True
