# Setup

### Install Dependencies

In [1]:
!pip install transformers bertviz shap captum

Collecting bertviz
  Downloading bertviz-1.4.0-py3-none-any.whl.metadata (19 kB)
Collecting shap
  Downloading shap-0.46.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (24 kB)
Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl.metadata (26 kB)
Collecting boto3 (from bertviz)
  Downloading boto3-1.35.22-py3-none-any.whl.metadata (6.6 kB)
Collecting slicer==0.0.8 (from shap)
  Downloading slicer-0.0.8-py3-none-any.whl.metadata (4.0 kB)
Collecting botocore<1.36.0,>=1.35.22 (from boto3->bertviz)
  Downloading botocore-1.35.22-py3-none-any.whl.metadata (5.7 kB)
Collecting jmespath<2.0.0,>=0.7.1 (from boto3->bertviz)
  Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)
Collecting s3transfer<0.11.0,>=0.10.0 (from boto3->bertviz)
  Downloading s3transfer-0.10.2-py3-none-any.whl.metadata (1.7 kB)
Downloading bertviz-1.4.0-py3-none-any.whl (157 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32

In [3]:
%%time
# All imports for the notebook live in this one cell; `%%time` reports how long they take.

# Standard library
import os
import sys
import copy
import random
# Colab helper module; `userdata.get(name)` is used below to look up a stored value by name.
from google.colab import userdata
# NOTE(review): the return value is discarded, so this call alone does not make the
# token available to later cells — presumably it should be assigned to a variable; confirm.
userdata.get('HF_TOKEN')

# Numerical / deep-learning stack
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers import utils
# Explainability / visualization libraries
import shap
from bertviz import model_view, head_view
from captum.attr import(
    FeatureAblation,
    LLMAttribution,
    TextTokenInput,
    TextTemplateInput,
    ProductBaselines,
)

CPU times: user 11.8 s, sys: 4.81 s, total: 16.6 s
Wall time: 42.7 s


### Load LLM (Llama 2 7B Chat)

In [4]:
# Read the Hugging Face access token from the Colab secret store instead of embedding
# it in the notebook: a literal token is leaked to anyone who can see the file or its history.
access_token = userdata.get('HF_TOKEN')

# Hub repository id of the checkpoint. Kept under its own name so that `model`
# only ever refers to the loaded network, never to a string.
model_id = "meta-llama/Llama-2-7b-chat-hf"

# The token is passed to both downloads.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)
model = AutoModelForCausalLM.from_pretrained(model_id, token=access_token)


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

# Visualizing Attention

In [5]:
# Question whose attention patterns are inspected in the cells below.
prompt: str = "What was Ross's pet monkey's name?"

In [8]:
# Tokenize the prompt into PyTorch tensors. `max_length` only caps the input when
# truncation is enabled, so request it explicitly instead of relying on the
# tokenizer's implicit fallback (which also emits a warning).
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=80)

# Inference only: no gradients are needed to read attention weights, so skip
# building the autograd graph to save memory on the forward pass.
with torch.no_grad():
    out = model(**inputs, output_attentions=True)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [9]:
# Print the module tree of the loaded model to see its layer structure.
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm):

In [12]:
# `out['attentions']` holds one tensor per decoder layer, so its length is the
# number of layers (not heads). Each tensor is laid out as
# (batch, heads, seq_len, seq_len), so the head count is the second dimension.
attention = out['attentions']

# `.shape` is metadata only, so no detach / device transfer is needed to read it.
print(f"# Attention layers : {len(attention)}")
print(f"Attention in 1st layer (batch, heads, seq, seq) : {attention[0].shape}")

# Attention Heads : 32
Attention in 1st head : torch.Size([1, 32, 13, 13])


In [13]:
# Show the token ids of the first (and only) sequence in the batch.
first_sequence_ids = inputs['input_ids'][0]
print(first_sequence_ids)

tensor([    1,  1724,   471, 13693, 29915, 29879,  5697,  1601,  1989, 29915,
        29879,  1024, 29973])


In [14]:
# Map the ids of the first sequence back to their token strings; these label the
# attention visualizations.
prompt_ids = inputs['input_ids'][0]
tokens = tokenizer.convert_ids_to_tokens(prompt_ids)
print(tokens)

['<s>', '▁What', '▁was', '▁Ross', "'", 's', '▁pet', '▁mon', 'key', "'", 's', '▁name', '?']


### Model View (all layers and heads at once)

In [15]:
# Render bertviz's model view for the prompt's attention tensors and tokens.
model_view(attention=attention, tokens=tokens)

<IPython.core.display.Javascript object>

In [16]:
# Render bertviz's head view for the same attention tensors and tokens.
head_view(attention=attention, tokens=tokens)

<IPython.core.display.Javascript object>