In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from safetensors.torch import load_file

In [3]:
# Run configuration. These values are interpolated into the checkpoint
# directory name below, so they must match the run that wrote the files.
l = 15  # substituted after "_l" in the directory name (presumably the layer index -- confirm)
lr = 8e-05
l1coef = 1.5
modelnames = ["base_llama", "Instruct_llama", "o_llama"]

# One sae_weights.safetensors path per entry of `modelnames`, in the same order.
cc_paths = []
for subject_model in modelnames:
    cc_paths.append(
        f"/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Language-Model-SAEs/JR_CC_tests/{subject_model}_l{l}_32x_lr{lr}_jumprelu_l1coef{l1coef}/sae_weights.safetensors"
    )

In [4]:
# Pick the first checkpoint path for a quick inspection; the bare expression
# on the last line echoes it as the cell output.
path = cc_paths[0]
path

'/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Language-Model-SAEs/JR_CC_tests/base_llama_l15_32x_lr8e-05_jumprelu_l1coef1.5/sae_weights.safetensors'

In [5]:
# Load that single checkpoint; the result is indexed by tensor name in the next cell.
ckpt = load_file(path)

In [6]:
# Check the decoder matrix's shape before loading every checkpoint.
ckpt["decoder.weight"].shape

torch.Size([4096, 131072])

In [7]:
# Collect the decoder matrix of every checkpoint in `cc_paths` (same order).
#
# The tensors were previously sent to "cuda" unconditionally, which raises on a
# machine without a GPU. Fall back to the CPU in that case; when a GPU is
# available the tensors land on "cuda" exactly as before.
device = "cuda" if torch.cuda.is_available() else "cpu"

decoders = []

for path in cc_paths:
    ckpt = load_file(path)

    # Disabled experiment, kept for reference (note that `layer` is not defined
    # in this notebook):
    # scaling_factor = (4096 ** .5) / ckpt[f'dataset_average_activation_norm.blocks.{layer}.hook_resid_post']
    decoders.append(ckpt["decoder.weight"].to(device))

In [8]:
# Norm of every column (reduction over dim 0) of each decoder matrix, stacked
# into one tensor with one row per entry of `decoders`.
with torch.no_grad():
    dec_norms = torch.stack([weight.norm(dim=0) for weight in decoders])

In [9]:
# Sanity check: one row of column norms per loaded decoder.
dec_norms.shape

torch.Size([3, 131072])

In [10]:
# Per-column total of the norms across checkpoints; this is the denominator
# used by the normalisation in the following cell.
dec_norms.sum(dim=0)

tensor([2.3315, 2.9109, 2.7694,  ..., 2.4798, 2.8669, 3.9965], device='cuda:0')

In [11]:
# Convert the per-checkpoint norms into shares: after this, the entries of each
# column (one entry per checkpoint) sum to 1.
norm_totals = dec_norms.sum(dim=0, keepdim=True)
# A column whose norm is zero in every checkpoint would give 0/0 = NaN, which
# later makes the int() colour conversion raise. Clamping the denominator to
# the smallest positive normal float leaves every non-degenerate column
# unchanged and maps an all-zero column to all zeros.
dec_norms = dec_norms / norm_totals.clamp_min(torch.finfo(dec_norms.dtype).tiny)

In [12]:
# Echo the normalised shares (one row per checkpoint).
dec_norms

tensor([[0.3312, 0.3456, 0.3289,  ..., 0.1393, 0.3360, 0.3968],
        [0.3440, 0.3370, 0.3348,  ..., 0.1557, 0.3291, 0.3378],
        [0.3248, 0.3174, 0.3364,  ..., 0.7049, 0.3348, 0.2654]],
       device='cuda:0')

In [13]:
# Copy the tensor to host memory and convert it to a NumPy array.
# NOTE(review): this rebinds `dec_norms` from a torch tensor to an ndarray, so
# the cell cannot be re-run on its own (an ndarray has no .cpu()).
dec_norms = dec_norms.cpu().numpy()

In [14]:
# Transpose so that each row holds the shares of one decoder column across the
# checkpoints; the row sums echoed below are then expected to be 1.
# NOTE(review): `dec_norms` is rebound to its own transpose, so running this
# cell a second time flips the array back -- re-run from the stacking cell.
dec_norms = dec_norms.T
dec_norms.sum(axis=1)

array([1., 1., 1., ..., 1., 1., 1.], shape=(131072,), dtype=float32)

In [15]:
# Keep only the first 1000 rows for plotting and echo them.
# NOTE(review): 1000 is an unexplained cut-off -- presumably chosen to keep the
# interactive 3D scatter responsive; confirm and consider a named constant.
data = dec_norms[:1000]
data

array([[0.33122045, 0.343972  , 0.32480755],
       [0.34562233, 0.33700696, 0.31737068],
       [0.3288614 , 0.33478662, 0.33635196],
       ...,
       [0.34554592, 0.34522504, 0.30922902],
       [0.29728433, 0.299494  , 0.40322167],
       [0.3339067 , 0.33178842, 0.3343049 ]],
      shape=(1000, 3), dtype=float32)

In [20]:
# Sequential row indices of `data`, passed as the per-point `text` of the 3D scatter.
ids = list(range(len(data)))

In [16]:
import numpy as np
import plotly.graph_objects as go

In [17]:
num_points = data.shape[0]

# Every point gets the same opacity. (A disabled earlier variant derived the
# opacity from a gaussian_kde density estimate, min-max rescaled into
# [0.2, 1.0], and formatted it with two decimals.)
alphas = np.ones((num_points,)) * 0.5

# Colour each point by its own coordinates: the three components are scaled by
# 256, clamped to [0, 255] and truncated to integers to form the R, G, B
# channels of a CSS-style "rgba(...)" string.
rgba_colors = []
for point, alpha in zip(data, alphas):
    r, g, b = (int(np.clip(component * 256, 0, 255)) for component in point[:3])
    rgba_colors.append(f"rgba({r}, {g}, {b}, {alpha})")

In [22]:
# Reference segments from the origin to 1 along x, y and z, drawn in red,
# green and blue respectively and hidden from the legend.
axis_lines = [
    go.Scatter3d(
        x=[0, end_x],
        y=[0, end_y],
        z=[0, end_z],
        mode="lines",
        line=dict(color=axis_color, width=4),
        showlegend=False,
    )
    for axis_color, (end_x, end_y, end_z) in (
        ("red", (1, 0, 0)),
        ("green", (0, 1, 0)),
        ("blue", (0, 0, 1)),
    )
]

# One marker per row of `data`; the row index is attached as the marker text.
points_trace = go.Scatter3d(
    x=data[:, 0],
    y=data[:, 1],
    z=data[:, 2],
    mode="markers",
    marker={
        "size": 5,
        # RGBA colour strings: the alpha channel indirectly controls opacity.
        "color": rgba_colors,
    },
    text=ids,
)

fig = go.Figure(data=[points_trace, *axis_lines])

# Adjust the layout to the custom coordinate system: all three axes share the
# same padded range and are titled B / I / O.
# (A disabled earlier variant used range [0, 1] with aspectmode="cube" and an
# aspect ratio of 1:1:1 to keep the x/y/z proportions fixed.)
scene_axes = {
    axis_key: dict(range=[-0.1, 1.1], title=axis_title, showgrid=False, zeroline=True)
    for axis_key, axis_title in (("xaxis", "B"), ("yaxis", "I"), ("zaxis", "O"))
}

fig.update_layout(
    scene=scene_axes,
    title="BOI-diff",
    width=1200,
    height=1200,
)

fig.show()