# Attention Visualization Notebook

This notebook helps visualize the transformer's attention mechanism for two stories. The data includes StoryID, Premise, Initial, Counterfactual, Original Ending, Edited Ending, and Generated Text.

## Instructions
1. Select a model from the dropdown menu.
2. Select a story ID from the dropdown menu.
3. The visualizations will update automatically based on your selections.


In [8]:
# Import necessary libraries
import numpy as np
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import logging
from bertviz.bertviz import head_view, model_view
import ipywidgets as widgets
from IPython.display import display

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Get Attention Data Function

In [9]:
# Function to load data from a CSV file
def load_data(file_path):
    if not file_path.exists():
        print(f"Data path {file_path} does not exist.")
        return None
    print(f"Loading data from {file_path}")
    return pd.read_csv(file_path)

# Load data and display
data_path = Path('/data/agirard/Projects/StoryRewriterAttention/data/test_data_sample-attention.csv')
data = load_data(data_path)

if data is not None:
    display(data.head())  # Display the first few rows of the data


# Function to get attention data from a given directory and story ID
def get_attention_data(attention_path, story_id):
    attention_dir = attention_path / str(story_id)
    logger.info(f"Loading attention data from {attention_dir}")

    if not attention_dir.exists():
        logger.error(f"Attention directory does not exist: {attention_dir}")
        return None

    try:
        encoder_attentions = [np.load(attention_dir / f'encoder_attentions_layer_{i}.npy') for i in range(12)]
        logger.info(f"Loaded encoder attentions for layers 0-11")
        decoder_attentions = [np.load(attention_dir / f'decoder_attentions_layer_{i}.npy') for i in range(12)]
        logger.info(f"Loaded decoder attentions for layers 0-11")
        cross_attentions = [np.load(attention_dir / f'cross_attentions_layer_{i}.npy') for i in range(12)]
        logger.info(f"Loaded cross attentions for layers 0-11")
    except Exception as e:
        logger.error(f"Error loading attention arrays: {e}")
        return None

    try:
        with open(attention_dir / "tokens.json") as f:
            tokens = json.load(f)
            logger.info("Loaded tokens.json")
    except Exception as e:
        logger.error(f"Error loading tokens.json: {e}")
        return None

    encoder_text = tokens.get('encoder_text', [])
    generated_text = tokens.get('generated_text', "")
    generated_text_tokens = tokens.get('generated_text_tokens', [])

    logger.info("Loaded encoder_text: %s", encoder_text)
    logger.info("Loaded generated_text: %s", generated_text)
    logger.info("Loaded generated_text_tokens: %s", generated_text_tokens)

    return encoder_attentions, decoder_attentions, cross_attentions, encoder_text, generated_text, generated_text_tokens



Loading data from /data/agirard/Projects/StoryRewriterAttention/data/test_data_sample-attention.csv


Unnamed: 0,StoryID,Premise,Initial,Counterfactual,Original Ending,Edited Ending,Generated Text
0,ca8a7f8d-7f63-422f-8007-c4a26bb8e889,Ela was babysitting.,Her young charge wanted chicken nuggets.,Her young charge wanted some hot cocoa.,"Ela checked, but there were none in the freeze...","Ela checked, but there were none in the pantry...","Ela checked, but there were none in the freeze..."
1,9387e571-2819-4e29-bedb-a35f0410da51,I wanted to make hot chocolate.,I took milk and warmed it up.,I didn't have the ingredients to make it though.,"Then, I added cocoa powder and stirred it all ...","I needed cocoa powder. I tasted it, but it was...","I took milk and warmed it up. Then, I added co..."


# Define Dropdown Options

In [14]:
# Define the model options for the dropdown
model_options = [
    "model_2024-03-22-10",
    "model_2024-04-09-22",
    "model_2024-04-08-13",
    "model_2024-03-22-15",
    "model_2024-04-10-10",
    "model_2024-04-08-09",
    "model_2024-04-10-14",
    "model_2024-05-13-17",
    "model_2024-05-14-20"
]

# Create a dropdown widget for model selection
model_dropdown = widgets.Dropdown(
    options=model_options,
    description='Model:',
    value=model_options[0]
)


# BERTViz Visualization Function for Story 1

In [18]:
# Function to update BERTViz visualization for the first story
def update_bertviz_visualization_story1(model):
    story_id = 'ca8a7f8d-7f63-422f-8007-c4a26bb8e889'
    attention_path = Path(f'/data/agirard/Projects/StoryRewriterAttention/data/{model}/attentions')
    attention_data = get_attention_data(attention_path, story_id)

    if attention_data:
        encoder_attentions, decoder_attentions, cross_attentions, encoder_text, generated_text, generated_text_tokens = attention_data

        # Ensure the cross-attention tensors are 5D
        cross_attentions = np.stack(cross_attentions)  # (num_layers, batch_size, num_heads, seq_len, seq_len)
        if cross_attentions.ndim == 4:  # If there's no batch dimension
            cross_attentions = cross_attentions[np.newaxis, ...]  # Add a batch dimension

        # Transpose to match the expected shape [batch_size, num_layers, num_heads, seq_len, seq_len]
        cross_attentions = np.transpose(cross_attentions, (1, 0, 2, 3, 4))

        tokens = {
            'encoder': encoder_text,
            'decoder': generated_text_tokens
        }

        # BERTViz expects attention in shape [batch_size, num_layers, num_heads, seq_len, seq_len]
        # and tokens for the encoder and decoder separately
        model_view(
            encoder_attention=None,
            decoder_attention=None,
            cross_attention=cross_attentions,
            encoder_tokens=tokens['encoder'],
            decoder_tokens=tokens['decoder']
        )

# Create an interactive output area for BERTViz for the first story
output_bertviz_story1 = widgets.Output()

# Function to handle dropdown value changes for BERTViz for the first story
def on_bertviz_value_change_story1(change):
    with output_bertviz_story1:
        output_bertviz_story1.clear_output()
        update_bertviz_visualization_story1(model_dropdown.value)

# Attach the update function to dropdown changes
model_dropdown.observe(on_bertviz_value_change_story1, names='value')

# Display the dropdown and output area for BERTViz for the first story
display(widgets.HTML("<h2>Story 1 Visualization</h2>"))
display(model_dropdown)
display(output_bertviz_story1)

# Initialize the BERTViz visualization with the default values for the first story
update_bertviz_visualization_story1(model_dropdown.value)


HTML(value='<h2>Story 1 Visualization</h2>')

Dropdown(description='Model:', options=('model_2024-03-22-10', 'model_2024-04-09-22', 'model_2024-04-08-13', '…

Output()

INFO:__main__:Loading attention data from /data/agirard/Projects/StoryRewriterAttention/data/model_2024-03-22-10/attentions/ca8a7f8d-7f63-422f-8007-c4a26bb8e889
INFO:__main__:Loaded encoder attentions for layers 0-11
INFO:__main__:Loaded decoder attentions for layers 0-11
INFO:__main__:Loaded cross attentions for layers 0-11
INFO:__main__:Loaded tokens.json
INFO:__main__:Loaded encoder_text: ['▁El', 'a', '▁was', '▁baby', 's', 'i', 'tting', '.', '▁Her', '▁young', '▁charge', '▁wanted', '▁chicken', '▁nu', 'g', 'get', 's', '.', '▁El', 'a', '▁checked', ',', '▁but', '▁there', '▁were', '▁none', '▁in', '▁the', '▁freezer', '.', '▁She', '▁went', '▁to', '▁McDonald', "'", 's', '▁and', '▁bought', '▁some', '▁nu', 'g', 'get', 's', '.', '▁The', '▁child', '▁was', '▁happy', '▁with', '▁his', '▁nu', 'g', 'get', 's', '.', '</s>', '▁El', 'a', '▁was', '▁baby', 's', 'i', 'tting', '.', '▁Her', '▁young', '▁charge', '▁wanted', '▁some', '▁hot', '▁coco', 'a', '.', '</s>']
INFO:__main__:Loaded generated_text: Ela c

Layer 0 attention shape before squeeze: (12, 12, 37, 74)


ValueError: cannot select an axis to squeeze out which has size not equal to one