# Diffusion Visualisation

This notebook visualises and tests the unguided and guided diffusion pipelines for generating RWARE warehouse layouts.

## Unguided Generation

In [71]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from guided_diffusion.script_util import create_classifier, classifier_defaults
from rware.warehouse import Warehouse

from diffusion_co_design.pretrain.rware.transform import (
    rgb_to_layout,
    storage_to_layout,
)
from diffusion_co_design.pretrain.rware.generate import generate
from diffusion_co_design.pretrain.rware.generator import Generator, GeneratorConfig
from diffusion_co_design.utils import OUTPUT_DIR
from diffusion_co_design.utils import cuda as device

# Pretrained diffusion generator used throughout this notebook.
_model_path = os.path.join(
    OUTPUT_DIR,
    "diffusion_pretrain",
    "default",
    "model100000.pt",
)
cfg = GeneratorConfig(
    generator_model_path=_model_path,
    size=16,
    batch_size=9,
    num_channels=3,  # alternative model variant: num_channels=1
)

# Fixed grid indices of agent start positions and goals, shared by every
# image -> layout conversion below.
agent_idxs = [1, 88, 132, 233, 162]
goal_idxs = [39, 185, 237, 238, 159]

# Converter from a generated image to a warehouse layout. storage_to_layout is
# the alternative (presumably paired with the num_channels=1 model — confirm).
to_layout = rgb_to_layout


def show_batch(environment_batch, *, ncols=3):
    """Render each generated environment image as a warehouse in a subplot grid.

    Each image is converted with ``to_layout`` using the notebook-level
    ``agent_idxs`` / ``goal_idxs``, rendered through
    ``Warehouse(..., render_mode="rgb_array").render()``, and drawn with
    ``imshow``.

    Args:
        environment_batch: iterable of generated images accepted by ``to_layout``.
        ncols: number of subplot columns. The default of 3 reproduces the
            original 3x3 grid for a batch of 9.

    Returns:
        ``(fig, axs)``, where ``axs`` is the flattened array of all grid axes.
    """
    renders = []
    for image in environment_batch:
        layout = to_layout(image, agent_idxs, goal_idxs)
        warehouse = Warehouse(layout=layout, render_mode="rgb_array")
        try:
            renders.append(warehouse.render())
        finally:
            # Close the environment even if rendering raises.
            warehouse.close()

    # Size the grid to the batch instead of assuming exactly 9 images. With a
    # fixed 3x3 grid, fewer images raised IndexError and extra images were
    # silently dropped.
    nrows = max(1, -(-len(renders) // ncols))  # ceiling division
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False
    )
    axs = axs.ravel()
    for ax, render in zip(axs, renders):
        ax.imshow(render)
    # Hide ticks on every cell, including unused trailing ones.
    for ax in axs:
        ax.axis("off")
    return fig, axs

In [None]:
# Sample a single batch from the unguided diffusion model and plot it.
generator = Generator(cfg)
environment_batch = generator.generate_batch()

fig, axs = show_batch(environment_batch)
fig.suptitle("Unguided Generation")
fig.tight_layout()

In [None]:
# Covariance test: collect NUM_TEST_IMAGES diffusion-generated layouts and
# NUM_TEST_IMAGES randomly generated layouts (same agent/goal indices) for
# the statistical comparison below.

NUM_TEST_IMAGES = 1000

# Diffusion samples, drawn in larger batches for throughput.
layouts = []
test_cfg = cfg.model_copy()
test_cfg.batch_size = 20
generator = Generator(test_cfg)
with tqdm(total=NUM_TEST_IMAGES) as pbar:
    while len(layouts) < NUM_TEST_IMAGES:
        environment_batch = generator.generate_batch()
        for image in environment_batch:
            # Stop at exactly NUM_TEST_IMAGES even when batch_size does not
            # divide it, so both sample sets have the same size.
            if len(layouts) >= NUM_TEST_IMAGES:
                break
            layouts.append(to_layout(image, agent_idxs, goal_idxs))
            pbar.update()
# One row per layout: its flattened ``highways`` array.
layouts_1 = np.stack([layout.highways.flatten() for layout in layouts])

# Random baseline: procedurally generated layouts with 50 shelves each.
layouts = []
for _ in tqdm(range(NUM_TEST_IMAGES)):
    storage = generate(
        size=cfg.size,
        n_shelves=50,
        agent_idxs=agent_idxs,
        goal_idxs=goal_idxs,
        n=1,
    )[0]
    layouts.append(storage_to_layout(storage, agent_idxs, goal_idxs))
layouts_2 = np.stack([layout.highways.flatten() for layout in layouts])

In [None]:
# Histogram of covariance entries for each sample set (log-scale counts).
# NOTE(review): np.cov treats rows as variables by default, and each row here
# is one layout, so this is the covariance *between layouts* (N x N). Pass
# rowvar=False if covariance between grid cells was intended — confirm.
# layout="constrained" replaces the discouraged Figure.set_constrained_layout.
fig, axs = plt.subplots(2, layout="constrained")
for ax, data, title in zip(
    axs,
    (layouts_1, layouts_2),
    ("Diffusion Generated", "Randomly Generated"),
):
    ax.set_title(title)
    ax.set_yscale("log")
    ax.hist(np.cov(data).flatten(), bins=50)

In [None]:
# Distribution of absolute covariance values for both sample sets, overlaid.
# ``layouts`` is a list of Layout objects at this point, so passing it to
# np.cov fails; use the stacked highway matrices instead.
for data, label in (
    (layouts_1, "Diffusion Generated"),
    (layouts_2, "Randomly Generated"),
):
    plt.hist(np.abs(np.cov(data).flatten()), bins=50, alpha=0.5, label=label)
plt.legend()

In [None]:
GUIDANCE_WT = 10
TRAIN_NUM_ITERATIONS = 30
TRAIN_BATCH_SIZE = 128
VALUE_LR = 3e-4
VALUE_WEIGHT_DECAY = 0.05

# Generator whose diffusion process and timestep sampler are reused when
# training the value model.
generator = Generator(cfg, guidance_wt=GUIDANCE_WT)

# Random RGB layouts for the pseudo value dataset. Use the shared agent/goal
# indices: the previous inline copy had goal 158 instead of 159, so these
# layouts did not match the ones generated elsewhere in the notebook.
X = generate(
    size=cfg.size,
    n_shelves=50,
    agent_idxs=agent_idxs,
    goal_idxs=goal_idxs,
    n=10_000,
    rgb=True,
)
X = torch.tensor(X)

# Pseudo value target: the sum of the shelves' x coordinates (a toy objective
# that grows as shelves sit at larger x).
y = []
for x in X:
    layout = rgb_to_layout(x, agent_idxs, goal_idxs)
    y.append(sum(s.x for s in layout.reset_shelves()))

# Shift so the minimum target is 1, then scale to unit mean.
y = torch.tensor(y).to(device)
y = y - y.min() + 1
y = y / y.mean(dtype=torch.float32)
# Move the last axis to position -3 (presumably channels-last HWC -> CHW for
# the classifier — confirm against generate(rgb=True)).
X = X.movedim(-1, -3).to(device)

dataset = TensorDataset(X, y)
# Shuffle so each epoch sees different mini-batch compositions.
dataloader = DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)

In [None]:
# Create value model
model_dict = classifier_defaults()
model_dict["image_size"] = cfg.size
model_dict["image_channels"] = 3
model_dict["classifier_width"] = 256
model_dict["classifier_depth"] = 2
model_dict["classifier_attention_resolutions"] = "16, 8, 4"
model_dict["output_dim"] = 1

model = create_classifier(**model_dict).to(device)

# Train
optim = torch.optim.Adam(
    model.parameters(), lr=VALUE_LR, weight_decay=VALUE_WEIGHT_DECAY
)
criterion = torch.nn.MSELoss()

model.train()
losses = []
with tqdm(range(TRAIN_NUM_ITERATIONS)) as pbar:
    for epoch in pbar:
        running_loss = 0.0
        for X_batch, y_batch in dataloader:
            optim.zero_grad()

            X_batch = X_batch.to(torch.float32)
            y_batch = y_batch.to(torch.float32)
            t, _ = generator.schedule_sampler.sample(len(X_batch), device)
            X_batch = generator.diffusion.q_sample(X_batch, t)

            y_pred = model(X_batch, t).squeeze()
            loss = criterion(y_pred, y_batch)
            running_loss += loss.item()
            loss.backward()
            optim.step()
        running_loss = running_loss / len(dataloader)
        pbar.set_description(f"Epoch {epoch} | Loss {running_loss}")
        losses.append(running_loss)

In [None]:
# Guided generation with the trained value model. The guidance weight used
# here (100) is much larger than GUIDANCE_WT.
generator = Generator(cfg, guidance_wt=100)
model.eval()
environment_batch = generator.generate_batch(value=model)
for env in environment_batch:
    # Convert with the same agent/goal indices that show_batch uses.
    layout = rgb_to_layout(env, agent_idxs, goal_idxs)
    shelves = layout.reset_shelves()
    # Report the shelf count and the pseudo value the model was trained on
    # (sum of shelf x coordinates), to check whether guidance has an effect.
    print(f"shelves={len(shelves)} value={sum(s.x for s in shelves)}")
fig, axs = show_batch(environment_batch)
fig.suptitle("Guided Generation")
fig.tight_layout()

# Note: Guidance weights need to be set really high for any decent signal?