In [None]:
"""
File: AttentionVisualization.ipynb
---------------------------------
Visualize the attention layers of transformer models for interpretability.
"""

In [None]:
import os

import numpy as np
import plotly.graph_objects as go
import torch
from bertviz import head_view, model_view
from bertviz.neuron_view import show
from torch.utils.data import DataLoader, Subset
from transformers import utils


utils.logging.set_verbosity_error()  # Suppress standard transformers warnings


# Repository root used as the working directory so the relative paths in
# `args` resolve. Set the ODYSSEY_ROOT environment variable to run the
# notebook outside the original cluster; the default is the original path.
ROOT = os.environ.get("ODYSSEY_ROOT", "/fs01/home/afallah/odyssey/odyssey")
os.chdir(ROOT)

from odyssey.data.dataset import FinetuneDataset
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.evals.prediction import load_finetuned_model, predict_patient_outcomes
from odyssey.models.model_utils import (
    load_finetune_data,
)

In [None]:
class args:
    """Notebook configuration, read as class attributes (e.g. ``args.max_len``)."""

    # Checkpoint and vocabulary locations (relative to ROOT).
    model_path = "checkpoints/best.ckpt"
    vocab_dir = "odyssey/data/vocab"

    # Fine-tuning data files, split selection and target column.
    data_dir = "odyssey/data/bigbird_data"
    sequence_file = "patient_sequences_2048_mortality.parquet"
    id_file = "dataset_2048_mortality.pkl"
    valid_scheme = "few_shot"
    # Kept as a string; presumably used as a lookup key by load_finetune_data -- confirm.
    num_finetune_patients = "20000"
    label_name = "label_mortality_1month"

    # Inference settings. visualize_attention squeezes the batch dimension,
    # so batch_size must stay 1.
    max_len = 2048
    batch_size = 1
    # NOTE(review): `device` is not referenced elsewhere in this notebook.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the concept tokenizer from the vocabulary directory and fit it on that
# vocabulary. `with_tasks=False` presumably leaves out task-specific tokens -- confirm.
tokenizer = ConceptTokenizer(data_dir=args.vocab_dir)
tokenizer.fit_on_vocab(with_tasks=False)

In [None]:
# Load the fine-tuned checkpoint; the bare `model` expression displays its repr.
# NOTE(review): args.device is not passed here -- confirm which device the model ends up on.
model = load_finetuned_model(args.model_path, tokenizer)
model

In [None]:
# Load the fine-tuning and test splits for the configured validation scheme
# and patient count.
fine_tune, fine_test = load_finetune_data(
    args.data_dir,
    args.sequence_file,
    args.id_file,
    args.valid_scheme,
    args.num_finetune_patients,
)

# Expose the chosen target column under the generic name "label". Reassign
# rather than renaming in place so the cell has no hidden mutation.
fine_tune = fine_tune.rename(columns={args.label_name: "label"})
fine_test = fine_test.rename(columns={args.label_name: "label"})

In [None]:
# Test-split dataset configured with the notebook's max_len.
test_dataset = FinetuneDataset(
    data=fine_test,
    tokenizer=tokenizer,
    max_len=args.max_len,
)

# Hand-picked test patients to visualize (positions in test_dataset).
# NOTE(review): the original inline comment read "85 and 88 are small" while
# the indices used are 85 and 89 -- confirm which patient was intended.
SAMPLE_PATIENT_INDICES = [85, 89]

test_loader = DataLoader(
    Subset(test_dataset, SAMPLE_PATIENT_INDICES),
    batch_size=args.batch_size,
)

In [None]:
# Take the first batch from the (unshuffled) loader: the test_dataset item at
# index 85. Only this one batch is visualized below.
patient = next(iter(test_loader))
patient

In [None]:
# Run the fine-tuned model on this patient; the returned mapping's
# "attentions" entry feeds the visualization cells below.
output = predict_patient_outcomes(patient, model)
output

In [None]:
# Decode the (padded) concept sequence back into space-separated token strings.
tokens = tokenizer.decode(patient["concept_ids"].squeeze(0).cpu().numpy()).split(" ")
# Number of positions counted by the attention mask (presumably the non-padding
# tokens), as a plain Python int; `.item()` also works if the tensor is on GPU.
truncate_at = int(patient["attention_mask"].sum().item())
# Per-layer attentions from the model: one entry per layer, each presumably
# (batch_size, num_heads, max_len, max_len) -- confirm.
attention_matrix = output["attentions"]
last_attention_matrix = attention_matrix[-1].detach()

In [None]:
# Crop every layer's attention to the first `truncate_at` positions so bertviz
# only renders those tokens; the result is a tuple with one tensor per layer.
truncated_attention_matrix = tuple(
    layer_attention[:, :, :truncate_at, :truncate_at]
    for layer_attention in attention_matrix
)
truncated_tokens = tokens[:truncate_at]

In [None]:
def visualize_attention(
    attention_weights,
    patient,
    special_tokens,
    tokenizer,
    truncate=False,
    only_cls=False,
    top_k=10,
):
    # Convert attention tensor to numpy array and squeeze the batch dimension
    concept_ids = patient["concept_ids"].squeeze(0).cpu().numpy()
    attention_weights = attention_weights.squeeze(0).cpu().numpy()

    # Truncate attention weights if specified
    if truncate:
        truncate_at = patient["attention_mask"].sum().numpy()
        attention_weights = attention_weights[:, :truncate_at, :truncate_at]
        concept_ids = concept_ids[:truncate_at]

    if only_cls:
        attention_weights = attention_weights[:, :1, :]

    # Average attention weights across heads
    attention_weights = attention_weights.mean(axis=0)

    # Generate token labels, marking special tokens with a special symbol
    x_token_labels = [
        f"{tokenizer.id_to_token(token)}"
        if tokenizer.id_to_token(token) in special_tokens
        else str(i)
        for i, token in enumerate(concept_ids)
    ]
    y_token_labels = ["[CLS]"]

    # Generate hover text
    hover_text = [
        [
            f"Token {tokenizer.id_to_token(concept_ids[row])} with Token {tokenizer.id_to_token(concept_ids[col])}:"
            f"Attention Value {attention_weights[row, col]:.3f}"
            for col in range(attention_weights.shape[1])
        ]
        for row in range(attention_weights.shape[0])
    ]

    # Generate annotations for special tokens
    annotations = []
    for i, token in enumerate(concept_ids):
        if tokenizer.id_to_token(token) in special_tokens:
            annotations.append(
                dict(
                    x=i,
                    y=0.5,
                    xref="x",
                    yref="paper",  # Use 'paper' coordinates for y
                    text=tokenizer.id_to_token(token),
                    showarrow=False,
                    font=dict(color="black", size=10),
                    textangle=-90,
                    bgcolor="red",
                    opacity=0.8,
                ),
            )

    # Plot the attention matrix as a heatmap
    fig = go.Figure(
        data=go.Heatmap(
            z=attention_weights,
            x=x_token_labels,
            y=y_token_labels,
            hoverongaps=False,
            hoverinfo="text",
            text=hover_text,
            colorscale="YlGnBu",
        ),
    )

    fig.update_layout(
        title="Attention Visualization",
        xaxis_nticks=len(concept_ids),
        yaxis_nticks=len(y_token_labels),
        xaxis_title="Token in Input Sequence",
        yaxis_title="Token in Input Sequence",
        annotations=annotations,
        xaxis_tickangle=-90,
    )

    # Print top k tokens with their attention values
    top_k_indices = np.argsort(-attention_weights, axis=None)[:top_k]
    top_k_values = attention_weights.flatten()[top_k_indices]
    top_k_indices = np.unravel_index(top_k_indices, attention_weights.shape)

    for idx in range(len(top_k_indices[0])):
        token1 = top_k_indices[0][idx]
        token2 = top_k_indices[1][idx]
        attention_value = top_k_values[idx]
        print(
            f"Token {tokenizer.id_to_token(concept_ids[token1])} "
            f"with Token {tokenizer.id_to_token(concept_ids[token2])}: "
            f"Attention Value {attention_value:.3f}",
        )

    fig.show()


# Plot the [CLS] query row of the last layer's head-averaged attention over the
# unpadded tokens, labelling these special tokens on the x axis.
special_tokens = ["[CLS]", "[VS]", "[VE]", "[REG]"]
visualize_attention(
    attention_weights=last_attention_matrix,
    tokenizer=tokenizer,
    patient=patient,
    special_tokens=special_tokens,
    only_cls=True,
    truncate=True,
    top_k=15,
)

In [None]:
# def visualize_attention(attention_weights, patient, special_tokens, tokenizer, truncate=False, only_cls=False, top_k=10):
#     # Convert attention tensor to numpy array and squeeze the batch dimension
#     concept_ids = patient['concept_ids'].squeeze(0).cpu().numpy()
#     attention_weights = attention_weights.squeeze(0).cpu().numpy()

#     # Truncate attention weights if specified
#     if truncate:
#         truncate_at = patient['attention_mask'].sum().numpy()
#         attention_weights = attention_weights[:, :truncate_at, :truncate_at]
#         concept_ids = concept_ids[:truncate_at]

#     if only_cls:
#         attention_weights = attention_weights[:, :1, :]

#     # Average attention weights across heads
#     attention_weights = attention_weights.mean(axis=0)

#     # Generate token labels, replacing special tokens with their names
#     token_labels = [tokenizer.id_to_token(token) if tokenizer.id_to_token(token) in special_tokens else '' for token in concept_ids]

#     # Plot the attention matrix as a heatmap
#     sns.set_theme(font_scale=1.2)
#     plt.figure(figsize=(15, 12))
#     ax = sns.heatmap(attention_weights, cmap="YlGnBu", linewidths=.5, annot=False, cbar=True)
#     ax.set_title('Attention Visualization')

#     # Set custom tick labels
#     # ax.set_xticks(np.arange(len(token_labels)) + 0.5)
#     # ax.set_xticklabels(token_labels, rotation=45, ha='right', fontsize=10)
#     # ax.set_yticks(np.arange(len(token_labels)) + 0.5)
#     # ax.set_yticklabels(token_labels, rotation=0, ha='right', fontsize=10)

#     ax.set_xlabel('Token in Input Sequence')
#     ax.set_ylabel('Token in Input Sequence')

#     # Print top k tokens with their attention values
#     top_k_indices = np.argsort(-attention_weights, axis=None)[:top_k]
#     top_k_values = attention_weights.flatten()[top_k_indices]
#     top_k_indices = np.unravel_index(top_k_indices, attention_weights.shape)

#     for idx in range(len(top_k_indices[0])):
#         token1 = top_k_indices[0][idx]
#         token2 = top_k_indices[1][idx]
#         attention_value = top_k_values[idx]
#         print(f"Token {tokenizer.id_to_token(concept_ids[token1])} "
#               f"with Token {tokenizer.id_to_token(concept_ids[token2])}: "
#               f"Attention Value {attention_value}")

#     plt.show()


# # Visualize the attention matrix with special tokens
# special_tokens = ['[CLS]', '[VS]', '[VE]', '[REG]']  # Update this list with your actual special tokens
# visualize_attention(last_attention_matrix, patient=patient, special_tokens=special_tokens, tokenizer=tokenizer, truncate=True, only_cls=True, top_k=25)

In [None]:
# Model view: render heads 0-5 of layer index 5 over the truncated tokens and
# save the generated HTML to a file instead of displaying it inline.
html_model_view = model_view(
    truncated_attention_matrix,
    truncated_tokens,
    include_layers=[5],
    include_heads=[0, 1, 2, 3, 4, 5],
    display_mode="light",
    html_action="return",
)

# Write as UTF-8 explicitly so the output does not depend on the platform's
# default encoding.
with open("model_view.html", "w", encoding="utf-8") as file:
    file.write(html_model_view.data)

In [None]:
# Head View: render all layers over the truncated tokens and save the
# generated HTML to a file instead of displaying it inline.
html_head_view = head_view(
    truncated_attention_matrix,
    truncated_tokens,
    # include_layers=[5],
    html_action="return",
)

# Write as UTF-8 explicitly so the output does not depend on the platform's
# default encoding.
with open("head_view.html", "w", encoding="utf-8") as file:
    file.write(html_head_view.data)

In [None]:
# Neuron View
# NOTE(review): bertviz's neuron_view.show also takes a required `sentence_a`
# argument that is not passed here, and its documentation expects a model from
# bertviz.transformers_neuron_view plus a Hugging Face tokenizer. As written,
# this call is likely to fail -- confirm against the installed bertviz version.
model_type = "bert"

show(model, model_type, tokenizer, display_mode="dark", layer=5, head=0)

In [None]:
# TODO:
# - Visualize attention for the [REG] token (may be tricky).
#   - DONE: Explain why row-wise and column-wise attention differ, i.e. what the matrix actually represents.
# - Include one example patient and visualize its attention matrix, labelled with the exact concept tokens.
# - Add markers that separate visits and special tokens.
# - Survey libraries for attention visualization (Amrit's suggestion).
# - Visualize the gradients.