In [8]:
import os
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import math
from scripts.constants import ATTENTION_LAYER_NAMES, WIN_SHIFT_ATTENTION_LAYER_INDEXES, DOWN_SAMPLE_ATTENTION_LAYER_INDEXES

## Load Data

In [None]:
# Data directories and paths
input_data_dir = 'input_data'
output_data_dir = 'output_data'
data_date = "2018-01-01"
data_time = "12:00"
input_name = "input_surface"
layer_index = 2
layer_name = ATTENTION_LAYER_NAMES[layer_index]

In [None]:
# Load the input data
input_surface_path = os.path.join(input_data_dir, data_date, data_time, f"{input_name.replace('/', '_')}.npy")
input_surface = np.load(input_surface_path)  

input_upper_path = os.path.join(input_data_dir, data_date, data_time, f"{input_name.replace('/', '_')}.npy")
input_upper = np.load(input_upper_path)  

if layer_index in WIN_SHIFT_ATTENTION_LAYER_INDEXES:
    # roll the data
    pass

In [None]:
# Load the attention head data
path = os.path.join(output_data_dir, data_date, data_time, f"{layer_name.replace('/', '_')}.npy")
attention_output = np.load(path)

(15, 64, 12, 144, 144)

## Select Hyperparameters

In [None]:
lat_index = 6
lon_index = 1
pl_index = 0

surface_var_index = {
    "MSLP": 0,
    "U10": 1,
    "V10": 2,
    "T2M": 3,
}

weather_var = "T2M"

head = 5
if layer_index in DOWN_SAMPLE_ATTENTION_LAYER_INDEXES:
    patchSize = 8
    chunk_size_lat = 48
    chunk_size_lon = 96
else:
    patchSize = 4
    chunk_size_lat = 24
    chunk_size_lon = 48

lat_coord_start = lat_index * chunk_size_lat
lat_coord_end = lat_coord_start + chunk_size_lat
lon_coord_start = lon_index * chunk_size_lon
lon_coord_end = lon_coord_start + chunk_size_lon

lat_coord_start, lat_coord_end, lon_coord_start, lon_coord_end

input_1 = input_surface[surface_var_index[weather_var], lat_coord_start:lat_coord_end, lon_coord_start:lon_coord_end]

att_pat = attention_output[lon_index, lat_index + lat_index * pl_index, head, :, :]

## Display the Attention Pattern and Input Data

In [23]:
# Create a subplot with two plots
fig = make_subplots(rows=1, cols=2)

# Add the two heatmaps to the figure
trace1 = go.Heatmap(
    z=input_1,
    x0=0, dx=1,
    y0=0, dy=1,
    showscale=False,
    hoverinfo='skip',
    # colorscale='Viridis',
    colorscale='turbo',
    zmin=input_1.min(), zmax=input_1.max(),
)

trace2 = go.Heatmap(
    z=att_pat,
    x0=0, dx=1,
    y0=0, dy=1,
    showscale=False,
    colorscale='Viridis',
    zmin=att_pat.min(), zmax=att_pat.max(),
)

fig.add_trace(trace1, row=1, col=1)
fig.add_trace(trace2, row=1, col=2)

# Set up the axes for both subplots
num_rows_input1, num_cols_input1 = input_1.shape
num_rows_att_pat, num_cols_att_pat = att_pat.shape

# Increase the overall figure size
fig.update_layout(
    width=1000,
    height=500,
    showlegend=False,
    margin=dict(l=10, r=10, t=10, b=10),
)

# Update axes for both heatmaps (same as before)
fig.update_xaxes(range=[0, num_cols_input1], row=1, col=1)
fig.update_yaxes(range=[num_rows_input1, 0], row=1, col=1, scaleanchor="x1", scaleratio=1)

fig.update_xaxes(range=[0, num_cols_att_pat], row=1, col=2)
fig.update_yaxes(range=[num_rows_att_pat, 0], row=1, col=2, scaleanchor="x2", scaleratio=1)

# Convert the figure to a FigureWidget for interactivity
fig_widget = go.FigureWidget(fig)

# Initialize a rectangle shape for highlighting (hidden initially)
fig_widget.layout.shapes = [
    dict(
        type="rect",
        xref='x1',
        yref='y1',
        x0=0,
        x1=1,
        y0=0,
        y1=1,
        line=dict(color="rgba(0, 128, 128, 1)", width=0),
        fillcolor="rgba(0, 128, 128, 0.5)",
        layer="above",
        visible=False,
    ),
    dict(
        type="rect",
        xref='x1',
        yref='y1',
        x0=0,
        x1=1,
        y0=0,
        y1=1,
        line=dict(color="rgba(255, 105, 180, 1)", width=0),
        fillcolor="rgba(255, 105, 180, 0.5)",
        layer="above",
        visible=False,
    )
]

def coords_2_highlight(fig_widget, lat, lon, shape_index):
    # Calculate rectangle coordinates in input_1
    x0 = lon
    x1 = lon + patchSize

    y0 = lat
    y1 = lat + patchSize

    with fig_widget.batch_update():
        # Update rectangle's position and make it visible
        fig_widget.layout.shapes[shape_index].x0 = x0
        fig_widget.layout.shapes[shape_index].x1 = x1
        fig_widget.layout.shapes[shape_index].y0 = y1  # Note inversion due to y-axis reversal
        fig_widget.layout.shapes[shape_index].y1 = y0
        fig_widget.layout.shapes[shape_index].visible = True

# Define the hover callback function
def hover_fn(trace, points, state):
    if points.point_inds:
        # Get the index of the hovered point in the flattened array
        y, x = points.point_inds[0]

        # Compute variables as in your code
        k_pl = math.floor(x / 72)
        q_pl = math.floor(y / 72)

        k_lat = (math.floor(x / 12) % 6) * patchSize
        q_lat = (math.floor(y / 12) % 6) * patchSize

        k_lon = (x % 12) * patchSize
        q_lon = (y % 12) * patchSize

        coords_2_highlight(fig_widget, k_lat, k_lon, 0)
        coords_2_highlight(fig_widget, q_lat, q_lon, 1)


# Define the unhover callback function to hide the rectangle
def unhover_fn(trace, points, state):
    with fig_widget.batch_update():
        fig_widget.layout.shapes[0].visible = False

# Attach the hover and unhover events to the second heatmap
fig_widget.data[1].on_hover(hover_fn)
fig_widget.data[1].on_unhover(unhover_fn)

# Display the interactive figure
fig_widget

FigureWidget({
    'data': [{'colorscale': [[0.0, '#30123b'], [0.07142857142857142, '#4145ab'],
                             [0.14285714285714285, '#4675ed'],
                             [0.21428571428571427, '#39a2fc'], [0.2857142857142857,
                             '#1bcfd4'], [0.35714285714285715, '#24eca6'],
                             [0.42857142857142855, '#61fc6c'], [0.5, '#a4fc3b'],
                             [0.5714285714285714, '#d1e834'], [0.6428571428571429,
                             '#f3c63a'], [0.7142857142857143, '#fe9b2d'],
                             [0.7857142857142857, '#f36315'], [0.8571428571428571,
                             '#d93806'], [0.9285714285714286, '#b11901'], [1.0,
                             '#7a0402']],
              'dx': 1,
              'dy': 1,
              'hoverinfo': 'skip',
              'showscale': False,
              'type': 'heatmap',
              'uid': '318e0074-d7b4-44bf-b956-3a915559fcdf',
              'x0': 0,
       