In [1]:
import sys

import copy
import numpy as np 
import os
import pickle

import pytest
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import WeightedRandomSampler

from src.config import EnvironmentConfig, OnlineTrainConfig, RunConfig, TransformerModelConfig
from src.decision_transformer.offline_dataset import TrajectoryDataset
from src.decision_transformer.train import test
from src.decision_transformer.eval import evaluate_dt_agent
from src.environments.environments import make_env
from src.environments.memory import MemoryEnv
from src.models.trajectory_transformer import (
    CloneTransformer,
    DecisionTransformer,
)
from src.utils.trajectory_sampling import get_filtered_trajectories
from src.utils.trajectory_writer import TrajectoryWriter

from minigrid.wrappers import ViewSizeWrapper
import plotly.express as px 

# Build a size-7 Memory environment with a fixed length and start position,
# then show its first rendered frame as a sanity check.
env = MemoryEnv(
    size=7,
    random_length=False,
    random_start_pos=False,
    max_steps=200,
    render_mode="rgb_array",
)
env.reset()
initial_frame = env.render()
px.imshow(initial_frame).show()

pygame 2.3.0 (SDL 2.24.2, Python 3.9.15)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Environment settings used for the rollouts in this notebook.
environment_config = EnvironmentConfig(
    env_id="MiniGrid-MemoryS7-v0",
    max_steps=1000,
    # Rendering options.
    render_mode="rgb_array",
    capture_video=False,
    # Observation options.
    view_size=7,
    fully_observed=False,
    one_hot_obs=False,
)

# Run metadata: tracking is switched off and the wandb fields carry
# placeholder "test" values.
run_config = RunConfig(
    exp_name="test",
    seed=1,
    wandb_project_name="test",
    wandb_entity="test",
    track=False,
)

# Transformer hyperparameters; the model runs on the GPU when one is available.
model_device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_model_config = TransformerModelConfig(
    d_model=128,
    n_heads=4,
    d_mlp=256,
    n_layers=2,
    state_embedding_type="grid",  # hard-coded for now to minigrid.
    # NOTE(review): this was previously annotated "one timestep of context",
    # which looks stale for a context length of 26 -- confirm how n_ctx maps
    # to timesteps.
    n_ctx=26,
    device=model_device,
)

# Online-training settings handed to the trajectory writer below.
# The keyword arguments are grouped into dicts purely for readability; the
# resulting call receives exactly the same keywords and values as before.
rollout_settings = dict(
    total_timesteps=180000,
    num_envs=4,
    num_steps=128,
)
optimisation_settings = dict(
    learning_rate=0.00025,
    decay_lr=False,
    num_minibatches=4,
    update_epochs=4,
    max_grad_norm=2,
)
objective_settings = dict(
    gamma=0.99,
    gae_lambda=0.95,
    clip_coef=0.4,
    ent_coef=0.2,
    vf_coef=0.5,
)
online_config = OnlineTrainConfig(
    use_trajectory_model=False,
    hidden_size=64,
    trajectory_path=None,
    fully_observed=False,
    # NOTE(review): pinned to CPU, whereas the transformer config picks CUDA
    # when available -- confirm this difference is intended.
    device=torch.device("cpu"),
    **rollout_settings,
    **optimisation_settings,
    **objective_settings,
)

# Give the model its own copy of the transformer config so that changes made
# through the model do not leak back into transformer_model_config.
dt_transformer_config = copy.deepcopy(transformer_model_config)
dt = DecisionTransformer(
    environment_config=environment_config,
    transformer_config=dt_transformer_config,
)
# NOTE(review): the config above already sets n_ctx=26, so this reassignment
# looks redundant -- kept as-is; confirm before removing.
dt.transformer_config.n_ctx = 26

In [3]:
# Where the evaluation rollouts are written, and how many to collect.
trajectory_path = "tests/tmp/test_trajectories.pkl"
num_trajectories = 300
# Round the trajectory count up to the next multiple of 8.
# (presumably 8 is the number of parallel evaluation envs -- TODO confirm)
trajectory_shape = ((num_trajectories + 7) // 8) * 8

# Cap episodes at 10 steps to speed up the test.
# Note: this mutates the shared environment_config object in place, so every
# later use of environment_config sees max_steps=10.
environment_config.max_steps = 10

# Environment factory for the evaluation run; `batch` doubles as the seed.
batch = 0
eval_run_name = f"dt_eval_videos_{batch}"
eval_env_func = make_env(
    environment_config,
    seed=batch,
    idx=0,
    run_name=eval_run_name,
)

# Writer targeting trajectory_path. The run, environment and online-training
# configs are passed along with it; no model config is supplied.
trajectory_writer = TrajectoryWriter(
    path=trajectory_path,
    model_config=None,
    run_config=run_config,
    environment_config=environment_config,
    online_config=online_config,
)

# Evaluate the decision transformer for num_trajectories episodes, passing
# trajectory_writer so the rollouts are saved to trajectory_path.
eval_device = "cuda" if torch.cuda.is_available() else "cpu"
statistics = evaluate_dt_agent(
    model=dt,
    env_id=environment_config.env_id,
    env_func=eval_env_func,
    trajectories=num_trajectories,
    initial_rtg=1,
    device=eval_device,
    track=False,
    trajectory_writer=trajectory_writer,
)

Evaluating DT:   0%|          | 0/300 [00:00<?, ?it/s]


Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\utils\tensor_new.cpp:233.)

Evaluating DT: Finished running 304 episodes.Current episodes are at timestep [10, 10, 10, 10, 10, 10, 10, 10] for reward [0. 0. 0. 0. 0. 0. 0. 0.]: 100%|██████████| 300/300 [00:04<00:00, 61.74it/s]


Writing to tests/tmp/test_trajectories.pkl
Trajectory written to tests/tmp/test_trajectories.pkl


In [4]:
# `sys` and `TrajectoryDataset` are already imported in the notebook's first
# cell, so they are not re-imported here.
# Only add the parent directory to sys.path once: an unconditional append
# stacks another duplicate entry every time this cell is re-run.
if '..' not in sys.path:
    sys.path.append('..')

from src.visualization import render_minigrid_observation, render_minigrid_observations

# Load the trajectories stored at trajectory_path (defined in an earlier cell).
dataset = TrajectoryDataset(trajectory_path, max_len=10, prob_go_from_end=1, pct_traj=100)

# NOTE(review): ims_shown and common_arrays are not used in this cell; they are
# kept in case later cells rely on them.
ims_shown = 0
common_arrays = [
    np.array([0., 0., 0.]),
    np.array([1., 0., 0.]),
    np.array([2., 5., 0.]),
    np.array([5., 1., 0.]),
    np.array([6., 1., 0.]),
]

# Unpack the first dataset item.
# (presumably states, actions, rewards, dones, returns-to-go, timesteps and
# mask -- TODO confirm against TrajectoryDataset.__getitem__)
s, a, r, d, rtg, ti, m = dataset[0]

# Render the observations of that item and show them as an animation.
px.imshow(render_minigrid_observations(env, s), animation_frame=0).show()