In [1]:
import io
import base64
import pandas as pd
import os

from dash import Dash, dcc, html, Input, Output, no_update
from jupyter_dash import JupyterDash 
import plotly.express as px
import numpy as np
from PIL import Image
import plotly.graph_objects as go

In [4]:
# --- Paths and model selection ---------------------------------------------
# Root directory of the precomputed UMAP embeddings. Locations used on other
# setups (kept for reference):
#   '/home/nu/data/contents_shared/NSD-stimuli/derivatives/umap/cos_sim'
#   '../../results/res_umap'
load_base_dir = './results/res_umap'
# Fail early, naming the offending path, when the notebook is started from a
# working directory in which the relative path does not resolve.
assert os.path.isdir(load_base_dir), f'Embedding directory not found: {load_base_dir!r}'
# Output directory for figure assets (not referenced by the cells shown here).
save_base_dir = '../../results/assets/fig04'
# Sub-directory of load_base_dir that identifies the model.
network = 'pytorch/brain_diffuser_versatile_diffusion'
# Feature set whose embedding is loaded; 'vision_encoder' is the alternative
# noted by the author.
feat_name = 'text_encoder'

In [5]:
# Embedding sets that can be visualised; `prefix` picks the one used by the
# rest of the notebook (it also drives which hover tables get merged).
prefix_list = [
    'nsd',
    'deeprecon',
    'both',
]
prefix = prefix_list[0]

# File name of the saved embedding for the chosen prefix. A
# '<prefix>_embedding_norm_default_param.npy' variant was also used by the
# author at some point.
file_name = f'{prefix}_embedding_norm.npy'

In [6]:
# Read the saved embedding array for the selected network / feature / prefix.
embeddings = np.load(
    os.path.join(load_base_dir, network, feat_name, file_name)
)

In [7]:
# Directory holding the hover-metadata CSVs for every dataset.
nsd_hover_base_dir = '/home/nu/data/contents_shared/NSD-stimuli/derivatives/umap/'


def _load_hover_table(csv_name, dataset_label):
    """Read one hover-metadata CSV and tag every row with its dataset label.

    Args:
        csv_name: File name of the CSV inside ``nsd_hover_base_dir``.
        dataset_label: Value written to the ``'dataset'`` column.

    Returns:
        DataFrame indexed by the CSV's first column, with an added
        ``'dataset'`` column.
    """
    table = pd.read_csv(os.path.join(nsd_hover_base_dir, csv_name), index_col=0)
    table['dataset'] = dataset_label
    return table


def _sorted_by_image_id(table):
    """Return ``table`` ordered by ``'image id'`` with a fresh 0..n-1 index."""
    return table.sort_values('image id').reset_index(drop=True)


# Dataset labels, defined once so the tables and ``data_dict`` cannot drift apart.
_NSD_TRAIN, _NSD_TEST = 'nsd-train (Sub 01)', 'nsd-test'
_DR_TRAIN, _DR_TEST = 'dr-train', 'dr-test'

# Load and tag the four hover tables.
df_NSD_train = _load_hover_table('NSD_sub01_hover.csv', _NSD_TRAIN)
df_NSD_test = _load_hover_table('NSD_shared1000_hover.csv', _NSD_TEST)
df_dr_train = _load_hover_table('GOD_train_hover.csv', _DR_TRAIN)
df_dr_test = _load_hover_table('GOD_test_hover.csv', _DR_TEST)

# NOTE(review): the *_sorted frames are not used below -- merge_df is built
# from the tables in their original CSV row order. Confirm which ordering the
# saved embedding rows follow.
df_NSD_train_sorted = _sorted_by_image_id(df_NSD_train)
df_NSD_test_sorted = _sorted_by_image_id(df_NSD_test)
df_dr_train_sorted = _sorted_by_image_id(df_dr_train)
df_dr_test_sorted = _sorted_by_image_id(df_dr_test)

# For each prefix: the (label, table) pairs to stack, in stacking order.
_selection_by_prefix = {
    'nsd': [(_NSD_TRAIN, df_NSD_train), (_NSD_TEST, df_NSD_test)],
    'deeprecon': [(_DR_TRAIN, df_dr_train), (_DR_TEST, df_dr_test)],
    'both': [
        (_DR_TRAIN, df_dr_train),
        (_DR_TEST, df_dr_test),
        (_NSD_TRAIN, df_NSD_train),
        (_NSD_TEST, df_NSD_test),
    ],
}

# An unknown prefix used to leave merge_df / data_dict undefined and surface
# later as a NameError; report it here instead.
if prefix not in _selection_by_prefix:
    raise ValueError(
        f'Unknown prefix {prefix!r}; expected one of {sorted(_selection_by_prefix)}'
    )
_selected = _selection_by_prefix[prefix]

# Row-wise stack of the selected tables.
merge_df = pd.concat([table for _label, table in _selected])
# Maps the position of each table in the stack to its dataset label; the hover
# callback looks labels up by Plotly's curveNumber.
data_dict = {position: label for position, (label, _table) in enumerate(_selected)}


In [8]:
# Put the embedding coordinates on the same index as the hover metadata, then
# place the two side by side (coordinates first) in a single plotting frame.
df_embed = pd.DataFrame(
    data=embeddings,
    index=merge_df.index,
    columns=['x', 'y'],
)
df_all = pd.concat(
    [df_embed, merge_df],
    axis=1,
)

In [57]:
# Scatter plot of the embedding, coloured by the 'dataset' column.
fig = px.scatter(
    df_all,
    x='x',
    y='y',
    color='dataset',
    width=800,
    height=800,
    color_discrete_sequence=px.colors.qualitative.D3[2:],
    title="Umap Visualization of CLIP-Text_encoder",
)

# Small, translucent markers so overlapping points remain distinguishable.
fig.update_traces(
    selector=dict(mode="markers"),
    marker=dict(size=6),
    opacity=0.3,
)

# Disable Plotly's built-in hover box on every trace. Kept as the cell's final
# expression: it returns the figure, so the notebook still renders the plot.
fig.update_traces(hoverinfo="none", hovertemplate=None)

In [58]:
# Helper functions
def np_image_to_base64(im_matrix):
    """Encode an image array as a base64 JPEG data URL.

    Args:
        im_matrix: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        A ``data:image/jpeg;base64, ...`` string usable as an ``<img>`` ``src``.
    """
    im = Image.fromarray(im_matrix)
    # The JPEG writer only accepts these modes; anything else (e.g. RGBA from
    # a 4-channel array) used to raise "cannot write mode ... as JPEG".
    if im.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        im = im.convert("RGB")
    buffer = io.BytesIO()
    im.save(buffer, format="jpeg")
    encoded_image = base64.b64encode(buffer.getvalue()).decode()
    return "data:image/jpeg;base64, " + encoded_image


def image_path_to_base64(im_matrix):
    """Load an image file and encode it as a base64 JPEG data URL.

    Args:
        im_matrix: Despite the name, this is whatever ``PIL.Image.open``
            accepts (a path or file object), not an array. The parameter name
            is kept for compatibility with existing callers.

    Returns:
        A ``data:image/jpeg;base64, ...`` string usable as an ``<img>`` ``src``.
    """
    buffer = io.BytesIO()
    # The context manager closes the underlying file; this runs on every hover
    # event, so an unclosed handle per call adds up quickly.
    with Image.open(im_matrix) as source:
        # The JPEG writer only accepts these modes; RGBA / palette images
        # (typical for PNGs) used to raise "cannot write mode ... as JPEG".
        if source.mode in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            writable = source
        else:
            writable = source.convert("RGB")
        writable.save(buffer, format="jpeg")
    encoded_image = base64.b64encode(buffer.getvalue()).decode()
    return "data:image/jpeg;base64, " + encoded_image

In [59]:
app = Dash(__name__)
df = df_all

# CSS shared by every element stacked inside the hover tooltip.
_TOOLTIP_ITEM_STYLE = {"width": "50px", 'display': 'block', 'margin': '0 auto'}

# Page layout: the scatter plot plus a tooltip that the callback below fills.
app.layout = html.Div(
    className="container",
    children=[
        dcc.Graph(id="graph-5", figure=fig, clear_on_unhover=True),
        dcc.Tooltip(id="graph-tooltip-5", direction='bottom'),
    ],
)


@app.callback(
    Output("graph-tooltip-5", "show"),
    Output("graph-tooltip-5", "bbox"),
    Output("graph-tooltip-5", "children"),
    Input("graph-5", "hoverData"),
)
def display_hover(hoverData):
    """Build the image-and-captions tooltip for the hovered point.

    Args:
        hoverData: The graph's ``hoverData`` property, or ``None`` when the
            pointer is not over a point.

    Returns:
        A ``(show, bbox, children)`` tuple for the tooltip. When nothing is
        hovered the tooltip is hidden and the other two outputs are left
        unchanged.
    """
    if hoverData is None:
        return False, no_update, no_update

    # Only the first hovered point is shown; others may be present.
    hover_data = hoverData["points"][0]

    data_name = data_dict[hover_data["curveNumber"]]
    bbox = hover_data["bbox"]
    num = hover_data["pointNumber"]

    # pointNumber is a position within the hovered trace, so the row is taken
    # positionally from that dataset's rows. A label-based `series[num]` is
    # only right when the index labels happen to equal 0..n-1. The dataset
    # filter is evaluated once instead of once per column.
    row = df.loc[df['dataset'] == data_name].iloc[num]

    image_name = row['dataset']
    im_url = image_path_to_base64(row['image_path'])

    # One list element per caption column cap1..cap5.
    caption_items = [
        html.Ul([row[f'cap{caption_idx}']], style=_TOOLTIP_ITEM_STYLE)
        for caption_idx in range(1, 6)
    ]

    children = [
        html.Div([
            html.P(
                f'{num}_{image_name}',
                style={'font-weight': 'bold', **_TOOLTIP_ITEM_STYLE},
            ),
            html.Img(src=im_url, style=_TOOLTIP_ITEM_STYLE),
            *caption_items,
        ])
    ]

    return True, bbox, children

In [60]:
# Start the Dash server on port 8060 using Dash's "tab" Jupyter display mode.
app.run(
    port=8060,
    jupyter_mode="tab",
)

Dash app running on http://127.0.0.1:8060/


<IPython.core.display.Javascript object>