In [8]:
import os
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import math
from scripts.constants import ATTENTION_LAYER_NAMES, WIN_SHIFT_ATTENTION_LAYER_INDEXES, DOWN_SAMPLE_ATTENTION_LAYER_INDEXES

## Load Data

In [9]:
# Root folders holding the saved model inputs and the saved layer outputs
input_data_dir = "input_data"
output_data_dir = "output_data"

# Sample selection: date and time are used as nested sub-directory names,
# and input_name is the stem of the .npy file to read
data_date, data_time = "2018-01-01", "12:00"
input_name = "input_surface"

# Attention layer to inspect, addressed by its position in ATTENTION_LAYER_NAMES
layer_index = 2
layer_name = ATTENTION_LAYER_NAMES[layer_index]

In [10]:
# Load the input data
# The array is read from <input_data_dir>/<data_date>/<data_time>/<input_name>.npy,
# with any '/' in the name replaced by '_' to form the file name.
input_surface_path = os.path.join(input_data_dir, data_date, data_time, f"{input_name.replace('/', '_')}.npy")
input_surface = np.load(input_surface_path)  

# NOTE(review): this path is built from the same `input_name` as the surface path,
# so `input_upper` ends up being a second copy of the surface file. Presumably a
# separate upper-air file name was intended here — confirm the file name and fix.
input_upper_path = os.path.join(input_data_dir, data_date, data_time, f"{input_name.replace('/', '_')}.npy")
input_upper = np.load(input_upper_path)  

# NOTE(review): this branch is currently a no-op, so no roll is applied for
# window-shift layers. The commented-out line refers to `input_data` and
# `chunk_size_lat`/`chunk_size_lon`, none of which are defined at this point.
if layer_index in WIN_SHIFT_ATTENTION_LAYER_INDEXES:
    # roll the data
    # input_data = np.roll(input_data, shift=(chunk_size_lat // 2, chunk_size_lon // 2), axis=(-2, -1))
    pass

In [11]:
# Load the attention head data
# Any '/' in the layer name is replaced by '_' to form the file name.
attention_file_name = f"{layer_name.replace('/', '_')}.npy"
path = os.path.join(output_data_dir, data_date, data_time, attention_file_name)
attention_output = np.load(path)

## Select Hyperparameters

In [12]:
# Which attention window to inspect (lat/lon window position) and the
# pl_index offset used when indexing the attention output
# (presumably a pressure-level index — TODO confirm).
lat_index, lon_index, pl_index = 6, 1, 0

# Channel position of each surface variable along the first axis of input_surface
surface_var_index = dict(MSLP=0, U10=1, V10=2, T2M=3)

weather_var = "T2M"

head = 5

# Down-sample layers use a patch size and window extents twice as large
# (in grid points) as the other layers.
patchSize, chunk_size_lat, chunk_size_lon = (
    (8, 48, 96) if layer_index in DOWN_SAMPLE_ATTENTION_LAYER_INDEXES else (4, 24, 48)
)

# Grid-point bounds of the selected window
lat_coord_start = lat_index * chunk_size_lat
lon_coord_start = lon_index * chunk_size_lon
lat_coord_end = lat_coord_start + chunk_size_lat
lon_coord_end = lon_coord_start + chunk_size_lon

lat_coord_start, lat_coord_end, lon_coord_start, lon_coord_end

# Crop the chosen surface variable to the selected window
lat_window = slice(lat_coord_start, lat_coord_end)
lon_window = slice(lon_coord_start, lon_coord_end)
input_1 = input_surface[surface_var_index[weather_var], lat_window, lon_window]

# NOTE(review): the second index multiplies pl_index by lat_index, which looks
# like it may have been meant to be a per-level stride rather than the window
# position; with pl_index == 0 it reduces to lat_index. Confirm against the
# layout of the saved attention array.
att_pat = attention_output[lon_index, lat_index + lat_index * pl_index, head, :, :]

## Display the Attention Pattern and Input Data

**Disclaimer:** *Some parts of this code are a bit strange and thrown together. You have been warned!*

In [19]:
# Panels shown in the left column, one per subplot row.
# hover_fn chooses its highlight layout depending on whether exactly three
# panels are present, so keep that in mind when changing the length of this list.
input_images = [input_1, input_1, input_1]
num_input_images = len(input_images)

# Titles are consumed in row-major order over the non-empty grid cells:
# row 1 holds the first input and the attention pattern, every later row
# holds one further input.
subplot_titles = ["Input Image 1", "Attention Pattern"] + \
                 [f"Input Image {i+1}" for i in range(1, num_input_images)]

# Two-column grid: inputs stacked on the left, a single attention heatmap on
# the right that spans every row (hence the None placeholders in later rows).
fig = make_subplots(
    rows=num_input_images, cols=2, 
    column_widths=[0.5, 0.5], 
    subplot_titles=subplot_titles,
    specs=[[{"type": "heatmap"}, {"rowspan": num_input_images, "type": "heatmap"}]] + \
          [[{"type": "heatmap"}, None] for _ in range(num_input_images - 1)]
)

# Add input heatmaps to the left column
# Left column: one heatmap per input image, each colour-scaled to its own
# value range. Hover is disabled on these so only the attention heatmap
# emits hover events.
input_heatmap_style = dict(
    x0=0, dx=1,
    y0=0, dy=1,
    showscale=False,
    hoverinfo='skip',
    colorscale='turbo',
)
for i, input_img in enumerate(input_images):
    trace = go.Heatmap(
        z=input_img,
        zmin=input_img.min(), zmax=input_img.max(),
        **input_heatmap_style,
    )
    fig.add_trace(trace, row=i + 1, col=1)

# Right column: the attention pattern, placed in the cell that spans all rows
attention_heatmap = go.Heatmap(
    z=att_pat,
    x0=0, dx=1,
    y0=0, dy=1,
    zmin=att_pat.min(), zmax=att_pat.max(),
    colorscale='Viridis',
    showscale=False,
)
fig.add_trace(attention_heatmap, row=1, col=2)

# Compute input axes indices dynamically
# Axis numbers of the input panels. The first row uses axis 1 (inputs) and
# axis 2 (attention); each further input row takes the next number from 3 on.
# With a single input the range below is empty, leaving just [1].
input_axes_indices = [1, *range(3, num_input_images + 2)]
attention_axes_index = 2  # Attention Pattern is at axes index 2

# Input panels: fit the axes to the image, put row 0 at the top (reversed
# y range) and lock y to x so that cells are drawn square.
for i, input_img in enumerate(input_images):
    num_rows, num_cols = input_img.shape
    panel = dict(row=i + 1, col=1)
    fig.update_xaxes(range=[0, num_cols], **panel)
    fig.update_yaxes(
        range=[num_rows, 0],
        scaleanchor=f"x{input_axes_indices[i]}", scaleratio=1,
        **panel,
    )

# Attention panel: same treatment, using its own axis pair
num_rows_att_pat, num_cols_att_pat = att_pat.shape
attention_panel = dict(row=1, col=2)
fig.update_xaxes(range=[0, num_cols_att_pat], **attention_panel)
fig.update_yaxes(
    range=[num_rows_att_pat, 0],
    scaleanchor=f"x{attention_axes_index}", scaleratio=1,
    **attention_panel,
)

# Update figure layout size
# Fixed overall width; height grows by 200 px per input row
fig.update_layout(
    showlegend=False,
    width=1200,
    height=200 * num_input_images,
    margin={"l": 10, "r": 10, "t": 10, "b": 10},
)

# Wrap the figure in a FigureWidget so Python hover callbacks can be attached
fig_widget = go.FigureWidget(fig)

# Initialize rectangles for highlighting with corrected xrefs and yrefs
def _highlight_rect(axis_index, line_color, fill_color):
    """Return the layout dict of a hidden, border-less rectangle on axes x<axis_index>/y<axis_index>."""
    return dict(
        type="rect",
        xref=f'x{axis_index}',
        yref=f'y{axis_index}',
        x0=0,
        x1=1,
        y0=0,
        y1=1,
        line=dict(color=line_color, width=0),
        fillcolor=fill_color,
        layer="above",
        visible=False,
    )


# Two rectangles per input panel, all hidden initially:
#   shapes[0 : num_input_images]                     -> teal, one per panel
#   shapes[num_input_images : 2 * num_input_images]  -> pink, one per panel
# hover_fn addresses the teal group for key positions and the pink group
# (offset by num_input_images) for query positions.
shapes = []
for line_color, fill_color in (
    ("rgba(0, 128, 128, 1)", "rgba(0, 128, 128, 0.5)"),
    ("rgba(255, 105, 180, 1)", "rgba(255, 105, 180, 0.5)"),
):
    shapes.extend(
        _highlight_rect(axis_index, line_color, fill_color)
        for axis_index in input_axes_indices[:num_input_images]
    )

# Assign the list of shapes as a tuple to fig_widget.layout.shapes
fig_widget.layout.shapes = tuple(shapes)

def coords_2_highlight(fig_widget, lat, lon, shape_index, surface=False, patch_size=None):
    """Move highlight rectangle(s) onto one patch and make them visible.

    Parameters
    ----------
    fig_widget : object exposing ``batch_update()`` and ``layout.shapes``
        The widget whose layout shapes are updated.
    lat, lon : int
        Coordinates of the patch's first row / first column, in the axis
        units of the panel the shape is attached to.
    shape_index : int
        Index into ``fig_widget.layout.shapes`` of the first rectangle to move.
    surface : bool, default False
        When true only ``shape_index`` is updated; otherwise the rectangle at
        ``shape_index + 1`` receives the same position as well.
    patch_size : int, optional
        Edge length of the highlighted square. Defaults to the notebook-level
        ``patchSize`` so existing calls behave as before.

    Raises
    ------
    IndexError
        If a targeted shape index does not exist in ``layout.shapes``.
    """
    if patch_size is None:
        # Fall back to the notebook's global setting for backward compatibility.
        patch_size = patchSize

    targets = (shape_index,) if surface else (shape_index, shape_index + 1)

    # Batch the property writes so the front end receives a single update.
    with fig_widget.batch_update():
        for idx in targets:
            rect = fig_widget.layout.shapes[idx]
            rect.x0 = lon
            rect.x1 = lon + patch_size
            # y0/y1 are stored far-edge first, matching the original code;
            # the rectangle covers the same area either way.
            rect.y0 = lat + patch_size
            rect.y1 = lat
            rect.visible = True

def hover_fn(trace, points, state):
    """Hover callback for the attention heatmap.

    The hovered cell's column index is decoded as the key token and its row
    index as the query token. Each token index is split into a level index
    (blocks of 72), a row within the level (blocks of 12) and a column, and
    the row/column are scaled by ``patchSize`` to get the highlight position.
    Key positions go to the first group of shapes, query positions to the
    second group (offset by ``num_input_images``).
    """
    if not points.point_inds:
        return

    y, x = points.point_inds[0]

    def decode(token):
        # 72 tokens per level, arranged as 6 rows of 12 columns
        level, within_level = divmod(token, 72)
        row, col = divmod(within_level, 12)
        return level, row * patchSize, col * patchSize

    k_pl, k_lat, k_lon = decode(x)
    q_pl, q_lat, q_lon = decode(y)

    if num_input_images == 3:
        # Level 0 maps to a single panel; any other level highlights two panels
        coords_2_highlight(fig_widget, k_lat, k_lon, k_pl, k_pl == 0)
        coords_2_highlight(fig_widget, q_lat, q_lon, num_input_images + q_pl, q_pl == 0)
    else:
        # Every level highlights a pair of adjacent panels
        coords_2_highlight(fig_widget, k_lat, k_lon, k_pl * 2)
        coords_2_highlight(fig_widget, q_lat, q_lon, num_input_images + q_pl * 2)

def unhover_fn(trace, points, state):
    """Unhover callback: hide every highlight rectangle in a single batched update."""
    rects = fig_widget.layout.shapes
    with fig_widget.batch_update():
        for rect in rects:
            rect.visible = False

# Attach the hover functions to the attention pattern trace
# The attention heatmap was added after all input heatmaps, so it sits right
# behind them in the trace list.
attention_trace = fig_widget.data[len(input_images)]
attention_trace.on_hover(hover_fn)
attention_trace.on_unhover(unhover_fn)

fig_widget

FigureWidget({
    'data': [{'colorscale': [[0.0, '#30123b'], [0.07142857142857142, '#4145ab'],
                             [0.14285714285714285, '#4675ed'],
                             [0.21428571428571427, '#39a2fc'], [0.2857142857142857,
                             '#1bcfd4'], [0.35714285714285715, '#24eca6'],
                             [0.42857142857142855, '#61fc6c'], [0.5, '#a4fc3b'],
                             [0.5714285714285714, '#d1e834'], [0.6428571428571429,
                             '#f3c63a'], [0.7142857142857143, '#fe9b2d'],
                             [0.7857142857142857, '#f36315'], [0.8571428571428571,
                             '#d93806'], [0.9285714285714286, '#b11901'], [1.0,
                             '#7a0402']],
              'dx': 1,
              'dy': 1,
              'hoverinfo': 'skip',
              'showscale': False,
              'type': 'heatmap',
              'uid': '321d06df-5119-4765-9deb-06c05a90bffc',
              'x0': 0,
       