In [None]:
from collections import defaultdict
import torch
import wandb
import numpy as np
from tqdm import tqdm
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns

from diffusion_co_design.rware.diffusion.generate import generate
from diffusion_co_design.rware.diffusion.transform import (
    storage_to_layout_image,
    graph_projection_constraint,
)
from diffusion_co_design.rware.schema import ScenarioConfig as RwareScenarioConfig
from wfcrl.environments.data_cases import floris_ormonde
from diffusion_co_design.wfcrl.schema import ScenarioConfig as WfcrlScenarioConfig
from diffusion_co_design.wfcrl.env import _create_designable_windfarm
from rware.warehouse import Warehouse

In [None]:
# Illustrate progressive noising of a rendered D-RWARE layout by blending the
# rendered frame with a fixed noise image at increasing noise weights.

# `generate` and `storage_to_layout_image` must be given the same goal cells.
# The four corners of the flattened 8x8 grid are 0, 7, 56 and 63; the original
# cell passed 55 to `generate`, which disagreed with the layout's goal at 56.
goal_idxs = [0, 7, 56, 63]

shelf_im = generate(size=8, n_shelves=20, goal_idxs=goal_idxs, n_colors=4)[0]
layout = storage_to_layout_image(
    shelf_im,
    agent_idxs=[12, 23, 54, 8],
    agent_colors=[-1, -1, -1, -1],
    goal_idxs=goal_idxs,
    goal_colors=[0, 1, 2, 3],
)
warehouse = Warehouse(layout=layout)
try:
    image = warehouse.render()
    # Keep every second pixel along both image axes.
    image = image[::2, ::2]
finally:
    # Release the renderer even if rendering fails.
    warehouse.close()

H, W, C = image.shape

# Gaussian noise image with the same shape as the frame, min-max rescaled to
# the 0..255 range so it can be blended with the rendered pixels.
noise = np.random.normal(loc=0.0, scale=1.0, size=image.shape)
noise = (noise - noise.min()) / (noise.max() - noise.min())
noise = noise * 255
noise = noise.astype(np.uint8)

# beta is the weight of the noise image: 0 -> clean frame, 1 -> pure noise.
for i, beta in enumerate([0, 0.5, 0.75, 1]):
    blended = ((1 - beta) * image + beta * noise).astype(np.uint8)
    plt.figure(figsize=(12, 6))
    plt.imshow(blended)
    plt.axis("off")
    plt.savefig(f"blended_{i}.png", bbox_inches="tight", dpi=300)

In [None]:
# Render the `layout` built in the previous cell at full resolution and save it.
warehouse = Warehouse(layout=layout)
try:
    image = warehouse.render()
finally:
    # Release the renderer even if rendering fails.
    warehouse.close()

# Draw on a dedicated figure rather than on whatever figure happens to be
# current, so the saved image never inherits content from an earlier cell.
fig, ax = plt.subplots()
ax.imshow(image)
ax.axis("off")
# No extension is given, so matplotlib picks the output format from its
# `savefig.format` setting.
fig.savefig("d-rware", bbox_inches="tight", dpi=300)

In [None]:
# Example D-RWARE scenario used by the point-diffusion figures below.
# NOTE(review): n_agents is 3 while four agent_idxs / agent_colors are listed —
# confirm which one is intended.
scenario = RwareScenarioConfig(
    name="d-rware-example",
    # Grid and episode
    size=8,
    max_steps=100,
    # Shelves
    n_shelves=16,
    n_colors=4,
    # Agents
    n_agents=3,
    agent_idxs=[12, 23, 54, 8],
    agent_colors=[-1, -1, -1, -1],
    # Goals
    n_goals=4,
    goal_idxs=[0, 7, 56, 63],
    goal_colors=[0, 1, 2, 3],
)


# Clean sample: the 4x4 block of integer grid points with both coordinates in
# 2..5, listed with the first coordinate varying slowest.
x0 = np.array([(gx, gy) for gx in range(2, 6) for gy in range(2, 6)])


def plot_points(layout, scenario: RwareScenarioConfig):
    """Scatter a set of 2D points on top of a square reference grid.

    Args:
        layout: Iterable of points; element 0 of each point is used as the x
            coordinate and element 1 as the y coordinate.
        scenario: Scenario whose ``size`` sets how many grid lines are drawn
            in each direction (at integer positions ``0 .. size - 1``).

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    fig, ax = plt.subplots(1, 1)

    xs = [point[0] for point in layout]
    ys = [point[1] for point in layout]
    # zorder=3 keeps the markers on top of the grid lines drawn below.
    ax.scatter(xs, ys, s=130, color="blue", edgecolors="black", linewidths=0.5, zorder=3)
    ax.axis("off")

    n_lines = scenario.size
    far_edge = n_lines - 1
    line_style = {"color": "gray", "linewidth": 1}

    # Vertical grid lines
    for col in range(n_lines):
        ax.plot([col, col], [0, far_edge], **line_style)

    # Horizontal grid lines
    for row in range(n_lines):
        ax.plot([0, far_edge], [row, row], **line_style)

    ax.set_aspect("equal")
    fig.set_tight_layout(True)
    return fig, ax


# Figures illustrating the point-based diffusion process on the grid.
# Coordinates live in [0, grid_max]; the projection below works in [-1, 1].
grid_max = scenario.size - 1

# x0: the clean sample
fig, ax = plot_points(x0, scenario=scenario)
fig.savefig("pug_0.png", bbox_inches="tight", dpi=300)

# xT: Gaussian noise mapped from [-1, 1] onto the grid and clipped to it
gaussian = np.random.normal(0, 1.0, size=x0.shape)
xT = np.clip((gaussian + 1) / 2 * grid_max, 0, grid_max)
fig, ax = plot_points(xT, scenario=scenario)
fig.savefig("pug_T.png", bbox_inches="tight", dpi=300)

# xt: an intermediate mix of x0 and xT
# NOTE(review): the weights alpha**0.4 and (1 - alpha) are kept exactly as in
# the original cell — confirm they are the intended illustration weights.
alpha = 0.7
x0_weight = alpha**0.4
xT_weight = 1 - alpha
xt = x0_weight * x0 + xT_weight * xT
fig, ax = plot_points(xt, scenario=scenario)
fig.savefig("pug_t.png", bbox_inches="tight", dpi=300)

# xt_0: x0 perturbed by a small amount of Gaussian noise (std 0.3)
xt_0 = x0 + np.random.normal(0, 0.3, size=x0.shape)
fig, ax = plot_points(xt_0, scenario=scenario)
fig.savefig("pug_t0.png", bbox_inches="tight", dpi=300)

# xt constrained: rescale xt_0 to [-1, 1], add a leading batch axis, apply the
# projection, then map the result back to grid coordinates and round.
project = graph_projection_constraint(scenario)
xt_0_batch = torch.tensor(xt_0 / grid_max * 2 - 1, dtype=torch.float32).unsqueeze(0)
xt_constr = project(xt_0_batch)[0].numpy()
xt_constr = np.round((xt_constr + 1) / 2 * grid_max)
fig, ax = plot_points(xt_constr, scenario=scenario)
fig.savefig("pug_t_constr.png", bbox_inches="tight", dpi=300)

In [None]:
# Render the Ormonde layout in the designable wind-farm environment.
margin = 300

# Shift the reference coordinates so the smallest x / y value sits exactly
# `margin` away from the origin.
raw_x = np.array(floris_ormonde.xcoords)
raw_y = np.array(floris_ormonde.ycoords)
xcoords = raw_x - raw_x.min() + margin
ycoords = raw_y - raw_y.min() + margin

# The map extends `margin` beyond the largest shifted coordinate as well.
# NOTE(review): max_steps reuses `margin` (300) although the loop below takes
# 2000 steps — confirm this is intentional.
scenario = WfcrlScenarioConfig(
    name="ormonde_render_example",
    n_turbines=len(xcoords),
    max_steps=margin,
    map_x_length=int(xcoords.max() + margin),
    map_y_length=int(ycoords.max() + margin),
    min_distance_between_turbines=400,
)

env = _create_designable_windfarm(
    scenario=scenario,
    initial_xcoords=xcoords.tolist(),
    initial_ycoords=ycoords.tolist(),
    render=True,
)

# Take some random steps: each step sends one yaw value drawn uniformly
# from [-5, 5).
env.reset()
for _ in range(2000):
    yaw = np.random.rand() * 10 - 5
    env.step({"yaw": np.array([yaw])})

fig, ax = plt.subplots(figsize=(6, 6))
ax.axis("off")
frame = env.render()
ax.imshow(frame, aspect="auto")

fig.savefig("wfcrl_ormonde.png", bbox_inches="tight", dpi=300)

In [None]:
# D-RWARE Corners Plot
project_name = "diffusion-co-design-rware"
api = wandb.Api()
runs = api.runs(path=project_name)

total_steps = 4000
runs_dict = defaultdict(list)
train_reward_key = "train/reward/episode_reward_mean"


# Wandb limits to 500, so iterate over the full history via `scan_history`.
def get_full_history(run, key):
    """Collect every logged value of ``key`` for ``run`` into a NumPy array.

    Args:
        run: Object exposing ``scan_history(keys=[...])`` that yields one
            mapping per logged row.
        key: Name of the metric to extract from each row.

    Returns:
        ``np.ndarray`` holding ``row[key]`` for each yielded row, in order.
    """
    return np.array([row[key] for row in run.scan_history(keys=[key])])


# Group the runs by name, keeping each run's config and full reward history.
for run in tqdm(runs):
    record = {
        "cfg": run.config,
        "reward": get_full_history(run, train_reward_key),
    }
    runs_dict[run.name].append(record)


In [None]:
# Plot configuration for the D-RWARE (Corner) training curves.

# Apply the seaborn theme first, then override the font family.
sns.set_theme(context="notebook")
mpl.rcParams["font.family"] = "monospace"
fig, ax = plt.subplots(1, 1, figsize=(8.1, 5))

# Frames-axis scaling: number of training iterations per complete run and
# the number of samples counted per iteration.
total_training_iterations = 4000
samples_per_iteration = 5000

# W&B run name -> legend label; the dict order also fixes the colour order.
key_to_label_map = {
    "corners_agent_distill_image_210525": "DiCD",
    "corners_agent_distill_gnn_210525": "DiCD-Points",
    "corners_agent_image_210525": "DiCD-MC",
    "corners_agent_descent_210525": "DiCD-Descent",
    "corners_agent_sampling_210525": "DiCD-Sampling",
    "corners_agent_fixed_210525": "Fixed",
    "corners_agent_random_210525": "DR",
    "corners_agent_rl_210525": "RL",
}
colors = sns.color_palette(n_colors=len(key_to_label_map))


def ema(data: np.ndarray, alpha: float = 0.95):
    """Exponential moving average along the first axis of ``data``.

    ``out[0] = data[0]`` and, for ``i >= 1``,
    ``out[i] = alpha * out[i - 1] + (1 - alpha) * data[i]``.

    Args:
        data: Array of values to smooth. Non-floating input (e.g. integers)
            is converted to ``float64`` first so the smoothed values are not
            truncated back to the input dtype.
        alpha: Weight given to the previous smoothed value.

    Returns:
        Array with the same shape as ``data``. Floating-point inputs keep
        their dtype; all other inputs yield ``float64``. An empty input yields
        an empty array instead of raising.
    """
    values = np.asarray(data)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)

    smoothed = np.empty_like(values)
    if values.shape[0] == 0:
        # Nothing to smooth; avoids indexing element 0 of an empty array.
        return smoothed

    smoothed[0] = values[0]
    for i in range(1, values.shape[0]):
        smoothed[i] = alpha * smoothed[i - 1] + (1 - alpha) * values[i]

    return smoothed


# Plot the EMA-smoothed mean training reward of every method in
# `key_to_label_map`, with a +/- 1 std band when more than one run exists.

# The "RL" baseline uses its own x-axis scale: each of its iterations counts
# as `samples_per_iteration + 10000` frames and only its first 1333
# iterations are drawn (1333 * 15000 frames, just under the 4000 * 5000
# frames spanned by the other methods).
# NOTE(review): the 10000 extra frames per iteration and the 1333 cut-off are
# carried over unchanged from the original cell — confirm their provenance.
rl_samples_per_iteration = samples_per_iteration + 10000
rl_plot_iterations = 1333

for (key, label), color in zip(key_to_label_map.items(), colors):
    is_rl = label == "RL"
    n_iterations = rl_plot_iterations if is_rl else total_training_iterations
    frames_per_iteration = (
        rl_samples_per_iteration if is_rl else samples_per_iteration
    )

    rewards = []
    # `.get` avoids inserting missing keys into the defaultdict.
    for run_data in runs_dict.get(key, []):
        raw_reward = run_data["reward"]
        # Check completeness BEFORE smoothing so that empty or partial
        # histories are skipped instead of crashing inside `ema`.
        if is_rl:
            if len(raw_reward) < n_iterations:
                # Too short to cover the plotted range, skip
                continue
        elif len(raw_reward) != n_iterations:
            # Run not complete, skip
            continue
        # The EMA is causal, so smoothing first and truncating afterwards
        # equals smoothing the truncated series. Truncating per run also
        # keeps the stacked array rectangular when RL runs differ in length.
        rewards.append(ema(raw_reward)[:n_iterations])

    if not rewards:
        # Nothing to plot for this method; its colour stays reserved.
        print(f"Skipping {label}: no complete runs found for '{key}'")
        continue
    rewards = np.array(rewards)

    # Iteration indices 1..n_iterations scaled to frames. `np.arange` gives
    # exact integer indices; `np.linspace(1, n + 1, n)` stretched the axis by
    # a non-integer step and ended one iteration too far.
    X = np.arange(1, n_iterations + 1) * frames_per_iteration

    mu = rewards.mean(axis=0)
    ax.plot(X, mu, color=color, label=label)
    if rewards.shape[0] > 1:
        std = rewards.std(axis=0)
        ax.fill_between(X, y1=mu - std, y2=mu + std, color=color, alpha=0.3)

# Label the axes, place the legend below the plot in four columns, and save.
ax.set(
    title="D-RWARE (Corner) Training Progress",
    xlabel="Frames",
    ylabel="Episode Reward",
)
ax.legend(
    loc="upper center",
    bbox_to_anchor=(0.5, -0.15),
    ncol=4,
    frameon=False,
)

fig.savefig("d-rware-corners.png", bbox_inches="tight", dpi=300)