In [3]:

import torch
import tiktoken
from tqdm import tqdm
from huggingface_hub import HfApi
import psutil

import gradio as gr
import tempfile
from pathlib import Path

from dual_attention.model_analysis.datlm_utils import datlm_forward_w_intermediate_results
from dual_attention import attention_utils

from bertviz import head_view, model_view

from dual_attention.hf import DualAttnTransformerLM_HFHub

In [4]:
# Repo ids of the DAT models that the `awni00` account has uploaded to the Hugging Face Hub
api = HfApi()
models = [model.modelId for model in api.list_models(author='awni00', search='DAT')]

# Module-level state shared by the functions below.
# Compute device and tokenizer are fixed for the whole session.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tiktoken.get_encoding("gpt2") # TODO: in the future, different models may require different tokenizers

# Currently loaded model (filled in by `load_model`)
loaded_model = None
loaded_model_name = None
is_model_loaded = False

# Results of the most recent forward pass (filled in by `run_forward_pass`)
intermediate_results = None
tokenized_text = None

# Output directory, created eagerly so later writes cannot fail on a missing folder
save_dir = Path('./static')
save_dir.mkdir(parents=True, exist_ok=True)


def load_model(selected_model_path, progress=gr.Progress(track_tqdm=True)):
    """Load a pretrained Dual Attention Transformer LM from the Hugging Face Hub.

    On success the model is stored in the module-level ``loaded_model``
    (``loaded_model_name`` and ``is_model_loaded`` are updated to match) and any
    forward-pass results cached for a previously loaded model are cleared.
    On failure the module-level state is left exactly as it was.

    Args:
        selected_model_path: Hub repo id of the model, e.g. ``"author/model-name"``.
        progress: Gradio progress tracker. Not used in the body; it is declared
            with ``track_tqdm=True`` so tqdm progress can be surfaced in the UI.

    Returns:
        A 4-tuple ``(load_status, load_button, run_forward_button, head_selection)``:
        a status message followed by three Gradio component updates. When loading
        fails, the three updates are no-op ``gr.update()`` values.
    """
    global loaded_model, loaded_model_name, is_model_loaded
    global intermediate_results, tokenized_text

    try:
        # Build everything into locals first so that a failure at any step
        # cannot leave the globals describing a half-loaded model.
        model = DualAttnTransformerLM_HFHub.from_pretrained(selected_model_path).to(device)
        # The app only runs inference; make sure train-only layers (e.g. dropout)
        # are disabled so the visualized scores are deterministic.
        model.eval()
        model_name = selected_model_path.split('/')[-1]
        head_selection = gr.update(choices=list(range(model.n_heads_ra)))
        run_forward_button = gr.Button("Run Forward Pass", visible=True)
        load_button = gr.update(visible=False)
    except Exception as e:
        # Deliberately broad: any loading problem is reported in the UI status
        # instead of crashing the app.
        load_status = f"Failed to load model '{selected_model_path}' 🥲. Error: {str(e)}"
        return load_status, gr.update(), gr.update(), gr.update()

    # Commit the new model only after every step above succeeded.
    loaded_model = model
    loaded_model_name = model_name
    is_model_loaded = True

    # Results cached from a different model would be visualized as if they
    # belonged to this one; force a fresh forward pass.
    intermediate_results = None
    tokenized_text = None

    load_status = f"Model `{selected_model_path}` loaded successfully. 😁"
    return load_status, load_button, run_forward_button, head_selection

def run_forward_pass(prompt_text):
    """Tokenize ``prompt_text`` and run the loaded model, caching intermediate results.

    On success the module-level ``intermediate_results`` and ``tokenized_text``
    are replaced with the results for this prompt. If no model is loaded or the
    prompt produces no tokens, they are left unchanged.

    Args:
        prompt_text: The text to feed through the model.

    Returns:
        A 2-tuple ``(forward_pass_status, generate_viz_button)``: a status message
        and a "Generate Visualization" button that is visible only when the
        forward pass succeeded.
    """
    global intermediate_results, tokenized_text
    if loaded_model is None:
        forward_pass_status = "No model loaded yet. Please load a model first."
        generate_viz_button = gr.Button("Generate Visualization", visible=False)
        return forward_pass_status, generate_viz_button

    token_ids = tokenizer.encode(prompt_text)
    if not token_ids:
        # torch.tensor([]) would be a float tensor of shape (0,), which is not a
        # valid batch of token ids; report the problem instead of crashing.
        forward_pass_status = "Prompt is empty. Please enter some text first."
        generate_viz_button = gr.Button("Generate Visualization", visible=False)
        return forward_pass_status, generate_viz_button

    # Shape (1, n_tokens): a batch containing the single prompt.
    prompt_tokens = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

    # A single BPE token may hold only part of a multi-byte UTF-8 character, so
    # strict decoding can raise UnicodeDecodeError; show a replacement char instead.
    token_strings = [
        tokenizer.decode_single_token_bytes(token_id).decode('utf-8', errors='replace')
        for token_id in token_ids
    ]

    # Inference only: skip autograd bookkeeping so the cached tensors do not
    # keep the whole computation graph alive.
    with torch.no_grad():
        _, forward_results = datlm_forward_w_intermediate_results(loaded_model, prompt_tokens)

    # Publish both pieces of state together, only after the forward pass succeeded.
    tokenized_text = token_strings
    intermediate_results = forward_results

    forward_pass_status = "Forward pass computed successfully. 😁"
    generate_viz_button = gr.Button("Generate Visualization", visible=True)
    return forward_pass_status, generate_viz_button

def causal_softmax(scores, temperature=1.0):
    """Softmax over the last axis of ``scores`` under a causal mask.

    Args:
        scores: 4-D tensor whose third dimension is the sequence length used to
            build the mask.
        temperature: Divisor applied to ``scores`` before masking.

    Returns:
        Tensor with the same shape as ``scores``; positions the causal mask
        disallows receive zero weight.
    """
    scaled = scores / temperature
    _, _, seq_len, _ = scaled.size()

    allowed = attention_utils.compute_causal_mask(seq_len, device=scaled.device)
    blocked = allowed.logical_not()

    # Additive mask: 0 where attending is allowed, -inf where it is blocked.
    additive_mask = torch.zeros(seq_len, seq_len, dtype=scaled.dtype, device=scaled.device)
    additive_mask = additive_mask.masked_fill(blocked, float('-inf'))

    return torch.nn.functional.softmax(scaled + additive_mask, dim=-1)

# Post-processing options for relation tensors, keyed by option name.
# Each value is a callable that takes a tensor and returns a tensor.
rel_processor_map = {
    # clamp values into [0, 1]
    "Clip": lambda rel: torch.clamp(rel, min=0, max=1),
    # 1.0 where the value is strictly positive, 0.0 elsewhere
    "Sign": lambda rel: torch.gt(rel, 0).float(),
    # element-wise sigmoid
    "Sigmoid-Normalize": torch.nn.functional.sigmoid,
    # causally masked softmax over the last axis
    "Softmax-Normalize": causal_softmax,
}
def generate_html_visualization(viz_type, view_type, head_selection, rel_processing, rel_scale=1.0):
    """Render the cached forward-pass results as a bertviz HTML visualization.

    Reads the module-level ``intermediate_results`` and ``tokenized_text``.

    Args:
        viz_type: Which quantity to show. One of
            "Self-Attention Attention Scores",
            "Relational-Attention Attention Scores",
            "Relational-Attention Relations" or
            "Relational-Attention Relations (Scaled by Attention Scores)".
        view_type: "Head View" or "Model View".
        head_selection: Index of the relational-attention head whose scores scale
            the relations. Only used by the "Scaled by Attention Scores" type.
        rel_processing: Key of ``rel_processor_map`` naming the processing applied
            to the relations. Only used by the two relation types.
        rel_scale: Multiplier applied to the relations before processing.

    Returns:
        The visualization as an HTML string, or a short plain-text message when
        a prerequisite is missing (no forward pass yet, or no head selected for
        the scaled view).

    Raises:
        ValueError: If ``viz_type`` or ``view_type`` is not one of the options above.
        KeyError: If ``rel_processing`` is needed but is not in ``rel_processor_map``.
    """

    if intermediate_results is None or tokenized_text is None:
        return "Please run forward pass first."

    if viz_type == "Self-Attention Attention Scores":
        scores = [x.cpu() for x in intermediate_results['sa_attn_scores']]
    elif viz_type == "Relational-Attention Attention Scores":
        scores = [x.cpu() for x in intermediate_results['ra_attn_scores']]
    elif viz_type == "Relational-Attention Relations":
        rel_processor = rel_processor_map[rel_processing]
        scores = [rel_processor(rels.transpose(-1, 1).cpu() * rel_scale) for rels in intermediate_results['ra_rels']]
    elif viz_type == "Relational-Attention Relations (Scaled by Attention Scores)":
        if head_selection is None:
            # Indexing with None would insert a new axis instead of picking a
            # head, silently producing a wrongly shaped result.
            return "Please select an attention head first."
        rel_processor = rel_processor_map[rel_processing]
        # unsqueeze(1) re-inserts the head axis so the selected head's scores
        # broadcast across the relation axis for any batch size.
        scores = [
            rel_processor(rels.transpose(-1, 1).cpu() * rel_scale) * attn[:, head_selection].unsqueeze(1).cpu()
            for rels, attn in zip(intermediate_results['ra_rels'], intermediate_results['ra_attn_scores'])
        ]
    else:
        raise ValueError(f"Invalid visualization type: {viz_type}")

    if view_type == "Head View":
        html_out = head_view(scores, tokenized_text, html_action='return')
    elif view_type == "Model View":
        html_out = model_view(scores, tokenized_text, html_action='return', display_mode="light")
    else:
        raise ValueError(f"Invalid view type: {view_type}")

    # The view functions return a display object; callers want the raw HTML.
    return html_out.data

def create_download_link(file_path):
    """Build an HTML anchor that opens a local file in a new browser tab.

    Args:
        file_path: Path (``str`` or ``pathlib.Path``) of the generated HTML file.
            Relative paths are resolved against the current working directory.

    Returns:
        An ``<a>`` element as a string whose ``href`` is a ``file://`` URI.
    """
    # Plain string interpolation yields a malformed URI for relative paths
    # (the first segment is parsed as a host) and for paths containing spaces,
    # quotes or backslashes. ``as_uri`` needs an absolute path and
    # percent-encodes unsafe characters.
    file_uri = Path(file_path).absolute().as_uri()
    return f"<a href='{file_uri}' target='_blank'>Click here to view the generated HTML</a>"

def print_machine_info():
    """Print CUDA availability, the GPU name, the CPU count and total RAM in GB."""
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        gpu_label = torch.cuda.get_device_name(0)
    else:
        gpu_label = "No GPU available"

    cpu_total = psutil.cpu_count()
    # psutil reports bytes; 1024 ** 3 bytes per GB
    ram_gb = psutil.virtual_memory().total / (1024 ** 3)

    report_lines = (
        f"CUDA Available: {has_cuda}",
        f"GPU: {gpu_label}",
        f"Number of CPUs: {cpu_total}",
        f"Total Memory: {ram_gb:.2f} GB",
    )
    for report_line in report_lines:
        print(report_line)

In [5]:
# Load the first model from the Hub listing; on success `load_model` stores it in
# the module-level `loaded_model`. The returned status tuple is the cell output.
load_model(models[0])

('Model `awni00/DAT-sa8-ra8-ns1024-sh8-nkvh4-343M` loaded successfully. 😁',
 {'__type__': 'update', 'visible': False},
 <gradio.components.button.Button at 0x2a15d0eae50>,
 {'choices': [0, 1, 2, 3, 4, 5, 6, 7], '__type__': 'update'})

In [6]:
# Example prompt. `run_forward_pass` caches its results in the module-level
# `intermediate_results` / `tokenized_text`, which the cells below read.
string = "A finite-state machine (FSM) or finite-state automaton (FSA, plural: automata), finite automaton, or simply a state machine, is a mathematical model of computation. "
run_forward_pass(string)

('Forward pass computed successfully. 😁',
 <gradio.components.button.Button at 0x2a15de6e5d0>)

In [7]:
# --- Settings for this visualization ---
viz_type = "Relational-Attention Relations"
view_type = 'Head View'

layer = 0
rel_processing = "Clip"
rel_scale = 1.0
head_selection = 0

# --- Collect one tensor per layer for the requested quantity ---
# The two plain attention-score types differ only in the key they read.
attn_score_keys = {
    "Self-Attention Attention Scores": 'sa_attn_scores',
    "Relational-Attention Attention Scores": 'ra_attn_scores',
}

if viz_type in attn_score_keys:
    scores = [layer_attn.cpu() for layer_attn in intermediate_results[attn_score_keys[viz_type]]]
elif viz_type == "Relational-Attention Relations":
    rel_processor = rel_processor_map[rel_processing]
    scores = [
        rel_processor(layer_rels.transpose(-1, 1).cpu() * rel_scale)
        for layer_rels in intermediate_results['ra_rels']
    ]
elif viz_type == "Relational-Attention Relations (Scaled by Attention Scores)":
    h = head_selection
    rel_processor = rel_processor_map[rel_processing]
    layer_pairs = zip(intermediate_results['ra_rels'], intermediate_results['ra_attn_scores'])
    scores = [
        rel_processor(layer_rels.transpose(-1, 1).cpu() * rel_scale) * layer_attn[:, h].cpu()
        for layer_rels, layer_attn in layer_pairs
    ]
else:
    raise ValueError(f"Invalid visualization type: {viz_type}")

# --- Render with the requested bertviz view ---
if view_type not in ("Head View", "Model View"):
    raise ValueError(f"Invalid view type: {view_type}")

if view_type == "Head View":
    html_out = head_view(scores, tokenized_text, html_action='return')
else:
    html_out = model_view(scores, tokenized_text, html_action='return', display_mode="light")

html_out

In [8]:
# Save the visualization rendered above as a standalone HTML file.
# NOTE(review): this rebinds the module-level `save_dir` that was set to
# './static' earlier in the notebook; confirm nothing later relies on the old value.
save_dir = Path('datlm_vizs')
save_dir.mkdir(parents=True, exist_ok=True)

file_path = save_dir / f"{loaded_model_name}-{viz_type}-{view_type}.html"
# Write UTF-8 explicitly: the platform default encoding differs between machines
# and may be unable to represent every character in the generated page.
with open(file_path, 'w', encoding='utf-8') as f:
    f.write(html_out.data)

In [None]:
# circuitsvis is only used by this cell, so it is imported here; the original
# cell referenced it without any import and failed with a NameError.
import circuitsvis

viz_type = "Relational-Attention Relations"
layer = 0
rel_processing = "None"
rel_scale = 1.0
head_selection = 0


def _resolve_rel_processor(name):
    """Return the relation-processing callable for ``name``.

    "None" is not a key of ``rel_processor_map``; it is treated here as
    "leave the relations unprocessed". Any other name is looked up in
    ``rel_processor_map`` (raising KeyError when unknown).
    """
    if name == "None":
        return lambda rels: rels
    return rel_processor_map[name]


# Collect one tensor per layer for the requested quantity.
if viz_type == "Self-Attention Attention Scores":
    scores = [x.cpu() for x in intermediate_results['sa_attn_scores']]
elif viz_type == "Relational-Attention Attention Scores":
    scores = [x.cpu() for x in intermediate_results['ra_attn_scores']]
elif viz_type == "Relational-Attention Relations":
    rel_processor = _resolve_rel_processor(rel_processing)
    scores = [rel_processor(rels.transpose(-1, 1).cpu() * rel_scale) for rels in intermediate_results['ra_rels']]
elif viz_type == "Relational-Attention Relations (Scaled by Attention Scores)":
    h = head_selection
    rel_processor = _resolve_rel_processor(rel_processing)
    scores = [rel_processor(rels.transpose(-1, 1).cpu() * rel_scale) * attn[:, h].cpu() for rels, attn in
        zip(intermediate_results['ra_rels'], intermediate_results['ra_attn_scores'])]
else:
    raise ValueError(f"Invalid visualization type: {viz_type}")

# Keep only the chosen layer and drop the batch axis (index 0).
# detach() is required before numpy() if the tensors track gradients.
scores = scores[layer].detach().numpy()[0]

html_out = circuitsvis.attention.attention_patterns(attention=scores, tokens=tokenized_text)
html_out