---
title: "Planned vs Reactive Experiment Results"
format:
  html:
    code-fold: true
    page-layout: custom
jupyter: python3
---

In [None]:
import json
import glob
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def _split_json_objects(text):
    """Return every top-level JSON object found in *text*, in order.

    The text may hold several objects written back to back. Decoding is
    delegated to ``json.JSONDecoder.raw_decode`` so braces that appear inside
    string values do not confuse the splitting.
    """
    decoder = json.JSONDecoder()
    objects = []
    pos = 0
    while True:
        start = text.find('{', pos)
        if start == -1:
            break
        try:
            obj, pos = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not a decodable object at this brace; resume the search just
            # past it instead of giving up on the rest of the file.
            pos = start + 1
            continue
        objects.append(obj)
    return objects


def process_json_files(file_pattern='*.json'):
    """
    Process multiple JSON files and combine their data into DataFrames.

    Each file may contain several JSON objects back to back. The first object
    with a 'results' key supplies the per-instance error rows (and optionally
    'reactive_results'); the first object with a 'FalsePositives' key supplies
    the false-positive counts.

    Args:
        file_pattern: glob pattern selecting the files to read.

    Returns:
        A tuple ``(error_df, reactive_df, false_positives_df)``. Each frame
        carries its default columns even when no rows were collected.

    A file that raises while being read or processed is reported with
    ``print`` and skipped; the remaining files are still processed.
    """
    all_error_data = []
    all_reactive_data = []
    false_positives_data = []

    # Sorted so that row order does not depend on directory listing order.
    for filename in sorted(glob.glob(file_pattern)):
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                json_objects = _split_json_objects(f.read())

            # Reset per file so a file without a main object can neither
            # inherit the previous file's user nor hit an unbound name.
            user_id = 'unknown'

            # Process the main data object
            main_data = next((obj for obj in json_objects if 'results' in obj), None)
            if main_data:
                user_id = main_data.get('userinfo', {}).get('id', 'unknown')

                # Handle reactive results if present
                if 'reactive_results' in main_data:
                    all_reactive_data.append({
                        'filename': filename,
                        'user_id': user_id,
                        **main_data.get('reactive_results', {}).get('ObjectsMissed', {})
                    })

                # Process main results: one row per (error type, instance)
                for error_type, error_data in main_data.get('results', {}).items():
                    for instance_num, instance_data in error_data.items():
                        all_error_data.append({
                            'error_type': error_type,
                            'instance': int(instance_num),
                            'missed': instance_data['missed'],
                            'time': instance_data['time'],
                            'filename': filename,
                            'user_id': user_id
                        })

            # Process false positives data
            fp_obj = next((obj for obj in json_objects if 'FalsePositives' in obj), None)
            if fp_obj:
                fp_data = fp_obj.get('FalsePositives', {}).get('FalsePositives', {})
                for error_type, count in fp_data.items():
                    false_positives_data.append({
                        'error_type': error_type,
                        'count': count,
                        'filename': filename,
                        'user_id': user_id
                    })

        except Exception as e:
            # Best-effort batch load: report the bad file and keep going.
            print(f"Error processing file {filename}: {e}")
            continue

    # Create DataFrames with default columns even if empty
    error_df = pd.DataFrame(all_error_data) if all_error_data else pd.DataFrame(
        columns=['error_type', 'instance', 'missed', 'time', 'filename', 'user_id'])
    reactive_df = pd.DataFrame(all_reactive_data) if all_reactive_data else pd.DataFrame(
        columns=['filename', 'user_id', 'rows', 'structures'])
    false_positives_df = pd.DataFrame(false_positives_data) if false_positives_data else pd.DataFrame(
        columns=['error_type', 'count', 'filename', 'user_id'])

    return error_df, reactive_df, false_positives_df


def _outcome_counts(errors, false_positives):
    """Return [caught, missed, false-positive] counts for the given rows."""
    # Cast first: on a non-boolean column `~` is a bitwise NOT, not a logical one.
    missed_flags = errors['missed'].astype(bool)
    return [
        (~missed_flags).sum(),
        missed_flags.sum(),
        false_positives['count'].sum(),
    ]


def create_interactive_visualization(error_df, false_positives_df, type):
    """
    Create interactive Plotly visualization with stacked bar plots for each error type.

    One subplot is drawn per error type. Each subplot holds an aggregated
    "All Users" stack (visible initially) plus one hidden stack per user; a
    dropdown switches between the aggregate and a single user.

    Args:
        error_df: rows with 'error_type', 'user_id' and 'missed' columns.
        false_positives_df: rows with 'error_type', 'user_id' and 'count' columns.
        type: mode label, capitalised into the figure title. (The name shadows
            the builtin but is kept so existing callers keep working.)

    Returns:
        The assembled Plotly figure.
    """
    error_types = ['CameraError', 'FlightError', 'HardwareError']
    users = sorted(error_df['user_id'].unique())
    # (legend group, display label, bar colour), in stacking order.
    categories = [
        ('caught', 'Caught', 'lightgreen'),
        ('missed', 'Missed', 'lightcoral'),
        ('false_positives', 'False Positives', 'lightskyblue'),
    ]

    # Create subplots with consistent width
    fig = make_subplots(
        rows=1,
        cols=3,
        subplot_titles=error_types,
        horizontal_spacing=0.1
    )

    # Trace indices recorded as traces are added, so the dropdown does not
    # have to re-derive them from the insertion order.
    total_traces = set()
    user_traces = {user_id: set() for user_id in users}

    for col, error_type in enumerate(error_types, 1):
        error_data = error_df[error_df['error_type'] == error_type]
        fp_data = false_positives_df[false_positives_df['error_type'] == error_type]

        # Aggregated view; only the first subplot contributes legend entries.
        totals = _outcome_counts(error_data, fp_data)
        for (group, label, colour), value in zip(categories, totals):
            fig.add_trace(
                go.Bar(
                    x=['All Users'],
                    y=[value],
                    name=label if col == 1 else None,
                    marker_color=colour,
                    showlegend=col == 1,
                    legendgroup=group,
                    hovertemplate=f'{label}: %{{y}}<br>Error Type: {error_type}<extra></extra>',
                ),
                row=1, col=col
            )
            total_traces.add(len(fig.data) - 1)

        # Individual user traces (initially hidden)
        for user_id in users:
            # False positives are summed like the other counts so that the
            # per-user bars add up to the aggregated bar.
            per_user = _outcome_counts(
                error_data[error_data['user_id'] == user_id],
                fp_data[fp_data['user_id'] == user_id],
            )
            for (group, label, colour), value in zip(categories, per_user):
                fig.add_trace(
                    go.Bar(
                        x=[f'User {user_id}'],
                        y=[value],
                        name=label,
                        marker_color=colour,
                        showlegend=False,
                        visible=False,
                        legendgroup=group,
                        hovertemplate=f'User {user_id}<br>{label}: %{{y}}<extra></extra>',
                    ),
                    row=1, col=col
                )
                user_traces[user_id].add(len(fig.data) - 1)

    # Dropdown buttons: the aggregate first, then one per user.
    trace_count = len(fig.data)
    buttons = [
        dict(
            args=[{'visible': [idx in total_traces for idx in range(trace_count)]}],
            label="All Users",
            method="restyle"
        )
    ]
    for user_id in users:
        selected = user_traces[user_id]
        buttons.append(dict(
            args=[{'visible': [idx in selected for idx in range(trace_count)]}],
            label=f"User {user_id}",
            method="restyle"
        ))

    # Update layout with all settings
    fig.update_layout(
        barmode='stack',
        title_text=f"{type.capitalize()} Mode: Error Distribution by Type",
        height=500,
        width=1000,
        showlegend=True,
        legend_title_text="Error Categories",
        hovermode='x unified',
        margin=dict(l=50, r=120, t=130, b=50),
        updatemenus=[dict(
            buttons=buttons,
            direction="down",
            showactive=True,
            x=1.2,
            xanchor="right",
            y=1.1,
            yanchor="bottom",
            bgcolor='white',
            bordercolor='darkgray',
            font=dict(size=12),
            pad=dict(r=10, t=10)
        )],
        annotations=[
            dict(
                text=title,
                x=x,
                y=1.025,
                xref="paper",
                yref="paper",
                showarrow=False,
                font=dict(size=14)
            )
            for title, x in zip(error_types, [0.13, 0.5, 0.87])
        ]
    )

    # Update axes labels
    fig.update_yaxes(title_text="Count", row=1, col=1)

    return fig

def create_heatmap(error_df, type):
    """
    Create Plotly heatmap for average error times with YlOrRd color scale and zero minimum.

    Each cell is the mean of the strictly positive 'time' values recorded for
    one instance/error type combination. A cell is 0 when that combination has
    no positive values, including when it has no rows at all.

    Args:
        error_df: rows with 'instance', 'error_type' and 'time' columns.
        type: mode label, capitalised into the figure title. (The name shadows
            the builtin but is kept so existing callers keep working.)

    Returns:
        The assembled Plotly figure.
    """
    if error_df.empty:
        # Nothing to pivot; fall through with a 0x0 matrix.
        all_data_matrix = pd.DataFrame()
    else:
        # Gather the raw time values of every instance/error type cell.
        all_data_matrix = pd.pivot_table(
            error_df,
            values='time',
            index='instance',
            columns='error_type',
            aggfunc=lambda x: list(x)
        )

    time_values = np.zeros(all_data_matrix.shape, dtype=float)
    hover_text = np.empty(all_data_matrix.shape, dtype=object)

    for i, instance in enumerate(all_data_matrix.index):
        for j, error_type in enumerate(all_data_matrix.columns):
            cell = all_data_matrix.iat[i, j]
            # Combinations absent from the data come out of the pivot as NaN
            # rather than a list; treat them as having no values.
            values = cell if isinstance(cell, list) else []
            non_zero_values = [v for v in values if v > 0]
            if non_zero_values:
                time_values[i, j] = np.mean(non_zero_values)
                hover_text[i, j] = f"Instance: {instance}<br>" \
                                   f"Error Type: {error_type}<br>" \
                                   f"Avg Time: {time_values[i, j]:.2f}<br>" \
                                   f"(from {len(non_zero_values)} non-zero values)"
            else:
                hover_text[i, j] = f"Instance: {instance}<br>" \
                                   f"Error Type: {error_type}<br>" \
                                   f"Time: 0 (all {len(values)} values were zero)"

    # .max() raises on a zero-size array, so guard the empty case.
    max_time = time_values.max() if time_values.size else 0.0

    fig = go.Figure(data=go.Heatmap(
        z=time_values,
        x=all_data_matrix.columns,
        y=all_data_matrix.index,
        colorscale='YlOrRd',
        showscale=True,
        text=np.round(time_values, 2),
        texttemplate='%{text}',
        textfont={'size': 10},
        hovertemplate='%{customdata}<extra></extra>',
        customdata=hover_text,
        zmin=0,  # Set minimum value to 0
        zmid=max_time / 2 if max_time > 0 else 0.5,  # Set midpoint
        zauto=False  # Disable automatic range
    ))

    fig.update_layout(
        title=f'{type.capitalize()} Mode: Average Error Resolution Time by Instance<br><sub>Using non-zero values where available; zero values shown where no non-zero times exist</sub>',
        xaxis_title='Error Type',
        yaxis_title='Instance',
        height=500,
        width=1000,  # Set explicit width
        margin=dict(l=50, r=50, t=100, b=50)
    )

    return fig


def planned_plot():
    """Load the '*-planned.json' result files and display both charts."""
    # The middle frame returned by the loader is not used here.
    errors, _, false_positives = process_json_files('*-planned.json')

    # Stacked-bar distribution first, then the timing heatmap.
    create_interactive_visualization(errors, false_positives, 'planned').show()
    create_heatmap(errors, 'planned').show()

def reactive_plot():
    """Load the '*-reactive.json' result files and display both charts."""
    # The middle frame returned by the loader is not used here.
    errors, _, false_positives = process_json_files('*-reactive.json')

    # Stacked-bar distribution first, then the timing heatmap.
    create_interactive_visualization(errors, false_positives, 'reactive').show()
    create_heatmap(errors, 'reactive').show()

::: {.panel-tabset}
## Planned

In [None]:
planned_plot()  # show the distribution chart and heatmap for the planned-mode files

## Reactive

In [None]:
reactive_plot()  # show the distribution chart and heatmap for the reactive-mode files

:::