# FFT Score-CAM Visualization Tool


### Imports

In [42]:
import dash
from dash import dcc, html, Input, Output, callback_context, State
import plotly.express as px
import pickle
import pandas as pd
import numpy as np
from plotly import graph_objects as go

# Pickled test results to visualise. Earlier experiment outputs in the same
# directory that can be swapped in here:
#   test_results_for_vis_noise_experiment.pkl
#   test_results_for_vis2.pkl
#   test_results_for_vis_noise_experiment_medium_power.pkl
# NOTE: unpickling can execute arbitrary code -- only load trusted result files.
DATA_PATH = '/home/trudes/XAI/fft-score-cam/test_results_for_vis_with_freq_cam.pkl'

# Load the results DataFrame and echo its last rows as a quick sanity check.
df = pd.read_pickle(DATA_PATH)
print(df.tail())

# Dash application that every callback below is registered on.
app = dash.Dash(__name__)

# Initial spectrogram view: the first 256 columns of df drawn as an image.
# Both axes are hidden and autoranged, and the colour bar sits on the left.
initial_main_fig = px.imshow(df.iloc[:, :256], color_continuous_scale='Viridis')
initial_main_fig.update_layout(
    title='Spectrogram',
    margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
    coloraxis_colorbar=dict(x=-0.15),
)
for _update_axes in (initial_main_fig.update_xaxes, initial_main_fig.update_yaxes):
    _update_axes(constrain='range', autorange=True, showticklabels=False, visible=False)

# Initial time-domain CAM strip: one column, one cell per DataFrame row.
# It is styled exactly like the figure update_cam_heatmap returns (Inferno,
# fixed global colour limits, y axis NOT reversed). The spectrogram's y axis is
# forced to autorange=True rather than 'reversed', so a reversed strip here
# rendered upside-down relative to the spectrogram. It also flipped and
# recoloured the moment the first zoom replaced it.
initial_cam_heatmap = go.Figure(
    go.Heatmap(
        z=[df['time_domain_CAM'].values],
        colorscale='Inferno',
        transpose=True,  # (1, N) -> (N, 1): row index runs along y
        colorbar=dict(title='Raw CAM Value'),
        zmin=df['time_domain_CAM'].min(),
        zmax=df['time_domain_CAM'].max(),
    )
)
initial_cam_heatmap.update_layout(
    xaxis=dict(showticklabels=False, showgrid=False),
    yaxis=dict(showticklabels=False, showgrid=False),
    margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
    title='Time Domain CAM'
)

def _quadrant_row(left_id, right_id):
    """Return a full-width flex row holding two side-by-side 50% x 45vh graphs."""
    return html.Div(
        [
            dcc.Graph(id=left_id, style={'height': '45vh', 'width': '50%'}),
            dcc.Graph(id=right_id, style={'height': '45vh', 'width': '50%'}),
        ],
        style={'display': 'flex', 'width': '100%'},
    )


# Page layout: controls on top, then one flex row containing the spectrogram
# (35%), the time-domain CAM strip (15%) and a 2x2 grid of detail plots (50%).
app.layout = html.Div([
    html.H1("RF Model Evaluation Toolkit"),
    html.Button("Reset App", id="btn-reset-zoom"),
    dcc.Checklist(
        id='heatmap-overlay-toggle',
        options=[{'label': 'Overlay Frequency CAM', 'value': 'overlay'}],
        value=[],  # overlay unchecked at start-up
    ),
    dcc.Store(id='zoom-store'),
    html.Div(
        [
            dcc.Graph(
                id='image-display',
                figure=initial_main_fig,
                style={'height': '90vh', 'width': '35%'},
            ),
            dcc.Graph(
                id='cam-heatmap',
                figure=initial_cam_heatmap,
                style={'height': '90vh', 'width': '15%'},
            ),
            html.Div(
                [
                    _quadrant_row('model-accuracy', 'frequency-magnitude'),
                    _quadrant_row('time-domain', 'constellation-plot'),
                ],
                style={'display': 'inline-block', 'width': '50%'},
            ),
        ],
        style={'display': 'flex'},
    ),
])

@app.callback(
    Output('cam-heatmap', 'figure'),
    [Input('image-display', 'relayoutData'),
     Input('btn-reset-zoom', 'n_clicks')],
    prevent_initial_call=True
)
def update_cam_heatmap(relayout_data, n_clicks):
    """Redraw the time-domain CAM strip for the rows visible in the spectrogram.

    Args:
        relayout_data: ``relayoutData`` of the spectrogram graph. Only the
            ``yaxis.range[0]`` / ``yaxis.range[1]`` keys are used.
        n_clicks: Reset-button click count (only used as a trigger).

    Returns:
        A ``go.Figure`` heatmap of ``df['time_domain_CAM']`` for the selected
        rows, coloured against the column's global min/max so colours stay
        comparable across zoom levels.
    """
    n_rows = df.shape[0]
    triggered = dash.callback_context.triggered
    reset_clicked = bool(triggered) and triggered[0]['prop_id'] == 'btn-reset-zoom.n_clicks'

    # Default selection: every row (reset click, autorange/autosize events, x-only zoom).
    y0, y1 = 0, n_rows
    if (not reset_clicked and relayout_data
            and 'yaxis.range[0]' in relayout_data and 'yaxis.range[1]' in relayout_data):
        # Only the y (row) range is used, so a y-only zoom is honoured. sorted()
        # tolerates a reversed axis reporting range[0] > range[1].
        lo, hi = sorted((float(relayout_data['yaxis.range[0]']),
                         float(relayout_data['yaxis.range[1]'])))
        # Round to the nearest row (assumes px.imshow centres row i on y == i,
        # i.e. a default RangeIndex) and keep the last visible row. Clip so iloc
        # never receives a negative bound, which would count from the end.
        y0 = int(np.clip(np.floor(lo + 0.5), 0, n_rows))
        y1 = int(np.clip(np.floor(hi + 0.5) + 1, y0, n_rows))

    cam = df['time_domain_CAM']
    fig = go.Figure(
        go.Heatmap(
            z=[cam.iloc[y0:y1].values],
            colorscale='Inferno',
            transpose=True,  # (1, k) -> (k, 1): selected rows run along y
            colorbar=dict(title='Raw CAM Value'),
            # Fixed colour limits so zooming does not rescale the colours.
            zmin=cam.min(),
            zmax=cam.max(),
        )
    )

    fig.update_layout(
        xaxis=dict(showticklabels=False, showgrid=False),
        yaxis=dict(showticklabels=False, showgrid=False),
        margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
        title='Time Domain CAM'
    )

    return fig


# @app.callback(
#     Output('image-display', 'figure'),
#     [Input('btn-reset-zoom', 'n_clicks')],
#     prevent_initial_call=True
# )
# def reset_zoom(n_clicks):
#     print(f"Reset Zoom clicked: {n_clicks}")  # Debugging statement

#     fig = px.imshow(df.iloc[:, :256], color_continuous_scale='viridis')  # Assuming the first 256 columns are image data
#     fig.update_layout(
#         coloraxis_colorbar=dict(
#             x=-0.15 
#         ),
#         margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
#         autosize=True,
#         xaxis=dict(constrain='range', autorange=True),
#         yaxis=dict(constrain='range', autorange=True),
#         title='Spectrogram'
#     )
#     fig.update_xaxes(showticklabels=False, visible=False)
#     fig.update_yaxes(showticklabels=False, visible=False)
#     return fig
# @app.callback(
#     Output('zoom-store', 'data'),
#     [Input('image-display', 'relayoutData'),
#      Input('btn-reset-zoom', 'n_clicks')],
#     [State('zoom-store', 'data')]
# )
# def manage_zoom_state(relayout_data, reset_n_clicks, current_zoom_state):
#     ctx = dash.callback_context

#     if not ctx.triggered:
#         # No input has fired, so we don't update the state
#         return dash.no_update

#     trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]

#     if trigger_id == 'image-display':
#         # The image display was zoomed, update the zoom state
#         return relayout_data if relayout_data else dash.no_update

#     elif trigger_id == 'btn-reset-zoom':
#         # The reset button was clicked, clear the zoom state
#         return None

#     return current_zoom_state  # Return the existing state if no relevant trigger




# Trace name tagging the frequency-CAM overlay so it can be found and removed.
FCAM_OVERLAY_NAME = 'fcam-overlay'


@app.callback(
    Output('image-display', 'figure'),
    [Input('heatmap-overlay-toggle', 'value'),
     Input('image-display', 'relayoutData'),
     Input('btn-reset-zoom', 'n_clicks')],
    [State('image-display', 'figure')],
    prevent_initial_call=True
)
def update_main_plot(overlay, relayout_data, reset_zoom_n_clicks, current_fig):
    """Rebuild the spectrogram after a zoom, a reset click or an overlay toggle.

    Args:
        overlay: Value of the overlay checklist. Contains ``'overlay'`` when the
            frequency-CAM overlay should be drawn.
        relayout_data: ``relayoutData`` of the spectrogram graph.
        reset_zoom_n_clicks: Reset-button click count (only used as a trigger).
        current_fig: The figure currently held by the graph (dict form).

    Returns:
        A ``go.Figure`` with hidden axes, the reported zoom applied, and at most
        one frequency-CAM overlay trace.
    """
    ctx = dash.callback_context
    trigger_id = ctx.triggered[0]['prop_id'].split('.')[0] if ctx.triggered else None

    if trigger_id == 'btn-reset-zoom' or not current_fig:
        # Work on a copy: mutating the module-level template would leak overlay
        # traces into every later reset.
        fig = go.Figure(initial_main_fig)
    else:
        # Rebuild from the figure held by the graph to keep its current view.
        fig = go.Figure(current_fig)
    fig.update_xaxes(showticklabels=False, visible=False)
    fig.update_yaxes(showticklabels=False, visible=False)

    if trigger_id == 'image-display' and relayout_data:
        # Only touch axes this event actually changed. Events such as
        # {'autosize': True} carry no range and must not force a fixed frame.
        # (The old fallback of df.shape[1] also counted the non-image columns.)
        for axis in ('xaxis', 'yaxis'):
            start_key, end_key = f'{axis}.range[0]', f'{axis}.range[1]'
            if start_key in relayout_data and end_key in relayout_data:
                fig.update_layout({axis: dict(
                    range=[relayout_data[start_key], relayout_data[end_key]],
                    autorange=False,  # an explicit range must win over autorange
                    constrain='range',
                )})
            elif relayout_data.get(f'{axis}.autorange'):
                fig.update_layout({axis: dict(autorange=True, constrain='range')})

    # Strip any existing overlay first so repeated zooms never stack traces and
    # unchecking removes exactly the overlay, not an arbitrary "last" trace.
    fig.data = tuple(trace for trace in fig.data if trace.name != FCAM_OVERLAY_NAME)
    if overlay and 'overlay' in overlay:
        fcam_columns = [col for col in df.columns if str(col).startswith('fcam')]
        fig.add_trace(
            go.Heatmap(
                z=df[fcam_columns].to_numpy(),
                colorscale='Inferno',
                opacity=0.5,
                showscale=False,
                name=FCAM_OVERLAY_NAME,
            )
        )

    return fig

@app.callback(
    Output('frequency-magnitude', 'figure'),
    [Input('image-display', 'relayoutData'),
     Input('btn-reset-zoom', 'n_clicks')],
    prevent_initial_call=True
)
def update_frequency_magnitude(relayout_data, n_clicks):
    """Plot the mean spectrogram row over the rows visible in the spectrogram.

    All of the first 256 columns are averaged, whatever the x (column) zoom is.

    Args:
        relayout_data: ``relayoutData`` of the spectrogram graph. Only the
            ``yaxis.range[0]`` / ``yaxis.range[1]`` keys are used.
        n_clicks: Reset-button click count (only used as a trigger).

    Returns:
        A ``go.Figure`` line/marker plot of the per-column mean.
    """
    n_rows = df.shape[0]
    triggered = dash.callback_context.triggered
    reset_clicked = bool(triggered) and triggered[0]['prop_id'] == 'btn-reset-zoom.n_clicks'

    # Default selection: every row (reset click, autorange/autosize events, x-only zoom).
    y0, y1 = 0, n_rows
    if (not reset_clicked and relayout_data
            and 'yaxis.range[0]' in relayout_data and 'yaxis.range[1]' in relayout_data):
        # Only the y (row) range is used, so a y-only zoom is honoured. sorted()
        # tolerates a reversed axis reporting range[0] > range[1].
        lo, hi = sorted((float(relayout_data['yaxis.range[0]']),
                         float(relayout_data['yaxis.range[1]'])))
        # Round to the nearest row (assumes px.imshow centres row i on y == i,
        # i.e. a default RangeIndex) and keep the last visible row. Clip so iloc
        # never receives a negative bound, which would count from the end.
        y0 = int(np.clip(np.floor(lo + 0.5), 0, n_rows))
        y1 = int(np.clip(np.floor(hi + 0.5) + 1, y0, n_rows))

    # Mean of each of the first 256 columns over the selected rows.
    frequency_magnitude = df.iloc[y0:y1, :256].mean(axis=0)

    fig = go.Figure(
        go.Scatter(
            x=frequency_magnitude.index,
            y=frequency_magnitude.values,
            mode='lines+markers',
            marker=dict(size=5),
            line=dict(width=1)
        )
    )

    fig.update_layout(
        xaxis_title='Index',
        yaxis_title='Frequency Magnitude',
        title='Average FFT Over Zoom Selection'
    )

    return fig



@app.callback(
    Output('model-accuracy', 'figure'),
    [Input('image-display', 'relayoutData'),
     Input('btn-reset-zoom', 'n_clicks')],
    prevent_initial_call=True
)
def update_model_accuracy(relayout_data, n_clicks):
    """Plot ``df['accuracy']`` for the rows visible in the spectrogram.

    Args:
        relayout_data: ``relayoutData`` of the spectrogram graph. Only the
            ``yaxis.range[0]`` / ``yaxis.range[1]`` keys are used.
        n_clicks: Reset-button click count (only used as a trigger).

    Returns:
        A ``go.Figure`` line/marker plot of accuracy against row position.
    """
    n_rows = df.shape[0]
    triggered = dash.callback_context.triggered
    reset_clicked = bool(triggered) and triggered[0]['prop_id'] == 'btn-reset-zoom.n_clicks'

    # Default selection: every row (reset click, autorange/autosize events, x-only zoom).
    y0, y1 = 0, n_rows
    if (not reset_clicked and relayout_data
            and 'yaxis.range[0]' in relayout_data and 'yaxis.range[1]' in relayout_data):
        # Only the y (row) range is used, so a y-only zoom is honoured. sorted()
        # tolerates a reversed axis reporting range[0] > range[1].
        lo, hi = sorted((float(relayout_data['yaxis.range[0]']),
                         float(relayout_data['yaxis.range[1]'])))
        # Round to the nearest row (assumes px.imshow centres row i on y == i,
        # i.e. a default RangeIndex) and keep the last visible row. Clip so iloc
        # never receives a negative bound, which would count from the end.
        y0 = int(np.clip(np.floor(lo + 0.5), 0, n_rows))
        y1 = int(np.clip(np.floor(hi + 0.5) + 1, y0, n_rows))

    zoomed_accuracy = df['accuracy'].iloc[y0:y1].to_numpy()

    fig = go.Figure(
        go.Scatter(
            x=list(range(y0, y1)),  # row positions of the selected samples
            y=zoomed_accuracy,
            mode='lines+markers',
            marker=dict(size=5),
            line=dict(width=1)
        )
    )

    fig.update_layout(
        xaxis_title='Index',
        yaxis_title='Accuracy',
        title='Model Accuracy'
    )

    return fig


@app.callback(
    Output('time-domain', 'figure'),
    [Input('image-display', 'relayoutData'),
     Input('btn-reset-zoom', 'n_clicks')],
    prevent_initial_call=True
)
def update_time_domain(relayout_data, n_clicks):
    """Plot the real part of ``df['IQ_data']`` for the rows visible in the spectrogram.

    Args:
        relayout_data: ``relayoutData`` of the spectrogram graph. Only the
            ``yaxis.range[0]`` / ``yaxis.range[1]`` keys are used.
        n_clicks: Reset-button click count (only used as a trigger).

    Returns:
        A ``go.Figure`` line/marker plot of Re(IQ) against row position. Each
        cell of ``IQ_data`` must be convertible with ``complex()``.
    """
    n_rows = df.shape[0]
    triggered = dash.callback_context.triggered
    reset_clicked = bool(triggered) and triggered[0]['prop_id'] == 'btn-reset-zoom.n_clicks'

    # Default selection: every row (reset click, autorange/autosize events, x-only zoom).
    y0, y1 = 0, n_rows
    if (not reset_clicked and relayout_data
            and 'yaxis.range[0]' in relayout_data and 'yaxis.range[1]' in relayout_data):
        # Only the y (row) range is used, so a y-only zoom is honoured. sorted()
        # tolerates a reversed axis reporting range[0] > range[1].
        lo, hi = sorted((float(relayout_data['yaxis.range[0]']),
                         float(relayout_data['yaxis.range[1]'])))
        # Round to the nearest row (assumes px.imshow centres row i on y == i,
        # i.e. a default RangeIndex) and keep the last visible row. Clip so iloc
        # never receives a negative bound, which would count from the end.
        y0 = int(np.clip(np.floor(lo + 0.5), 0, n_rows))
        y1 = int(np.clip(np.floor(hi + 0.5) + 1, y0, n_rows))

    # complex() accepts numbers and complex-literal strings alike.
    iq = np.array([complex(val) for val in df['IQ_data'].iloc[y0:y1]], dtype=complex)

    fig = go.Figure(
        go.Scatter(
            x=list(range(y0, y1)),
            y=iq.real.tolist(),
            mode='lines+markers',
            marker=dict(size=5),
            line=dict(width=1)
        )
    )

    fig.update_layout(
        xaxis_title='Index',
        yaxis_title='Real Part of IQ Data',
        title='Time Domain'
    )

    return fig



@app.callback(
    Output('constellation-plot', 'figure'),
    [Input('image-display', 'relayoutData'),
     Input('btn-reset-zoom', 'n_clicks')],
    prevent_initial_call=True
)
def update_constellation_plot(relayout_data, n_clicks):
    """Scatter I against Q of ``df['IQ_data']`` for the rows visible in the spectrogram.

    Args:
        relayout_data: ``relayoutData`` of the spectrogram graph. Only the
            ``yaxis.range[0]`` / ``yaxis.range[1]`` keys are used.
        n_clicks: Reset-button click count (only used as a trigger).

    Returns:
        A ``go.Figure`` marker plot with Re(IQ) on x and Im(IQ) on y. Each cell
        of ``IQ_data`` must be convertible with ``complex()``.
    """
    n_rows = df.shape[0]
    triggered = dash.callback_context.triggered
    reset_clicked = bool(triggered) and triggered[0]['prop_id'] == 'btn-reset-zoom.n_clicks'

    # Default selection: every row (reset click, autorange/autosize events, x-only zoom).
    y0, y1 = 0, n_rows
    if (not reset_clicked and relayout_data
            and 'yaxis.range[0]' in relayout_data and 'yaxis.range[1]' in relayout_data):
        # Only the y (row) range is used, so a y-only zoom is honoured. sorted()
        # tolerates a reversed axis reporting range[0] > range[1].
        lo, hi = sorted((float(relayout_data['yaxis.range[0]']),
                         float(relayout_data['yaxis.range[1]'])))
        # Round to the nearest row (assumes px.imshow centres row i on y == i,
        # i.e. a default RangeIndex) and keep the last visible row. Clip so iloc
        # never receives a negative bound, which would count from the end.
        y0 = int(np.clip(np.floor(lo + 0.5), 0, n_rows))
        y1 = int(np.clip(np.floor(hi + 0.5) + 1, y0, n_rows))

    # Parse each IQ value once, then split into in-phase / quadrature parts.
    iq = np.array([complex(val) for val in df['IQ_data'].iloc[y0:y1]], dtype=complex)

    fig = go.Figure(
        go.Scatter(
            x=iq.real.tolist(),
            y=iq.imag.tolist(),
            mode='markers'
        )
    )

    fig.update_layout(
        xaxis_title='In-phase Component (I)',
        yaxis_title='Quadrature Component (Q)',
        title='Constellation Plot'
    )

    return fig


if __name__ == '__main__':
    # `run_server` is deprecated in Dash 2.x and removed in Dash 3; prefer
    # `run`, and fall back for old releases that lack it.
    _run = getattr(app, 'run', None) or app.run_server
    _run(debug=True)


       0  1  2  3  4  5  6  7  8  9  ...   fcam246   fcam247   fcam248  \
12795  0  0  0  0  0  0  0  0  0  0  ...  0.888793  0.888793  0.888793   
12796  0  0  0  0  0  0  0  0  0  0  ...  0.888793  0.888793  0.888793   
12797  0  0  0  0  0  0  0  0  0  0  ...  0.888793  0.888793  0.888793   
12798  0  0  0  0  0  0  0  0  0  0  ...  0.888793  0.888793  0.888793   
12799  0  0  0  0  0  0  0  0  0  0  ...  0.888793  0.888793  0.888793   

        fcam249   fcam250   fcam251   fcam252   fcam253   fcam254   fcam255  
12795  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  
12796  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  
12797  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  
12798  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  
12799  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  0.888793  

[5 rows x 515 columns]
