In [1]:
# install dependencies
%%capture
!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common \
    patchelf \
    xvfb

In [2]:
# install OpenAI gym packages
%%capture
!pip install gym==0.21.0
!pip install free-mujoco-py
!pip install git+https://github.com/huggingface/transformers 

!pip install colabgymrender==1.0.2
!pip install xvfbwrapper
!pip install imageio==2.4.1

In [3]:
# mount Google Drive
from google.colab import drive
drive.mount('/drive')

Mounted at /drive


In [4]:
# import our AI packages
%%capture
import torch
import mujoco_py
import gym
import numpy as np

from colabgymrender.recorder import Recorder
from transformers import DecisionTransformerModel

In [6]:
# Build the environment
directory = '/drive/My Drive/video'

env = gym.make('Walker2d-v3')
env = Recorder(env, directory, fps=30)
max_ep_len = 1000
device = "cuda"
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 3600 / scale  # evaluation is conditioned on a return of 3600 and scaled accordingly

# mean and standard deviation computed from training dataset
state_mean = np.array(
    [ 1.2384834e+00, 1.9578537e-01, -1.0475016e-01, -1.8579608e-01, 2.3003316e-01, 
     2.2800924e-02, -3.7383768e-01, 3.3779100e-01, 3.9250960e+00, -4.7428459e-03, 2.5267061e-02, 
     -3.9287535e-03, -1.7367510e-02, -4.8212224e-01, 3.5432147e-04, -3.7124525e-03, 2.6285544e-03]

)
state_std = np.array(
    [0.06664903, 0.16980624, 0.17309439, 0.21843709, 0.74599105, 0.02410989, 0.3729872, 0.6226182, 
     0.9708009, 0.72936815, 1.504065, 2.495893, 3.511518, 5.3656907, 0.79503316, 4.317483, 6.1784487]
)

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Instantiate the decision transformer model
model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-walker2d-expert")
model = model.to(device)
print(list(model.encoder.wpe.parameters()))

state_mean = torch.from_numpy(state_mean).to(device=device)
state_std = torch.from_numpy(state_std).to(device=device)

[Parameter containing:
tensor([[0., -0., 0.,  ..., 0., 0., -0.],
        [0., -0., -0.,  ..., 0., -0., -0.],
        [-0., 0., 0.,  ..., -0., -0., -0.],
        ...,
        [-0., -0., -0.,  ..., -0., -0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., -0., -0.,  ..., -0., 0., -0.]], device='cuda:0',
       requires_grad=True)]


In [7]:
# Explore our state & action space
print('State Space Dimmension:', state_dim)
print('Action Space Dimmension:', act_dim)

State Space Dimmension: 17
Action Space Dimmension: 6


## Autoregressive Prediction Function
The model's prediction is conditioned on a sequences of states, actions, time-steps and returns. The action for the current time-step is included as zeros and masked to not skew the model's attention distribution.

In [19]:
# Function that gets an action from the model using autoregressive prediction 
# with a window of the previous 20 timesteps
def get_action(model, states, actions, rewards, returns_to_go, timesteps):
    # This implementation does not condition on past rewards

    states = states.reshape(1, -1, model.config.state_dim)
    actions = actions.reshape(1, -1, model.config.act_dim)
    returns_to_go = returns_to_go.reshape(1, -1, 1)
    timesteps = timesteps.reshape(1, -1)

    states = states[:, -model.config.max_length :]
    actions = actions[:, -model.config.max_length :]
    returns_to_go = returns_to_go[:, -model.config.max_length :]
    timesteps = timesteps[:, -model.config.max_length :]
    padding = model.config.max_length - states.shape[1] # pad all tokens to sequence length
    
    attention_mask = torch.cat([torch.zeros(padding), torch.ones(states.shape[1])])
    attention_mask = attention_mask.to(dtype=torch.long).reshape(1, -1)
    attention_mask = attention_mask.to(device)
    states = torch.cat([torch.zeros((1, padding, model.config.state_dim), device=device), states], dim=1).float()
    actions = torch.cat([torch.zeros((1, padding, model.config.act_dim), device=device), actions], dim=1).float()
    returns_to_go = torch.cat([torch.zeros((1, padding, 1), device=device), returns_to_go], dim=1).float()
    timesteps = torch.cat([torch.zeros((1, padding), dtype=torch.long, device=device), timesteps], dim=1)

    state_preds, action_preds, return_preds = model(
        states=states,
        actions=actions,
        rewards=rewards,
        returns_to_go=returns_to_go,
        timesteps=timesteps,
        attention_mask=attention_mask,
        return_dict=False,
    )

    return action_preds[0, -1]

In [20]:
# Interact with the environment and create a video
episode_return, episode_length = 0, 0
state = env.reset()
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
rewards = torch.zeros(0, device=device, dtype=torch.float32)

timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

for t in range(max_ep_len):
    actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
    rewards = torch.cat([rewards, torch.zeros(1, device=device)])

    action = get_action(model, (states - state_mean) / state_std, actions, rewards, target_return, timesteps)
    actions[-1] = action
    action = action.detach().cpu().numpy()

    state, reward, done, _ = env.step(action)

    cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
    states = torch.cat([states, cur_state], dim=0)
    rewards[-1] = reward

    pred_return = target_return[0, -1] - (reward / scale)
    target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
    timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

    episode_return += reward
    episode_length += 1

    if done:
        break

In [21]:
# Play the video
env.play()

Output hidden; open in https://colab.research.google.com to view.