In [None]:
import pandas as pd
import plotly.graph_objs as go
import plotly.offline as pyo
import os

In [3]:
# Example input for create_radar_plot: one state / severity label, the feature
# names that label the radar axes, and four numeric series (two "avg_*" and
# two "label_*") that are drawn as the four traces.
#
# NOTE(review): 'features' lists 12 names, but each of the four value lists
# holds 13 numbers. create_radar_plot pads the feature list with an empty name
# and then discards every position whose name is empty, so the 13th value of
# each series is never plotted. Confirm whether a feature name is missing here
# or whether the trailing value is an extra (e.g. an aggregate) that should be
# removed from the data.
SAMPLE_DATA = {
    'state': 'IL',   # used in the plot title and the output file name
    'label': 'FATAL',   # used in the title, the legend entries and the file name
    'topk': 3,   # used in the title ("Top3 ...") and the file name
    # Axis labels of the radar plot (12 entries).
    'features': ['BAC', 'cause of crash', 'driver information', 'dynamic condition', 'intersection related', 'location type', 'num of vehicles', 'person information', 'specific location and route ID', 'time', 'vehicle information', 'work zone'],
    # Plotted as the filled "Positive Shapley Average" trace (13 entries).
    'avg_positive': [0.155447123539857, 0.1120321328848092, 0.0456962695169951, 0.0099706550875836, 0.0048780487804878, 0.0990414734959551, 0.0236351831579134, 0.0, 0.0584680839860688, 0.0547746646461932, 0.0139575288294397, 0.0187350364830606, 0.0465709171680576],
    # Plotted as the filled "Negative Shapley Average" trace (13 entries).
    'avg_negative': [0.1170764868685747, 0.3393624334613537, 0.2777229598164427, 0.2572568497164587, 0.1519634730326408, 0.2498509536510904, 0.2224060028312215, 0.0, 0.0927285292286164, 0.3953802275225382, 0.245938741335926, 0.1445186207845508, 0.2254970982796407],
    # Plotted as the "<Label> - Positive" line trace (13 entries).
    'label_positive': [0.7307692307692307, 0.0384615384615384, 0.0, 0.0, 0.0, 0.1153846153846153, 0.0, 0.0384615384615384, 0.1153846153846153, 0.0, 0.0, 0.0384615384615384, 0.2],
    # Plotted as the "<Label> - Negative" line trace (13 entries).
    'label_negative': [0.0384615384615384, 0.3076923076923077, 0.5, 0.3846153846153846, 0.0769230769230769, 0.0, 0.4615384615384615, 0.0769230769230769, 0.3461538461538461, 0.3461538461538461, 0.0, 0.2307692307692307, 0.1846153846153846]
}

In [4]:
def _clean_radar_inputs(data_dict):
    """
    Build aligned, caller-independent copies of the radar plot inputs.

    The feature names and the four value series are copied (so the lists stored
    in ``data_dict`` are never modified), padded to a common length (names with
    "", values with 0) and then filtered: a position is kept only when its
    feature name is a non-blank string and none of the four values is None/NaN.

    Returns:
    tuple: (features, avg_positive, label_positive, avg_negative, label_negative),
           five new lists of equal length.
    """
    features = list(data_dict['features'])
    value_keys = ('avg_positive', 'label_positive', 'avg_negative', 'label_negative')
    series = [list(data_dict[key]) for key in value_keys]

    # Pad everything to the longest input so positions can be compared index by index.
    max_len = max([len(features)] + [len(values) for values in series])
    features.extend([""] * (max_len - len(features)))
    for values in series:
        values.extend([0] * (max_len - len(values)))

    def is_valid(index):
        name = features[index]
        if not isinstance(name, str) or name.strip() == "":
            return False
        return all(
            values[index] is not None and not pd.isna(values[index])
            for values in series
        )

    keep = [index for index in range(max_len) if is_valid(index)]
    features = [features[index] for index in keep]
    series = [[values[index] for index in keep] for values in series]
    return (features, *series)


def create_radar_plot(data_dict, output_dir="./output", show_plot=True, high_resolution=True):
    """
    Create a radar plot comparing positive and negative Shapley values.

    Parameters:
    data_dict (dict): Dictionary with the keys 'state', 'label', 'topk',
        'features', 'avg_positive', 'avg_negative', 'label_positive' and
        'label_negative' (the original notes describe it as the output of an
        extract_data function, which is not shown here). The dictionary and
        the lists it holds are not modified.
    output_dir (str): Directory to save output files (created if missing)
    show_plot (bool): Whether to open the saved HTML plot in the browser
    high_resolution (bool): Whether to save the PNG at scale 10 instead of 1

    Returns:
    plotly.graph_objs.Figure: The figure, after it has been saved as
        '<state>_<label>_top<topk>_all_radar_plot' .html and .png in output_dir.

    Raises:
    KeyError: If a required key is missing from data_dict.
    ValueError: If no entry with a valid feature name and values remains.
    """
    # Extract scalar metadata from dictionary
    state = data_dict['state']
    label = data_dict['label']
    topk = data_dict['topk']

    # Work on cleaned copies so repeated calls see the same input data.
    features, avg_pos, ind_pos, avg_neg, ind_neg = _clean_radar_inputs(data_dict)
    if not features:
        raise ValueError(
            "No valid entries to plot: every position has a blank feature name "
            "or a missing value."
        )

    # Calculate global min and max for plot scaling (with a small margin)
    all_values = avg_pos + ind_pos + avg_neg + ind_neg
    global_min = min(all_values) - 0.05
    global_max = max(all_values) + 0.05

    # Repeat the first point at the end so each polygon is closed.
    theta_closed = features + [features[0]]

    def closed(values):
        return values + [values[0]]

    # Create radar plot data
    data_plot = [
        go.Scatterpolar(
            r=closed(avg_pos),
            theta=theta_closed,
            fill="toself",
            fillcolor="rgba(0, 0, 255, 0.1)",
            name="Positive Shapley Average",
        ),
        go.Scatterpolar(
            r=closed(ind_pos),
            theta=theta_closed,
            fill="none",
            name=f"{label.title()} - Positive",
            line=dict(width=2, color="red"),
        ),
        go.Scatterpolar(
            r=closed(avg_neg),
            theta=theta_closed,
            fill="toself",
            fillcolor="rgba(30, 0, 0, 0.1)",
            name="Negative Shapley Average",
        ),
        go.Scatterpolar(
            r=closed(ind_neg),
            theta=theta_closed,
            fill="none",
            name=f"{label.title()} - Negative",
            line=dict(width=2, color="purple"),
        ),
    ]

    # Add concentric circles for the grid at 0%, 25%, 50%, 75% and 100% of the range
    span = global_max - global_min
    radii = [
        global_min,
        global_min + 0.25 * span,
        global_min + 0.5 * span,
        global_min + 0.75 * span,
        global_max,
    ]

    for radius in radii:
        data_plot.append(
            go.Scatterpolar(
                r=[radius] * len(theta_closed),
                theta=theta_closed,
                mode="lines",
                line=dict(color="rgba(0, 0, 0, 0.5)", dash="dash", width=1),
                showlegend=False,
            )
        )

    # Create the figure
    fig = go.Figure(
        data=data_plot,
        layout=go.Layout(
            title=go.layout.Title(
                text=f"Top{topk} Ratios for Shapley Values of {label.title()} Severity in {state}",
                font=dict(size=17),
            ),
            polar=dict(
                bgcolor="white",
                radialaxis=dict(
                    visible=True,
                    range=[global_min, global_max],
                    tickvals=radii,
                    showticklabels=False,
                ),
            ),
            showlegend=True,
            legend=dict(font=dict(size=16)),
            paper_bgcolor="white",
            plot_bgcolor="white",
        ),
    )

    # Create output directory if it doesn't exist (no separate exists() check,
    # which would be racy and is unnecessary with exist_ok).
    os.makedirs(output_dir, exist_ok=True)

    # Save plot as HTML and PNG; os.path.join avoids a doubled separator when
    # output_dir already ends with one (e.g. './').
    filename_base = os.path.join(output_dir, f'{state}_{label}_top{topk}_all_radar_plot')
    pyo.plot(fig, filename=f'{filename_base}.html', auto_open=show_plot)

    # Save high-resolution image if requested
    scale = 10 if high_resolution else 1
    fig.write_image(f'{filename_base}.png', scale=scale)

    print(f"Plot saved as {filename_base}.html and {filename_base}.png")
    return fig


In [6]:
# Render the sample radar chart. With output_dir='./' the HTML and PNG files
# are written to the current working directory; show_plot=True asks plotly to
# open the saved HTML, and high_resolution=True exports the PNG at scale 10.
# The call is the cell's final expression, so the figure it returns is also
# the cell's result.
create_radar_plot(
    SAMPLE_DATA, 
    output_dir='./', 
    show_plot=True, 
    high_resolution=True
)

Plot saved as .//IL_FATAL_top3_all_radar_plot.html and .//IL_FATAL_top3_all_radar_plot.png
