In [1]:
import os
# os.environ['http_proxy'] = 'http://10.176.58.101:7890'
# os.environ['https_proxy'] = 'http://10.176.58.101:7890'

import sys
# sys.path.append('/remote-home1/jxwang/project/monofuctional_attn/')

import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components import Attention
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

import einops
import einsum

import copy

device='cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [2]:
import importlib
import src.model.toy_model

importlib.reload(src.model.toy_model)

from src.model.toy_model import toy_attn, toy_attn_model, toy_sparse_attn

In [3]:
num_embeddings = 6
num_unembeddings = 5
d_model = 2
d_head = 2
n_heads = 2
model = toy_attn_model(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings, d_model=d_model, d_head=d_head, n_heads=n_heads).to(device)

In [4]:
mapping = {
    (0, 1): 0,
    (0, 2): 1,
    (0, 3): 2,
    (0, 4): 3,
    (0, 5): 4,
}

def generate_train_data(num_embeddings, num_unembeddings, batch_size):
    batch = torch.zeros(batch_size, 2, dtype=torch.int32)
    batch[:, 0] = 0
    batch[:batch_size//2, 0] = torch.randint(1, 6, (batch_size//2,), dtype=torch.int32)
    batch[:, 1] = torch.randint(1, 6, (batch_size,), dtype=torch.int32)
    
    y = torch.tensor([
        mapping.get((batch[i][0].item(), batch[i][1].item()), torch.randint(0, 5, (1,)).item())
        for i in range(batch_size)
    ])
    
    return batch, y

def generate_valid_data(num_embeddings, num_unembeddings, batch_size):
    batch = torch.zeros(batch_size, 2, dtype=torch.int32)
    batch[:, 0] = 0
    batch[:, 1] = torch.randint(1, 6, (batch_size,), dtype=torch.int32)
    
    y = torch.tensor([mapping[(batch[i][0].item(), batch[i][1].item())] for i in range(batch_size)])
    
    return batch, y
def generate_test_data(num_embeddings, num_unembeddings):
    batch = torch.zeros(5, 2, dtype=torch.int32)
    batch[:, 0] = 0
    batch[:, 1] = torch.tensor([1, 2, 3, 4, 5])
    
    y = torch.tensor([mapping[(batch[i][0].item(), batch[i][1].item())] for i in range(5)])
    
    return batch, y

In [6]:
from tqdm.notebook import tqdm

batch_size = 32
num_batches = 2 ** 14
lr = 1e-2
optimizer = optim.SGD(model.parameters(), lr=lr)

pbar = tqdm(range(num_batches), desc='Training', unit='batch', dynamic_ncols=True)

for i in pbar:
    batch, y = generate_train_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings, batch_size=batch_size)
    batch, y = batch.to(device), y.to(device)
    
    logits = model(batch)
    
    p = logits[:, -1].softmax(dim=-1)
    
    loss = F.cross_entropy(p, y)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    with torch.no_grad():
        test_batch, test_y = generate_test_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings)
        test_batch, test_y = test_batch.to(device), test_y.to(device)
        
        test_logits = model(test_batch)
        test_p = test_logits[:, -1].softmax(dim=-1)
        
        test_loss = F.cross_entropy(test_p, test_y)
        
        test_preds = torch.argmax(test_p, dim=1)
        correct = (test_preds == test_y).sum().item()
        total = test_y.size(0)
        accuracy = correct / total
    
        pbar.set_postfix({'test_loss': test_loss.item(), 'accuracy': f'{accuracy:.4f}'})

Training:   0%|                                                                                               …

In [72]:
with torch.no_grad():
    batch, y = generate_test_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings)
    batch, y = batch.to(device), y.to(device)
    resid = model.input_to_embed(batch)
    out, out_heads, pattern = model.attn(resid)
    out, out_heads, pattern = out[:, -1].unsqueeze(dim=0), out_heads[:, -1], pattern[:, -1]
    unembeds = model.unembed.weight.unsqueeze(0)

print(out.shape, out_heads.shape, unembeds.shape)
print(pattern.shape)

torch.Size([1, 5, 2]) torch.Size([5, 2, 2]) torch.Size([1, 5, 2])
torch.Size([5, 2, 2])


In [51]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

vector_out_heads = [out_heads[:, i] for i in range(out_heads.size(1))]
vector_unembeds = [unembeds[i] for i in range(unembeds.size(0))]
vector_out = [out[i] for i in range(out.size(0))]

names_out_heads = [
    [f'{i}' for i in range(out_heads[:, j].size(0))] for j in range(out_heads.size(1))
]
names_unembeds = [
    [f'{i}' for i in range(unembeds[j].size(0))] for j in range(unembeds.size(0))
]
names_out = [
    [f'{i}' for i in range(out[j].size(0))] for j in range(out.size(0))
]

fig = make_subplots(rows=2, cols=len(vector_out_heads), subplot_titles=[f'head {i+1}' for i in range(len(vector_out_heads))])

for idx, (vectors, names) in enumerate(zip(vector_out_heads, names_out_heads)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=1, col=idx+1)

for idx, (vectors, names) in enumerate(zip(vector_unembeds, names_unembeds)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=2, col=idx+1)
        
for idx, (vectors, names) in enumerate(zip(vector_out, names_out)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=2, col=idx+2)
        
fig.update_layout(
    title='',
    width=300 * len(vector_out_heads),
    height=300 * 2,
    showlegend=False
)

for idx in range(len(vector_out_heads)):
    fig.update_xaxes(range=[-10, 10], row=1, col=idx+1, title_text='heads output')
    fig.update_yaxes(range=[-10, 10], row=1, col=idx+1, title_text='')

for idx in range(1):
    fig.update_xaxes(range=[-5, 5], row=2, col=idx+1, title_text='unembeddings')
    fig.update_yaxes(range=[-5, 5], row=2, col=idx+1, title_text='')
    fig.update_xaxes(range=[-10, 10], row=2, col=idx+2, title_text='output')
    fig.update_yaxes(range=[-10, 10], row=2, col=idx+2, title_text='')

# 显示图形
fig.show()


In [53]:
embeds = model.embed.weight.unsqueeze(0)
unembeds = model.unembed.weight.unsqueeze(0)
q = model.embed.weight @ model.attn.W_Q + model.attn.b_Q.unsqueeze(1)
k = model.embed.weight @ model.attn.W_K + model.attn.b_K.unsqueeze(1)
ov = (model.embed.weight @ model.attn.W_V + model.attn.b_V.unsqueeze(1)) @ model.attn.W_O

In [54]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# 示例：创建多个包含二维向量的张量
vector_q = [q[i] for i in range(q.size(0))]
vector_k = [k[i] for i in range(k.size(0))]
vector_ov = [ov[i] for i in range(ov.size(0))]
vector_embeds = [embeds[i] for i in range(embeds.size(0))]
vector_unembeds = [unembeds[i] for i in range(unembeds.size(0))]

names_q = [
    [f'{i}' for i in range(q[j].size(0))] for j in range(q.size(0))
]
names_k = [
    [f'{i}' for i in range(k[j].size(0))] for j in range(k.size(0))
]
names_ov = [
    [f'{i}' for i in range(ov[j].size(0))] for j in range(ov.size(0))
]
names_embeds = [
    [f'{i}' for i in range(embeds[j].size(0))] for j in range(embeds.size(0))
]
names_unembeds = [
    [f'{i}' for i in range(unembeds[j].size(0))] for j in range(unembeds.size(0))
]

fig = make_subplots(rows=4, cols=2, subplot_titles=[f'head {i+1}' for i in range(len(vector_q))])

for idx, (vectors, names) in enumerate(zip(vector_q, names_q)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=1, col=idx+1)

for idx, (vectors, names) in enumerate(zip(vector_k, names_k)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=2, col=idx+1)
        
for idx, (vectors, names) in enumerate(zip(vector_ov, names_ov)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=3, col=idx+1)
          
for idx, (vectors, names) in enumerate(zip(vector_embeds, names_embeds)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=4, col=idx+1)

for idx, (vectors, names) in enumerate(zip(vector_unembeds, names_unembeds)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=4, col=idx+2)
        
fig.update_layout(
    title='',
    width=400 * 2,
    height=400 * 4,
    showlegend=False
)

for idx in range(len(vector_q)):
    fig.update_xaxes(range=[-5, 5], row=1, col=idx+1, title_text='q')
    fig.update_yaxes(range=[-5, 5], row=1, col=idx+1, title_text='')

for idx in range(len(vector_k)):
    fig.update_xaxes(range=[-5, 5], row=2, col=idx+1, title_text='k')
    fig.update_yaxes(range=[-5, 5], row=2, col=idx+1, title_text='')

for idx in range(len(vector_ov)):
    fig.update_xaxes(range=[-5, 5], row=3, col=idx+1, title_text='ov')
    fig.update_yaxes(range=[-5, 5], row=3, col=idx+1, title_text='')
    
for idx in range(1):
    fig.update_xaxes(range=[-5, 5], row=4, col=idx+1, title_text='embed')
    fig.update_yaxes(range=[-5, 5], row=4, col=idx+1, title_text='')
    fig.update_xaxes(range=[-5, 5], row=4, col=idx+2, title_text='unembed')
    fig.update_yaxes(range=[-5, 5], row=4, col=idx+2, title_text='')

# 显示图形
fig.show()


In [99]:
sparse_d_model = 2
sparse_d_head = 1
sparse_n_heads = 6
sparse_model = toy_sparse_attn(d_model = sparse_d_model, d_head = sparse_d_head, n_heads = sparse_n_heads, use_l1_loss=True).to(device)

In [100]:
from tqdm.notebook import tqdm

batch_size = 4
num_batches = 2 ** 13
lr = 1e-2
l1_coefficient = 2
optimizer = optim.SGD(sparse_model.parameters(), lr=lr)

pbar = tqdm(range(num_batches), desc='Training', unit='batch', dynamic_ncols=True)

for i in pbar:
    batch, y = generate_test_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings)
    batch, y = batch.to(device), y.to(device)
    
    with torch.no_grad():
        resid = model.input_to_embed(batch)
        out, _, _ = model.attn(resid)
    
    out_hat,  _, _, l1 = sparse_model(resid)
    
    mse_loss = F.mse_loss(out[:, -1], out_hat[:, -1])
    l1_loss = l1[:, -1].mean()
    loss = mse_loss + l1_coefficient * l1_loss
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    logits = model.unembed(out_hat[:, -1])
    p = logits.softmax(dim=-1)
    preds = torch.argmax(p, dim=1)
    correct = (preds == y).sum().item()
    total = y.size(0)
    accuracy = correct / total
    
    pbar.set_postfix({'mse_loss': mse_loss.item(), 'l1_loss': l1_loss.item(), 'accuracy': f'{accuracy:.4f}'})

Training:   0%|                                                                                               …

In [101]:
with torch.no_grad():
    batch, y = generate_test_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings)
    batch, y = batch.to(device), y.to(device)
    resid = model.input_to_embed(batch)
    out, out_heads, pattern = sparse_model.attn(resid)
    out, out_heads, pattern = out[:, -1].unsqueeze(dim=0), out_heads[:, -1], pattern[:, -1]
    unembeds = model.unembed.weight.unsqueeze(0)

print(out.shape, out_heads.shape, unembeds.shape)
print(pattern.shape)

torch.Size([1, 5, 2]) torch.Size([5, 6, 2]) torch.Size([1, 5, 2])
torch.Size([5, 2, 2])


In [102]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

vector_out_heads = [out_heads[:, i] for i in range(out_heads.size(1))]
vector_unembeds = [unembeds[i] for i in range(unembeds.size(0))]
vector_out = [out[i] for i in range(out.size(0))]

names_out_heads = [
    [f'{i}' for i in range(out_heads[:, j].size(0))] for j in range(out_heads.size(1))
]
names_unembeds = [
    [f'{i}' for i in range(unembeds[j].size(0))] for j in range(unembeds.size(0))
]
names_out = [
    [f'{i}' for i in range(out[j].size(0))] for j in range(out.size(0))
]

fig = make_subplots(rows=2, cols=len(vector_out_heads), subplot_titles=[f'head {i+1}' for i in range(len(vector_out_heads))])

for idx, (vectors, names) in enumerate(zip(vector_out_heads, names_out_heads)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=1, col=idx+1)

for idx, (vectors, names) in enumerate(zip(vector_unembeds, names_unembeds)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=2, col=idx+1)
        
for idx, (vectors, names) in enumerate(zip(vector_out, names_out)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=2, col=idx+2)
        
fig.update_layout(
    title='',
    width=300 * len(vector_out_heads),
    height=300 * 2,
    showlegend=False
)

for idx in range(len(vector_out_heads)):
    fig.update_xaxes(range=[-10, 10], row=1, col=idx+1, title_text='heads output')
    fig.update_yaxes(range=[-10, 10], row=1, col=idx+1, title_text='')

for idx in range(1):
    fig.update_xaxes(range=[-5, 5], row=2, col=idx+1, title_text='unembeddings')
    fig.update_yaxes(range=[-5, 5], row=2, col=idx+1, title_text='')
    fig.update_xaxes(range=[-10, 10], row=2, col=idx+2, title_text='output')
    fig.update_yaxes(range=[-10, 10], row=2, col=idx+2, title_text='')

# 显示图形
fig.show()


In [177]:
import importlib
import src.model.toy_model

importlib.reload(src.model.toy_model)

from src.model.toy_model import toy_attn, toy_attn_model, toy_sparse_attn

In [187]:
sparse_d_model = 2
sparse_d_head = 1
sparse_n_heads = 6
sparse_model_2 = toy_sparse_attn(d_model = sparse_d_model, d_head = sparse_d_head, n_heads = sparse_n_heads, use_l1_loss=False, use_topk=True, top_k = 1).to(device)

In [189]:
from tqdm.notebook import tqdm

batch_size = 2
num_batches = 2 ** 10
lr = 1e-2
optimizer = optim.SGD(sparse_model_2.parameters(), lr=lr)

pbar = tqdm(range(num_batches), desc='Training', unit='batch', dynamic_ncols=True)

for i in pbar:
    batch, y = generate_train_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings, batch_size=batch_size)
    batch, y = batch.to(device), y.to(device)
    
    with torch.no_grad():
        resid = model.input_to_embed(batch)
        out, _, _ = model.attn(resid)
    
    _, _, _, _, top_k_out, _ = sparse_model_2(resid)
    
    mse_loss = F.mse_loss(out[:, -1], top_k_out[:, -1])
    loss = mse_loss
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    with torch.no_grad():
        test_batch, test_y = generate_test_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings)
        test_batch, test_y = test_batch.to(device), test_y.to(device)
        
        resid = model.input_to_embed(test_batch)
        _, _, _, _, top_k_out, _ = sparse_model_2(resid)
        
        test_logits = model.unembed(top_k_out[:, -1])
        
        test_p = test_logits.softmax(dim=-1)
        
        test_preds = torch.argmax(test_p, dim=1)
        correct = (test_preds == test_y).sum().item()
        total = test_y.size(0)
        accuracy = correct / total
    
        pbar.set_postfix({'mse_loss': mse_loss.item(), 'accuracy': f'{accuracy:.4f}'})

Training:   0%|                                                                                               …

In [197]:
with torch.no_grad():
    batch, y = generate_test_data(num_embeddings=num_embeddings, num_unembeddings=num_unembeddings)
    batch, y = batch.to(device), y.to(device)
    resid = model.input_to_embed(batch)
    _, out_heads, top_k_indices, top_k_out_heads, top_k_out, pattern = sparse_model_2(resid)
    out, out_heads, pattern = top_k_out[:, -1].unsqueeze(dim=0), out_heads[:, -1], pattern
    unembeds = model.unembed.weight.unsqueeze(0)

print(out.shape, out_heads.shape, unembeds.shape)
print(pattern.shape)

torch.Size([1, 5, 2]) torch.Size([5, 6, 2]) torch.Size([1, 5, 2])
torch.Size([5, 6, 2, 2])


In [191]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

vector_out_heads = [out_heads[:, i] for i in range(out_heads.size(1))]
vector_unembeds = [unembeds[i] for i in range(unembeds.size(0))]
vector_out = [out[i] for i in range(out.size(0))]

names_out_heads = [
    [f'{i}' for i in range(out_heads[:, j].size(0))] for j in range(out_heads.size(1))
]
names_unembeds = [
    [f'{i}' for i in range(unembeds[j].size(0))] for j in range(unembeds.size(0))
]
names_out = [
    [f'{i}' for i in range(out[j].size(0))] for j in range(out.size(0))
]

fig = make_subplots(rows=2, cols=len(vector_out_heads), subplot_titles=[f'head {i+1}' for i in range(len(vector_out_heads))])

for idx, (vectors, names) in enumerate(zip(vector_out_heads, names_out_heads)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=1, col=idx+1)

for idx, (vectors, names) in enumerate(zip(vector_unembeds, names_unembeds)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=2, col=idx+1)
        
for idx, (vectors, names) in enumerate(zip(vector_out, names_out)):
    x_start = torch.zeros(vectors.size(0))
    y_start = torch.zeros(vectors.size(0))

    x_end = vectors[:, 0]
    y_end = vectors[:, 1]

    for i in range(vectors.size(0)):
        fig.add_trace(go.Scatter(
            x=[x_start[i].item(), x_end[i].item()],
            y=[y_start[i].item(), y_end[i].item()],
            mode='lines+markers+text',
            text=[None, names[i]],  # 在终点添加文本
            textposition='top center',  # 文本位置
            name=names[i],
        ), row=2, col=idx+2)
        
fig.update_layout(
    title='',
    width=300 * len(vector_out_heads),
    height=300 * 2,
    showlegend=False
)

for idx in range(len(vector_out_heads)):
    fig.update_xaxes(range=[-10, 10], row=1, col=idx+1, title_text='heads output')
    fig.update_yaxes(range=[-10, 10], row=1, col=idx+1, title_text='')

for idx in range(1):
    fig.update_xaxes(range=[-5, 5], row=2, col=idx+1, title_text='unembeddings')
    fig.update_yaxes(range=[-5, 5], row=2, col=idx+1, title_text='')
    fig.update_xaxes(range=[-10, 10], row=2, col=idx+2, title_text='output')
    fig.update_yaxes(range=[-10, 10], row=2, col=idx+2, title_text='')

# 显示图形
fig.show()


In [200]:
print(pattern[:, 2])

tensor([[[1.0000, 0.0000],
         [0.2648, 0.7352]],

        [[1.0000, 0.0000],
         [0.4964, 0.5036]],

        [[1.0000, 0.0000],
         [0.0088, 0.9912]],

        [[1.0000, 0.0000],
         [0.9143, 0.0857]],

        [[1.0000, 0.0000],
         [0.0024, 0.9976]]], device='cuda:0')
