In [None]:
import pandas as pd
import plotly.graph_objects as go

# Read the per-step steering results into a DataFrame.
df = pd.read_csv("multi_goal_steering.csv")

# Rewrite negated labels into their positive form with two chained literal
# (regex=False) substring passes: every "not low" becomes "high" first, and
# only afterwards every "not high" becomes "low".
df["label"] = (
    df["label"]
    .str.replace("not low", "high", regex=False)
    .str.replace("not high", "low", regex=False)
)

# Create a Plotly 3D figure holding one lines+markers trace per label.
fig = go.Figure()

# Unique labels, in order of first appearance, computed once. Missing (NaN)
# labels are dropped: ``df["label"] == NaN`` never matches any row, so they
# could only ever produce an empty, unnamed trace.
unique_labels = df["label"].dropna().unique()

# Map each label to a colour. The palette is cycled so that having more
# labels than colours reuses colours instead of leaving labels out of the
# map (which previously raised KeyError on the lookup below).
palette = ["red", "blue", "green", "orange"]
color_map = {
    label: palette[i % len(palette)] for i, label in enumerate(unique_labels)
}

# Add one trace per label; rows are sorted by step so the connecting line
# follows the step axis rather than CSV row order.
for label in unique_labels:
    subset = df[df["label"] == label].sort_values(by="step")
    fig.add_trace(
        go.Scatter3d(
            x=subset["step"],
            y=subset["charge_ph7"],
            z=subset["instability_index"],
            mode="lines+markers",
            marker=dict(size=4),
            line=dict(width=3, color=color_map[label]),
            name=label,
        )
    )

# Configure the figure using plotly's underscore paths: the figure title, the
# three 3D-scene axis titles, and a legend anchored at (0, 1) in paper
# coordinates.
fig.update_layout(
    title_text="Multi-goal Steering: Charge at pH7 & Instability Index over Steps",
    scene_xaxis_title="Step",
    scene_yaxis_title="Charge at pH7",
    scene_zaxis_title="Instability Index",
    legend_x=0,
    legend_y=1,
)

# Render the interactive figure.
fig.show()
