In [1]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))
import pytorch_lightning as pl
import torch
from model.model_interface import LLM
import torch.utils.data as tud
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
import torch.nn.functional as F
from tqdm.notebook import tqdm
from utils.viz_tool import *


# Let Hugging Face tokenizers run in parallel; setting the variable explicitly
# also silences the tokenizers fork warning.
os.environ.update(TOKENIZERS_PARALLELISM="true")
# Trade float32 matmul precision for speed ('medium' permits reduced-precision
# internal computation such as bfloat16/TF32 where a fast kernel exists).
torch.set_float32_matmul_precision("medium")

LLM Config

In [2]:
# Local Hugging Face hub snapshots on the original machine — adjust hf_hub_dir
# for another setup. Only llama_7b_local_dir is loaded below; the GPT-2 paths
# are kept so the model can be swapped by editing llm_config.
hf_hub_dir = "/nvme/yangyuchen1/huggingface/hub"
gpt2_local_dir = f"{hf_hub_dir}/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8"
gpt2_xl_local_dir = f"{hf_hub_dir}/models--gpt2-xl/snapshots/33cdb5c0db5423c1879b1b9f16c352988e8754a8"
llama_7b_local_dir = f"{hf_hub_dir}/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348"

llm_config = dict(model_name=llama_7b_local_dir)
mt = LLM(**llm_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Hook Config

In [62]:
# Options passed to LLM.add_hook for every transformer block: retain each
# block's output (moved to CPU), do not retain inputs, and do not edit outputs.
# NOTE(review): exact semantics of clone/detach are defined by LLM.add_hook.
hook_config = {
    "retain_output": True,
    "retain_input": False,
    "edit_output": None,
    "clone": False,
    "detach": False,
    "device": "cpu"
}

n_layer = mt.model.config.num_hidden_layers

# Locate the stack of transformer blocks for the loaded architecture instead of
# toggling commented-out lines: GPT-2 models keep them in `.transformer.h`,
# LLaMA models in `.model.layers`.
if hasattr(mt.model, "transformer"):
    blocks = mt.model.transformer.h
else:
    blocks = mt.model.model.layers
assert len(blocks) == n_layer, f"expected {n_layer} blocks, found {len(blocks)}"

# Remove hooks from any previous run (presumably also discarding their retained
# outputs), then register one hook per block named block_0 ... block_{n_layer-1}.
mt.clear_hook()
for i, block in enumerate(blocks):
    mt.add_hook(module=block, name=f"block_{i}", **hook_config)

Dataset Config

In [63]:
bsz = 1  # batch size; later cells map a dataset index to (batch, position) with it

from dataset.knowns import Knowns

knowns_data_dir = "/nvme/yangyuchen1/coding/gpt_re/data"
dst = Knowns(knowns_data_dir, mt.tokenizer)
# Sequential order (shuffle=False, the DataLoader default) matters: later cells
# index the retained hook outputs and predictions by batch position.
dl = DataLoader(dst, batch_size=bsz, shuffle=False, collate_fn=dst.collate_fn)

Loaded dataset with 1209 elements


Trainer Config

In [64]:
# Lightning trainer used only for inference (trainer.predict) below.
trainer_config = dict(
    precision="16-mixed",  # automatic mixed precision
    accelerator="auto",
    devices=[5],           # GPU index on the original machine — adjust locally
)
trainer = pl.Trainer(**trainer_config)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [65]:
# Inference pass: `res` holds one predict_step result per batch of `dl`; the registered hooks retain block outputs as a side effect.
res = trainer.predict(model=mt, dataloaders=dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

The dataloader, predict_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 128 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



Predicting: 0it [00:00, ?it/s]

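Analyse Mean Residual Norm

For each layer, plot the norm of the block's output after averaging over non-padding tokens and then over all prompts.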

In [66]:
# Per-layer norm of the block output, averaged over real (non-padding) tokens
# of each prompt and then over all prompts.
# NOTE(review): this is the norm of the mean output vector (||mean h||), not the
# mean of per-token norms (mean ||h||) — confirm which one the "Avg norm" title
# is meant to describe.
# Assumes the hooks hold exactly one output per batch, in `dl` order, i.e.
# trainer.predict ran once after the hooks were (re-)registered.
assert len(mt.hooks["block_0"].outputs) == len(dl), (
    "hook outputs do not line up with the dataloader; re-run the hook cell before trainer.predict"
)

blocks_mean_output = [[] for _ in range(n_layer)]
for idx, (_, attention_mask, _) in enumerate(dl):
    # Broadcast the mask over the hidden dimension instead of materialising
    # a repeated copy; the float mask also promotes the product to fp32.
    token_mask = attention_mask.unsqueeze(-1).float()   # [bsz, seq_len, 1]
    n_tokens = attention_mask.sum(dim=1).unsqueeze(-1)  # [bsz, 1]
    for i in range(n_layer):
        # presumably element 0 of the block's output tuple is its hidden states
        hidden = mt.hooks[f"block_{i}"].outputs[idx][0]
        masked_mean = (hidden * token_mask).sum(dim=1) / n_tokens  # [bsz, hidden_size]
        # To use only the last position instead: hidden[:, -1, :]
        blocks_mean_output[i].append(masked_mean)
blocks_mean_output = [torch.vstack(b).mean(0) for b in blocks_mean_output]
plotly_bar("Avg norm of layer output", [torch.norm(b).item() for b in blocks_mean_output])

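Unembedding

Inspect the model's top-5 next-token predictions, with probabilities, for a single prompt from the dataset.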

In [106]:
# Dataset index of the prompt to inspect.
idx = 125

# Split the flat dataset index into (batch index, position within the batch).
bi, i = divmod(idx, bsz)

# Fetch only the batch we need. The previous `list(dl)[bi]`, evaluated three
# times, collated the entire dataset on every call. Relies on `dl` yielding
# batches in a fixed (unshuffled) order; negative indices count from the end.
target_bi = bi + len(dl) if bi < 0 else bi
batch = next((b for k, b in enumerate(dl) if k == target_bi), None)
if batch is None:
    raise IndexError(f"idx={idx} is out of range for a dataset of {len(dst)} examples")
input_ids, attention_mask, labels = batch[0][i], batch[1][i], batch[2][i]
# Retained block outputs for this prompt, one entry per registered hook.
outputs = [h.outputs[bi][0][i] for h in mt.hooks.values()]
# The last input token is </s>, so the prediction for the token after the prompt
# sits at position -2. NOTE(review): with bsz > 1 and right padding this would
# land on a padding position — confirm the padding side if bsz changes.
logits = res[bi]['logits'][i][-2]

input_tokens = mt.tokenizer.decode(input_ids)
# Top-5 next-token candidates with their softmax probabilities (softmax in fp32).
prob, next_ids = torch.topk(F.softmax(logits.float(), dim=-1), 5)
next_token = {mt.tokenizer.decode(t): "{:7f}".format(p.item()) for p, t in zip(prob, next_ids)}

print(f"prompt:{input_tokens}")
print(f"next token:{next_token}")

# mt.generate(input_tokens, max_new_tokens=5)

prompt: The headquarter of Zillow is in downtown
next token:{' Seattle': '0.964193', ' Belle': '0.020972', ',': '0.002076', ' of': '0.001804', ' Se': '0.001192'}


In [111]:
import plotly.graph_objs as go
import numpy as np

# Toy 3x3 matrix used to try out hover text on a heatmap.
data = np.array([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

# Auxiliary information dictionary to attach to every cell.
info_dict = {'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}

# customdata mirrors the shape of `data`, holding the same dict in every cell.
# NOTE(review): plotly may render a whole dict under %{customdata} as an opaque
# object; %{customdata.key1} may be what is intended — verify in the browser.
customdata = [[info_dict for _ in row] for row in data]

trace = go.Heatmap(
    z=data,
    customdata=customdata,
    hovertemplate='x=%{x}<br>y=%{y}<br>customdata=%{customdata}',
)
layout = go.Layout(
    title='Heatmap Example',
    xaxis=dict(title='X Axis'),
    yaxis=dict(title='Y Axis'),
)

fig = go.Figure(data=[trace], layout=layout)
fig.show()
