In [None]:
import torch
from einops import rearrange
import matplotlib.pyplot as plt

import math

record = torch.load("/data/kxia2/mamba/artifacts/mamba2-130m-thepile_new00/params_for_debug_mamba2-130m_thepile_debug00.pt", map_location="cpu")

# Dividing a natural-log quantity by ln(10) gives the same value as log10(exp(x)),
# but never underflows to -inf when the cumulative sum becomes very negative
# (float32 exp() flushes to 0 below roughly -104, which breaks the curves).
LN10 = math.log(10.0)

# Iterate over the layers the record actually stores instead of a hard-coded count
# (another cell indexes this same record with range(24)).
for layer in range(len(record['delta_t'])):
    # Per-token log-decay dt * A; squeeze(0) drops a leading singleton batch dim
    # (assumes delta_t is (1, L, H) or (L, H) and A broadcasts over H -- TODO confirm).
    tA = (record['delta_t'][layer] * record['A'][layer]).squeeze(0)
    # Cumulative decay along the token axis, expressed in log10 units.
    A_cumsum = torch.cumsum(tA, dim=0) / LN10

    plt.figure(figsize=(12, 8))
    for h in range(A_cumsum.shape[1]):  # one curve per channel
        plt.plot(range(A_cumsum.shape[0]), A_cumsum[:, h].numpy(), label=f"channel idx={h}")

    plt.title(f"Layer {layer}")
    plt.xlabel("token idx")
    plt.ylabel("Cumulative decay (log)")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small', ncol=2)
    plt.tight_layout()
    plt.show()

In [1]:
import torch
from einops import rearrange
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import math
import os

record = torch.load(
    "/data/kxia2/mamba/artifacts/mamba2-130m-thepile_new00/params_for_debug_mamba2-130m_thepile_debug00.pt",
    map_location="cpu"
)

# cumsum / ln(10) == log10(exp(cumsum)) without exp() underflowing to 0 (-> -inf in log10).
LN10 = math.log(10.0)
threshold = -4       # log10 cumulative decay; channels below it at the split token are "local"
SPLIT_TOKEN = 2000   # token index at which channels are classified (marked by a dashed line)
OUT_DIR = "/data/kxia2/mamba/visualization"
os.makedirs(OUT_DIR, exist_ok=True)

# Iterate over the layers actually stored in the record instead of a hard-coded count.
for layer in range(len(record['delta_t'])):
    # Per-token log-decay dt * A with a leading singleton dim squeezed away
    # (assumes delta_t is (1, L, H) or (L, H) -- TODO confirm).
    tA = (record['delta_t'][layer] * record['A'][layer]).squeeze(0)
    A_cumsum = torch.cumsum(tA, dim=0) / LN10
    seq_len, n_heads = A_cumsum.shape

    # Classify channels by their cumulative decay at the split token, or at the last
    # token when the sequence is not longer than the split token.
    split_idx = SPLIT_TOKEN if seq_len > SPLIT_TOKEN else seq_len - 1
    below_threshold = A_cumsum[split_idx, :] < threshold
    above_threshold = ~below_threshold

    local_colors = cm.Blues(torch.linspace(0.4, 1, n_heads))
    global_colors = cm.Reds(torch.linspace(0.4, 1, n_heads))

    fig, ax = plt.subplots(figsize=(16, 8))
    x = range(seq_len)

    # Local channels are drawn first, then global ones, so legend entries are grouped.
    for h in range(n_heads):
        if below_threshold[h]:
            ax.plot(
                x, A_cumsum[:, h].numpy(),
                color=local_colors[h], alpha=0.8,
                label=f"Channel {h} (Local)", linewidth=2
            )
    for h in range(n_heads):
        if above_threshold[h]:
            ax.plot(
                x, A_cumsum[:, h].numpy(),
                color=global_colors[h], alpha=0.8,
                label=f"Channel {h} (Global)", linewidth=2
            )

    # No title / axis labels: keeps the original figure style.
    # Mark the token at which the local/global split was actually evaluated.
    ax.axvline(x=split_idx, color="gray", linestyle="--", alpha=0.8)
    ax.grid(alpha=0.5)
    for spine in ax.spines.values():
        spine.set_linewidth(2)
    # Negative left limit presumably leaves room for the upper-left legend.
    ax.set_xlim(-450, seq_len)
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))  # de-duplicate legend entries by label
    ax.legend(by_label.values(), by_label.keys(), loc="upper left", fontsize=12)

    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, f"mamba2_decay_layer{layer}.png"))
    plt.close(fig)

In [9]:
import torch
from einops import rearrange
import matplotlib.pyplot as plt
import matplotlib.cm as cm

ckpt_path = "/data/kxia2/mamba/artifacts/mamba2-130m-thepile_new00/params_for_debug_mamba2-130m_thepile_debug00.pt"
record = torch.load(ckpt_path, map_location="cpu")

# Sanity check: tile layer 0's delta_t ten times along dim 1 and inspect the resulting shape.
first_layer_dt = record['delta_t'][0]
delta_t = first_layer_dt.repeat(1, 10, 1)
delta_t.shape

torch.Size([1, 20000, 24])

In [24]:
import torch
from einops import rearrange
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def get_channelwise_dt_threshold(delta_t, dt_thre=None, response_length=0):
    """Build a keep-mask selecting positions whose delta_t exceeds a per-channel threshold.

    Args:
        delta_t: tensor indexed as (batch, channel, length); the callers in this
            notebook rearrange to "b h l" before calling.
        dt_thre: thresholds, one per channel. Expected to be 1-D so that
            ``unsqueeze(0).unsqueeze(-1)`` broadcasts it against ``delta_t``.
            Required despite the ``None`` default.
        response_length: number of trailing positions that are always kept
            (forced to True). Values >= the sequence length keep every position.

    Returns:
        Boolean tensor with the same shape as ``delta_t``.

    Raises:
        ValueError: if ``dt_thre`` is None.
    """
    if dt_thre is None:
        raise ValueError("dt_thre must be provided (one threshold per channel)")
    # (C,) -> (1, C, 1): each channel's threshold broadcasts over batch and length.
    mask = delta_t > dt_thre.unsqueeze(0).unsqueeze(-1)
    if response_length > 0:
        # Clamp at 0: a negative start index would make the slice cover only the
        # tail of the sequence instead of all of it when response_length > length.
        decode_start = max(delta_t.shape[2] - response_length, 0)
        mask[:, :, decode_start:] = True
    return mask

import math
import os

record = torch.load(
    "/data/kxia2/mamba/artifacts/mamba2-130m-thepile_new00/params_for_debug_mamba2-130m_thepile_debug00.pt",
    map_location="cpu"
)

# cumsum / ln(10) == log10(exp(cumsum)) without exp() underflowing to 0 (-> -inf) on long sequences.
LN10 = math.log(10.0)
layer = 13
threshold = -4       # log10 decay at the split token below which a channel counts as "local"
SPLIT_TOKEN = 2000   # token index at which channels are classified
OUT_DIR = "/data/kxia2/mamba/visualization"
os.makedirs(OUT_DIR, exist_ok=True)

base_delta_t = record['delta_t'][layer]
# Tokens per recorded sequence: repeat(1, n, 1) tiles dim -2 (2000 tokens for this record).
base_len = base_delta_t.shape[-2]

# Per-channel delta_t thresholds keyed by sequence length ("2k", "10k", ...).
# Loop-invariant, so load and parse once instead of on every iteration.
dt_thre_path = f"/data/kxia2/mamba/artifacts/mamba2-130m-thepile_newavg/delta_t-thre/delta_t-thre_layer_{layer}.pt"
dt_thre_all = torch.load(dt_thre_path, map_location="cpu")
thre_key_by_len = {int(k[:-1]) * 1000: k for k in dt_thre_all}
if not thre_key_by_len:
    raise ValueError(f"no delta_t thresholds found in {dt_thre_path}")

# Fixed seed so the token shuffle -- and therefore every saved figure -- is reproducible.
shuffle_gen = torch.Generator().manual_seed(0)

for n in [1, 2, 5, 10, 20, 30, 40, 50, 60]:
    length = n * base_len
    # Tile the recorded sequence n times and shuffle its tokens.
    delta_t = base_delta_t.repeat(1, n, 1)
    indices = torch.randperm(delta_t.size(1), generator=shuffle_gen)
    delta_t = delta_t[:, indices, :]
    tA = (delta_t * record['A'][layer]).squeeze(0)
    A_cumsum_initial = torch.cumsum(tA, dim=0) / LN10
    seq_len, n_heads = A_cumsum_initial.shape

    split_idx = SPLIT_TOKEN if seq_len > SPLIT_TOKEN else seq_len - 1
    below_threshold = A_cumsum_initial[split_idx, :] < threshold
    above_threshold = ~below_threshold

    # For "global" channels, zero delta_t wherever it does not exceed that channel's
    # threshold for the closest available length; local channels are left untouched.
    nearest_len = min(thre_key_by_len, key=lambda x: abs(length - x))
    channel_dt_thre_all = dt_thre_all[thre_key_by_len[nearest_len]]
    delta_t = rearrange(delta_t, "b l h -> b h l")
    global_dt = delta_t[:, above_threshold]
    topk_mask = get_channelwise_dt_threshold(delta_t=global_dt, dt_thre=channel_dt_thre_all[above_threshold])
    delta_t[:, above_threshold] = torch.where(topk_mask, global_dt, torch.zeros_like(global_dt))
    delta_t = rearrange(delta_t, "b h l -> b l h")
    tA = (delta_t * record['A'][layer]).squeeze(0)
    A_cumsum_recalculated = torch.cumsum(tA, dim=0) / LN10

    global_colors = cm.Reds(torch.linspace(0.4, 1, n_heads))

    fig, ax = plt.subplots(figsize=(16, 8))
    x = range(seq_len)
    # Only global channels are drawn in this figure.
    for h in range(n_heads):
        if not above_threshold[h]:
            continue
        if n == 1:
            # At the original length the unmodified curve is the reference to align to.
            ax.plot(
                x, A_cumsum_initial[:, h].numpy(),
                color=global_colors[h], alpha=0.8,
                label=f"Channel {h} (Align Target)", linewidth=2
            )
        else:
            ax.plot(
                x, A_cumsum_initial[:, h].numpy(),
                color=global_colors[h], alpha=0.8,
                label=f"Channel {h} (Global Origin)", linewidth=2
            )
            ax.plot(
                x, A_cumsum_recalculated[:, h].numpy(),
                color=global_colors[h], alpha=0.8,
                label=f"Channel {h} (Global Ours)", linewidth=2, linestyle='--'
            )

    ax.set_ylim(-35, 3)
    ax.grid(alpha=0.5)
    for spine in ax.spines.values():
        spine.set_linewidth(2)
    # Negative left margin (scaled with n) presumably leaves room for the legend.
    ax.set_xlim(-510 * n, seq_len)
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))  # de-duplicate legend entries by label
    ax.legend(by_label.values(), by_label.keys(), loc="upper left", fontsize=10)

    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, f"mamba2_decay_longmamba_len{length}_layer{layer}.png"))
    plt.close(fig)

In [None]:
import torch
from einops import rearrange
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def get_channelwise_dt_threshold(delta_t, dt_thre=None, response_length=0):
    """Build a keep-mask selecting positions whose delta_t exceeds a per-channel threshold.

    Args:
        delta_t: tensor indexed as (batch, channel, length); the callers in this
            notebook rearrange to "b h l" before calling.
        dt_thre: thresholds, one per channel. Expected to be 1-D so that
            ``unsqueeze(0).unsqueeze(-1)`` broadcasts it against ``delta_t``.
            Required despite the ``None`` default.
        response_length: number of trailing positions that are always kept
            (forced to True). Values >= the sequence length keep every position.

    Returns:
        Boolean tensor with the same shape as ``delta_t``.

    Raises:
        ValueError: if ``dt_thre`` is None.
    """
    if dt_thre is None:
        raise ValueError("dt_thre must be provided (one threshold per channel)")
    # (C,) -> (1, C, 1): each channel's threshold broadcasts over batch and length.
    mask = delta_t > dt_thre.unsqueeze(0).unsqueeze(-1)
    if response_length > 0:
        # Clamp at 0: a negative start index would make the slice cover only the
        # tail of the sequence instead of all of it when response_length > length.
        decode_start = max(delta_t.shape[2] - response_length, 0)
        mask[:, :, decode_start:] = True
    return mask

import math
import os

record = torch.load(
    "/data/kxia2/mamba/artifacts/mamba2-130m-thepile_new00/params_for_debug_mamba2-130m_thepile_debug00.pt",
    map_location="cpu"
)

# cumsum / ln(10) == log10(exp(cumsum)) without exp() underflowing to 0 (-> -inf) on long sequences.
LN10 = math.log(10.0)
layer = 13
threshold = -4       # log10 decay at the split token below which a channel counts as "local"
SPLIT_TOKEN = 2000   # token index at which channels are classified
OUT_DIR = "/data/kxia2/mamba/visualization"
os.makedirs(OUT_DIR, exist_ok=True)

base_delta_t = record['delta_t'][layer]
# Tokens per recorded sequence: repeat(1, n, 1) tiles dim -2 (2000 tokens for this record).
base_len = base_delta_t.shape[-2]

# Per-channel delta_t thresholds keyed by sequence length ("2k", "10k", ...).
# Loop-invariant, so load and parse once instead of on every iteration.
dt_thre_path = f"/data/kxia2/mamba/artifacts/mamba2-130m-thepile_newavg/delta_t-thre/delta_t-thre_layer_{layer}.pt"
dt_thre_all = torch.load(dt_thre_path, map_location="cpu")
thre_key_by_len = {int(k[:-1]) * 1000: k for k in dt_thre_all}
if not thre_key_by_len:
    raise ValueError(f"no delta_t thresholds found in {dt_thre_path}")

# Fixed seed so the token shuffle -- and therefore every saved figure -- is reproducible.
shuffle_gen = torch.Generator().manual_seed(0)

# log10 cumulative decay of the n == 1 (original-length) run; its final value per
# channel is the endpoint of the "Align Target" reference line in later figures.
target = None
for n in [1, 2, 5, 10, 20, 30, 40, 50, 60]:
    length = n * base_len
    # Tile the recorded sequence n times and shuffle its tokens.
    delta_t = base_delta_t.repeat(1, n, 1)
    indices = torch.randperm(delta_t.size(1), generator=shuffle_gen)
    delta_t = delta_t[:, indices, :]
    tA = (delta_t * record['A'][layer]).squeeze(0)
    A_cumsum_initial = torch.cumsum(tA, dim=0) / LN10
    seq_len, n_heads = A_cumsum_initial.shape

    if n == 1:
        # The original-length run only supplies the target; no figure is drawn for it,
        # so none is opened (previously a figure was created here and never closed).
        target = A_cumsum_initial
        continue
    if target is None:
        raise RuntimeError("the n == 1 run must precede longer runs: it provides the alignment target")

    split_idx = SPLIT_TOKEN if seq_len > SPLIT_TOKEN else seq_len - 1
    below_threshold = A_cumsum_initial[split_idx, :] < threshold
    above_threshold = ~below_threshold

    # For "global" channels, zero delta_t wherever it does not exceed that channel's
    # threshold for the closest available length; local channels are left untouched.
    nearest_len = min(thre_key_by_len, key=lambda x: abs(length - x))
    channel_dt_thre_all = dt_thre_all[thre_key_by_len[nearest_len]]
    delta_t = rearrange(delta_t, "b l h -> b h l")
    global_dt = delta_t[:, above_threshold]
    topk_mask = get_channelwise_dt_threshold(delta_t=global_dt, dt_thre=channel_dt_thre_all[above_threshold])
    delta_t[:, above_threshold] = torch.where(topk_mask, global_dt, torch.zeros_like(global_dt))
    delta_t = rearrange(delta_t, "b h l -> b l h")
    tA = (delta_t * record['A'][layer]).squeeze(0)
    A_cumsum_recalculated = torch.cumsum(tA, dim=0) / LN10

    global_colors = cm.Reds(torch.linspace(0.4, 1, n_heads))
    # Final cumulative decay per channel at the original length (independent of base length).
    target_end = target[-1]

    fig, ax = plt.subplots(figsize=(16, 8))
    x = range(seq_len)
    for h in range(n_heads):
        if not above_threshold[h]:
            continue
        # Straight reference line from 0 to the original-length endpoint, stretched
        # over the longer sequence.
        ax.plot(
            [0, seq_len],
            [0, target_end[h].item()],
            color=global_colors[h],
            linestyle="-.",
            alpha=0.8,
            label=f"Channel {h} (Align Target)"
        )
        ax.plot(
            x, A_cumsum_initial[:, h].numpy(),
            color=global_colors[h], alpha=0.8,
            label=f"Channel {h} (Global Origin)", linewidth=1.5
        )
        ax.plot(
            x, A_cumsum_recalculated[:, h].numpy(),
            color=global_colors[h], alpha=0.8,
            label=f"Channel {h} (Global Ours)", linewidth=1.5, linestyle='--'
        )

    ax.set_xlim(-0.22 * seq_len, seq_len * 1.05)
    ax.set_ylim(-30, None)

    ax.grid(alpha=0.5)
    for spine in ax.spines.values():
        spine.set_linewidth(2)

    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))  # de-duplicate legend entries by label
    ax.legend(by_label.values(), by_label.keys(), loc="upper left", fontsize=10)

    # Save the figure and release it.
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, f"mamba2_decay_longmamba_len{length}_layer{layer}.png"))
    plt.close(fig)
