In [1]:
from pathlib import Path

In [38]:
# Root of the Overleaf project folder (synced through Dropbox) that the figure
# cells write into (sub-folders `fig/` and `appendix_fig/`).
# Set the PLOT_BASE_PATH environment variable to use another location on a
# different machine; the default is the original author's local path.
import os

plot_base_path = Path(
    os.environ.get(
        "PLOT_BASE_PATH",
        "/Users/larsankile/Library/CloudStorage/Dropbox/Apps/Overleaf/[CoRL24] From Imitation to Refinement",
    )
).expanduser()

In [3]:
import matplotlib.pyplot as plt

# Render all matplotlib text in Times New Roman.
plt.rc("font", family="Times New Roman")

# Analyze Distillation Scaling


In [None]:
import pandas as pd

In [None]:
# Load the per-run results of the `scaling_low_1` sweep, keeping only the run
# name and the two success-rate metrics (in this column order).
df = pd.read_csv(
    "/data/scratch/ankile/robust-rearrangement/notebooks/data/scaling_low_1.csv"
).loc[:, ["Name", "best_success_rate", "success_rate"]]

df.head()

In [None]:
# Group key = run name with its final "-"-separated component removed, so runs
# that differ only in that suffix (presumably the seed) share a group.
df = df.assign(
    group_name=lambda frame: frame["Name"].str.split("-").str[:-1].str.join("-")
)

df.head()

In [None]:
# Average the metrics over all runs that share a group_name (presumably the
# different seeds of one configuration). `numeric_only=True` is required: the
# string `Name` column cannot be averaged, and pandas >= 2.0 raises a TypeError
# for it instead of silently dropping it.
df = df.groupby("group_name").mean(numeric_only=True).reset_index()

df

# Results 4.2 -- Scaling and BC/RL Distillation

## Scaling


In [None]:
import wandb
import matplotlib.pyplot as plt

In [None]:
# W&B public-API client; run paths of the form "<project>/<run_id>" resolve
# under the "robust-assembly" entity.
api = wandb.Api(overrides=dict(entity="robust-assembly"))

In [None]:
# W&B run IDs to compare, keyed by a short label of the form "state_<size>".
run_ids = {
    "state_1k": "45i0dikc",
    "state_10k": "mm93m3qx",
    "state_50k": "8joapjpv",
    "state_100k": "wi6t4u51",
}

# W&B project that contains the runs above.
project = "ol-vision-scaling-low-1"

In [None]:
# Fetch the W&B run object for every entry of `run_ids`, keeping the same labels.
runs = {
    label: api.run("{}/{}".format(project, wandb_run_id))
    for label, wandb_run_id in run_ids.items()
}

In [None]:
# Best success rate stored in each run's W&B summary, keyed by run label.
best_success_rates = {
    label: runs[label].summary["best_success_rate"] for label in runs
}

best_success_rates

In [None]:
# BC and RL policy success rates as fractions in [0, 1].
# NOTE(review): the scaling-plot cell redefines both names on a 0-100 scale
# (54 and 95), overwriting these values.
bc_success_rate, rl_success_rate = 0.54, 0.95

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Set the font family to Times New Roman
plt.rcParams["font.family"] = "Times New Roman"

# Data (success rates in percent): the distilled policy at each number of
# demonstrations, plus the RL and BC policies as horizontal reference lines.
# NOTE(review): these percent-scale values overwrite the fraction-scale
# bc_success_rate / rl_success_rate defined in an earlier cell.
num_demonstrations = [50, 1000, 10000]
distilled_success_rate = [54, 70, 77]
rl_success_rate = 95
bc_success_rate = 54

# Use the explicit Figure/Axes interface rather than the pyplot state machine.
fig, ax = plt.subplots(figsize=(4, 3.5))

ax.axhline(
    y=rl_success_rate,
    linestyle="--",
    linewidth=3,
    color="#2398DA",
    label="RL policy success rate",
)
ax.axhline(
    y=bc_success_rate,
    linestyle="--",
    linewidth=3,
    color="#6B529C",
    label="BC policy success rate",
)
ax.plot(
    num_demonstrations,
    distilled_success_rate,
    marker="o",
    linestyle="-",
    linewidth=3,
    markersize=8,
    label="Distilled policy success rate",
    color="#E34A6F",
)

ax.set_xlabel("Number of Demonstrations", fontsize=16)

# Log-scaled x-axis with power-of-ten tick labels; only ticks that fall inside
# the x-limits set below are drawn.
ax.set_xscale("log")
ax.set_xticks([10, 100, 1_000, 10_000])
ax.set_xticklabels(["$10^1$", "$10^2$", "$10^3$", "$10^4$"], fontsize=16)

ax.legend(loc="lower left", fontsize=15, frameon=False, handletextpad=0.2)

ax.set_ylim(0, 100)
ax.set_xlim(40, 11_000)

# Y ticks every 25 points. tick_params applies the label size to every tick,
# including ones created after set_yticks, so the size cannot be lost.
ax.set_yticks(np.arange(0, 101, 25))
ax.tick_params(axis="y", labelsize=16)

# Dashed horizontal grid drawn beneath the data; no frame around the axes.
ax.grid(axis="y", linestyle="--", linewidth=0.5, zorder=1)
ax.set_frame_on(False)

fig.tight_layout()

# Save the figure as a PDF (before showing, so the saved file matches the layout)
fig.savefig(plot_base_path / "fig" / "scaling.pdf", format="pdf", dpi=300)

plt.show()

## BC/RL Distillation


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Data (success rates in percent). The top of each stacked bar is `total`, the
# state-based teacher's success rate; the lower segment is the vision-based
# student's success rate.
# NOTE(review): the scaling figure uses 54 for the BC policy while the DP total
# here is 53; confirm which value is intended.
categories = ["DP", "DP + residual PPO"]
vision_based_student = [50, 73]
total = [53, 95]

# Height of the teacher segment stacked on the student bar, i.e. the gap between
# the teacher (total) and the student success rate.
state_based_teacher = np.array(total) - np.array(vision_based_student)
# A negative gap would draw a downward segment; fail loudly instead.
assert (state_based_teacher >= 0).all(), "teacher total must be >= student success rate"

# Set the width of each bar
bar_width = 0.75

# Set the positions of the bars on the x-axis
r = range(len(categories))

fig, ax = plt.subplots(figsize=(4, 5.5))
ax.bar(
    r,
    state_based_teacher,
    width=bar_width,
    bottom=vision_based_student,
    label="State-based teacher",
    color="#6B529C",
    zorder=2,
)
ax.bar(
    r,
    vision_based_student,
    width=bar_width,
    label="Vision-based student",
    color="#2398DA",
    zorder=2,
)
ax.set_xticks(r)
ax.set_xticklabels(categories, fontsize=16)

# Legend in the upper left, nudged slightly outside the axes.
ax.legend(
    loc="upper left",
    fontsize=14,
    frameon=False,
    bbox_to_anchor=(-0.18, 0.98),
    handletextpad=0.2,
)

# No frame; dashed horizontal grid beneath the bars.
ax.set_frame_on(False)
ax.grid(axis="y", linestyle="--", linewidth=0.5, zorder=1)

ax.set_ylim(0, 100)
ax.tick_params(axis="y", labelsize=16)

# Apply tight_layout BEFORE saving so the saved PDF and the displayed figure
# share the same layout.
fig.tight_layout()
fig.savefig(
    plot_base_path / "fig" / "distillation.pdf",
    format="pdf",
    dpi=300,
    bbox_inches="tight",
)

plt.show()

# Sankey diagram of real-world trial outcomes


In [4]:
import sys

# Show which Python interpreter backs this kernel (environment sanity check).
print(sys.executable, file=sys.stdout)

/opt/homebrew/opt/python@3.11/bin/python3.11


In [97]:
import plotly.graph_objects as go

labels = [
    "Corner (10 / 10)",
    "Grasp (9 / 10)",
    "Insert (6 / 10)",
    "Screw (5 / 10)",
    "Success",
    "Failure",
]

# Links between nodes (indices into `labels`). Each value is the number of the
# 10 real-world trials that moved from link_source[i] to link_target[i], either
# on to the next stage or into "Failure" (index 5).
link_source = [0, 0, 1, 1, 2, 2, 3, 3]
link_target = [1, 5, 2, 5, 3, 5, 4, 5]
link_value = [10, 0, 9, 1, 6, 3, 5, 1]

# Each intermediate stage (Grasp, Insert, Screw) must pass on exactly as many
# trials as it receives; otherwise a typo in the counts would be drawn silently.
assert all(
    sum(v for t, v in zip(link_target, link_value) if t == node)
    == sum(v for s, v in zip(link_source, link_value) if s == node)
    for node in (1, 2, 3)
), "Sankey flow is not conserved at an intermediate stage"

fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=labels,
                color=[
                    "#F4F269",
                    "#CEE26B",
                    "#A8D26D",
                    "#82C26E",
                    "#5CB270",
                    "#A41623",
                ],
                # Normalized positions for the first five nodes only; "Failure"
                # (index 5) has no entry and is placed by plotly.
                # NOTE(review): plotly.js appears to skip a node position whose
                # x or y is exactly 0, so "Corner" and "Success" may also be
                # auto-placed; confirm against the rendered figure before tuning.
                x=[0, 0.25, 0.45, 0.65, 1],
                y=[0, 0.5, 0.4, 0.25, 0.0],
            ),
            link=dict(
                source=link_source,
                target=link_target,
                value=link_value,
                # Derived from the values so the link labels cannot drift apart.
                label=[str(v) for v in link_value],
            ),
            textfont=dict(size=30, family="Times New Roman"),
        )
    ]
)

fig.update_layout(
    margin=dict(l=0, r=0, t=20, b=10),
    height=400,
    width=1000,
)

fig.write_image(str(plot_base_path / "appendix_fig" / "real_success_sankey.pdf"))
fig.show()