In [None]:
import os
import shutil
import math
import numpy as np
from IPython.display import Video
from omegaconf import OmegaConf
from tqdm import tqdm
from gymnasium.utils.save_video import save_video
import torch
from tensordict import TensorDict
from torchrl.envs import EnvBase
from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
from matplotlib import pyplot as plt
from guided_diffusion.script_util import create_classifier, classifier_defaults
from diffusion_co_design.pretrain.rware.transform import storage_to_layout
from rware.warehouse import Warehouse
from diffusion_co_design.rware.env import create_env, create_batched_env
from diffusion_co_design.rware.model import rware_models
from diffusion_co_design.utils import (
    omega_to_pydantic,
    get_latest_model,
    cuda,
    OUTPUT_DIR,
)
from diffusion_co_design.pretrain.rware.transform import storage_to_layout
from diffusion_co_design.bin.train_rware import TrainingConfig, DesignerRegistry
from diffusion_co_design.pretrain.rware.generator import (
    Generator,
    GeneratorConfig,
    OptimizerDetails,
)

FIGURE_SIZE_CNST = 3  # inches per subplot when sizing the figure grids below
CACHE_ID = "train_visualisation"
# True: wipe the cache directory and regenerate its contents; False: reuse it.
RECOMPUTE = False

# Parameters
# NOTE(review): absolute, machine-specific path — point this at your own run.
training_dir = "/home/markhaoxiang/.diffusion_co_design/training/2025-02-23/21-11-20"
device = cuda

# Get latest policy
checkpoint_dir = os.path.join(training_dir, "checkpoints")
latest_policy = get_latest_model(checkpoint_dir, "policy_")

# Get config
hydra_dir = os.path.join(training_dir, ".hydra")
training_config = os.path.join(hydra_dir, "config.yaml")
cfg = omega_to_pydantic(OmegaConf.load(training_config), TrainingConfig)

# Create environment
cache_dir = os.path.join(OUTPUT_DIR, ".tmp", CACHE_ID)
if RECOMPUTE and os.path.exists(cache_dir):
    # Start from a clean slate so stale artefacts cannot leak into this run.
    shutil.rmtree(cache_dir)
os.makedirs(cache_dir, exist_ok=True)

master_designer, env_designer = DesignerRegistry.get(
    "random",
    cfg.scenario,
    cache_dir,
    environment_batch_size=32,
    device=device,
)
agent_idxs = cfg.scenario.agent_idxs
goal_idxs = cfg.scenario.goal_idxs
env = create_env(cfg.scenario, env_designer, render=True, device=device)
policy, _ = rware_models(env, cfg.policy, device=device)
# map_location keeps loading robust when the checkpoint was saved on a
# different device than the one this notebook runs on.
policy.load_state_dict(torch.load(latest_policy, map_location=device))


def view_video(env: EnvBase, policy):
    """Roll out ``policy`` in ``env``, save the rendered frames as a video.

    Every rollout step renders one frame. The frames are written with
    ``save_video`` (10 fps) into ``<cache_dir>/video``.

    Args:
        env: Environment to roll out; must support ``render()``.
        policy: Policy passed to ``env.rollout``.

    Returns:
        A zero-argument callable that builds an embedded ``Video`` for
        ``<cache_dir>/video/rl-video-episode-0.mp4`` (presumably the file name
        ``save_video`` produces with its default naming — confirm if the
        gymnasium defaults change).
    """
    video_dir = os.path.join(cache_dir, "video")
    video_out = os.path.join(cache_dir, "video/rl-video-episode-0.mp4")
    rendered = []

    def _capture(current_env, _tensordict):
        # Rollout callback: grab one rendered frame per step.
        rendered.append(current_env.render())

    env.rollout(
        policy=policy,
        max_steps=cfg.scenario.max_steps,
        auto_cast_to_device=True,
        callback=_capture,
    )

    save_video(frames=rendered, video_folder=video_dir, fps=10)

    def _show():
        return Video(filename=video_out, embed=True)

    return _show

In [2]:
# Test if the value function is able to discern good environments!
# Builds (or reloads) a dataset of environment states paired with the
# agent-averaged episode reward observed at episode end.

NUM_PARALLEL_COLLECTION = 25
DATASET_SIZE = 10_000
# DATASET_SIZE = 4
BATCH_SIZE = 128

collection_env = create_batched_env(
    num_environments=NUM_PARALLEL_COLLECTION,
    scenario=cfg.scenario,
    designer=env_designer,
    is_eval=False,
    device="cpu",
)

env_returns = ReplayBuffer(
    storage=LazyTensorStorage(max_size=DATASET_SIZE),
    sampler=SamplerWithoutReplacement(),
    batch_size=BATCH_SIZE,
)

env_returns_path = os.path.join(cache_dir, "env_returns")
if RECOMPUTE:
    for _ in tqdm(range(DATASET_SIZE // NUM_PARALLEL_COLLECTION)):
        rollout = collection_env.rollout(
            max_steps=cfg.scenario.max_steps, policy=policy, auto_cast_to_device=True
        )
        done = rollout.get(("next", "done"))
        # Drop only the trailing singleton dim (assumes done is (..., 1) — TODO
        # confirm): a bare squeeze() would also collapse the batch or time
        # dimension whenever either has size 1, breaking the mask.
        X = rollout.get("state")[done.squeeze(-1)]
        # Average the reward over the agent dimension, keep only finished steps.
        y = rollout.get(("next", "agents", "episode_reward")).mean(-2)[done]
        data = TensorDict({"env": X, "episode_reward": y}, batch_size=len(y))
        env_returns.extend(data)
        # Release the rollout right away; done per iteration so that a
        # zero-iteration loop cannot hit an undefined name.
        del rollout, done
    env_returns.dumps(env_returns_path)
else:
    if not os.path.exists(env_returns_path):
        raise FileNotFoundError(
            f"No cached dataset at {env_returns_path}; "
            "set RECOMPUTE = True and re-run the notebook from the top to collect it."
        )
    env_returns.loads(env_returns_path)

In [3]:
GUIDANCE_WT = 50
TRAIN_NUM_EPOCHS = 80
VALUE_LR = 3e-4
VALUE_WEIGHT_DECAY = 0.05
OUTPUT_CHANNELS = 1
# NOTE: rebinds the notebook-wide flag; from here on it only controls whether
# the value model is retrained (True) or loaded from its checkpoint (False).
RECOMPUTE = False

# Value-model artefacts (relative to the current working directory).
VALUE_MODEL_PATH = "train_visualisation_classifier.pt"
VALUE_LOSSES_PATH = "train_visualisation_classifier_losses.pt"

ITERATIONS_PER_EPOCH = math.ceil(DATASET_SIZE / BATCH_SIZE)

# Binary (1, size, size) mask with ones at the goal cells; goal indices are
# flattened row-major positions on the size x size grid.
goal_map = np.zeros((cfg.scenario.size, cfg.scenario.size))
for goal in goal_idxs:
    goal_map[goal // cfg.scenario.size, goal % cfg.scenario.size] = 1
goal_map = (
    torch.from_numpy(goal_map).unsqueeze(0).to(device=device, dtype=torch.float32)
)

# Create value model
pretrain_dir = os.path.join(OUTPUT_DIR, "diffusion_pretrain", cfg.scenario.name)
latest_checkpoint = get_latest_model(pretrain_dir, "model")

gen_cfg = GeneratorConfig(
    batch_size=8,
    generator_model_path=latest_checkpoint,
    size=cfg.scenario.size,
    num_channels=OUTPUT_CHANNELS,
)
generator = Generator(gen_cfg, guidance_wt=GUIDANCE_WT)

model_dict = classifier_defaults()
model_dict["image_size"] = cfg.scenario.size
# One extra input channel for the goal map concatenated to each environment.
model_dict["image_channels"] = OUTPUT_CHANNELS + 1
model_dict["classifier_width"] = 256
model_dict["classifier_depth"] = 2
model_dict["classifier_attention_resolutions"] = "16, 8, 4"
model_dict["output_dim"] = 1

model = create_classifier(**model_dict).to(device)

# Train
optim = torch.optim.Adam(
    model.parameters(), lr=VALUE_LR, weight_decay=VALUE_WEIGHT_DECAY
)
criterion = torch.nn.MSELoss()

# Always defined, so the plotting cell also works when the model is loaded
# from disk instead of being retrained.
losses = []

if RECOMPUTE:
    model.train()
    with tqdm(range(TRAIN_NUM_EPOCHS)) as pbar:
        for _ in pbar:
            running_loss = 0.0
            for _ in range(ITERATIONS_PER_EPOCH):
                optim.zero_grad()
                sample = env_returns.sample(batch_size=BATCH_SIZE)
                X_batch = sample.get("env").to(dtype=torch.float32, device=device)
                # Add goal map
                goal_map_batch = goal_map.expand(X_batch.shape[0], -1, -1).unsqueeze(1)
                X_batch = torch.cat([X_batch, goal_map_batch], dim=1)

                # Normalisation: rescale x -> 2x - 1
                X_batch = X_batch * 2 - 1
                y_batch = sample.get("episode_reward").to(
                    dtype=torch.float32, device=device
                )
                y_pred = model(X_batch).squeeze()
                loss = criterion(y_pred, y_batch)
                loss.backward()
                optim.step()

                running_loss += loss.item()
            running_loss = running_loss / ITERATIONS_PER_EPOCH
            losses.append(running_loss)
            pbar.set_description(f"Loss {running_loss}")
    torch.save(model.state_dict(), VALUE_MODEL_PATH)
    # Persist the loss curve next to the weights so it survives kernel restarts.
    torch.save(losses, VALUE_LOSSES_PATH)
else:
    # map_location keeps loading robust if the checkpoint was saved on another device.
    model.load_state_dict(torch.load(VALUE_MODEL_PATH, map_location=device))
    # Checkpoints written before the loss curve was persisted have no such file.
    if os.path.exists(VALUE_LOSSES_PATH):
        losses = torch.load(VALUE_LOSSES_PATH)

In [None]:
# `losses` is only populated when the value model was trained (or its loss
# curve was cached); fall back to an empty history instead of crashing.
loss_history = list(globals().get("losses") or [])

fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set_title("Value Function Train loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Mean MSE loss")
if not loss_history:
    print("No loss history available: the value model was loaded from its checkpoint.")

# Best (lowest) epoch loss, shown as the cell output when a history exists.
min(loss_history) if loss_history else None

In [None]:
torch.cuda.empty_cache()

# Rank every stored environment by its episode reward, highest first.
env_returns_sorted_index = torch.argsort(
    env_returns.storage["episode_reward"], descending=True
)
best_5 = env_returns_sorted_index[:5]
worst_5 = env_returns_sorted_index[-5:]

fig, axs = plt.subplots(2, 5)
fig.set_size_inches(5 * FIGURE_SIZE_CNST, 2 * FIGURE_SIZE_CNST)

# Top row: the five highest-reward layouts; bottom row: the five lowest.
for row, ranked_indices in enumerate((best_5, worst_5)):
    if row == 1:
        # Separator between the best and worst reward printouts.
        print("===")
    for col, idx in enumerate(ranked_indices):
        layout = storage_to_layout(
            env_returns.storage["env"][idx],
            cfg.scenario.agent_idxs,
            cfg.scenario.goal_idxs,
        )
        print(env_returns.storage["episode_reward"][idx])

        # Render the layout through a throwaway warehouse instance.
        warehouse = Warehouse(layout=layout, render_mode="rgb_array")
        im = warehouse.render()
        warehouse.close()

        ax = axs[row, col]
        ax.imshow(im)
        ax.axis("off")

In [None]:
torch.cuda.empty_cache()
N = 50

# Sanity check: the value model should score the N highest-reward stored
# environments above the N lowest-reward ones.
model.eval()

# Inference only: skip autograd so no graph is kept alive on the device.
with torch.no_grad():
    good_envs = env_returns.storage["env"][env_returns_sorted_index[:N]].to(
        device=device, dtype=torch.float32
    )
    # Broadcast the goal map to one extra channel per environment.
    goal_map_batch = goal_map.expand(good_envs.shape[0], -1, -1).unsqueeze(1)
    good_envs = torch.cat([good_envs, goal_map_batch], dim=1)
    # Same x -> 2x - 1 rescaling as used during training.
    print(f"Good environments: {model(good_envs * 2 - 1).mean()}")

    bad_envs = env_returns.storage["env"][env_returns_sorted_index[-N:]].to(
        device=device, dtype=torch.float32
    )
    bad_envs = torch.cat([bad_envs, goal_map_batch], dim=1)
    print(f"Bad environments: {model(bad_envs * 2 - 1).mean()}")

del good_envs
del bad_envs

In [None]:
# Take agent/goal positions from the scenario config rather than hard-coding
# them, so rendered layouts always agree with the goal map fed to the value model.
agent_idxs = cfg.scenario.agent_idxs
goal_idxs = cfg.scenario.goal_idxs


# Fresh generator using the same guidance weight as configured above.
generator = Generator(gen_cfg, guidance_wt=GUIDANCE_WT)
model.eval()
operation = OptimizerDetails()
operation.num_recurrences = 8
operation.backward_steps = 0
# Goal map rescaled with the same x -> 2x - 1 mapping applied to model inputs.
operation.operated_image = goal_map * 2 - 1


def show_batch(environment_batch, n: int = 8):
    """Render the first ``n`` environments of a batch in a square grid of subplots.

    Args:
        environment_batch: Sized, sliceable collection of environment images.
            Each item is converted with ``storage_to_layout`` using the
            notebook-level ``agent_idxs`` and ``goal_idxs``.
        n: Maximum number of environments to draw. Clamped to the batch
            length, so asking for more than is available no longer raises.

    Returns:
        ``(fig, axs)``: the figure and the flattened array of its axes. The
        grid is 3x3 (12x12 inches) for up to nine environments and grows to
        the next square size beyond that.
    """
    n = max(0, min(n, len(environment_batch)))

    # Only render what will actually be displayed.
    layouts = []
    for image in environment_batch[:n]:
        layout = storage_to_layout(image, agent_idxs, goal_idxs)
        warehouse = Warehouse(layout=layout, render_mode="rgb_array")
        try:
            layouts.append(warehouse.render())
        finally:
            # Release the renderer even if rendering fails.
            warehouse.close()

    side = max(3, math.ceil(math.sqrt(n)))
    fig, axs = plt.subplots(side, side, figsize=(4 * side, 4 * side))
    axs = axs.ravel()
    for ax in axs:
        ax.axis("off")
    for ax, rendered in zip(axs, layouts):
        ax.imshow(rendered)
    return fig, axs


# Generate a batch of environments with the value model steering the sampler.
environment_batch = generator.generate_batch(
    value=model, use_operation=True, operation_override=operation
)


# Loop variable deliberately not called `env`: that name holds the RL
# environment created at the top of the notebook and must not be clobbered.
for generated_env in environment_batch:
    layout = storage_to_layout(generated_env, agent_idxs, goal_idxs)
    print(len(layout.reset_shelves()))
fig, axs = show_batch(environment_batch)
fig.suptitle("Guided Generation")
fig.tight_layout()

# Move the last axis to position 1, i.e. (d0, d1, d2, d3) -> (d0, d3, d1, d2);
# presumably channels-last to channels-first — confirm against the generator.
X_batch = (
    torch.from_numpy(environment_batch)
    .to(device=device, dtype=torch.float32)
    .moveaxis((0, 1, 2, 3), (0, 2, 3, 1))
)
# Size the goal-map batch from the data instead of hard-coding the batch size.
X_batch = torch.cat(
    [X_batch, goal_map.unsqueeze(0).expand(X_batch.shape[0], -1, -1, -1)], dim=1
)
X_batch = (X_batch * 2) - 1
# Inference only: no autograd graph needed to print the predicted values.
with torch.no_grad():
    print(model(X_batch))

In [None]:
# Counterfactual: randomly generated environments

# No value model / operation passed: baseline samples to compare against the
# guided batch.
environment_batch = generator.generate_batch()


# Loop variable deliberately not called `env`: that name holds the RL
# environment created at the top of the notebook and must not be clobbered.
for generated_env in environment_batch:
    layout = storage_to_layout(generated_env, agent_idxs, goal_idxs)
    print(len(layout.reset_shelves()))
fig, axs = show_batch(environment_batch)
# Title fixed: these samples are the unguided baseline, not guided generation.
fig.suptitle("Unguided Generation")
fig.tight_layout()

# Move the last axis to position 1, i.e. (d0, d1, d2, d3) -> (d0, d3, d1, d2);
# presumably channels-last to channels-first — confirm against the generator.
X_batch = (
    torch.from_numpy(environment_batch)
    .to(device=device, dtype=torch.float32)
    .moveaxis((0, 1, 2, 3), (0, 2, 3, 1))
)
# Size the goal-map batch from the data instead of hard-coding the batch size.
X_batch = torch.cat(
    [X_batch, goal_map.unsqueeze(0).expand(X_batch.shape[0], -1, -1, -1)], dim=1
)
X_batch = (X_batch * 2) - 1
# Inference only: no autograd graph needed to print the predicted values.
with torch.no_grad():
    print(model(X_batch))