# FFT Score-CAM visualization tool


### Imports

In [1]:
import dash
from dash import dcc, html, Input, Output, callback_context
import plotly.express as px
import pickle
import pandas as pd
import numpy as np
from plotly import graph_objects as go

# Location of the pickled results table that every callback reads from.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
RESULTS_PICKLE_PATH = '/home/trudes/XAI/fft-score-cam/test_results_for_vis2.pkl'

# Load the table once at import time and echo its last rows as a sanity check.
df = pd.read_pickle(RESULTS_PICKLE_PATH)
print(df.tail())

# Create the Dash application that owns the layout and the callbacks.
app = dash.Dash(__name__)


def _quadrant_graph(graph_id):
    """Return an empty quarter-panel graph with the given component id."""
    return dcc.Graph(id=graph_id, style={'height': '45vh', 'width': '50%'})


def _quadrant_row(left_id, right_id):
    """Return a full-width flex row holding two quarter-panel graphs."""
    return html.Div(
        [_quadrant_graph(left_id), _quadrant_graph(right_id)],
        style={'display': 'flex', 'width': '100%'},
    )


# Page structure: title, reset button, then a flex row split into the main
# image (left half) and a 2x2 grid of detail plots (right half).
_main_image = dcc.Graph(
    id='image-display',
    figure={},
    style={'height': '90vh', 'width': '50%'},
)
_detail_panel = html.Div(
    [
        _quadrant_row('model-accuracy', 'frequency-magnitude'),   # top row
        _quadrant_row('time-domain', 'constellation-plot'),       # bottom row
    ],
    style={'display': 'inline-block', 'width': '50%'},
)

app.layout = html.Div([
    html.H1("Display Image from Pickled Data"),
    html.Button("Reset Zoom", id="btn-reset-zoom"),
    html.Div([_main_image, _detail_panel], style={'display': 'flex'}),
])

@app.callback(
    Output('image-display', 'figure'),
    [Input('btn-reset-zoom', 'n_clicks')],
    prevent_initial_call=True
)
def reset_zoom(n_clicks):
    """Rebuild the main heatmap so it is shown at its full, un-zoomed extent.

    Args:
        n_clicks: Click counter of the "Reset Zoom" button; only logged.

    Returns:
        A fresh Plotly figure rendering the first 256 columns of ``df`` as an
        image with no margins and hidden axes.
    """
    print(f"Reset Zoom clicked: {n_clicks}")  # Debugging statement

    # The first 256 columns are assumed to hold the image data.
    image_columns = df.iloc[:, :256]
    heatmap = px.imshow(image_columns, color_continuous_scale='viridis')

    heatmap.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
        xaxis={'constrain': 'range', 'autorange': True},
        yaxis={'constrain': 'range', 'autorange': True},
    )

    # Hide tick labels and axis lines on both axes.
    for hide_axis in (heatmap.update_xaxes, heatmap.update_yaxes):
        hide_axis(showticklabels=False, visible=False)

    return heatmap


@app.callback(
    Output('frequency-magnitude', 'figure'),
    [Input('image-display', 'relayoutData')],
    prevent_initial_call=True
)
def update_frequency_magnitude(relayout_data):
    """Plot the per-column mean of the image rows inside the current zoom.

    Args:
        relayout_data: The ``relayoutData`` dict emitted by the main image
            graph. Only events carrying explicit x and y axis ranges
            (``'xaxis.range[0]'`` / ``'yaxis.range[0]'`` keys) are handled.

    Returns:
        A line+marker figure of the mean of each of the first 256 columns
        over the zoomed rows, or an empty figure when the event carries no
        zoom range.
    """
    if not relayout_data or 'xaxis.range[0]' not in relayout_data or 'yaxis.range[0]' not in relayout_data:
        return go.Figure()

    # Convert the zoomed y-range to row indices. Both ends are clamped into
    # [0, n_rows] and then sorted, so a reversed y-axis (range[0] > range[1])
    # or a range that lies partly outside the image still yields a valid,
    # correctly ordered slice instead of an empty or wrapped-around one.
    n_rows = df.shape[0]
    y0, y1 = sorted(
        min(max(int(relayout_data[key]), 0), n_rows)
        for key in ('yaxis.range[0]', 'yaxis.range[1]')
    )

    # Rows come from the zoom; all 256 image columns are always averaged
    # (the first 256 columns are assumed to be the image data).
    zoomed_df = df.iloc[y0:y1, :256]

    # Calculate the average of each column
    frequency_magnitude = zoomed_df.mean(axis=0)

    # Create a connected scatter plot using graph_objects
    fig = go.Figure(
        go.Scatter(
            x=frequency_magnitude.index,
            y=frequency_magnitude.values,
            mode='lines+markers',
            marker=dict(size=5),
            line=dict(width=1)
        )
    )

    fig.update_layout(
        xaxis_title='Column Index',
        yaxis_title='Average Value'
    )

    return fig



@app.callback(
    Output('model-accuracy', 'figure'),
    [Input('image-display', 'relayoutData')],
    prevent_initial_call=True
)
def update_model_accuracy(relayout_data):
    """Plot the ``accuracy`` column for the rows inside the current zoom.

    Args:
        relayout_data: The ``relayoutData`` dict emitted by the main image
            graph. Only events carrying explicit x and y axis ranges
            (``'xaxis.range[0]'`` / ``'yaxis.range[0]'`` keys) are handled.

    Returns:
        A line+marker figure of accuracy against row position, or an empty
        figure when the event carries no zoom range.
    """
    if not relayout_data or 'xaxis.range[0]' not in relayout_data or 'yaxis.range[0]' not in relayout_data:
        return go.Figure()  # Return an empty plot if there's no zoom data

    # Convert the zoomed y-range to row indices. Both ends are clamped into
    # [0, n_rows] and then sorted, so a reversed y-axis (range[0] > range[1])
    # or a range that lies partly outside the image still yields a valid,
    # correctly ordered slice instead of an empty or wrapped-around one.
    n_rows = df.shape[0]
    y0, y1 = sorted(
        min(max(int(relayout_data[key]), 0), n_rows)
        for key in ('yaxis.range[0]', 'yaxis.range[1]')
    )

    # Extract the relevant portion of the accuracy data
    zoomed_accuracy = df['accuracy'].iloc[y0:y1].to_numpy()

    # Create a scatter plot of the accuracy data; x is the positional row
    # index so it lines up with the rows of the main image.
    fig = go.Figure(
        go.Scatter(
            x=list(range(y0, y1)),
            y=zoomed_accuracy,
            mode='lines+markers',
            marker=dict(size=5),
            line=dict(width=1)
        )
    )

    fig.update_layout(
        xaxis_title='Index',
        yaxis_title='Accuracy'
    )

    return fig


@app.callback(
    Output('time-domain', 'figure'),
    [Input('image-display', 'relayoutData')],
    prevent_initial_call=True
)
def update_time_domain(relayout_data):
    """Plot the real part of ``IQ_data`` for the rows inside the current zoom.

    Args:
        relayout_data: The ``relayoutData`` dict emitted by the main image
            graph. Only events carrying explicit x and y axis ranges
            (``'xaxis.range[0]'`` / ``'yaxis.range[0]'`` keys) are handled.

    Returns:
        A line+marker figure of the real component against row position, or
        an empty figure when the event carries no zoom range.
    """
    if not relayout_data or 'xaxis.range[0]' not in relayout_data or 'yaxis.range[0]' not in relayout_data:
        return go.Figure()  # Return an empty plot if there's no zoom data

    # Convert the zoomed y-range to row indices. Both ends are clamped into
    # [0, n_rows] and then sorted, so a reversed y-axis (range[0] > range[1])
    # or a range that lies partly outside the image still yields a valid,
    # correctly ordered slice instead of an empty or wrapped-around one.
    n_rows = df.shape[0]
    y0, y1 = sorted(
        min(max(int(relayout_data[key]), 0), n_rows)
        for key in ('yaxis.range[0]', 'yaxis.range[1]')
    )

    # Extract the relevant portion of the IQ_data column
    zoomed_iq_data = df['IQ_data'].iloc[y0:y1]

    # complex() normalises each sample (complex number or its string form)
    # before taking the real component; the result is a plain list of floats.
    real_part = [complex(val).real for val in zoomed_iq_data]

    # Create a connected scatter plot of the real part of the IQ data
    fig = go.Figure(
        go.Scatter(
            x=list(range(y0, y1)),
            y=real_part,
            mode='lines+markers',
            marker=dict(size=5),
            line=dict(width=1)
        )
    )

    fig.update_layout(
        xaxis_title='Index',
        yaxis_title='Real Part of IQ Data'
    )

    return fig



@app.callback(
    Output('constellation-plot', 'figure'),
    [Input('image-display', 'relayoutData')],
    prevent_initial_call=True
)
def update_constellation_plot(relayout_data):
    """Return an empty scatter figure for the constellation panel.

    Placeholder: ``relayout_data`` is accepted so the callback fires on zoom
    events of the main image, but it is not used yet.
    """
    empty_scatter = px.scatter()
    return empty_scatter

if __name__ == '__main__':
    # Newer Dash releases expose ``app.run`` and deprecate ``app.run_server``;
    # prefer ``run`` when it exists and only fall back to ``run_server`` on
    # older installs, so the script starts on either.
    start_server = getattr(app, 'run', None) or app.run_server
    start_server(debug=True)


       0  1  2  3  4  5  6  7  8  9  ...  249  250  251  252  253  254  255  \
12795  0  0  0  0  0  0  0  0  0  0  ...    0    0    0    0    0    0    0   
12796  0  0  0  0  0  0  0  0  0  0  ...    0    0    0    0    0    0    0   
12797  0  0  0  0  0  0  0  0  0  0  ...    0    0    0    0    0    0    0   
12798  0  0  0  0  0  0  0  0  0  0  ...    0    0    0    0    0    0    0   
12799  0  0  0  0  0  0  0  0  0  0  ...    0    0    0    0    0    0    0   

                  IQ_data  time_domain_CAM   accuracy  
12795 -2.166814+5.005035j         1.000976  0.9726025  
12796  0.000000+1.678518j         1.001125  0.9726025  
12797 -0.091556+1.892148j         1.001275  0.9726025  
12798 -4.028443-2.166814j         1.001424  0.9726025  
12799 -2.075259-3.448591j         1.001573  0.9726025  

[5 rows x 259 columns]
