# In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import yaml
import pickle

from learning.plotting.utils import (
    plot_attention_heatmap,
    plot_attention_time_series,
    plot_attention_over_time_grid,
    plot_key_attention_trends,
    plot_token_attention_trends,
    plot_node_attention_trends,
    plot_all_attention_heads,
    plot_transformer_attention
)

# Use a larger default font size for every figure created by this script.
plt.rcParams["font.size"] = 18

# The config file is looked up in the current working directory, so the
# script must be launched from the directory containing ppo_config.yaml.
plotting_dir = Path().resolve()
# NOTE: despite the name, config_dir is the path of the YAML file itself.
config_dir = plotting_dir / "ppo_config.yaml"

# safe_load only builds plain Python objects (dicts, lists, scalars).
with config_dir.open("r") as file:
    config = yaml.safe_load(file)

# Placeholder value; `data` is rebound to the unpickled contents of each
# attention log that gets loaded further down.
data = []

# One 12x7 inch figure holding a single axes.
fig, ax = plt.subplots(figsize=(12, 7))

# Bookkeeping meant for giving each experiment-batch combination its own color.
# NOTE(review): nothing in the visible code fills or reads these two names -
# confirm whether they are still needed.
color_idx = 0
experiment_colors = {}

def _plot_token_trends(experiment, n_agents, attention_over_time,
                       attn_types=("Dec_L0",), head_idx=1):
    """Plot token-attention trends for three representative source tokens.

    For every attention type in *attn_types* that has at least one recorded
    timestep, ``plot_token_attention_trends`` is called for source tokens
    ``1``, ``n_agents // 2`` and ``n_agents - 2``.

    Args:
        experiment: Experiment name, forwarded to the plotting helper.
        n_agents: Number of agents; determines the source-token indices.
        attention_over_time: Container indexed by attention type (for example
            ``"Dec_L0"``) whose values are sequences of per-timestep arrays.
            Assumes axis 1 of each array indexes attention heads - TODO confirm.
        attn_types: Attention types to plot.
        head_idx: Index of the attention head to plot.
    """
    for attn_type in attn_types:
        if attn_type not in attention_over_time:
            continue

        timesteps = attention_over_time[attn_type]
        # len() works for lists and arrays alike; comparing against [] with
        # != turns into an elementwise comparison if this is a NumPy array.
        if len(timesteps) == 0:
            continue

        # Determine how many attention heads were actually recorded.
        num_heads = timesteps[0].shape[1]
        print(f"Found {num_heads} attention heads for {attn_type}")

        # Do not ask the plotting helper for a head that does not exist
        # (e.g. a single-head model with the default head_idx of 1).
        if head_idx >= num_heads:
            print(
                f"Skipping {attn_type}: head {head_idx} requested "
                f"but only {num_heads} available"
            )
            continue

        for src_idx in (1, n_agents // 2, n_agents - 2):
            plot_token_attention_trends(
                experiment,
                n_agents,
                attention_over_time,
                attn_type=attn_type,
                src_idx=src_idx,
                head_idx=head_idx,
            )


# Walk every batch / experiment / trial combination listed in the config and
# plot the attention logs that exist on disk.
for batch in config["batches"]:
    for experiment in config["experiments"]:
        for trial in config["trials"]:
            checkpoint_path = Path(f"{config['base_path']}/{batch}/{experiment}/{trial}/logs/attention.dat")

            # Combinations without an attention log are skipped silently.
            if not checkpoint_path.is_file():
                continue

            # SECURITY: pickle.load can execute arbitrary code while loading.
            # Only point base_path at attention logs you produced yourself.
            with open(checkpoint_path, "rb") as handle:
                data = pickle.load(handle)

            # The log maps an agent count to that run's attention data.
            for n_agents, att_data in data.items():
                if experiment in ("gat", "graph_transformer"):
                    # Graph-model plots are currently disabled. To re-enable,
                    # read att_data["edge_indices"] / att_data["attention_weights"]
                    # and call, for example:
                    #   plot_attention_heatmap(experiment, edge_indices[-1], attention_weights[-1])
                    #   plot_attention_time_series(edge_indices, attention_weights, top_k=10)
                    #   plot_node_attention_trends(experiment, edge_indices, attention_weights,
                    #                              source_node_idx=src_idx)
                    continue

                if experiment != "transformer_full":
                    continue

                # Other transformer plots that are currently disabled:
                #   plot_transformer_attention(experiment, attention_weights["Enc_L0"],
                #                              "Attention Weights ")   # was limited to n_agents == 8
                #   plot_attention_over_time_grid(attention_over_time, attn_type=attn_type,
                #                                 head_idx=head_idx, num_samples=5)
                #   plot_key_attention_trends(attention_over_time, attn_type=attn_type,
                #                             head_idx=head_idx, top_k=10)
                # (previously looped over "Enc_L0", "Dec_L0", "Cross_L0" and two heads)

                # Token-level trends are only plotted for these team sizes.
                if n_agents in (8, 24):
                    _plot_token_trends(
                        experiment,
                        n_agents,
                        att_data["attention_over_time"],
                    )

plt.show()