# In [1]:
import json
import pandas as pd
import dash
from dash import dcc, html
import plotly.graph_objects as go
from dash.dependencies import Input, Output

# Load the JSON export. JSON text is UTF-8 by specification (RFC 8259), so it
# is decoded explicitly instead of with the platform's default encoding, which
# varies between operating systems (e.g. cp1252 on many Windows setups).
with open('umap_data_new_Original.json', 'r', encoding='utf-8') as json_file:
    data_to_save = json.load(json_file)

# Flatten each record into parallel per-column lists. Every record is read as
# having an 'id', a 'metadata' dict holding the three loss values and the three
# UMAP coordinates, and a 'url' entry whose first element is used as the hover
# image.
# NOTE(review): 'url' is presumably a list of URLs — confirm; if it were a
# plain string, [0] would keep only its first character.
ids = []
feature_losses = []
image_losses = []
total_losses = []
umap1 = []
umap2 = []
umap3 = []
images = []

for item in data_to_save:
    ids.append(item['id'])
    metadata = item['metadata']
    feature_losses.append(metadata['feature_loss'])
    image_losses.append(metadata['image_loss'])
    total_losses.append(metadata['total_loss'])
    umap1.append(metadata['umap1'])
    umap2.append(metadata['umap2'])
    umap3.append(metadata['umap3'])
    images.append(item['url'][0])

# One row per record; these column names are what the layout and callbacks use.
df = pd.DataFrame({
    'ID': ids,
    'Feature Loss': feature_losses,
    'Image Loss': image_losses,
    'Total Loss': total_losses,
    'UMAP1': umap1,
    'UMAP2': umap2,
    'UMAP3': umap3,
    'Image': images
})

# Initialize the Dash app
app = dash.Dash(__name__)


def _loss_slider_section(column, id_prefix, step):
    """Build a labelled RangeSlider covering the full range of ``df[column]``.

    The slider gets id ``f'{id_prefix}-slider'`` and starts with its full range
    selected. It is followed by an empty Div, ``f'{id_prefix}-slider-output'``,
    that a callback fills with a text read-out of the selection.
    """
    col_min = df[column].min()
    col_max = df[column].max()
    return html.Div([
        html.Label(f'{column} Range:'),
        dcc.RangeSlider(
            id=f'{id_prefix}-slider',
            min=col_min,
            max=col_max,
            step=step,
            value=[col_min, col_max],
            tooltip={"placement": "bottom", "always_visible": True}
        ),
        # Dash forwards ``style`` to React, which expects camelCased keys.
        html.Div(id=f'{id_prefix}-slider-output', style={'marginTop': 20})
    ])


def _loss_input_section(column, id_prefix, step):
    """Build the pair of numeric min/max inputs for ``df[column]``.

    The inputs get ids ``f'{id_prefix}-input-min'`` and
    ``f'{id_prefix}-input-max'`` and are pre-filled with the column's
    minimum and maximum.
    """
    return html.Div([
        html.Label(f'Enter {column} Range:'),
        dcc.Input(id=f'{id_prefix}-input-min', type='number', value=df[column].min(), step=step),
        dcc.Input(id=f'{id_prefix}-input-max', type='number', value=df[column].max(), step=step),
    ], style={'marginTop': 20})


# (DataFrame column, component-id prefix, slider/input step) for each loss.
# Image loss uses a finer step than the other two losses.
_LOSS_CONTROLS = [
    ('Feature Loss', 'feature-loss', 0.1),
    ('Image Loss', 'image-loss', 0.001),
    ('Total Loss', 'total-loss', 0.1),
]

# Page order: 3-D graph, its hover tooltip, the three range sliders, then the
# three numeric min/max input pairs.
app.layout = html.Div(
    [
        dcc.Graph(id='3d-plot', style={'height': '800px'}),
        dcc.Tooltip(id="graph-tooltip"),
    ]
    + [_loss_slider_section(col, prefix, step) for col, prefix, step in _LOSS_CONTROLS]
    + [_loss_input_section(col, prefix, step) for col, prefix, step in _LOSS_CONTROLS]
)

@app.callback(
    Output('feature-loss-slider', 'value'),
    Output('image-loss-slider', 'value'),
    Output('total-loss-slider', 'value'),
    Input('feature-loss-input-min', 'value'),
    Input('feature-loss-input-max', 'value'),
    Input('image-loss-input-min', 'value'),
    Input('image-loss-input-max', 'value'),
    Input('total-loss-input-min', 'value'),
    Input('total-loss-input-max', 'value')
)
def update_sliders(feature_loss_min, feature_loss_max, image_loss_min, image_loss_max, total_loss_min, total_loss_max):
    """Copy the typed min/max values into the three loss RangeSliders.

    A ``dcc.Input(type='number')`` reports ``None`` while it is empty or holds
    an invalid number. Passing that ``None`` through would make
    ``update_graph`` compare the loss columns against ``None`` and fail. A
    missing bound is therefore replaced by the corresponding column's data
    minimum or maximum.

    Returns:
        Three ``[low, high]`` lists for the feature, image and total loss
        sliders, in that order.
    """
    def _bounds(column, low, high):
        # Fall back to the data extent for any bound the user has not supplied.
        if low is None:
            low = df[column].min()
        if high is None:
            high = df[column].max()
        return [low, high]

    return (
        _bounds('Feature Loss', feature_loss_min, feature_loss_max),
        _bounds('Image Loss', image_loss_min, image_loss_max),
        _bounds('Total Loss', total_loss_min, total_loss_max),
    )

@app.callback(
    Output('3d-plot', 'figure'),
    Output('feature-loss-slider-output', 'children'),
    Output('image-loss-slider-output', 'children'),
    Output('total-loss-slider-output', 'children'),
    Input('feature-loss-slider', 'value'),
    Input('image-loss-slider', 'value'),
    Input('total-loss-slider', 'value')
)
def update_graph(feature_loss_range, image_loss_range, total_loss_range):
    """Redraw the 3-D UMAP scatter for rows inside all three loss ranges.

    Each ``*_range`` argument is a RangeSlider value. Element 0 is the
    inclusive lower bound and element 1 the inclusive upper bound.

    Returns:
        The Plotly figure, followed by the three text read-outs of the
        selected feature, image and total loss ranges.
    """
    # Series.between is inclusive on both ends: lower <= value <= upper.
    in_range = (
        df['Feature Loss'].between(feature_loss_range[0], feature_loss_range[1])
        & df['Image Loss'].between(image_loss_range[0], image_loss_range[1])
        & df['Total Loss'].between(total_loss_range[0], total_loss_range[1])
    )
    shown = df[in_range]

    # Points are coloured by total loss. The built-in hover label is disabled
    # (hoverinfo='none'); the per-point fields travel in customdata instead.
    trace = go.Scatter3d(
        x=shown['UMAP1'],
        y=shown['UMAP2'],
        z=shown['UMAP3'],
        mode='markers',
        marker={
            'size': 5,
            'color': shown['Total Loss'],
            'colorscale': 'Viridis',
            'colorbar': {'title': 'Total Loss'},
        },
        customdata=shown[['ID', 'Feature Loss', 'Image Loss', 'Total Loss', 'Image']].values,
        hoverinfo='none',
    )
    figure = go.Figure(data=[trace])
    figure.update_layout(
        title='3D UMAP Visualization',
        scene={
            'xaxis_title': 'UMAP Component 1',
            'yaxis_title': 'UMAP Component 2',
            'zaxis_title': 'UMAP Component 3',
        },
        margin={'l': 0, 'r': 0, 'b': 0, 't': 40},
    )

    feature_text = f"Feature Loss Range: {feature_loss_range[0]:.1f} - {feature_loss_range[1]:.1f}"
    image_text = f"Image Loss Range: {image_loss_range[0]:.3f} - {image_loss_range[1]:.3f}"
    total_text = f"Total Loss Range: {total_loss_range[0]:.1f} - {total_loss_range[1]:.1f}"

    return figure, feature_text, image_text, total_text

@app.callback(
    Output("graph-tooltip", "show"),
    Output("graph-tooltip", "bbox"),
    Output("graph-tooltip", "children"),
    Input("3d-plot", "hoverData")
)
def display_hover(hoverData):
    """Show an image-and-loss tooltip beside the hovered point of the 3-D plot.

    Args:
        hoverData: Plotly hover payload from the '3d-plot' graph, or None when
            nothing is hovered. Only the first point is used. Its
            ``customdata`` is read as [ID, feature loss, image loss,
            total loss, image URL], the column order used for customdata when
            the figure is built.

    Returns:
        ``(show, bbox, children)`` for the ``graph-tooltip`` component.
    """
    if hoverData is None:
        return False, {}, ""

    point = hoverData["points"][0]
    bbox = point["bbox"]
    point_id, feature_loss, image_loss, total_loss, img_src = point["customdata"]

    children = [
        html.Div([
            html.Img(src=img_src, style={"width": "200px", "height": "200px"}),
            # Dash renders string children as plain text, so a literal "<br>"
            # would appear verbatim; use html.Br() components for line breaks.
            # Image loss uses 3 decimals, matching its 0.001 slider precision.
            html.P([
                f"Index: {point_id}", html.Br(),
                f"Feature Loss: {feature_loss:.2f}", html.Br(),
                f"Image Loss: {image_loss:.3f}", html.Br(),
                f"Total Loss: {total_loss:.2f}",
            ])
        ], style={"padding": "10px", "backgroundColor": "rgba(255, 255, 255, 0.8)", "border": "1px solid black"})
    ]

    return True, bbox, children

if __name__ == '__main__':
    # ``run_server`` is deprecated in Dash 2.x and removed in Dash 3.0;
    # ``run`` is the supported entry point and takes the same arguments.
    app.run(debug=True)
