In [1]:
import plotly.graph_objects as go

In [2]:
# Rows of (category, label, start, end) for the comparison chart. The rows are
# written grouped by category and then flattened. Dict insertion order keeps the
# resulting row order identical to a hand-written flat list. Every row's start
# value is the integer 0 here.
data = [
    (category, label, 0, value)
    for category, rows in {
        "Driver Behavior": [
            ("Other", 0.030192299474557222),
            ("Improper Driving Behavior", 0.025939835859643306),
            ("Aggressive Driving Behavior", 0.14236087067251474),
            ("Impairment-Related", 0.10870720512535892),
            ("Traffic Control/Right-of-Way Violations", 0.059803444324595695),
        ],
        "BAC": [
            ("Unknown or not offered", 0.02216343476517974),
            (">=80", 0.4808653369157912),
            ("<80", 0.47121275022682596),
        ],
        "Work Zone": [
            ("No", -0.0032566226227484884),
            ("Yes", 0.10091418860472486),
        ],
        "UserType": [
            ("Pedalcyclist or Pedestrian", 0.1682276858720747),
            ("Other", 0.006634390128970741),
        ],
        "Roadway Type": [
            ("Non-Freeway", 0.1054306915666446),
            ("Freeway", 0.05661843837256921),
        ],
    }.items()
    for label, value in rows
]


In [3]:
# Unpack the (category, label, start, end) rows of `data` into parallel lists.
titles = [f"{category}: {label}" for category, label, _, _ in data]
start_values = [start for _, _, start, _ in data]
end_values = [end for _, _, _, end in data]

# Bars that end above their start are pink; all others (including ties) are blue.
INCREASE_COLOR = "#de9cad"
DECREASE_COLOR = "#7c9bbf"
colors = [
    INCREASE_COLOR if end > start else DECREASE_COLOR
    for start, end in zip(start_values, end_values)
]

# Numeric row positions with labels attached via tickvals/ticktext. Duplicate
# titles therefore cannot collapse into a single categorical row.
y_pos = list(range(len(titles)))

ROW_HEIGHT_PX = 25
MARGIN = dict(l=200, r=50, t=50, b=50)

fig = go.Figure()

# A single bar trace draws every row. Each bar spans start -> end (`base` is the
# start, `x` is the length). The bare numeric y position is meaningless to a
# reader, so the hover box shows the row title and raw values instead.
fig.add_trace(go.Bar(
    x=[end - start for start, end in zip(start_values, end_values)],
    y=y_pos,
    base=start_values,
    orientation="h",
    marker=dict(color=colors),
    hovertext=titles,
    customdata=[[start, end] for start, end in zip(start_values, end_values)],
    hovertemplate=(
        "%{hovertext}<br>start: %{customdata[0]:.4f}"
        "<br>end: %{customdata[1]:.4f}<extra></extra>"
    ),
    showlegend=False,  # legend entries come from the dummy traces below
))

# Legend-only traces. A per-point marker colour gets no legend entry of its own,
# so an empty line trace is added for each colour's meaning.
fig.add_trace(go.Scatter(
    x=[None], y=[None], mode="lines",
    line=dict(color=INCREASE_COLOR, width=4),
    name="> Start Value",
))
fig.add_trace(go.Scatter(
    x=[None], y=[None], mode="lines",
    line=dict(color=DECREASE_COLOR, width=4),
    name="<= Start Value",
))

fig.update_layout(
    yaxis=dict(
        tickvals=y_pos,
        ticktext=titles,
        autorange="reversed",  # first row of `data` at the top
        automargin=True,       # make room for long tick labels
    ),
    xaxis=dict(
        title="Correlation values",
        gridcolor="lightgray",
        # Bars diverge from 0 (values can be negative), so draw the zero line.
        zeroline=True,
        zerolinecolor="gray",
    ),
    title="Generalized Data Comparison",
    barmode="overlay",
    # Size the plot area by row count plus the fixed top/bottom margins, so few
    # rows do not leave a near-zero plot area.
    height=ROW_HEIGHT_PX * len(titles) + MARGIN["t"] + MARGIN["b"],
    margin=MARGIN,
    # Must be True: layout.showlegend=False hides the legend entirely,
    # regardless of trace-level showlegend.
    showlegend=True,
    plot_bgcolor="rgba(0,0,0,0)",
    paper_bgcolor="rgba(0,0,0,0)",
)

fig.show()
fig.write_html("general_comparison.html")
