In [None]:
import os
from collections import defaultdict
import torch
import wandb
import numpy as np
from tqdm import tqdm
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns

from diffusion_co_design.common import OUTPUT_DIR, get_latest_model, cuda
from diffusion_co_design.rware.diffusion.generate import generate
from diffusion_co_design.rware.env import render_env as render_rware_env
from diffusion_co_design.rware.design import create_designer, DescentDesigner
from diffusion_co_design.rware.diffusion.transform import (
    storage_to_layout,
    storage_to_layout_image,
    graph_projection_constraint,
    train_to_eval,
)
from diffusion_co_design.rware.diffusion.generator import (
    Generator as RwareGenerator,
    OptimizerDetails,
)
from diffusion_co_design.rware.model.classifier import make_model
from diffusion_co_design.rware.schema import (
    ScenarioConfig as RwareScenarioConfig,
    TrainingConfig as RwareTrainingConfig,
)
from wfcrl.environments.data_cases import floris_ormonde
from diffusion_co_design.wfcrl.schema import (
    ScenarioConfig as WfcrlScenarioConfig,
    TrainingConfig as WfcrlTrainingConfig,
)
from diffusion_co_design.wfcrl.design import DicodeDesigner
from diffusion_co_design.wfcrl.diffusion.generate import Generate
from diffusion_co_design.wfcrl.env import _create_designable_windfarm, render_layout
from rware.warehouse import Warehouse


# Wandb limits sampled history to 500 rows, so stream the full log instead.
def get_full_history(run, key):
    """Collect every logged value of ``key`` from a wandb run.

    Args:
        run: Object exposing ``scan_history(keys=[...])`` that yields one
            dict-like row per logged step.
        key: Name of the metric to extract from each row.

    Returns:
        A NumPy array holding ``row[key]`` for every yielded row, in order.
    """
    return np.array([row[key] for row in run.scan_history(keys=[key])])


def ema(data: np.ndarray, alpha: float = 0.95):
    """Exponential moving average along the first axis.

    ``out[0] = data[0]`` and
    ``out[i] = alpha * out[i - 1] + (1 - alpha) * data[i]`` for ``i >= 1``.

    Args:
        data: Array-like whose first axis is the time axis.
        alpha: Weight given to the previous smoothed value.

    Returns:
        An array with the same shape as ``data``. Floating/complex inputs keep
        their dtype; any other dtype (e.g. integers) is promoted to float64 so
        the running average is not truncated. Empty input yields an empty
        array instead of raising.
    """
    data = np.asarray(data)
    # Integer input must not dictate the output dtype, otherwise every
    # assignment below would silently round towards zero.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    smoothed = np.zeros(data.shape, dtype=out_dtype)

    if data.shape[0] == 0:
        return smoothed

    smoothed[0] = data[0]
    for i in range(1, data.shape[0]):
        smoothed[i] = alpha * smoothed[i - 1] + (1 - alpha) * data[i]

    return smoothed


# Device handle passed as `device=` / `map_location=` by the later cells;
# `cuda` is imported from diffusion_co_design.common.
device = cuda

In [None]:
# Render an example D-RWARE warehouse and save four copies of it blended with
# an increasing share of Gaussian noise (blended_0.png ... blended_3.png).

# One goal list is shared by generation and rendering so both agree on which
# cells are goals. The original passed 55 to `generate` but 56 when rendering;
# 0, 7, 56 and 63 are the corners of an 8x8 grid under row-major flat indexing
# (assumed indexing scheme - TODO confirm against `generate`).
corner_goal_idxs = [0, 7, 56, 63]

shelf_im = generate(
    size=8, n_shelves=20, goal_idxs=list(corner_goal_idxs), n_colors=4
)[0]
layout = storage_to_layout_image(
    shelf_im,
    agent_idxs=[12, 23, 54, 8],
    agent_colors=[-1, -1, -1, -1],
    goal_idxs=list(corner_goal_idxs),
    goal_colors=[0, 1, 2, 3],
)

# Always release the renderer, even if rendering fails.
warehouse = Warehouse(layout=layout)
try:
    image = warehouse.render()
finally:
    warehouse.close()
# Keep every second pixel in both spatial dimensions.
image = image[::2, ::2]

H, W, C = image.shape

# Gaussian noise rescaled to span the full uint8 range [0, 255].
noise = np.random.normal(loc=0.0, scale=1.0, size=image.shape)
noise = (noise - noise.min()) / (noise.max() - noise.min())
noise = noise * 255
noise = noise.astype(np.uint8)

# beta is the noise share: 0 is the clean render, 1 is pure noise.
for i, beta in enumerate([0, 0.5, 0.75, 1]):
    blended = ((1 - beta) * image + beta * noise).astype(np.uint8)
    plt.figure(figsize=(12, 6))
    plt.imshow(blended)
    plt.axis("off")
    plt.savefig(f"blended_{i}.png", bbox_inches="tight", dpi=300)

In [None]:
# Save a full-resolution render of `layout` (defined by the previous cell).
warehouse = Warehouse(layout=layout)
image = warehouse.render()
warehouse.close()

fig, ax = plt.subplots()
ax.imshow(image)
ax.axis("off")
# No extension is given, so matplotlib picks its default output format.
fig.savefig("d-rware", bbox_inches="tight", dpi=300)

In [None]:
# Scenario used for the point-based diffusion illustrations below.
# NOTE(review): n_agents is 3 while agent_idxs / agent_colors list four
# entries - confirm which one is intended.
scenario = RwareScenarioConfig(
    name="d-rware-example",
    n_agents=3,
    n_shelves=16,
    n_colors=4,
    goal_idxs=[0, 7, 56, 63],
    agent_idxs=[12, 23, 54, 8],
    agent_colors=[-1, -1, -1, -1],
    n_goals=4,
    goal_colors=[0, 1, 2, 3],
    max_steps=100,
    size=8,
)


# x0: the 16 (x, y) points of the 4x4 block covering coordinates 2..5 on both
# axes, ordered with x as the outer index.
x0 = np.array([(x, y) for x in range(2, 6) for y in range(2, 6)])


def plot_points(layout, scenario: RwareScenarioConfig):
    """Scatter a set of 2-D points on top of a square grid.

    Args:
        layout: Iterable of points; element 0 of each point is used as the x
            coordinate and element 1 as the y coordinate.
        scenario: Provides ``size``, the number of grid lines per axis.

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    fig, ax = plt.subplots(1, 1)

    xs, ys = [], []
    for point in layout:
        xs.append(point[0])
        ys.append(point[1])

    # zorder=3 keeps the markers above the grid lines drawn afterwards.
    ax.scatter(
        xs, ys, s=130, color="blue", edgecolors="black", linewidths=0.5, zorder=3
    )
    ax.axis("off")

    far_edge = scenario.size - 1
    line_style = dict(color="gray", linewidth=1)

    # Vertical grid lines first, then horizontal ones.
    for col in range(scenario.size):
        ax.plot([col, col], [0, far_edge], **line_style)

    for row in range(scenario.size):
        ax.plot([0, far_edge], [row, row], **line_style)

    ax.set_aspect("equal")
    fig.set_tight_layout(True)
    return fig, ax


def _save_points(points, filename):
    """Plot ``points`` on the scenario grid and write the figure to ``filename``."""
    figure, axes = plot_points(points, scenario=scenario)
    figure.savefig(filename, bbox_inches="tight", dpi=300)
    return figure, axes


# Largest coordinate on the grid; used to map between [-1, 1] and grid units.
grid_max = scenario.size - 1

# x0: the clean initial design.
fig, axs = _save_points(x0, "pug_0.png")

# xT: Gaussian noise mapped from [-1, 1] to grid units and clipped to the grid.
xT = np.clip(
    (np.random.normal(0, 1.0, size=x0.shape) + 1) / 2 * grid_max,
    0,
    grid_max,
)
fig, axs = _save_points(xT, "pug_T.png")

# xt: a fixed mix of the clean design and the noise sample.
alpha = 0.7
xt = alpha**0.4 * x0 + (1 - alpha) * xT
fig, axs = _save_points(xt, "pug_t.png")

# xt_0: the clean design perturbed by small Gaussian noise (std 0.3).
xt_0 = x0 + np.random.normal(0, 0.3, size=x0.shape)
fig, axs = _save_points(xt_0, "pug_t0.png")

# Constrained version of xt_0: rescale to [-1, 1], apply the graph projection
# constraint on a batch of one, map back to grid units and round.
xt_0_normalised = torch.tensor(
    xt_0 / grid_max * 2 - 1, dtype=torch.float32
).unsqueeze(0)
xt_constr = graph_projection_constraint(scenario)(xt_0_normalised)[0].numpy()
xt_constr = (xt_constr + 1) / 2 * grid_max
xt_constr = np.round(xt_constr)
fig, axs = _save_points(xt_constr, "pug_t_constr.png")

In [None]:
# Render the Ormonde turbine layout from wfcrl after a series of random yaw
# actions and save it as wfcrl_ormonde.png.
margin = 300

# Shift each coordinate list so its smallest value sits `margin` from zero.
xcoords, ycoords = (
    np.array(coords) - np.min(coords) + margin
    for coords in (floris_ormonde.xcoords, floris_ormonde.ycoords)
)

# NOTE(review): max_steps reuses `margin` (300) - confirm this is intentional
# rather than a coincidence of values.
scenario = WfcrlScenarioConfig(
    name="ormonde_render_example",
    n_turbines=len(xcoords),
    max_steps=margin,
    map_x_length=int(xcoords.max() + margin),
    map_y_length=int(ycoords.max() + margin),
    min_distance_between_turbines=400,
)

env = _create_designable_windfarm(
    scenario=scenario,
    initial_xcoords=xcoords.tolist(),
    initial_ycoords=ycoords.tolist(),
    render=True,
)

# Take some random steps: each action is one yaw value drawn from [-5, 5).
env.reset()
for _ in range(2000):
    yaw = np.random.rand() * 10 - 5
    env.step({"yaw": np.array([yaw])})

fig, axs = plt.subplots(figsize=(6, 6))
axs.axis("off")
axs.imshow(env.render(), aspect="auto")

fig.savefig("wfcrl_ormonde.png", bbox_inches="tight", dpi=300)

In [None]:
# D-RWARE Corners Plot
# Download the full metric histories of every run in the wandb project and
# group them by run name (repeats of one experiment share a name).
project_name = "diffusion-co-design-rware"
api = wandb.Api()
runs = api.runs(path=project_name)

total_steps = 4000
runs_dict = defaultdict(list)
train_reward_key = "train/reward/episode_reward_mean"

# Record field -> wandb metric key, fetched in this order for every run.
history_fields = {
    "reward": train_reward_key,
    "d_loss": "train/designer_loss",
    "d_min": "train/design_y_min",
    "d_max": "train/design_y_max",
}

for run in tqdm(runs):
    record = {"cfg": run.config}
    for field, metric_key in history_fields.items():
        record[field] = get_full_history(run, metric_key)
    runs_dict[run.name].append(record)

In [None]:
# D-RWARE (corners) training curves: mean +/- std of the EMA-smoothed episode
# reward across repeats, one line per method, plotted against frames.
sns.set_theme(context="notebook")
mpl.rcParams["font.family"] = "monospace"
fig, axs = plt.subplots(1, 1)
fig.set_size_inches(8.1, 5)

key_to_label_map = {
    "corners_agent_distill_image_210525": "DiCoDe",
    "corners_agent_distill_gnn_210525": "DiCoDe-Points",
    "corners_agent_image_210525": "DiCoDe-MC",
    "corners_agent_descent_210525": "DiCoDe-Descent",
    "corners_agent_sampling_210525": "DiCoDe-Sampling",
    "corners_agent_fixed_210525": "Fixed",
    "corners_agent_random_210525": "DR",
    "corners_agent_rl_210525": "RL",
}
total_training_iterations = 4000
samples_per_iteration = 5000
# The RL baseline is charged 10000 extra frames per iteration, so it is cut
# after 1333 iterations: 1333 * 15000 is roughly the 4000 * 5000 frames the
# other methods use.
rl_samples_per_iteration = samples_per_iteration + 10000
rl_max_iterations = 1333
colors = sns.color_palette(n_colors=len(key_to_label_map))


for (key, label), color in zip(key_to_label_map.items(), colors):
    is_rl = label == "RL"
    # .get avoids inserting empty entries into the defaultdict.
    runs = runs_dict.get(key, [])

    rewards = []
    final_values = []
    for x in runs:
        n_logged = len(x["reward"])
        # Check completeness before smoothing so empty or partial histories
        # are skipped rather than passed to ema. RL runs are exempt from the
        # exact-length requirement.
        if n_logged == 0 or (not is_rl and n_logged != total_training_iterations):
            continue
        reward = ema(x["reward"])
        final_values.append(reward[-1])
        rewards.append(reward[:rl_max_iterations] if is_rl else reward)

    if not rewards:
        print(label, "has no usable runs, skipping")
        continue

    if is_rl:
        # Mean of the last smoothed value of each run, before truncation.
        print("RL edge case", np.mean(final_values))

    # Align runs on their shortest length so they stack into a 2-D array.
    n_iterations = min(len(r) for r in rewards)
    rewards = np.array([r[:n_iterations] for r in rewards])

    # Iteration indices 1..n as integers. np.linspace(1, n + 1, n) would give
    # a spacing of n / (n - 1) and end at n + 1.
    frames_per_iteration = rl_samples_per_iteration if is_rl else samples_per_iteration
    X = np.arange(1, n_iterations + 1) * frames_per_iteration

    mu = rewards.mean(axis=0)
    print(label, f"mean: {mu[-1]}")
    axs.plot(X, mu, color=color, label=label)
    if rewards.shape[0] > 1:
        std = rewards.std(axis=0)
        print(f"std: {std[-1]}")
        axs.fill_between(X, y1=mu - std, y2=mu + std, color=color, alpha=0.3)

axs.set_title("D-RWARE (Corner) Training Progress")
axs.set_xlabel("Frames")
axs.set_ylabel("Episode Reward")
axs.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=4, frameon=False)

fig.savefig(fname="d-rware-corners.png", bbox_inches="tight", dpi=300)

In [None]:
# Three side-by-side panels: loss, target minimum, target maximum.
fig, axs = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle("Training Targets for Environment Critic")


def plot_mu_and_std(ax, X, data, color, label):
    """Draw the mean over axis 0 of ``data`` with a +/- one std band.

    Args:
        ax: Axes-like object providing ``plot`` and ``fill_between``.
        X: x coordinates, one per column of ``data``.
        data: Array whose first axis indexes the repeats being aggregated.
        color: Colour used for both the line and the band.
        label: Legend label attached to the mean line.
    """
    centre = data.mean(axis=0)
    spread = data.std(axis=0)
    lower, upper = centre - spread, centre + spread

    ax.plot(X, centre, color=color, label=label)
    ax.fill_between(X, y1=lower, y2=upper, color=color, alpha=0.3)


# Compare the critic's training loss and target range between the distilled
# and the non-distilled image designers.
colors = sns.color_palette(n_colors=2)

# Only runs whose designer-loss history has exactly this many entries are
# treated as complete; shorter or longer histories are ignored.
expected_designer_steps = 3995

# Axis decoration does not depend on the run, so do it once up front.
for ax, title, ylabel in (
    (axs[0], "MSE Loss", "Loss"),
    (axs[1], "Training Target y Min", "Value"),
    (axs[2], "Training Target y Max", "Value"),
):
    ax.set_title(title)
    ax.set_xlabel("Training Step")
    ax.set_ylabel(ylabel)

for i, key in enumerate(
    ("corners_agent_distill_image_210525", "corners_agent_image_210525")
):
    label = key_to_label_map.get(key)
    # .get avoids inserting empty entries into the defaultdict.
    runs = runs_dict.get(key, [])
    c = colors[i]

    d_loss = []
    d_min = []
    d_max = []
    for x in runs:
        if len(x["d_loss"]) == expected_designer_steps:
            d_loss.append(ema(x["d_loss"]))
            d_min.append(ema(x["d_min"]))
            d_max.append(ema(x["d_max"]))

    if not d_loss:
        # Without this guard `d_loss.shape[1]` below raises IndexError.
        print(f"Skipping {key}: no run with {expected_designer_steps} designer steps")
        continue

    d_loss = np.array(d_loss)
    d_min = np.array(d_min)
    d_max = np.array(d_max)

    steps = range(d_loss.shape[1])
    plot_mu_and_std(ax=axs[0], X=steps, data=d_loss, color=c, label=label)
    plot_mu_and_std(ax=axs[1], X=steps, data=d_min, color=c, label=label)
    plot_mu_and_std(ax=axs[2], X=steps, data=d_max, color=c, label=label)

# One shared legend below the panels, built from the first panel's handles.
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="upper center",
    bbox_to_anchor=(0.5, 0.02),
    ncol=4,
    frameon=False,
)
fig.set_tight_layout(True)
fig.savefig(fname="ablation-distill-training.png", bbox_inches="tight", dpi=300)

In [None]:
# Map each method label to (scenario, training config, checkpoint dirs).
# The scenario is parsed once, from the first method that has wandb runs, and
# shared by every entry.
runs_dict_processed: dict[
    str, tuple[RwareScenarioConfig, RwareTrainingConfig, list]
] = {}
scenario = None
for key, label in key_to_label_map.items():
    # .get avoids IndexError (and a stray defaultdict entry) for methods that
    # have no downloaded runs.
    key_runs = runs_dict.get(key, [])
    if not key_runs:
        print(f"Skipping {key}: no wandb runs found")
        continue

    cfg = key_runs[0]
    if scenario is None:
        scenario = RwareScenarioConfig.from_raw(cfg["cfg"]["scenario"])
    train_cfg = RwareTrainingConfig.from_raw(cfg["cfg"])

    checkpoints = os.path.join("downloaded_checkpoints", key)
    if not os.path.exists(checkpoints):
        print(f"Skipping {checkpoints}")
        continue

    # os.listdir order is arbitrary; sort so `repeats[0]` is reproducible.
    repeats = sorted(os.path.join(checkpoints, x) for x in os.listdir(checkpoints))
    if not repeats:
        print(f"Skipping {checkpoints}: directory is empty")
        continue

    runs_dict_processed[label] = (scenario, train_cfg, repeats)

In [None]:
# PUG Ablation
# Build N candidate environments with each search method (PUG, UG, gradient
# descent, best-of-K sampling, random) using the trained "DiCoDe-Points" critic.
scenario, train_cfg, repeats = runs_dict_processed["DiCoDe-Points"]
diffusion_dir = pretrain_dir = os.path.join(
    OUTPUT_DIR, "rware", "diffusion", "graph", scenario.name
)
latest_diffusion_checkpoint = get_latest_model(diffusion_dir, "model")
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(
    os.path.join(repeats[0], "designer_3999.pt"), map_location=device
)

model = make_model(
    model=train_cfg.designer.value_model.name,
    scenario=scenario,
    model_kwargs=train_cfg.designer.value_model.model_kwargs,
    device=device,
)
model.load_state_dict(state_dict)


N = 32


def _make_guidance_operation(projection_constraint=None):
    """Build the guidance settings shared by the PUG and UG runs.

    ``projection_constraint`` is only assigned when given, so the UG settings
    keep whatever default ``OptimizerDetails`` provides for it.
    """
    details = OptimizerDetails()
    details.lr = 0.01
    details.num_recurrences = 8
    details.backward_steps = 16
    details.forward_guidance_wt = 5.0
    if projection_constraint is not None:
        details.projection_constraint = projection_constraint
    return details


# PUG: guidance with the graph projection constraint.
generator = RwareGenerator(
    batch_size=N,
    generator_model_path=latest_diffusion_checkpoint,
    scenario=scenario,
    representation="graph",
    device=device,
)
guidance_model = model
guidance_model.eval()

operation = _make_guidance_operation(graph_projection_constraint(scenario))

pug_batch = np.array(
    generator.generate_batch(
        value=guidance_model,
        use_operation=True,
        operation_override=operation,
    )
)

# UG: same guidance settings but WITHOUT the projection constraint. The
# original reused the PUG `operation`, which made both batches come from an
# identical configuration.
# NOTE(review): assumes OptimizerDetails defaults to no projection constraint
# - TODO confirm.
ug_operation = _make_guidance_operation()
ug_batch = np.array(
    generator.generate_batch(
        value=guidance_model,
        use_operation=True,
        operation_override=ug_operation,
    )
)

# Descent
# `DescentDesigner` is the class this notebook imports; the name
# `GradientDescentDesigner` used before is not defined anywhere in it.
# NOTE(review): assumes DescentDesigner accepts these keyword arguments - TODO
# confirm.
grad_designer = DescentDesigner(
    scenario=scenario,
    classifier=train_cfg.designer.value_model,
    epochs=32,
    gradient_iterations=10,
    gradient_lr=0.03,
    device=device,
)
grad_designer.model = model

descent_batch = np.array(grad_designer._reset_env_buffer(N))

# Sampling: draw N * K random designs, score them with the critic and keep
# the best of each group of K.
K = 32
with torch.no_grad():
    u_x = generate(
        size=scenario.size,
        n_shelves=scenario.n_shelves,
        goal_idxs=scenario.goal_idxs,
        n_colors=scenario.n_colors,
        training_dataset=True,
        representation="graph",
        n=N * K,
    )

    u_x = torch.tensor(np.array(u_x), device=device)

    y = model(u_x)
    u_x = u_x.reshape(N, K, scenario.n_shelves, 2)
    y = y.reshape(N, K)
    _, best_idxs = y.max(dim=-1)
    sampling_batch = u_x[torch.arange(N), best_idxs]
    sampling_batch = train_to_eval(sampling_batch, scenario, "graph")
    sampling_batch = sampling_batch.numpy(force=True)
    # Free the large candidate tensors before the next cell runs.
    del u_x, best_idxs, y
    torch.cuda.empty_cache()

# DR: N random designs.
random_batch = np.array(
    generate(
        size=scenario.size,
        n_shelves=scenario.n_shelves,
        goal_idxs=scenario.goal_idxs,
        n_colors=scenario.n_colors,
        training_dataset=False,
        representation="graph",
        n=N,
    )
)

In [None]:
# Score every batch with the critic and compare the value distributions of
# the search methods in a box plot.
sns.set_theme(style="whitegrid")

labels = []
exp_returns = []

# pc = graph_projection_constraint(scenario)
# ug_batch_alt = torch.tensor(ug_batch)
# ug_batch_alt = ug_batch_alt / (scenario.size - 1) * 2 - 1
# ug_batch_alt = pc(ug_batch_alt)
# ug_batch_alt = train_to_eval(ug_batch_alt, scenario, "graph")

# Per method: index of the best and worst scored environment in its batch.
selected_envs = {}

for label, batch in (
    ("PUG", pug_batch),
    ("UG", ug_batch),
    ("Descent", descent_batch),
    ("Sampling", sampling_batch),
    ("DR", random_batch),
):
    # Eval to train, run through model. The dtype is pinned to float32, as in
    # the earlier projection-constraint cell, so the input dtype does not
    # depend on how each batch was produced.
    x = torch.tensor(batch, dtype=torch.float32, device=device)
    x = x / (scenario.size - 1) * 2 - 1
    with torch.no_grad():
        # Flatten so each method contributes a 1-D sample of critic values.
        y = model(x).numpy(force=True).reshape(-1)
        print(label, y.mean().item())

    labels.append(label)
    exp_returns.append(y)
    selected_envs[label] = {"best_idx": y.argmax(), "worst_idx": y.argmin()}

colors = sns.color_palette(n_colors=len(exp_returns))

fig, ax = plt.subplots(figsize=(5, 5))
box = ax.boxplot(
    exp_returns,
    patch_artist=True,
    boxprops=dict(linewidth=1.2),
    medianprops=dict(color="black", linewidth=1.5),
    whiskerprops=dict(color="gray"),
    capprops=dict(color="gray"),
)
# Set tick labels on the axes rather than through boxplot's `labels=`
# keyword, which recent matplotlib releases deprecate.
ax.set_xticklabels(labels)

# Apply colors
for patch, color in zip(box["boxes"], colors):
    patch.set_facecolor(color)
    patch.set_edgecolor("black")


ax.set_title("Environment Search Comparison")
ax.set_ylabel("Critic Value")
ax.set_xlabel("Generator Method")
fig.tight_layout()
fig.savefig(fname="ablation-pug-box.png", bbox_inches="tight", dpi=300)

In [None]:
# Show, side by side, the environment the critic scored highest for three of
# the search methods.
showcased_batches = {
    "PUG": pug_batch,
    "Descent": descent_batch,
    "Sampling": sampling_batch,
}

fig, axs = plt.subplots(1, 3, figsize=(9, 6))

for ax, (label, batch) in zip(axs, showcased_batches.items()):
    best_env = batch[selected_envs[label]["best_idx"]]

    ax.imshow(render_rware_env(best_env, scenario, "graph"))
    ax.set_title(f"{label}", fontsize=10)
    ax.axis("off")

fig.savefig(fname="ablation-envs.png", bbox_inches="tight", dpi=300)

In [None]:
# Generate heatmaps
# For each designer, sample about N environments spread evenly over its
# training repeats and sum their storage grids into one heatmap.
N = 100

label_to_heatmap = {}
for label, representation in [("DiCoDe", "image"), ("DiCoDe-Points", "graph")]:
    scenario, train_cfg, repeats = runs_dict_processed[label]
    if not repeats:
        # Avoids a ZeroDivisionError below and np.stack on an empty list.
        print(f"Skipping {label}: no checkpoint directories")
        continue

    diffusion_dir = pretrain_dir = os.path.join(
        OUTPUT_DIR, "rware", "diffusion", representation, scenario.name
    )
    latest_diffusion_checkpoint = get_latest_model(diffusion_dir, "model")

    # `create_designer` is the factory this notebook imports; the name
    # `DesignerRegistry` used before is not defined anywhere in it.
    # NOTE(review): assumes create_designer takes these keyword arguments and
    # returns a pair whose first item exposes `.master_designer` - TODO confirm.
    designer, _ = create_designer(
        designer=train_cfg.designer,
        scenario=scenario,
        ppo_cfg=train_cfg.ppo,
        artifact_dir=".",
        device=device,
    )
    designer = designer.master_designer

    # Environments sampled per repeat.
    B = N // len(repeats)
    envs = []
    for checkpoint_dir in tqdm(repeats):
        print(checkpoint_dir)
        designer.model.load_state_dict(
            torch.load(
                os.path.join(checkpoint_dir, "designer_3999.pt"), map_location=device
            )
        )

        batch = designer._reset_env_buffer(batch_size=B)
        for env in batch:
            layout = storage_to_layout(
                features=env, config=scenario, representation=representation
            )
            envs.append(layout.storage)
    heatmap = np.stack(envs).sum(axis=0)
    label_to_heatmap[label] = heatmap


In [None]:
# One 2x2 figure per designer: panel i shows heatmap[i], titled with the
# i-th colour name, all on a shared colour scale.
colors = ("Teal", "Purple", "Blue", "Green")
for label, heatmap in label_to_heatmap.items():
    fig, axs = plt.subplots(2, 2, figsize=(6, 6))
    fig.suptitle(label)

    color_min, color_max = 0, np.max(heatmap)

    # axs.flat walks the grid row by row, matching (i // 2, i % 2).
    for i, (ax, c) in enumerate(zip(axs.flat, colors)):
        im = ax.imshow(
            heatmap[i], cmap="viridis", aspect="equal", vmin=color_min, vmax=color_max
        )
        ax.set_title(c)
        ax.axis("off")

    # Leave room on the right for a colour bar built from the last image.
    fig.subplots_adjust(top=0.95, right=0.85)
    cbar_ax = fig.add_axes([0.88, 0.20, 0.03, 0.7])
    fig.colorbar(im, cax=cbar_ax)

    fig.savefig(fname=f"{label}_heatmap.png", bbox_inches="tight", dpi=300)


In [None]:
# WFCRL Plot
# Download the reward history and, where logged, the final designer weights
# of every run in the wandb project, grouped by run name.
project_name = "diffusion-co-design-wfcrl"
api = wandb.Api()
runs = api.runs(path=project_name)

total_steps = 301
wfcrl_runs_dict = defaultdict(list)
train_reward_key = "train/reward/episode_reward_mean"


for run in tqdm(runs):
    name = run.name
    run_data = {
        "cfg": run.config,
        "reward": get_full_history(run, train_reward_key),
        # Stays None when the run logged no "designer_final*" artifact.
        "designer_state_dict": None,
    }

    for artifact in run.logged_artifacts():
        if not artifact.name.startswith("designer_final"):
            continue
        checkpoint_path = os.path.join(artifact.download(), "designer_300.pt")
        run_data["designer_state_dict"] = torch.load(
            checkpoint_path, map_location=device
        )

    wfcrl_runs_dict[name].append(run_data)

In [None]:
# WFCRL training curves: mean +/- std of the EMA-smoothed episode reward
# across repeats, one panel per scenario.
sns.set_theme(context="notebook")
mpl.rcParams["font.family"] = "monospace"
fig, axs = plt.subplots(1, 2)
fig.set_size_inches(15, 4)

# Panel index -> {wandb run name: legend label}.
key_to_label_map = {
    0: {
        "wfcrl_fixed": "Fixed",
        "wfcrl_diffusion_distill": "DiCoDe",
    },
    1: {
        "wfcrl_fixed_rect_8": "Fixed",
        "wfcrl_diffusion_distill_rect_8": "DiCoDe",
    },
}
total_training_iterations = 301
samples_per_iteration = 300
colors = sns.color_palette(n_colors=2)

# Iteration indices 1..n as integers. np.linspace(1, n + 1, n) would give a
# spacing of n / (n - 1) and end at n + 1.
X = np.arange(1, total_training_iterations + 1) * samples_per_iteration

for ax_id, data_dict in key_to_label_map.items():
    ax = axs[ax_id]

    for (key, label), color in zip(data_dict.items(), colors):
        # .get avoids inserting empty entries into the defaultdict.
        runs = wfcrl_runs_dict.get(key, [])

        rewards = []
        for x in runs:
            # Check completeness before smoothing so empty or partial
            # histories are skipped rather than passed to ema.
            if len(x["reward"]) != total_training_iterations:
                # Run not complete, skip
                continue
            rewards.append(ema(x["reward"]))

        if not rewards:
            # Without this guard `mu[-1]` below raises IndexError.
            print(ax_id, label, "has no complete runs, skipping")
            continue
        rewards = np.array(rewards)

        mu = rewards.mean(axis=0)
        print(ax_id, label, f"mean: {mu[-1]}")
        ax.plot(X, mu, color=color, label=label)
        if rewards.shape[0] > 1:
            std = rewards.std(axis=0)
            print(f"std: {std[-1]}")
            ax.fill_between(X, y1=mu - std, y2=mu + std, color=color, alpha=0.3)

fig.suptitle("WFCRL Training Progress")
axs[0].set_title("Square-10")
axs[0].set_xlabel("Frames")
axs[0].set_ylabel("Mean Episode Reward")
axs[1].set_title("Rect-8")
axs[1].set_xlabel("Frames")
axs[1].set_ylabel("Mean Episode Reward")
# One shared legend below the panels, built from the first panel's handles.
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="upper center",
    bbox_to_anchor=(0.5, 0.02),
    ncol=4,
    frameon=False,
)
fig.savefig(fname="wfcrl-training.png", bbox_inches="tight", dpi=300)

In [None]:
# Render one example layout per (scenario, method) and save it as
# "<label>_<scenario name>.png". "DiCoDe" layouts come from the trained
# designer; every other label gets a layout from the random generator.
for data_dict in key_to_label_map.values():
    for key, label in data_dict.items():
        # .get avoids inserting empty entries into the defaultdict.
        runs = wfcrl_runs_dict.get(key, [])
        if not runs:
            print(f"Skipping {key}: no wandb runs found")
            continue

        scenario = WfcrlScenarioConfig.from_raw(runs[0]["cfg"]["scenario"])
        training_cfg = WfcrlTrainingConfig.from_raw(runs[0]["cfg"])

        if label == "DiCoDe":
            designer_state_dict = runs[0]["designer_state_dict"]
            if designer_state_dict is None:
                # The run logged no "designer_final*" artifact.
                print(f"Skipping {key}: no designer checkpoint downloaded")
                continue

            # Work on a copy so the run's stored config is left untouched.
            diffusion = training_cfg.designer.diffusion.model_copy()
            diffusion.forward_guidance_annealing = False
            # `DicodeDesigner` is the class this notebook imports; the name
            # `DiffusionDesigner` used before is not defined anywhere in it.
            # NOTE(review): assumes DicodeDesigner accepts these keyword
            # arguments - TODO confirm.
            designer = DicodeDesigner(
                scenario=scenario,
                classifier=training_cfg.designer.model,
                diffusion=diffusion,
                normalisation_statistics=training_cfg.normalisation,
                total_training_iterations=10,
            )
            designer.model.load_state_dict(state_dict=designer_state_dict)
            layout = designer._reset_env_buffer(1)[0]
        else:
            generator = Generate(
                num_turbines=scenario.n_turbines,
                map_x_length=scenario.map_x_length,
                map_y_length=scenario.map_y_length,
                minimum_distance_between_turbines=scenario.min_distance_between_turbines,
            )
            layout = generator(n=1, training_dataset=False).squeeze()

        im = render_layout(x=layout, scenario=scenario)
        filename = f"{label.lower()}_{scenario.name}.png"
        mpl.image.imsave(filename, im)