In [2]:
from collections import defaultdict
import numpy as np
import os
from tqdm import tqdm

In [60]:
# PATH: directory whose entries are each read with np.load by the loading cell.
# PATH_IMG: directory of hover thumbnails exposed through the Flask '/images/<filename>' route.
PATH, PATH_IMG = 'data', 'imgs'

In [4]:
# Load every per-sample attention dump in PATH and bucket it by layer index.
# Files are assumed to cycle through layers 0..N_LAYERS-1 in sorted filename
# order — TODO confirm against the dump's naming scheme.
N_LAYERS = 4
attns = defaultdict(list)
# Skip hidden entries (e.g. `.ipynb_checkpoints`, `.DS_Store`). np.load fails on them,
# and because '.' sorts first, a single extra entry would shift the layer
# interleave (index % N_LAYERS) for every file after it.
files = sorted(name for name in os.listdir(PATH) if not name.startswith('.'))
for c, f in enumerate(tqdm(files)):
    path = os.path.join(PATH, f)
    attn = np.load(path)
    attns[c % N_LAYERS].append(attn)

100%|██████████████████████████████████████| 4000/4000 [00:32<00:00, 124.07it/s]


In [119]:
# Attention head to visualise, written as "<layer>.<head>".
HEAD = '3.3'
layer, head = (int(part) for part in HEAD.split('.'))
# Stack all samples collected for `layer`, then keep only the chosen head (axis 1).
arr = np.array(attns[layer])[:, head, :]

In [120]:
from flask import Flask, send_from_directory
import dash
from dash import html, dcc, Input, Output
import plotly.express as px
import numpy as np

server = Flask(__name__)
app = dash.Dash(__name__, server=server)

n = 4  # Scale factor for display: screen pixels per heatmap cell

# Show arr transposed and inverted: with the 'gray' scale (black = low),
# `1 - arr` makes high values in arr render dark.
img = 1 - arr.T
# Size the figure from the data rather than hard-coding 1000 x 129, so every cell
# stays exactly n px wide. The hover callback positions thumbnails with `x*n`
# and relies on that. Presumably arr is (num_samples, seq_len) — TODO confirm.
img_height, img_width = img.shape[:2]
fig = px.imshow(img, color_continuous_scale='gray', origin='upper', labels={'color': 'Pixel Value'})
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
fig.update_layout(
    width=img_width * n,
    height=img_height * n,
    margin=dict(l=0, r=0, t=0, b=20),  # no left margin, so data column x starts at x*n px
    coloraxis_showscale=False,
)

# Heatmap on top; below it, a relatively positioned container for the two
# absolutely positioned hover thumbnails driven by update_hover_images.
app.layout = html.Div([
    dcc.Graph(id='main-image', figure=fig, style={'width': '90vw', 'height': '77vh'}),
    html.Div([
        html.Img(id='hover-image-1', src='', style={'maxWidth': '200px', 'display': 'block'}),
        html.Img(id='hover-image-2', src='', style={'maxWidth': '200px', 'display': 'block'})
    ], style={'position': 'relative'})
])

@server.route('/images/<path:filename>')
def serve_image(filename):
    """Return the file `filename` from the PATH_IMG directory.

    Backs the '/images/...' URLs that update_hover_images assigns to the
    hover <img> tags. Lookup is delegated to Flask's send_from_directory,
    which is documented to refuse paths that escape the given directory.
    """
    image_dir = PATH_IMG
    response = send_from_directory(image_dir, filename)
    return response

@app.callback(
    # One flat list, in the same order as the returned 4-tuple.
    [
        Output('hover-image-1', 'src'),
        Output('hover-image-2', 'src'),
        Output('hover-image-1', 'style'),
        Output('hover-image-2', 'style'),
    ],
    Input('main-image', 'hoverData'),
)
def update_hover_images(hoverData):
    """Update both hover thumbnails for the heatmap cell under the cursor.

    Image 1 shows sample ``x`` (the hovered column). Image 2 shows sample
    ``x - (seq_len - y)``, where ``seq_len = arr.shape[1]`` is the number of
    heatmap rows. Presumably this is the earlier sample that the hovered row
    refers to — TODO confirm.

    Args:
        hoverData: Dash hover payload for 'main-image', or None when nothing
            is hovered.

    Returns:
        (src_1, src_2, style_1, style_2). Each image is hidden, with an empty
        src, when there is nothing to show for it.
    """
    hidden = {'display': 'none'}
    if not hoverData or not hoverData.get('points'):
        return '', '', hidden, hidden

    point = hoverData['points'][0]
    # Cast so the `:06d` format below cannot fail on float-valued coordinates.
    x = int(point['x'])
    y = int(point['y'])
    # The heatmap shows arr.T, so its row count is arr.shape[1]
    # (replaces a hard-coded 129).
    seq_len = arr.shape[1]
    x2 = x - (seq_len - y)

    # Each heatmap column is n px wide. Image 1 is drawn 130 px to the right
    # of image 2's left edge.
    src_1 = f'/images/{x:06d}.png'
    style_1 = {'maxWidth': '200px', 'display': 'block', 'position': 'absolute', 'left': f'{x*n+130}px'}

    if x2 >= 0:
        src_2 = f'/images/{x2:06d}.png'
        style_2 = {'maxWidth': '200px', 'display': 'block', 'position': 'absolute', 'left': f'{x*n}px'}
    else:
        # No sample before index 0. Hide image 2 rather than pointing it at
        # '/images/', which renders as a broken image.
        src_2 = ''
        style_2 = {'display': 'none'}

    return src_1, src_2, style_1, style_2

# Start the Dash development server. `app.run` replaces `app.run_server`,
# which is deprecated in Dash 2.x and removed in Dash 3.
if __name__ == '__main__':
    app.run(debug=True)
