In [None]:
import d3rlpy
import minari
import time
import imageio
import os 

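# Select the experiment to visualize

Choose the training regime (`experiment`), the `task`, and the `algorithm` whose policy should be rolled out. Set `save_gif = False` to watch the rollout in a live window instead of writing a GIF.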

In [None]:
# Which trained policy to roll out. The selections are validated at the end
# of this cell so a typo fails immediately, before every dataset and policy
# below has been loaded.
experiment = 'online'  # offline - finetuning - online
task = 'hammer'        # pen - relocate - hammer - door
algorithm = 'CQL'      # IQL - CQL - TD3+BC - AWAC (BC only for 'offline')

# True: collect frames and write a GIF under results/gifs/;
# False: show a live "human" render window instead.
save_gif = True

EXPERIMENTS = ('offline', 'finetuning', 'online')
TASKS = ('pen', 'relocate', 'hammer', 'door')
ALGORITHMS = ('IQL', 'CQL', 'TD3+BC', 'AWAC', 'BC')

if experiment not in EXPERIMENTS:
    raise ValueError(f"experiment must be one of {EXPERIMENTS}, got {experiment!r}")
if task not in TASKS:
    raise ValueError(f"task must be one of {TASKS}, got {task!r}")
if algorithm not in ALGORITHMS:
    raise ValueError(f"algorithm must be one of {ALGORITHMS}, got {algorithm!r}")
if algorithm == 'BC' and experiment != 'offline':
    # The policy-loading cells only load a BC checkpoint for 'offline'.
    raise ValueError("BC policies are only loaded for the 'offline' experiment")

# Load the datasets

In [None]:
# One Minari D4RL expert dataset per task. In this notebook each dataset is
# used to recover the matching environment for rendering.
pen_dataset, relocate_dataset, hammer_dataset, door_dataset = (
    minari.load_dataset(dataset_id)
    for dataset_id in (
        "D4RL/pen/expert-v2",
        "D4RL/relocate/expert-v2",
        "D4RL/hammer/expert-v2",
        "D4RL/door/expert-v2",
    )
)

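# Load the trained policies

Each task has IQL, CQL, TD3+BC and AWAC checkpoints; BC checkpoints are loaded only for the `offline` experiment.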

### Pen

In [None]:
# Pen policies keyed by algorithm name, loaded in IQL, CQL, TD3+BC, AWAC order.
policies_pen = {
    name: d3rlpy.load_learnable(path)
    for name, path in (
        ("IQL", f"policies/{experiment}/pen_iql.d3"),
        ("CQL", f"policies/{experiment}/pen_cql.d3"),
        ("TD3+BC", f"policies/{experiment}/pen_td3bc.d3"),
        ("AWAC", f"policies/{experiment}/pen_awac.d3"),
    )
}

# A BC checkpoint is only loaded for the offline experiment.
if experiment == 'offline':
    policies_pen["BC"] = d3rlpy.load_learnable(
        f"policies/{experiment}/pen_bc.d3"
    )

### Relocate

In [None]:
# Relocate policies keyed by algorithm name, loaded in IQL, CQL, TD3+BC, AWAC order.
policies_relocate = {
    name: d3rlpy.load_learnable(path)
    for name, path in (
        ("IQL", f"policies/{experiment}/relocate_iql.d3"),
        ("CQL", f"policies/{experiment}/relocate_cql.d3"),
        ("TD3+BC", f"policies/{experiment}/relocate_td3bc.d3"),
        ("AWAC", f"policies/{experiment}/relocate_awac.d3"),
    )
}

# A BC checkpoint is only loaded for the offline experiment.
if experiment == 'offline':
    policies_relocate["BC"] = d3rlpy.load_learnable(
        f"policies/{experiment}/relocate_bc.d3"
    )

### Hammer

In [None]:
# Hammer policies keyed by algorithm name, loaded in IQL, CQL, TD3+BC, AWAC order.
policies_hammer = {
    name: d3rlpy.load_learnable(path)
    for name, path in (
        ("IQL", f"policies/{experiment}/hammer_iql.d3"),
        ("CQL", f"policies/{experiment}/hammer_cql.d3"),
        ("TD3+BC", f"policies/{experiment}/hammer_td3bc.d3"),
        ("AWAC", f"policies/{experiment}/hammer_awac.d3"),
    )
}

# A BC checkpoint is only loaded for the offline experiment.
if experiment == 'offline':
    policies_hammer["BC"] = d3rlpy.load_learnable(
        f"policies/{experiment}/hammer_bc.d3"
    )

### Door

In [None]:
# Door policies keyed by algorithm name, loaded in IQL, CQL, TD3+BC, AWAC order.
policies_door = {
    name: d3rlpy.load_learnable(path)
    for name, path in (
        ("IQL", f"policies/{experiment}/door_iql.d3"),
        ("CQL", f"policies/{experiment}/door_cql.d3"),
        ("TD3+BC", f"policies/{experiment}/door_td3bc.d3"),
        ("AWAC", f"policies/{experiment}/door_awac.d3"),
    )
}

# A BC checkpoint is only loaded for the offline experiment.
if experiment == 'offline':
    policies_door["BC"] = d3rlpy.load_learnable(
        f"policies/{experiment}/door_bc.d3"
    )

# Visualize the selected policy

In [None]:
def visualize(env, policy, *, save=None, gif_path=None, step_delay=0.01):
    """Roll out one episode of ``policy`` in ``env``, optionally saving a GIF.

    Args:
        env: Gymnasium-style environment (``reset()`` returns ``(obs, info)``,
            ``step()`` returns ``(obs, reward, terminated, truncated, info)``).
            It is always closed when this function exits, even if the
            rollout raises or is interrupted.
        policy: Object whose ``predict`` maps a batch of observations to a
            batch of actions (e.g. a policy loaded with ``d3rlpy.load_learnable``).
        save: Whether to collect ``env.render()`` frames and write a GIF.
            ``None`` (default) uses the notebook-level ``save_gif`` flag.
            Frames are only meaningful when ``env`` was created with
            ``render_mode="rgb_array"``.
        gif_path: Where to write the GIF. ``None`` (default) builds
            ``results/gifs/{task}_{experiment}_{algorithm}.gif`` from the
            notebook-level selection.
        step_delay: Seconds to sleep after each step when not saving, so a
            ``"human"`` render window plays at a watchable speed.

    Returns:
        The undiscounted episode return (sum of per-step rewards).
    """
    if save is None:
        save = save_gif  # fall back to the notebook-level flag
    if save and gif_path is None:
        gif_path = f'results/gifs/{task}_{experiment}_{algorithm}.gif'

    total_reward = 0
    frames = []
    try:
        obs, _ = env.reset()
        done = False
        while not done:
            # predict() works on batches: add a leading batch axis, take the single action back out.
            action = policy.predict(obs[None])[0]
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward

            if save:
                frames.append(env.render())
            elif step_delay:
                time.sleep(step_delay)
    finally:
        # Release the simulator / viewer window even on error or KeyboardInterrupt.
        env.close()

    if save:
        # dirname is '' for a bare filename; makedirs('') would raise.
        os.makedirs(os.path.dirname(gif_path) or '.', exist_ok=True)
        # NOTE(review): imageio's GIF `duration` units depend on the plugin/version
        # (seconds in the legacy plugin, milliseconds in the pillow-based one);
        # confirm 1/30 yields ~30 fps with the installed imageio.
        imageio.mimsave(gif_path, frames, duration=1/30)
        print(f"GIF saved to: {gif_path}")

    print(f"Episode finished with return: {total_reward:.2f}")
    return total_reward

In [None]:
# Saving a GIF needs off-screen RGB frames; otherwise open an interactive viewer.
render_mode = "rgb_array" if save_gif else "human"

# Per-task dataset (used here to recover the matching environment) and loaded policies.
task_assets = {
    'pen': (pen_dataset, policies_pen),
    'relocate': (relocate_dataset, policies_relocate),
    'hammer': (hammer_dataset, policies_hammer),
    'door': (door_dataset, policies_door),
}
if task not in task_assets:
    # Without this check an unknown task would silently reuse the `env`/`policy`
    # left over from a previous run of this cell (or raise a NameError).
    raise ValueError(f"Unknown task {task!r}; expected one of {list(task_assets)}")
dataset, task_policies = task_assets[task]

if algorithm not in task_policies:
    raise ValueError(
        f"No {algorithm!r} policy loaded for task {task!r} in the {experiment!r} "
        f"experiment; available: {list(task_policies)}"
    )
# Resolve the policy before creating the environment so a bad selection
# cannot leave an unclosed environment behind.
policy = task_policies[algorithm]
env = dataset.recover_environment(render_mode=render_mode)

visualize(env, policy)