##### token probability per layer

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re
from rich import print as pp

In [None]:
## Location of the dataset, relative to the notebook's working directory
path_to_all_data = "../data/text_with_entities.json"

## JSON text is UTF-8 by specification; the encoding is passed explicitly so that
## loading does not depend on the platform's default locale encoding.
with open(path_to_all_data, 'r', encoding = 'utf-8') as json_file:
    ddi_data = json.load(json_file)

In [None]:
def get_token_probs_per_layer(model, tokenizer, input_text, tokens_of_interest):
    """
    Track next-token probabilities layer by layer.

    The prompt is run through the model once with ``output_hidden_states = True``.
    For every returned hidden state, the vector at the last prompt position is
    projected through the model's output embedding (``model.get_output_embeddings()``)
    and turned into a probability distribution over the vocabulary.

    NOTE(review): intermediate hidden states are projected directly; no extra
    normalization is applied before the projection. Confirm that this is the
    intended variant of the analysis for the model in use.

    Args:
        model: The HuggingFace transformer model.
        tokenizer: The tokenizer for the model.
        input_text (str): The prompt to analyze.
        tokens_of_interest (list[str]): Tokens whose probability should be tracked.
            Each one is encoded with a leading space; if it maps to several
            sub-tokens, only the first sub-token is tracked.

    Returns:
        list[dict]: One dictionary per entry of ``outputs.hidden_states``, in order.
            Each dictionary has the keys:
            - "layer": index of the hidden state in ``outputs.hidden_states``.
            - "tokens_of_interest": mapping from each token of interest to its probability.
            - "top_5_tokens": list of ``{"token", "probability"}`` dictionaries for the
              most probable tokens (at most 5), in descending order of probability.

    Raises:
        IndexError: If a token of interest encodes to an empty list of token ids.
    """

    ## Tokenize input
    inputs = tokenizer(input_text, return_tensors = "pt").to(model.device)

    ## Get token IDs for tokens of interest.
    ## If a word is tokenized into multiple tokens, only the first one is used.
    token_ids_of_interest = [
        tokenizer.encode(" " + token, add_special_tokens = False)[0]
        for token in tokens_of_interest
    ]

    results = []

    ## Everything below is inference only, so no autograd graph is needed
    with torch.no_grad():

        ## Run model with hidden states output for just the prompt
        outputs = model(**inputs, output_hidden_states = True)

        ## Get the layer of the model that projects hidden states to vocabulary logits
        output_projection = model.get_output_embeddings()

        ## For each layer find top probable tokens and probabilities of tokens of interest
        for layer_idx, layer_hidden_state in enumerate(outputs.hidden_states):

            ## Get last token position hidden state
            last_token_hidden = layer_hidden_state[0, -1, :]

            ## Project to vocabulary space
            logits = output_projection(last_token_hidden)

            ## Convert to probabilities; softmax is done in float32 so that a
            ## half-precision model does not round small probabilities away
            probs = torch.softmax(logits.float(), dim = -1)

            ## Extract probabilities for tokens of interest
            token_probs = {
                token: probs[token_id].item()
                for token, token_id in zip(tokens_of_interest, token_ids_of_interest)
            }

            ## Never request more entries than the vocabulary has
            top_k = min(5, probs.shape[-1])
            top_probs, top_ids = torch.topk(probs, top_k)
            top_5_tokens = [
                {
                    "token": tokenizer.decode(token_id),
                    "probability": prob
                }
                for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
            ]

            ## Exactly one record per layer
            results.append({
                "layer": layer_idx,
                "tokens_of_interest": token_probs,
                "top_5_tokens": top_5_tokens
            })

    return results

In [None]:
from pathlib import Path

## Matches ASCII control characters (0x00-0x1F and 0x7F); they are removed from the prompt
pattern = r'[\x00-\x1F\x7F]'
model_name = "microsoft/Phi-4-mini-instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = "auto",
    torch_dtype = "auto",
    trust_remote_code = True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
## Explicit inference mode, matching the attention-flow section of this notebook
model.eval()

## Make sure the output directory exists before the first result is written
output_dir = Path('../results/token_prob/phi')
output_dir.mkdir(parents = True, exist_ok = True)

for idx, entry in enumerate(ddi_data):

    tokens_to_check = ddi_data[entry]['entities']
    ## NOTE(review): "entitites" in the prompt is a typo. It is left unchanged because
    ## correcting it alters the tokenization and would make new results incomparable
    ## with previously saved ones.
    eval_prompt = "The drugs, chemicals and medical entitites mentioned in the text. - \"" + ddi_data[entry]['full_text'] + "\" are the following: "
    eval_prompt = re.sub(pattern, '', eval_prompt)
    results = get_token_probs_per_layer(model, tokenizer, eval_prompt, tokens_to_check)

    with open(output_dir / f'sample_{idx + 1}.json', 'w') as file_writer:
        json.dump(results, file_writer)

    ## Only the first entry is processed; remove this `break` to run over all entries
    break

##### attention flow

In [1]:
import json
import re
import seaborn as sns
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer



In [2]:
## Location of the dataset, relative to the notebook's working directory
path_to_all_data = "../data/text_with_entities.json"

## JSON text is UTF-8 by specification; the encoding is passed explicitly so that
## loading does not depend on the platform's default locale encoding.
with open(path_to_all_data, 'r', encoding = 'utf-8') as json_file:
    ddi_data = json.load(json_file)

In [3]:
def calculate_and_visualize_attention_rollout(model, tokenizer, input_text, show_plot = False, scale = 'log'):
    """
    Calculate (and optionally visualize) attention rollout for a model and input text.

    For every layer the attention weights are averaged over heads, mixed equally
    with the identity matrix (``0.5 * attention + 0.5 * I``) and multiplied onto
    the running rollout matrix. The column sums of the final rollout matrix,
    normalized to sum to 1, are reported as per-token contribution scores.

    Args:
        model: The Hugging Face transformer model.
        tokenizer: The tokenizer for the model.
        input_text (str): The prompt to analyze. Only the first sequence of the
            tokenized batch is used.
        show_plot (bool): If True, displays a heatmap of the rollout matrix and a
            bar plot of the contribution scores. Nothing is plotted otherwise.
        scale (str): Only affects the heatmap. With 'log' the heatmap shows
            ``log(rollout + 1e-9)``; any other value plots the raw rollout values.

    Returns:
        pandas.DataFrame: One row per input token, in prompt order, with the columns
            - 'token': the decoded token text.
            - 'contribution_score': the normalized column sum of the rollout matrix.

    Raises:
        ValueError: If the model does not return attention weights.
    """

    device = model.device
    inputs = tokenizer(input_text, return_tensors = "pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_attentions = True)

    attention_matrices = outputs.attentions
    if attention_matrices is None or any(layer is None for layer in attention_matrices):
        raise ValueError(
            "The model did not return attention weights; attention rollout needs "
            "an attention implementation that supports output_attentions = True."
        )

    input_tokens = [tokenizer.decode(token_id) for token_id in inputs['input_ids'][0]]
    seq_len = len(input_tokens)

    ## Calculate attn rollout.
    ## Accumulation is done in float32 on a single device so that half-precision
    ## weights or layers placed on different devices cannot break the product.
    identity_matrix = torch.eye(seq_len, device = device, dtype = torch.float32)
    rollout = identity_matrix.clone()
    for layer_attention in attention_matrices:

        ## First batch element, averaged over heads
        avg_head_attention = layer_attention[0].to(device = device, dtype = torch.float32).mean(dim = 0)
        residual_attention = 0.5 * avg_head_attention + 0.5 * identity_matrix
        rollout = residual_attention @ rollout

    ## Calculate percentage contribution scores
    raw_scores = rollout.sum(dim = 0)
    probabilistic_scores = raw_scores / raw_scores.sum()
    contribution_scores = probabilistic_scores.cpu().numpy()

    ## Create dataFrame to store values
    contributions = pd.DataFrame({
        'token': input_tokens,
        'contribution_score': contribution_scores
    })

    ## For visualization
    if show_plot:

        rollout_data = rollout.cpu().numpy()
        cbar_label = "Attention Score"
        plot_title = f"Attention Rollout for Prompt:\n'{input_text}'"

        if scale == 'log':
            ## Small offset avoids log(0) for positions with zero rollout weight
            rollout_data = np.log(rollout_data + 1e-9)
            cbar_label = "Log Attention Score"
            plot_title = f"Log-Scaled Attention Rollout for Prompt:\n'{input_text}'"

        fig1, ax1 = plt.subplots(figsize = (14, 12))
        sns.heatmap(rollout_data, annot = False, cmap = 'magma', ax = ax1, cbar_kws = {'label': cbar_label})
        ax1.set_xticks(np.arange(seq_len) + 0.5)
        ax1.set_xticklabels(input_tokens, rotation=90, ha = "center")
        ax1.set_yticks(np.arange(seq_len) + 0.5)
        ax1.set_yticklabels(input_tokens, rotation=0, va = "center")
        ax1.set_xlabel("Attended To (Source Tokens)")
        ax1.set_ylabel("Attention From (Target Representations)")
        ax1.set_title(plot_title, fontsize = 16)
        plt.tight_layout(pad = 2.0)
        plt.show()

        ## Bars are keyed by token position, not token text: the same text can occur
        ## several times in a prompt, and a categorical axis keyed by text would
        ## merge those occurrences into a single aggregated bar.
        positions = np.arange(seq_len)
        fig2, ax2 = plt.subplots(figsize = (12, 8))
        sns.barplot(x = positions,
                    y = contribution_scores,
                    palette = 'viridis',
                    ax = ax2)
        ax2.set_xticks(positions)
        ax2.set_xticklabels(input_tokens, rotation = 45, ha = "right")
        ax2.set_title("Overall Token Contribution Scores (from Attention Rollout)")
        ax2.set_xlabel("Input Token")
        ax2.set_ylabel("Total Contribution Score (Column Sum)")
        plt.tight_layout()
        plt.show()

    return contributions

In [None]:
from pathlib import Path

## Matches ASCII control characters (0x00-0x1F and 0x7F); they are removed from the prompt
pattern = r'[\x00-\x1F\x7F]'
model_name = "microsoft/Phi-4-mini-instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = "auto",
    torch_dtype = "auto",
    trust_remote_code = True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

## Make sure the output directory exists before the first result is written
output_dir = Path('../results/attn_rollout/phi')
output_dir.mkdir(parents = True, exist_ok = True)

for idx, entry in enumerate(ddi_data):

    ## NOTE(review): "entitites" in the prompt is a typo. It is left unchanged because
    ## correcting it alters the tokenization and would make new results incomparable
    ## with previously saved ones.
    eval_prompt = "The drugs, chemicals and medical entitites mentioned in the text. - \"" + ddi_data[entry]['full_text'] + "\" are the following: "
    eval_prompt = re.sub(pattern, '', eval_prompt)

    contribution_df = calculate_and_visualize_attention_rollout(model,
                                                                tokenizer,
                                                                eval_prompt,
                                                                scale = 'log')

    ## The table is printed in prompt order (it is not sorted by score)
    print("\n--- Top Contributing Tokens ---")
    print(contribution_df.to_string())

    contribution_df.to_csv(output_dir / f'sample_{idx + 1}.csv', index = False)

    ## Only the first entry is processed; remove this `break` to run over all entries
    break

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]




--- Top Contributing Tokens ---
             token  contribution_score
0              The        9.999976e-01
1            drugs        1.618518e-07
2                ,        8.151707e-07
3        chemicals        3.443011e-08
4              and        5.354339e-08
5          medical        2.770667e-08
6              ent        2.488946e-08
7               it        2.749413e-08
8             ites        3.545462e-08
9        mentioned        6.290023e-08
10              in        6.653143e-08
11             the        3.193286e-07
12            text        2.490004e-08
13               .        7.352909e-08
14               -        2.770193e-08
15               "        3.526010e-08
16              In        4.115075e-08
17           order        1.195736e-08
18              to        2.296895e-08
19         provide        1.880360e-08
20     information        1.773302e-08
21             for        3.023531e-08
22             the        1.130675e-07
23     appropriate        1.057

In [5]:
## NOTE(review): `final_rollout` is not assigned anywhere in the visible notebook;
## calculate_and_visualize_attention_rollout returns only the contributions
## DataFrame, so this cell raises the NameError recorded in its output.
final_rollout.shape

NameError: name 'final_rollout' is not defined

In [None]:
## Display the contributions DataFrame bound to `contribution_df`.
## NOTE(review): the recorded output has an "Unnamed: 0" column and is sorted by score,
## which the loop above does not produce — presumably from a different session; confirm.
contribution_df

Unnamed: 0,token,contribution_score
0,The,8.600391e+01
2,",",7.010803e-05
11,the,2.746357e-05
1,drugs,1.391992e-05
22,the,9.724270e-06
...,...,...
77,applicable,1.988141e-08
82,the,8.104623e-09
83,following,5.734450e-09
84,:,5.221844e-09
