## Imports

In [None]:
# Automatically reload edited Python modules (e.g. justin_arm, diffuser) before
# each cell executes, so source changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import gymnasium as gym
from diffuser.utils.config import Config, get_params, get_device_settings
from justin_arm.training_justin import Justin_Trainer
from justin_arm.helper import (
    create_state_action_array,
    interpolate_trajectories,
    condition_start_end_per_trajectory,
)
from rokin import robots, vis
from diffuser.utils.arrays import report_parameters, batchify
from diffuser.datasets.sequence import TrajectoryDataset
import numpy as np
import torch
import sys
import matplotlib.pyplot as plt
from datetime import datetime
import wandb
from tqdm import tqdm

# Render original and diffused trajectories:
from justin_arm.visualize import (
    plot_trajectory_per_frames,
    plot_q_values_per_trajectory,
    plot_multiple_trajectories,
)
from justin_arm.helper import robot_env_dist, analyze_distance
from diffuser.utils.arrays import apply_dict, batch_to_device, to_device, to_np

import os

## Parse Arguments and Parameters

In [None]:
# Build the project's argument parser (defaults come from get_params).

parser = get_params()

# Override the defaults for the Justin arm (7-dimensional action/observation
# space). Tokens are passed explicitly because the notebook has no CLI.
args = parser.parse_args(
    [
        "--action_dim",
        "7",
        "--observation_dim",
        "7",
        "--train_batch_size",
        "16",
        "--savepath",
        "saved_justin_ep100_n100/",
        "--dataset",
        "new_dataset",
        "--horizon",
        "32",
        "--save_freq",
        "10000",
        "--train_lr",
        "0.001",
        "--n_timesteps",
        "256",
    ]
)

# Seed torch (CPU and all CUDA devices) and NumPy.
seed = args.seed
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Resolve the compute device from the parsed arguments.
device = get_device_settings(args)

# Create the output directory if needed; exist_ok=True avoids the
# check-then-create race of os.path.exists followed by os.makedirs.
os.makedirs(args.savepath, exist_ok=True)

In [None]:
# Load the joint-space paths and the world image stored under id 4123.
# NOTE(review): 4123 presumably identifies the training world (id 6547 is
# loaded later for out-of-world evaluation) — confirm.
dataset = np.load("justin_arm/data/q_paths_4123.npy")
dataset_image = np.load("justin_arm/data/image_4123.npy")
# Wrap the raw paths together with the model horizon and the world image; the
# wrapper exposes `normalizer` and `normalized_data`.
trajectory_dataset = TrajectoryDataset(
    dataset=dataset, horizon=args.horizon, image=dataset_image
)
# Robot model handed to the trainer below.
robot = robots.JustinArm07()


# The configs below are only declared here; the objects are created later by
# calling each config (model_config(), diffusion_config(model), ...).
# Each `savepath` tuple presumably names the file under args.savepath where
# that config is stored — verify in diffuser.utils.config.Config.
model_config = Config(
    args.model,
    savepath=(args.savepath, "model_config.pkl"),
    horizon=args.horizon,
    # Width of one trajectory step: observation plus action features.
    transition_dim=args.observation_dim + args.action_dim,
    cond_dim=args.observation_dim,
    dim_mults=args.dim_mults,
    device=device,
)
# The diffusion wrapper's class is given as an import-path string.
diffusion_config = Config(
    _class="models.diffuser.GaussianDiffusion",
    savepath=(args.savepath, "diffusion_config.pkl"),
    horizon=args.horizon,
    observation_dim=args.observation_dim,
    action_dim=args.action_dim,
    n_timesteps=args.n_timesteps,
    loss_type=args.loss_type,
    clip_denoised=args.clip_denoised,
    predict_epsilon=args.predict_epsilon,
    # loss weighting
    action_weight=args.action_weight,
    loss_weights=args.loss_weights,
    loss_discount=args.loss_discount,
    device=device,
)

# Trainer config; results_folder points at the same directory as the configs.
trainer_config = Config(
    Justin_Trainer,
    savepath=(args.savepath, "trainer_config.pkl"),
    train_batch_size=args.train_batch_size,
    train_lr=args.train_lr,
    name=args.env_name,
    gradient_accumulate_every=args.gradient_accumulate_every,
    ema_decay=args.ema_decay,
    sample_freq=args.sample_freq,
    save_freq=args.save_freq,
    label_freq=args.label_freq,
    save_parallel=args.save_parallel,
    results_folder=args.savepath,
    bucket=args.bucket,
    n_reference=args.n_reference,
    n_samples=args.n_samples,
    device=device,
)

In [None]:
# Inspect the dataset normalization: first the normalizer's `maxs`, then
# min / max / mean / std over the whole normalized data array.
print(trajectory_dataset.normalizer.maxs)

normalized = trajectory_dataset.normalized_data
summary_stats = (
    ("Min: ", np.min),
    ("Max: ", np.max),
    ("Mean: ", np.mean),
    ("Std: ", np.std),
)
for label, reduce_fn in summary_stats:
    print(label, reduce_fn(normalized))

In [None]:
# Instantiate the configured objects. Order matters: the diffusion wrapper
# wraps the network, and the trainer receives the diffusion model, the
# dataset, the device and the robot model.

model = model_config()
diffuser = diffusion_config(model)
trainer = trainer_config(diffuser, trajectory_dataset, device, robot)

## Check that the forward pass works

In [None]:
# Summarize the network's parameters.
report_parameters(model)

# Smoke test: one forward + backward pass on a single dataset item.
print("Testing forward...", end=" ", flush=True)
batch = batchify(trajectory_dataset[0])
loss, _ = diffuser.loss(*batch)
loss.backward()
# Discard the gradients produced by this test so they cannot be folded into
# the first optimizer step of the real training run (relevant if the trainer
# only zeroes gradients after stepping).
diffuser.zero_grad()
print("✓")

## Using the trainer requires taking care of the 'device' in the folders

# Training process including rendering

In [None]:
# Timestamp for the run name, formatted day_month_year-hour-minute.
current_time = datetime.now().strftime("%d_%m_%Y-%H-%M")

# Log to Weights & Biases only when enabled via the arguments; otherwise
# `run` is not defined.
if args.use_wandb:
    wandb_init_kwargs = dict(
        config=args,
        project=args.wandb_project,
        entity=args.wandb_entity,
        name="Run100_100_" + current_time,
        group="Group-Name",
        job_type="training",
        reinit=True,
    )
    run = wandb.init(**wandb_init_kwargs)

## Training

In [None]:
# Alternatively derive the epoch count from the arguments:
# n_epochs = int(args.n_train_steps // args.n_steps_per_epoch)
n_epochs = 100
# Train steps per epoch; 100 epochs x 100 steps = 10,000 in total.
n_steps_per_epoch = 100
diffuser.to(device)
for i in tqdm(range(n_epochs)):
    # tqdm.write prints above the progress bar instead of corrupting it.
    tqdm.write(f"Epoch {i} / {n_epochs} | {args.savepath}")
    trainer.train(n_train_steps=n_steps_per_epoch)

## Experiment 1: Training on a single datapoint

In [None]:
# Overfit to a single datapoint


# Option A: take whatever trajectory the dataloader yields with batch size 1
# (the dataset index is then unknown):
# single_input = next(iter(trainer.dataloader))
# single_input = batch_to_device(single_input, device)

# Option B (used): pick a fixed trajectory from the normalized dataset.
single_input_idx = 800
single_input = batchify(trajectory_dataset[single_input_idx])
single_input = batch_to_device(single_input, device)

# Raw (unnormalized) trajectory taken directly from the loaded array and
# interpolated to the model horizon (args.horizon, 32 here) rather than a
# hard-coded length.
# NOTE(review): this uses index 10 while single_input uses 800. If both are
# meant to show the same trajectory, align the indices (assumes
# trajectory_dataset[i] corresponds to dataset[i] — verify).
unnormalized_idx = 10
single_input_unnormalized = interpolate_trajectories(
    dataset[unnormalized_idx], args.horizon
)

# Optional overfitting run on the single datapoint (disabled):
# # Start training
# diffuser.to(device)
# n_epochs = 500  # Overfitting typically requires fewer epochs
# for i in tqdm(range(n_epochs)):
#     print(f"Epoch {i} / {n_epochs} | {args.savepath}")
#     trainer.train_single_datapoint(n_train_steps=20, single_input=single_input)

## Load an existing model and visualize its performance

In [None]:
# Load the trained checkpoint from the run's save directory.
# The file name is built from args.savepath instead of being hard-coded, so
# changing --savepath keeps training and loading in sync.
model_path = os.path.join(args.savepath, "state_10000.pt")
trainer.load(directory=model_path, epoch=100)

### Plot the trajectory taken directly from the dataset:

In [None]:
# Disabled cell: plot the raw (unnormalized) trajectory selected earlier,
# print its collision score and animate it in 3D. Uncomment to run.
# # Plot the original trajectory
# %matplotlib inline

# # Plot the original and diffused trajectories:
# # Original:

# # Get collision_metric:
# distance = robot_env_dist(
#     q=single_input_unnormalized[0], robot=trainer.robot, img=trainer.dataset.image[0]
# )

# score = analyze_distance(distance)

# print(f"Collision score: {score}")


# print(single_input_unnormalized[0].shape)

# plot_trajectory_per_frames(single_input_unnormalized[0])
# plot_q_values_per_trajectory(single_input_unnormalized[0])


# limits = np.array([[-1.25, +1.25], [-1.25, +1.25], [-1.25, +1.25]])
# vis.three_pv.animate_path(
#     robot=trainer.robot,
#     q=single_input_unnormalized[0],
#     kwargs_robot=dict(color="red"),
#     kwargs_world=dict(img=trainer.dataset.image[0], limits=limits, color="yellow"),
# )

### Plotting the reference trajectory:

In [None]:
%matplotlib inline
# Plot the reference trajectory for the selected single input.
# NOTE(review): the return value (a collision score, judging by its use in a
# later cell) is discarded here.
trainer.plot_reference_data(single_input)


### Plotting the diffused reconstruction:

In [None]:
# Now plotting for the diffused trajectory:
%matplotlib inline
collision_score = trainer.render_given_sample(single_input, render_3d=True)
print(f"Collision score: {collision_score}")

In [None]:
# Out-of-world generalization: evaluate the trained model on another world.
# NOTE(review): assumes world 6547 was not used for training — confirm.
dataset = np.load("justin_arm/data/q_paths_6547.npy")
dataset_image = np.load("justin_arm/data/image_6547.npy")
trajectory_dataset = TrajectoryDataset(
    dataset=dataset, horizon=args.horizon, image=dataset_image
)

# Rebuild the trainer around the same diffusion model but the new dataset.
robot = robots.JustinArm07()
trainer = trainer_config(diffuser, trajectory_dataset, device, robot)


# Reload the trained checkpoint into the new trainer; the path follows
# args.savepath rather than a hard-coded string.
model_path = os.path.join(args.savepath, "state_10000.pt")
trainer.load(directory=model_path, epoch=100)


# Pick a fixed (not random) datapoint from the new world.
single_input = batchify(trajectory_dataset[10])
single_input = batch_to_device(single_input, device)


# Compare reference and diffusion:

In [None]:
%matplotlib inline
# Plot the reference trajectory of the new world's datapoint without the 3D
# rendering, and print the collision score it returns.
collision_score = trainer.plot_reference_data(single_input, render_3d=False)
print(f"Collision score: {collision_score}")

In [None]:
# Now plotting for the diffused trajectory:
%matplotlib inline
# empty ndarray of shape (0,args.horizon,7)
q_paths = np.zeros((0, args.horizon, 7))
for i in range(20):
    q_path, collision_score = trainer.render_given_sample(single_input, render_3d=False)
    q_path = np.expand_dims(q_path, axis=0)
    q_paths = np.append(q_paths, q_path, axis=0)


In [None]:
# Sanity-check the dimensions of the stacked sample array.
print(np.shape(q_paths))

In [None]:
%matplotlib inline
plot_multiple_trajectories(q_paths[:10], q_paths.shape[0])