In [8]:
from pathlib import Path
import os

folder_name = "RW3_put_the_granola_bar_inside_the_basket"

traj_files = []
success_traces = []
for run_folder in Path(folder_name).glob("run_*"):
    traj_files.append(str(run_folder / "trajectory.hdf5"))
    success_traces.append(os.path.exists(str(run_folder / "Success.log")))

print(traj_files)

['RW3_put_the_granola_bar_inside_the_basket/run_007/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_009/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_008/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_006/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_001/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_010/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_004/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_003/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_002/trajectory.hdf5', 'RW3_put_the_granola_bar_inside_the_basket/run_005/trajectory.hdf5']


In [9]:
import h5py
import plotly

fig = plotly.graph_objs.Figure()
trajs = []
for traj_id in range(len(traj_files)):
    with h5py.File(traj_files[traj_id], "r") as f:
        skill_ids = f['skill_id'][:]

    skill_ids = skill_ids.squeeze()
    trajs.append(skill_ids[::10])
    # plot sequence of skill ids
    fig.add_trace(plotly.graph_objs.Scatter(y=skill_ids, mode='lines'))
fig.show()
print(success_traces)


[True, True, True, False, True, False, True, True, False, False]


In [11]:
import plotly.graph_objects as go
import plotly.express as px

def plot_trajectories(trajectories):
    """
    Plots the given trajectories as horizontal bar segments with unique colors.
    :param trajectories: List of lists. Each inner list is a sequence of integer indices.
    """

    # Flatten the trajectories and create a set to get all unique ids
    unique_ids = sorted(set([item for sublist in trajectories for item in sublist]))
    colors = px.colors.qualitative.Plotly[:len(unique_ids)] # Assuming there are not more unique ids than available colors
    
    # Map each id to a color
    id_color_map = dict(zip(unique_ids, colors))
    
    fig = go.Figure()
    
    # For each trajectory, add a bar
    for i, traj in enumerate(trajectories):
        start = 0
        for idx in traj:
            fig.add_shape(type="rect",
                          x0=start, x1=start+1, y0=i-0.4, y1=i+0.4,
                          line=dict(color=id_color_map[idx]),
                          fillcolor=id_color_map[idx])
            start += 1
            
    # Add legend information
    for uid, color in id_color_map.items():
        fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers',
                                 marker=dict(size=10, color=color),
                                 legendgroup=str(uid), showlegend=True, name=str(uid)))

    # Set axis parameters
    fig.update_layout(yaxis=dict(tickvals=list(range(len(trajectories))),
                                 ticktext=["Trajectory {}".format(i+1) for i in range(len(trajectories))]),
                      xaxis_title="Sequence",
                      yaxis_title="Trajectories",
                      title="Visual Representation of Trajectories")
    
    # Add title
    fig.update_layout(title_text="Trajectories of TASK_NAME")

    # Make the figure interactive


    fig.show()

def plot_interactive_trajectories(trajectories):
    """
    Plots the given trajectories as horizontal bar segments with unique colors.
    :param trajectories: List of lists. Each inner list is a sequence of integer indices.
    """
    # Flatten the trajectories and create a set to get all unique ids
    unique_ids = sorted(set([item for sublist in trajectories for item in sublist]))
    # colors = go.Figure().data[0].marker.colors[:len(unique_ids)] # Using default color sequence
    colors = px.colors.qualitative.Plotly[:len(unique_ids)] # Assuming there are not more unique ids than available colors
    
    # Map each id to a color
    id_color_map = dict(zip(unique_ids, colors))
    
    fig = go.Figure()

    # For each trajectory, add bar segments
    for i, traj in enumerate(trajectories):
        base = 0
        for idx in traj:
            fig.add_trace(go.Bar(
                y=[f"Rollout {i+1}"],
                x=[1],
                base=[base],
                marker_color=id_color_map[idx],
                orientation='h',
                hovertext=f"ID: {idx}",
                hoverinfo="text",
                showlegend=False
            ))
            base += 1
        if success_traces[i]:
            # Add a green block to indicate success
            fig.add_trace(go.Bar(
                y=[f"Rollout {i+1}"],
                x=[1],
                base=[base],
                marker_color="green",
                orientation='h',
                hovertext=f"Success",
                hoverinfo="text",
                showlegend=False
            ))
        else:
            # Add a red block to indicate failure
            fig.add_trace(go.Bar(
                y=[f"Rollout {i+1}"],
                x=[1],
                base=[base],
                marker_color="red",
                orientation='h',
                hovertext=f"Failure",
                hoverinfo="text",
                showlegend=False
            ))

    # Add legend manually using dummy data
    for idx, color in id_color_map.items():
        fig.add_trace(go.Bar(
            y=[None],
            x=[None],
            marker_color=color,
            orientation='h',
            name=str(idx),
            showlegend=True
        ))
    fig.add_trace(go.Bar(
        y=[None],
        x=[None],
        marker_color="green",
        orientation='h',
        name="Success",
        showlegend=True
    ))
    fig.add_trace(go.Bar(
        y=[None],
        x=[None],
        marker_color="red",
        orientation='h',
        name="Failure",
        showlegend=True
    ))

    fig.update_layout(
        barmode='stack',
        yaxis=dict(autorange="reversed"),
        xaxis_title="Sequence",
        title="Visual Representation of Trajectories"
    )

    fig.show()
# Test with dummy trajectories

# plot_trajectories(trajs)


plot_interactive_trajectories(trajs)

