## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Put the parent of the current working directory at the front of sys.path
# (presumably so the `diffuser` package imported below resolves — confirm
# against the repository layout).
root_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
root_already_listed = root_dir in sys.path
if not root_already_listed:
    sys.path[:0] = [root_dir]

In [3]:
import pdb
from datetime import datetime

import gymnasium as gym
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb
from minari import DataCollector, StepDataCallback
from tqdm import tqdm

from diffuser.utils.arrays import report_parameters, batchify
from diffuser.utils.config import Config, get_params, get_device_settings
from diffuser.utils.training import Trainer

## Parse Arguments and Parameters

In [None]:
# Build the parser that carries the experiment settings
parser = get_params()

# Keep only the recognised namespace (element 0 of the returned pair); flags
# the parser does not know are ignored rather than raising — assuming
# `get_params()` returns an argparse-style parser, TODO confirm.
args = parser.parse_known_args(sys.argv[1:])[0]

# Seed torch (CPU and all CUDA devices) and NumPy with the configured seed
seed = args.seed
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
    seed_fn(seed)

# Resolve the compute device from the parsed settings
device = get_device_settings(args)

In [None]:
def _savepath_for(filename):
    """Return the ``(args.savepath, filename)`` pair passed to a Config as savepath."""
    return (args.savepath, filename)


# Dataset loader settings
dataset_config = Config(
    args.loader,
    savepath=_savepath_for("dataset_config.pkl"),
    env=args.env_name,
    horizon=args.horizon,
    normalizer=args.normalizer,
    preprocess_fns=args.preprocess_fns,
    use_padding=args.use_padding,
    max_path_length=args.max_path_length,
)

# Renderer settings
render_config = Config(
    args.renderer,
    savepath=_savepath_for("render_config.pkl"),
    env=args.env_name,
)

# Network settings: a transition is an observation concatenated with an
# action, hence the summed dimension.
model_config = Config(
    args.model,
    savepath=_savepath_for("model_config.pkl"),
    horizon=args.horizon,
    transition_dim=args.observation_dim + args.action_dim,
    cond_dim=args.observation_dim,
    dim_mults=args.dim_mults,
    device=device,
)

# Diffusion process settings; the target class is given by its dotted name
diffusion_config = Config(
    _class="models.diffuser.GaussianDiffusion",
    savepath=_savepath_for("diffusion_config.pkl"),
    horizon=args.horizon,
    observation_dim=args.observation_dim,
    action_dim=args.action_dim,
    n_timesteps=args.n_timesteps,
    loss_type=args.loss_type,
    clip_denoised=args.clip_denoised,
    predict_epsilon=args.predict_epsilon,
    # loss weighting
    action_weight=args.action_weight,
    loss_weights=args.loss_weights,
    loss_discount=args.loss_discount,
    device=device,
)

# Trainer settings
trainer_config = Config(
    Trainer,
    savepath=_savepath_for("trainer_config.pkl"),
    train_batch_size=args.train_batch_size,
    name=args.env_name,
    train_lr=args.train_lr,
    gradient_accumulate_every=args.gradient_accumulate_every,
    ema_decay=args.ema_decay,
    sample_freq=args.sample_freq,
    save_freq=args.save_freq,
    label_freq=args.label_freq,
    save_parallel=args.save_parallel,
    results_folder=args.savepath,
    bucket=args.bucket,
    n_reference=args.n_reference,
    n_samples=args.n_samples,
    device=device,
)

In [None]:
# Load objects
# Each config object is called to obtain the component it was declared for
# (presumably Config.__call__ instantiates the configured class — confirm).
# The order matters: later components are built from earlier ones.
dataset = dataset_config()
renderer = render_config()
model = model_config()
# The diffusion object is constructed around the network created above
diffuser = diffusion_config(model)
# NOTE(review): `device` is passed positionally here and trainer_config was
# also declared with `device=device` — confirm this does not hand Trainer the
# same argument twice.
trainer = trainer_config(diffuser, dataset, renderer, device)

In [None]:
# Restore previously saved trainer state.
# NOTE(review): the path names a single checkpoint file ("state_900000.pt") yet
# is passed as `directory`, and the step in the file name (900000) differs from
# `epoch=500000` — confirm which checkpoint Trainer.load actually restores.
model_path = "saved/10_ep/state_900000.pt"
trainer.load(directory=model_path, epoch=500000)

In [None]:
# Create the environment named in the settings and reset it once up front
env = gym.make(args.env_name)
env.reset()

# Ask the trainer to render ten times, one sample per call, with
# `get_cond_from_env=True` passed through to render_samples.
for render_round in range(10):
    trainer.render_samples(env, get_cond_from_env=True, batch_size=1)

## Forward pass is working

In [None]:
# Presumably prints a parameter summary for the network — see
# diffuser.utils.arrays.report_parameters.
report_parameters(model)

# Smoke test: compute the loss on the first dataset item and backpropagate it.
# The check mark is only printed if neither step raises.
print("Testing forward...", end=" ", flush=True)
batch = batchify(dataset[0])
# Only the first returned value (the loss) is used; the second is discarded
loss, _ = diffuser.loss(*batch)
loss.backward()
print("✓")

## Using the trainer requires taking care of the 'device' in the folders

# Training process including rendering

In [None]:
# Timestamp (day_month_year-hour-minute) used to label this run
current_time = datetime.now().strftime("%d_%m_%Y-%H-%M")

# Start a Weights & Biases run only when tracking is enabled in the settings;
# `run` stays undefined otherwise.
if args.use_wandb:
    wandb_run_options = dict(
        config=args,
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=f"Run_{current_time}",
        group="Group-Name",
        job_type="training",
        reinit=True,
    )
    run = wandb.init(**wandb_run_options)

In [None]:
# Number of outer iterations and the number of optimisation steps requested
# from the trainer in each one. Both are fixed here for a short run; the
# settings-derived epoch count is kept for reference:
# n_epochs = int(args.n_train_steps // args.n_steps_per_epoch)
n_epochs = 5
n_steps_per_epoch = 10000

# Move the diffusion model to the selected device before training starts
diffuser.to(device)
for i in tqdm(range(n_epochs)):
    # `i` is zero-based, so the last line printed reads "Epoch 4 / 5"
    print(f"Epoch {i} / {n_epochs} | {args.savepath}")
    trainer.train(n_train_steps=n_steps_per_epoch)

In [None]:
def extract_datasets(file_path):
    """Read the observation and qpos arrays from an HDF5 file.

    Args:
        file_path: Path of the HDF5 file, opened read-only.

    Returns:
        A tuple ``(observations, qpos)`` of NumPy arrays holding the contents
        of the ``"observations"`` and ``"infos/qpos"`` datasets.
    """
    dataset_keys = ("observations", "infos/qpos")
    with h5py.File(file_path, "r") as hdf5_file:
        # Copy both datasets into memory while the file is still open
        observations, qpos = [np.array(hdf5_file[key]) for key in dataset_keys]

    return observations, qpos


# Machine-specific location of the maze2d-large-dense-v1 HDF5 file; change it
# to wherever the file lives on your machine.
file_path = "/Users/magic-rabbit/Downloads/maze2d-large-dense-v1.hdf5"
observations, qpos = extract_datasets(file_path)

# Printing shapes of the arrays to confirm extraction
print("Observations shape:", observations.shape)
print("Qpos shape:", qpos.shape)

# Occupancy grid of the maze: 1 marks a wall cell, 0 a free cell
LARGE_MAZE = [
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
    [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
    [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
    [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],
    [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
    [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
    [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
background_array = np.array(LARGE_MAZE)

# Start from a cleared current figure and draw the maze as the background
plt.clf()
fig = plt.gcf()

plt.imshow(
    background_array,
    cmap=plt.cm.binary,
)

path_length = len(observations)
# Draw at most the first 100000 observations. The colour array must have
# exactly one entry per plotted point, otherwise scatter raises, so its length
# follows the number of points actually available.
n_plotted = min(path_length, 100000)
colors = plt.cm.jet(np.linspace(0, 1, n_plotted))
# Observation column 1 goes on the horizontal axis and column 0 on the vertical
# axis, matching the (row, column) layout of the imshow background.
plt.plot(observations[:n_plotted, 1], observations[:n_plotted, 0], c="black", zorder=10)
plt.scatter(observations[:n_plotted, 1], observations[:n_plotted, 0], c=colors, zorder=20)

In [None]:
# Medium

# Machine-specific location of the maze2d-medium-dense-v1 HDF5 file; change it
# to wherever the file lives on your machine.
file_path = "/Users/magic-rabbit/Downloads/maze2d-medium-dense-v1.hdf5"
observations, qpos = extract_datasets(file_path)

# Printing shapes of the arrays to confirm extraction
print("Observations shape:", observations.shape)
print("Qpos shape:", qpos.shape)

# Occupancy grid of the maze: 1 marks a wall cell, 0 a free cell
MEDIUM_MAZE = [
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 0, 0, 1, 1, 0, 0, 1],
    [1, 0, 0, 1, 0, 0, 0, 1],
    [1, 1, 0, 0, 0, 1, 1, 1],
    [1, 0, 0, 1, 0, 0, 0, 1],
    [1, 0, 1, 0, 0, 1, 0, 1],
    [1, 0, 0, 0, 1, 0, 0, 1],
    [1, 1, 1, 1, 1, 1, 1, 1],
]
background_array = np.array(MEDIUM_MAZE)


# Start from a cleared current figure and draw the maze as the background
plt.clf()
fig = plt.gcf()

plt.imshow(
    background_array,
    cmap=plt.cm.binary,
)

path_length = len(observations)
# Draw at most the first 100000 observations. The colour array must have
# exactly one entry per plotted point, otherwise scatter raises, so its length
# follows the number of points actually available.
n_plotted = min(path_length, 100000)
colors = plt.cm.jet(np.linspace(0, 1, n_plotted))
# Observation column 1 goes on the horizontal axis and column 0 on the vertical
# axis, matching the (row, column) layout of the imshow background.
plt.plot(observations[:n_plotted, 1], observations[:n_plotted, 0], c="black", zorder=10)
plt.scatter(observations[:n_plotted, 1], observations[:n_plotted, 0], c=colors, zorder=20)

# Visualize diversity in the dataset from certain regions: 

In [None]:
print(dataset.fields)
print(dataset.fields.observations.shape)


# First two observation components at the first step of every stored trajectory
start_coords = [
    (observation[0][0], observation[0][1])
    for observation in dataset.fields.observations
]
# First two observation components at the last step of every stored trajectory.
# Both components are read from the same step; mixing a fixed index (255) for
# one with -1 for the other pairs values from different steps whenever the
# stored length is not 256.
# NOTE(review): assumes the last stored step holds a real state rather than
# padding — confirm against the dataset loader.
end_coords = [
    (observation[-1][0], observation[-1][1])
    for observation in dataset.fields.observations
]

# Print min max of start and end coords
print("Start coords min:", np.min(start_coords, axis=0))
print("Start coords max:", np.max(start_coords, axis=0))
print("End coords min:", np.min(end_coords, axis=0))
print("End coords max:", np.max(end_coords, axis=0))

In [None]:
# Pair every trajectory's start point with its end point
trajectories = [
    {"start": start, "end": end} for start, end in zip(start_coords, end_coords)
]

print(len(trajectories))

# Edges of the 2D grid used to bucket points: four intervals of width 2 per
# axis, covering -4 to 4.
x_boundaries = [-4.0, -2.0, 0.0, 2.0, 4.0]
y_boundaries = [-4.0, -2.0, 0.0, 2.0, 4.0]

# Define regions: "R_<i><j>" is the cell spanning the i-th x interval and the
# j-th y interval (both 1-based). The loop ranges follow the boundary lists so
# the grid stays consistent if the edges are changed.
regions = {}
for i in range(len(x_boundaries) - 1):
    for j in range(len(y_boundaries) - 1):
        region_name = f"R_{i+1}{j+1}"
        regions[region_name] = {
            "x": (x_boundaries[i], x_boundaries[i + 1]),
            "y": (y_boundaries[j], y_boundaries[j + 1]),
        }


# Function to determine the region of a point
def find_region(point, region_bounds=None):
    """Return the name of the region that contains ``point``.

    Each region covers the half-open interval ``[low, high)`` on both axes, so
    a point on a shared edge belongs to the region whose lower bound it is.
    The one exception is the outermost upper edge of the whole grid, which is
    treated as inclusive so that points lying exactly on it are not dropped.

    Args:
        point: Indexable whose first two entries are the x and y coordinates.
        region_bounds: Optional mapping of region name to a dict with ``"x"``
            and ``"y"`` entries, each a ``(low, high)`` pair. Defaults to the
            module-level ``regions`` mapping.

    Returns:
        The name of the first matching region in iteration order, or ``None``
        (after printing a message) when no region contains the point.
    """
    if region_bounds is None:
        region_bounds = regions

    # Largest upper bound per axis; only this outer edge is made inclusive
    x_outer = max((b["x"][1] for b in region_bounds.values()), default=None)
    y_outer = max((b["y"][1] for b in region_bounds.values()), default=None)

    def _within(low, high, value, outer_edge):
        # Half-open test, plus the closed case for the grid's outer edge
        return low <= value < high or (value == high == outer_edge)

    for region, bounds in region_bounds.items():
        if _within(*bounds["x"], point[0], x_outer) and _within(
            *bounds["y"], point[1], y_outer
        ):
            return region

    print(f"No region found for point: {point}")
    return None


# Cluster trajectories
# clusters[start_region][end_region] records how many trajectories go from one
# region to the other and which trajectory indices they are.
clusters = {
    start_name: {end_name: {"count": 0, "indices": []} for end_name in regions}
    for start_name in regions
}
print(clusters)
for idx, trajectory in enumerate(trajectories):
    start_region = find_region(trajectory["start"])
    end_region = find_region(trajectory["end"])
    if not (start_region and end_region):
        # At least one endpoint fell outside every region
        print("No matching region found")
        continue

    bucket = clusters[start_region][end_region]
    bucket["count"] += 1
    bucket["indices"].append(idx)


# Reduce the nested records to plain counts and tabulate them
cluster_counts = {
    start_name: {end_name: record["count"] for end_name, record in row.items()}
    for start_name, row in clusters.items()
}
clusters_df = pd.DataFrame(cluster_counts).T
print(clusters_df)
# Total number of trajectories that were assigned to a region pair
print(clusters_df.sum(axis=1).sum())
# Get the max entry from the dataframe
max_entry = clusters_df.max().max()

In [None]:
# Observations of the trajectories recorded under clusters["R_33"]["R_13"]
selected_indices = clusters["R_33"]["R_13"]["indices"]
cluster_trajectories = dataset.fields.observations[selected_indices]

In [None]:
# Show the shape of the selected subset; its first dimension is the number of
# trajectory indices stored for that region pair.
print(cluster_trajectories.shape)