# Heatmap of token probabilities

This notebook offers a simple script that highlights "unexpected" tokens within a prompt. Each token's highlight is based on the difference between the probability of the model's most likely prediction at that position and the probability the model assigned to the token that actually appears there.
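In symbols (the notation is introduced here for illustration and does not appear in the code): for a prompt $x_1, \dots, x_n$ and a model with next-token distribution $p(\cdot \mid x_{<t})$ over a vocabulary $V$, the score of token $x_t$ is

$$
d_t = \max_{v \in V} p(v \mid x_{<t}) - p(x_t \mid x_{<t}), \qquad 0 \le d_t \le 1 .
$$

A token that is the model's top prediction scores $d_t = 0$. A token the model considered nearly impossible, while being confident about some other token, scores close to 1. The first token has no preceding context, so the script scores it 0.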

Example:

Consider the prompt 'print("Hello}World")'.
After processing 'print("Hello', the model assigns the token '}' a very low probability of being the next token, which is why the following script highlights it.
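A minimal, model-free sketch of that score for a single position. The distribution below is made up for illustration (it is not CodeGemma output), and `token_surprise` is a hypothetical helper name:

```python
def token_surprise(next_token_probs: dict[str, float], actual_token: str) -> float:
    """Return max_v p(v) - p(actual_token) for one position.

    `next_token_probs` maps candidate tokens to probabilities; any token
    missing from the mapping is treated as having probability 0.
    """
    predicted = max(next_token_probs.values())
    actual = next_token_probs.get(actual_token, 0.0)
    return predicted - actual


# Hypothetical distribution after the context 'print("Hello'
# (the remaining probability mass belongs to the rest of the vocabulary).
probs_after_hello = {'"': 0.55, ',': 0.25, ' World': 0.15, '}': 0.001}

print(round(token_surprise(probs_after_hello, '"'), 3))  # 0.0: the top prediction
print(round(token_surprise(probs_after_hello, '}'), 3))  # 0.549: highly unexpected
```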

In [1]:
from transformers import GemmaTokenizer, AutoModelForCausalLM
import torch

In [2]:
gpu = torch.device('cuda:0')

In [3]:
model_id = "google/codegemma-1.1-7b-it"
tokenizer = GemmaTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=gpu, torch_dtype=torch.float16)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu_pytorch_tanh`, edit the `model.config` to set `hidden_activation=gelu_pytorch_tanh`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


In [4]:
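def escape_from_hex(r, g, b, text):
    """Wrap `text` in an ANSI 24-bit foreground color escape.

    Despite the name, the color is given as three integer channels (0-255),
    not as a hex string. The trailing ESC[0m resets the color afterwards.
    """
    return f"\x1B[38;2;{r};{g};{b}m{text}\x1B[0m"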

In [5]:
print(escape_from_hex(255, 240, 0, "test"), escape_from_hex(255, 0, 0, "test"))

test test
(the first "test" in yellow, the second in red)
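`escape_from_hex` emits the ANSI SGR sequence `ESC[38;2;R;G;Bm`, which sets a 24-bit foreground color, and ends with `ESC[0m`, which resets it. The highlighting loop puts round(difference × 255) into the red channel and leaves green and blue at 0, so expected tokens render black and the most surprising ones bright red. A sketch of that mapping as a standalone helper; the name `heat_escape` and the clamping are additions for illustration, not part of the original notebook:

```python
def heat_escape(score: float, text: str) -> str:
    """Color `text` by a surprise score in [0, 1] using an ANSI 24-bit escape.

    The score is clamped to [0, 1] so the red channel always stays within
    0-255. Only red varies: 0 gives black, 1 gives pure red.
    """
    red = round(min(max(score, 0.0), 1.0) * 255)
    return f"\x1b[38;2;{red};0;0m{text}\x1b[0m"


print(repr(heat_escape(0.2, "some")))  # '\x1b[38;2;51;0;0msome\x1b[0m'
```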


In [9]:
prompt = """
You are debugging some code. Can you help me identify the issues in this piece of code? Can you give me for every line a probability if this line contains a bug:
def sumElementsOfList(in_list):
    sum = 0
    for i in list:
        sum -= i
    return sum
"""
#with open("convex_hull.py") as f:
#    prompt = f.read()

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"][0]
print("num_tokens", inputs["input_ids"].shape)
with torch.no_grad():
    # logits[j] is the model's prediction for the token at position j + 1.
    logits = model(**inputs).logits[0]

# Softmax in float32: float16 cannot represent very small probabilities.
probabilities = torch.softmax(logits.float(), dim=-1)
# Probability of the most likely next token at each position...
pred_probs = probabilities[:-1].max(dim=-1).values
# ...and of the token that actually comes next. The last position has no
# next token, so it is dropped.
act_probs = probabilities[:-1].gather(1, input_ids[1:].unsqueeze(1)).squeeze(1)
differences = (pred_probs - act_probs).tolist()

prompt_tokens = tokenizer.convert_ids_to_tokens(input_ids)
threshold = 0.2
result = []
# Token j is scored with the prediction made at position j - 1; the first
# token (<bos>) has no prediction and gets a difference of 0.
for token, difference in zip(prompt_tokens, [0.0] + differences):
    # Keep SentencePiece's "▁" space marker visible on highlighted tokens so
    # unexpected whitespace stands out; otherwise render it as a plain space.
    if difference <= threshold:
        token = token.replace("▁", " ")
    # Red channel scales with the difference: black = expected, red = surprising.
    result.append(escape_from_hex(round(difference * 255), 0, 0, token))

# Drop every reference to the GPU tensors before releasing cached memory.
del inputs, input_ids, logits, probabilities, pred_probs, act_probs
torch.cuda.empty_cache()
# Print the prompt with colored tokens
colored_prompt = ''.join(result)
print(colored_prompt)

num_tokens torch.Size([1, 71])
(Colors are not rendered in this export: each token is followed by its red-channel value in parentheses, i.e. round(255 × difference), and `\n` marks a newline token. Tokens above the 0.2 threshold keep the visible `▁` space marker. The captured output ends after `List`.)
<bos> (0)  \n (255)
You (77)  are (0)  debugging (44)  ▁some (213)  code (0)  . (74)  ▁Can (177)  you (0)  help (36)  me (0)  ▁identify (70)  the (0)  ▁issues (70)  in (0)  ▁this (86)  ▁piece (226)  of (0)  code (0)  ? (0)  ▁Can (250)  you (0)  ▁give (178)  me (0)  ▁for (131)  ▁every (221)  ▁line (241)  a (33)  ▁probability (94)  ▁if (219)  ▁this (164)  line (0)  ▁contains (79)  ▁a (145)  bug (0)  : (174)  \n (254)
def (101)  ▁sum (68)  Elements (229)  Of (250)  List (0)  …