In [1]:
import pytorch_lightning as pl
import torch
from model.model_interface import LLM
from dataset.knowns import Knowns
import torch.utils.data as tud
from lightning.pytorch.loggers import TensorBoardLogger
import torch.nn.functional as F
from gpthook import TraceDict
import os 
import random

In [2]:
# --- LLM config: keyword arguments forwarded to LLM(**llm_config) ---
llm_config = dict(model_name="gpt2")

# --- Dataset / DataLoader config ---
# batch_size stays 1: calculate_comp_flow only supports a single example per batch.
dl_config = dict(batch_size=1)
data_dir = "data"   # first positional argument to Knowns
size = 1000         # number of examples requested from Knowns

# --- Lightning Trainer config: keyword arguments for pl.Trainer ---
trainer_config = dict(
    precision=16,        # 16-bit automatic mixed precision (see the Trainer log output)
    accelerator="auto",  # let Lightning choose the available accelerator
    devices=1,
)

In [3]:
# Model wrapper first: the dataset needs its tokenizer.
mt = LLM(**llm_config)
# Knowns dataset built from data_dir with the model's tokenizer, limited to `size` examples.
dst = Knowns(data_dir, mt.tokenizer, size)
# Batches are assembled by the dataset's own collate_fn.
dl = tud.DataLoader(dataset=dst, collate_fn=dst.collate_fn, **dl_config)
trainer = pl.Trainer(**trainer_config)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loaded dataset with 1000 elements


In [None]:
# NOTE(review): this cell previously held the truncated expression `trainer.p`,
# which raises AttributeError and breaks Restart & Run All. It was presumably
# meant to become a `trainer.predict(...)` call; re-enable once the intended
# call (and LLM's predict_step signature) is confirmed, e.g.:
# predictions = trainer.predict(mt, dataloaders=dl)

In [None]:
# Clean (unpatched) forward pass over every batch while TraceDict records the
# traced module activations on CPU. Only the last batch's `logits` / `clean_td`
# remain bound after the loop.
# Batches unpack as (input_ids, attention_mask, labels), matching calculate_comp_flow;
# the previous `i, a, l` names left `input_ids` / `attention_mask` undefined.
for input_ids, attention_mask, labels in dl:
    with torch.no_grad(), TraceDict(mt.model, device="cpu") as clean_td:
        # NOTE(review): calculate_comp_flow calls `mt.predict_step` with the same
        # arguments; confirm which method LLM actually exposes.
        logits = mt.predict(input_ids, attention_mask=attention_mask, output_attentions=True)['logits']  # [bsz, seq, vocab]

In [None]:
def calculate_comp_flow(mt:LLM, batch:tuple, comp_key, comp_kind):
    """Measure how patching a traced component changes the clean top-1 prediction.

    Runs one clean forward pass under ``TraceDict`` and takes the most likely
    next token at the last position as the reference answer. Then, for every
    layer ``1 .. n_layer - 1``, it feeds the clean output of component
    ``f"{comp_key}_{layer - 1}"`` to ``trace_comp_patch``. This is done once per
    single token position and once for all positions together.

    Only ``batch_size == 1`` is supported.

    Args:
        mt: wrapper exposing ``.model``, ``.tokenizer`` and ``.predict_step``.
        batch: ``(input_ids, attention_mask, labels)`` as produced by the
            DataLoader; ``labels`` is unused.
        comp_key: name prefix of the traced component (e.g. ``"attn"``). For
            ``"attn"`` the first element of the module's output tuple is used.
        comp_kind: forwarded unchanged to ``trace_comp_patch``.

    Returns:
        dict with keys:
            ``table``: ``gt_prob - patched_prob`` per (token position, layer),
                transposed to ``[seq_len, n_layer - 1]`` (assuming
                ``trace_comp_patch`` returns probabilities shaped like the clean
                ones; TODO confirm);
            ``comp_key`` / ``comp_kind``: the arguments, echoed back;
            ``input_tokens``: decoded input tokens;
            ``answer``: decoded clean top-1 prediction;
            ``attn_weight_diff``: per layer, the summed absolute difference of
                the two tensors found in ``attn_{layer}``'s ``output[2]`` when
                all positions are patched.

    NOTE(review): ``trace_comp_patch`` is not defined or imported in the visible
    cells; it must be defined elsewhere before calling this.
    """
    model, tokenizer = mt.model, mt.tokenizer
    input_ids, attention_mask, _labels = batch

    # Everything the patched runs touch must sit on the model's device
    # (previously an undefined global `device`).
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    # trace_comp_patch takes the model inputs as a dict (previously an undefined `inp`).
    inp = {"input_ids": input_ids, "attention_mask": attention_mask}
    seq_len = input_ids.shape[-1]
    # Take row 0 rather than squeeze(): squeeze() turns a length-1 sequence
    # into a 0-d tensor, which cannot be iterated.
    input_tokens = [tokenizer.decode([t]) for t in input_ids[0]]

    # Clean run: trace activations (stored on CPU) and get the reference answer.
    with torch.no_grad(), TraceDict(model, device="cpu") as clean_td:
        logits = mt.predict_step(input_ids, attention_mask=attention_mask, output_attentions=True)['logits']  # [bsz, seq, vocab]
    clean_prob = F.softmax(logits, dim=-1)
    gt_idx = torch.argmax(clean_prob[:, -1, :], dim=-1)  # clean top-1 token at the last position
    answer = tokenizer.decode(gt_idx)
    gt_prob = clean_prob[:, -1, gt_idx]

    # Activation entering block_0, as traced (left on CPU, as before).
    x0 = clean_td["block_0"].input
    table = []
    attn_weight_diff = []
    for layer in range(1, model.config.n_layer):
        comp = clean_td[f"{comp_key}_{layer - 1}"].output
        if comp_key == "attn":
            # The attention module's traced output is a tuple; element 0 is used.
            comp = comp[0]
        comp = comp.to(device)

        # Patch one token position at a time.
        column = []
        for t_idx in range(seq_len):
            prob, _ = trace_comp_patch(model, inp, x0, layer, [t_idx], comp, comp_kind)
            column.append(gt_prob - prob[:, -1, gt_idx])
        table.append(torch.vstack(column))

        # Patch all positions at once and compare the pair of tensors that
        # attn_{layer} reports in output[2] (presumably unpatched vs. patched
        # attention weights; TODO confirm).
        all_positions = list(range(seq_len))
        _, td = trace_comp_patch(model, inp, x0, layer, all_positions, comp, comp_kind, output_attentions=True)
        attn_weight_o, attn_weight_fixed = td[f'attn_{layer}'].output[2]
        attn_weight_diff.append((attn_weight_o - attn_weight_fixed).abs().sum(dim=-1).sum(dim=-1))

    attn_weight_diff = torch.vstack(attn_weight_diff)
    # Drop only the trailing singleton dim: a bare squeeze() would also drop the
    # token or layer axis when it has length 1, and then transpose(0, 1) fails.
    table = torch.stack(table).squeeze(-1)
    return {"table": table.transpose(0, 1).cpu(),
            "comp_key": comp_key,
            "comp_kind": comp_kind,
            "input_tokens": input_tokens,
            "answer": answer,
            "attn_weight_diff": attn_weight_diff.cpu()}
