In [None]:
import gc
import os
import torch

from pathlib import Path
from transformers import LlamaForCausalLM, LlamaTokenizer, PreTrainedModel
from generate_answers_lit import TEMPLATES

# Directory the tokenizer is loaded from ("~" is expanded later, before loading).
TOKENIZER_PATH = Path("~/Llama-2-7b-hf/")
# Model name string. The second assignment overrides the first, so MODEL_NAME
# ends up as None and MODEL_PATH is what gets selected further down; remove the
# `MODEL_NAME = None` line to select the named model instead.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
MODEL_NAME = None
# Local checkpoint used whenever MODEL_NAME is falsy.
MODEL_PATH = Path("~/Llama-2-7b-rag-baseline/epoch=4-step=120.ckpt/")
# Prompt template name handed to get_prompt() when building the prompts.
PROMPT_STYLE: TEMPLATES = 'chatml' # 'chatml', 'chatml_enhanced', 'llama-2'

In [None]:
# Select model: an explicit MODEL_NAME string wins; otherwise fall back to the
# local checkpoint path.
model_name_or_path = MODEL_NAME if MODEL_NAME else MODEL_PATH

# Derive a flat prefix for output file names from the selected model.
if isinstance(model_name_or_path, Path):
    outfile_prefix = str(model_name_or_path)
    # Remove the literal "~/" home prefix only. str.lstrip('~/') would treat its
    # argument as a *set* of characters and could also strip "~" or "/"
    # characters that belong to the first path component.
    if outfile_prefix.startswith('~/'):
        outfile_prefix = outfile_prefix[len('~/'):]
    # For absolute paths, drop the leading separator(s) so the prefix does not
    # start with "_" after the replacement below.
    outfile_prefix = outfile_prefix.lstrip('/')
else:
    assert isinstance(model_name_or_path, str)
    outfile_prefix = model_name_or_path

# Path separators are not allowed inside a single file name component.
outfile_prefix = outfile_prefix.replace('/', '_')
outfile_prefix

In [None]:
# Expand "~" so the loaders receive real filesystem paths.
TOKENIZER_PATH = TOKENIZER_PATH.expanduser()
model_name_or_path = (
    model_name_or_path.expanduser()
    if isinstance(model_name_or_path, Path)
    else model_name_or_path
)

# Show what is about to be loaded so a wrong path is spotted immediately.
print("Tokenizer path:", TOKENIZER_PATH)
print("Model path:", model_name_or_path)

# Tokenizer; token id 0 is used as the padding id.
tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH, device_map='cuda')
tokenizer.pad_token_id = 0

# Model, with the "eager" attention implementation requested.
model = LlamaForCausalLM.from_pretrained(
    model_name_or_path,
    device_map='cuda',
    attn_implementation="eager",
)
# Keep the embedding table size in sync with the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))
model.eval()

In [None]:
def clear_cache():
    """Utility function to help prevent OOM.

    Runs Python garbage collection and then asks PyTorch to release its cached
    CUDA memory.
    """
    # Collect first so objects kept alive only by reference cycles are dropped
    # before the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

# Start from a clean state right after the model-loading cell.
clear_cache()

In [None]:
# Get list of prompts
from generate_answers_lit import get_prompt
from load_jsonl_files import load_jsonl_files

# Load the dataset rows via load_jsonl_files('dataset')
dataset_list = load_jsonl_files('dataset')
# Build one prompt per dataset row using the configured prompt style
prompts = [get_prompt(datarow, PROMPT_STYLE) for datarow in dataset_list]

In [None]:
import matplotlib.pyplot
matplotlib.use('Agg')
import seaborn
import pandas

from matplotlib.colors import LogNorm
from typing import Optional, Tuple

def pad_tensor(tensor: torch.Tensor, shape: Tuple) -> torch.Tensor:
    """Return a zero tensor of ``shape`` with ``tensor`` copied into its leading corner.

    The copied region is sized by the first three dimensions of ``tensor``.
    The output is created with ``torch.zeros(*shape)``, so it uses torch's
    default dtype and device rather than those of ``tensor``.
    """
    padded = torch.zeros(*shape)
    dim0, dim1, dim2 = tensor.shape[0], tensor.shape[1], tensor.shape[2]
    padded[:dim0, :dim1, :dim2] = tensor
    return padded

def merge_tensors(tensors: Tuple[torch.Tensor]) -> torch.Tensor:
    """Merge tuple of tensors (with varying dimensions) into a single tensor.

    Every input is treated as a 3-D tensor ``(d0, rows, width)``. The inputs
    are stacked along dim 1 in the given order, and each one is zero-padded on
    the right (dim 2) up to the width of the LAST tensor in the tuple.

    Args:
        tensors: Non-empty tuple of 3-D tensors. No input may be larger than
            the last tensor in dim 0 or dim 2.

    Returns:
        A tensor of shape ``(tensors[-1].shape[0], total_rows,
        tensors[-1].shape[2])`` where ``total_rows`` is the sum of every
        input's dim-1 size. It is allocated with ``torch.zeros`` using torch's
        default dtype and device.

    Raises:
        IndexError: If ``tensors`` is empty.
    """
    num_answer_weights = sum(tensor.shape[1] for tensor in tensors[1:])
    print(f"num_answer_weights: {num_answer_weights}")

    # The last tensor defines the leading size and the padded width.
    last_shape = tensors[-1].shape
    total_rows = sum(tensor.shape[1] for tensor in tensors)

    # Allocate the result once and copy each input into its row range, instead
    # of splitting every row into its own tensor, padding each one separately
    # and concatenating them afterwards.
    merged = torch.zeros(last_shape[0], total_rows, last_shape[2])
    row = 0
    for tensor in tensors:
        dim0, rows, width = tensor.shape[0], tensor.shape[1], tensor.shape[2]
        merged[:dim0, row:row + rows, :width] = tensor
        row += rows
    return merged

def get_tokens_and_att_weights(model: PreTrainedModel, prompt: str) -> Tuple[list, torch.Tensor]:
    """Generate an answer for ``prompt`` and collect last-layer attention weights.

    Uses the module-level ``tokenizer`` to encode the prompt, runs
    ``model.generate`` without sampling (``do_sample=False``) with attention
    outputs enabled, and merges the last layer's attention of every generation
    step into a single matrix.

    Args:
        model: The model used for generation.
        prompt: The full prompt text.

    Returns:
        A ``(tokens, mean_attention_weights)`` tuple: the token strings of the
        generated sequence with the final token dropped, and the merged
        last-layer attention averaged over dim 0, as a detached CPU tensor.

    Raises:
        AssertionError: If a step reports an unexpected number of layers, or
            if the output sequence does not start with the prompt ids.
    """
    # Put the inputs on whatever device the model lives on.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    # Generate attention weights
    output = model.generate(
        input_ids,
        max_new_tokens=4096,
        pad_token_id=0,
        do_sample=False,
        # temperature=0.01,
        # top_p=None,
        # top_k=200,
        return_dict_in_generate=True,
        output_attentions=True,
    )
    # TODO: Modify huggingface transformers code to save only last-layer attn weights
    # instead of all layers' weights. This will help with CUDA OOM and greatly increase performance.
    attention_sequences = output.attentions

    # Sanity check: every generation step must report one entry per layer.
    # The count comes from the model config rather than a hard-coded 32, so
    # models with a different depth are accepted too.
    expected_layers = getattr(model.config, 'num_hidden_layers', None)
    if expected_layers is not None:
        for attention in attention_sequences:
            assert len(attention) == expected_layers, \
                f"expected {expected_layers} layers, got {len(attention)}"

    # Keep only the last layer of each step and drop the batch dimension.
    tuple_of_last_attention_layers = tuple(attention[-1].squeeze(dim=0) for attention in attention_sequences)
    last_layer_attentions = merge_tensors(tuple_of_last_attention_layers)

    sequence_ids = output.sequences[0]
    assert sequence_ids[:len(input_ids[0])].equal(input_ids[0]), \
        "The prompt is not in the prefix of the model output."
    num_answer_ids = len(sequence_ids) - len(input_ids[0])
    print(f"num_answer_ids: {num_answer_ids}")
    # # Decode the generated answer
    # answer = tokenizer.decode(sequence_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # print(answer)
    tokens = tokenizer.convert_ids_to_tokens(sequence_ids, skip_special_tokens=False)

    # Average attention weights across all heads
    mean_attention_weights = last_layer_attentions.mean(dim=0).detach().cpu()
    print(f"mean_attention_weights.shape: {mean_attention_weights.shape}")
    # Drop the final token so the labels line up with the attention rows
    # (presumably the last generated token is never fed back through the
    # model, so it has no attention row — verify).
    tokens = tokens[:-1]
    return tokens, mean_attention_weights

In [None]:
def plot_attention_heatmap(
    tokens: list,
    mean_attention_weights: torch.Tensor,
    cmap: str = 'Reds',
    title: Optional[str] = None,
    outpath: Path = Path('./heatmaps/heatmap.png'),
) -> None:
    """Render an attention matrix as a log-scaled heatmap and save it as an image.

    Args:
        tokens: Token strings used as both row and column labels; must have
            the same length as ``mean_attention_weights``.
        mean_attention_weights: Square 2-D attention matrix to plot.
        cmap: Colormap name. For ``'hot'``, cells the log norm cannot show are
            drawn black.
        title: Optional figure title.
        outpath: Destination image path. If it already exists, nothing is
            drawn and the function returns immediately.

    Raises:
        AssertionError: If ``tokens`` and ``mean_attention_weights`` differ in length.
    """
    if outpath.exists():
        print(f"{outpath} already exists. Skipping.")
        return

    print(f"Processing {outpath}...")
    assert len(tokens) == len(mean_attention_weights), \
        f"tokens: {len(tokens)}, weights: {len(mean_attention_weights)}"

    start = 0
    seq_len = len(tokens)
    figsize = (192, 192)

    # Create a DataFrame
    df = pandas.DataFrame(
        mean_attention_weights[start:start+seq_len, start:start+seq_len],
        index=tokens[start:start+seq_len],
        columns=tokens[start:start+seq_len]
    )

    print("Plotting heatmap...")
    # Plot the heatmap
    fig = matplotlib.pyplot.figure(figsize=figsize)
    try:
        if cmap == 'hot':
            cmap = matplotlib.colormaps['hot'].copy()
            cmap.set_bad(color='black')
        hm = seaborn.heatmap(
            df,
            annot=False,
            cmap=cmap,
            cbar=False,
            square=True,
            norm=LogNorm(),
            xticklabels=1,
            yticklabels=1,
        )
        if title:
            hm.set_title(title)
        hm.set_xticklabels(hm.get_xmajorticklabels(), fontsize=7)
        hm.set_yticklabels(hm.get_ymajorticklabels(), fontsize=7)
        matplotlib.pyplot.tick_params(
            axis='x',
            which='both',
            bottom=True,
            top=False,
            labelbottom=True,
            labeltop=False,
            rotation=90
        )
        # Path.parent is "." for a bare file name, so this also works when
        # outpath has no directory component (os.makedirs('') would raise).
        outpath.parent.mkdir(parents=True, exist_ok=True)
        matplotlib.pyplot.savefig(outpath, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, each call leaks a very large figure in the worker.
        matplotlib.pyplot.close(fig)
    print(f"Heatmap saved to {outpath}")

In [None]:
import multiprocessing


def _report_plot_failure(exc: BaseException) -> None:
    """Print an exception raised by a heatmap worker task.

    Passed as ``error_callback`` so failures inside the pool are reported
    instead of being dropped with the unused AsyncResult.
    """
    print(f"ERROR: heatmap worker failed: {exc!r}")


clear_cache()
outfolder = Path('./heatmaps/')
pool = multiprocessing.Pool(
    processes=multiprocessing.cpu_count(),
    maxtasksperchild=8, # Required to prevent memory leak; adjust based on available RAM
)
try:
    for i, prompt in enumerate(prompts):
        outfile = Path(outfile_prefix + '_' + str(i) + '.png')
        outpath = outfolder / outfile
        if outpath.exists():
            print(f"{outpath} already exists. Skipping.")
            continue

        oomed = False
        try:
            tokens, mean_attention_weights = get_tokens_and_att_weights(model, prompt)
        except torch.cuda.OutOfMemoryError:
            print(f"ERROR: CUDA OOMed. Skipping {outpath}.")
            oomed = True

        # Clear the cache only after leaving the except block: while the
        # handler is active, the exception's traceback still references the
        # failed call's frame and therefore its tensors.
        clear_cache()
        if oomed:
            continue

        # Plot in a worker process so generation for the next prompt can
        # start while this heatmap is being rendered.
        pool.apply_async(
            plot_attention_heatmap,
            kwds = {
                'tokens': tokens,
                'mean_attention_weights': mean_attention_weights,
                # cmap='hot',
                'outpath': outpath
            },
            error_callback=_report_plot_failure,
        )
finally:
    # Always shut the pool down, even if the loop is interrupted, and wait
    # for the heatmaps that were already submitted.
    pool.close()
    pool.join()

In [None]:
# Final cleanup after the heatmap loop has finished and the pool has been joined.
clear_cache()