In [1]:
import os 
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))
import pytorch_lightning as pl
import torch
from model.model_interface import LLM
from dataset.knowns import Knowns
import torch.utils.data as tud
from lightning.pytorch.loggers import TensorBoardLogger
import torch.nn.functional as F
from utils.gpthook import TraceDict
import os 
import random
from tqdm.notebook import tqdm

In [2]:
# --- Model configuration: keyword arguments forwarded to LLM(...) ---
llm_config = dict(model_name="gpt2")

# --- Data configuration ---
# dl_config is forwarded to the DataLoader; data_dir and size are passed
# to the Knowns dataset constructor.
dl_config = dict(batch_size=1, num_workers=1)
data_dir = "data"
size = 100

# --- Trainer configuration: keyword arguments forwarded to pl.Trainer(...) ---
trainer_config = dict(
    precision="16-mixed",
    accelerator="auto",
    devices=1,
)

# Process-wide settings applied once when this cell runs.
torch.set_float32_matmul_precision("medium")
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [3]:
# Build the model wrapper from the LLM config.
mt = LLM(**llm_config)

# The dataset is tokenized with the model's own tokenizer.
dst = Knowns(data_dir, mt.tokenizer, size)

# Batches are assembled by the dataset's collate function.
dl = tud.DataLoader(
    dst,
    collate_fn=dst.collate_fn,
    **dl_config,
)

# Lightning trainer used to drive prediction.
trainer = pl.Trainer(**trainer_config)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loaded dataset with 100 elements


In [3]:
def calculate_comp_flow(mt:LLM, batch:tuple, comp_key, comp_kind):
    """Measure how patching one component's clean output changes the prediction.

    A clean forward pass is traced first. Then, for every layer from 1 to
    ``n_layer - 1``, the clean output of component ``comp_key`` from the
    previous layer is handed to ``trace_comp_patch`` once per token position,
    and the drop in probability of the clean top-1 token is recorded.

    Only supports a batch size of 1.

    Args:
        mt: Model wrapper exposing ``model``, ``tokenizer`` and ``device``;
            calling it runs a forward pass that returns a mapping with a
            ``'logits'`` entry.
        batch: ``(input_ids, attention_mask, labels)``; ``labels`` is unused.
        comp_key: Prefix of the traced component name (looked up as
            ``f"{comp_key}_{layer - 1}"``). For ``"attn"`` the first element
            of the traced output is used, otherwise the output itself.
        comp_kind: Passed through unchanged to ``trace_comp_patch``.

    Returns:
        dict with keys:
            ``"table"``: per-token / per-layer probability drops, moved to CPU
                (presumably ``[seq, n_layer - 1]`` for batch size 1 -- TODO confirm).
            ``"comp_key"``, ``"comp_kind"``: the arguments, echoed back.
            ``"input_tokens"``: each input id decoded individually.
            ``"answer"``: single-element list with the decoded top-1 token.
            ``"attn_weight_diff"``: per-layer summed absolute difference between
                the two attention-weight tensors returned in the patched trace,
                moved to CPU.
    """
    model, tokenizer, device = mt.model, mt.tokenizer, mt.device
    input_ids, attention_mask, _labels = batch
    input_tokens = [[tokenizer.decode([t]) for t in seq] for seq in input_ids]

    # The original body read an undefined name `inp`, which only worked when a
    # stale variable was left in the kernel. Build it from the batch instead.
    # NOTE(review): assumes trace_comp_patch expects a dict of model inputs
    # keyed like the tokenizer output -- confirm against its definition.
    inp = {"input_ids": input_ids, "attention_mask": attention_mask}
    seq_len = len(inp['input_ids'][0])
    all_token_idxs = list(range(seq_len))

    # Clean (unpatched) forward pass; traced activations are stored on CPU.
    with torch.no_grad(), TraceDict(mt.model, device="cpu") as clean_td:
        logits = mt(input_ids, attention_mask=attention_mask, output_attentions=True)['logits'] # [bsz, seq, vocab]

    # Top-1 token at the last position and its clean probability.
    clean_prob = F.softmax(logits, dim=-1)
    gt_idx = torch.argmax(clean_prob[:,-1,:], dim=-1)
    answers = [mt.tokenizer.decode(gt_idx)]
    gt_prob = clean_prob[:,-1,gt_idx]

    x0 = clean_td["block_0"].input
    table = []
    attn_weight_diff = []
    for layer in range(1, model.config.n_layer):
        # Clean output of the chosen component from the previous layer.
        if comp_key == "attn":
            comp = clean_td[f"{comp_key}_{layer - 1}"].output[0]
        else:
            comp = clean_td[f"{comp_key}_{layer - 1}"].output
        comp = comp.to(device)

        # Patch one token position at a time and record the probability drop.
        column = []
        for t_idx in range(seq_len):
            prob, td = trace_comp_patch(model, inp, x0, layer, [t_idx], comp, comp_kind)
            column.append(gt_prob - prob[:,-1,gt_idx])
        column = torch.vstack(column)
        table.append(column)

        # Patch all token positions at once and compare the two attention
        # weight tensors reported for this layer.
        prob, td = trace_comp_patch(model, inp, x0, layer, all_token_idxs, comp, comp_kind, output_attentions=True)
        attn_weight_o, attn_weight_fixed = td[f'attn_{layer}'].output[2]
        attn_weight_diff.append((attn_weight_o-attn_weight_fixed).abs().sum(dim=-1).sum(dim=-1))
    attn_weight_diff = torch.vstack(attn_weight_diff)
    table = torch.stack(table).squeeze()
    return {"table":table.transpose(0,1).cpu(),
            "comp_key":comp_key,
            "comp_kind":comp_kind,
            "input_tokens": input_tokens,  
            "answer":answers,
            "attn_weight_diff":attn_weight_diff.cpu()}


In [5]:
# Run prediction over the whole dataloader; `res` collects the per-batch outputs.
res = trainer.predict(model=mt, dataloaders=dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

In [2]:
# launch tensorboard
%tensorboard --logdir lightning_logs/ --port 6009

UsageError: Line magic function `%tensorboard` not found.


In [8]:
# Show which fields the first prediction output contains.
res[0].keys()

odict_keys(['logits', 'past_key_values'])