In [1]:
import os
import wandb
wandb.login()

%env WANDB_PROJECT=Decision_Transformer
%env WANDB_NOTEBOOK_NAME=Project_DT

import random
from dataclasses import dataclass

import numpy as np
import torch
from datasets import load_dataset
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrahulboipai[0m ([33miisc-rl[0m). Use [1m`wandb login --relogin`[0m to force relogin


env: WANDB_PROJECT=Decision_Transformer
env: WANDB_NOTEBOOK_NAME=Project_DT


In [2]:

%env WANDB_LOG_MODEL=true

env: WANDB_LOG_MODEL=true


# Dataset

In [3]:
#dataset = load_dataset("edbeeching/decision_transformer_gym_replay", "halfcheetah-expert-v2")
dataset = load_dataset("edbeeching/decision_transformer_gym_replay", "halfcheetah-medium-v2")
#dataset = load_dataset("edbeeching/decision_transformer_gym_replay", "halfcheetah-medium-replay-v2")

In [4]:
@dataclass
class DecisionTransformerGymDataCollator:
    return_tensors: str = "pt"
    max_len: int = 20 #subsets of the episode we use for training
    state_dim: int = 17  # size of state space
    act_dim: int = 6  # size of action space
    max_ep_len: int = 1000 # max episode length in the dataset
    scale: float = 1000.0  # normalization of rewards/returns
    state_mean: np.array = None  # to store state means
    state_std: np.array = None  # to store state stds
    p_sample: np.array = None  # a distribution to take account trajectory lengths
    n_traj: int = 0 # to store the number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset
        # calculate dataset stats for normalization of states
        states = []
        traj_lens = []
        for obs in dataset["observations"]:
            states.extend(obs)
            traj_lens.append(len(obs))
        self.n_traj = len(traj_lens)
        states = np.vstack(states)
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

        traj_lens = np.array(traj_lens)
        self.p_sample = traj_lens / sum(traj_lens)

    def _discount_cumsum(self, x, gamma):
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        # a batch of dataset features
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

        for ind in batch_inds:
            # for feature in features:
            feature = self.dataset[int(ind)]
            si = random.randint(0, len(feature["rewards"]) - 1)

            # get sequences from dataset
            s.append(np.array(feature["observations"][si : si + self.max_len]).reshape(1, -1, self.state_dim))
            a.append(np.array(feature["actions"][si : si + self.max_len]).reshape(1, -1, self.act_dim))
            r.append(np.array(feature["rewards"][si : si + self.max_len]).reshape(1, -1, 1))

            d.append(np.array(feature["dones"][si : si + self.max_len]).reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff
            rtg.append(
                self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[
                    : s[-1].shape[1]   # TODO check the +1 removed here
                ].reshape(1, -1, 1)
            )
            if rtg[-1].shape[1] < s[-1].shape[1]:
                print("if true")
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            s[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, self.state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - self.state_mean) / self.state_std
            a[-1] = np.concatenate(
                [np.ones((1, self.max_len - tlen, self.act_dim)) * -10.0, a[-1]],
                axis=1,
            )
            r[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, self.max_len - tlen)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, self.max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, self.max_len - tlen)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        d = torch.from_numpy(np.concatenate(d, axis=0))
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

# Training

In [5]:
class TrainableDT(DecisionTransformerModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        output = super().forward(**kwargs)
        # add the DT loss
        action_preds = output[1]
        action_targets = kwargs["actions"]
        attention_mask = kwargs["attention_mask"]
        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]
        action_targets = action_targets.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]

        loss = torch.mean((action_preds - action_targets) ** 2)
        return {"loss": loss}

    def original_forward(self, **kwargs):
        return super().forward(**kwargs)

In [6]:
collator = DecisionTransformerGymDataCollator(dataset["train"])

config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDT(config)

In [7]:

wandb.init(
    project="Decision_Transformer",
    name="HalfCheetahm-v2-Train",
    group="Half-Cheetah-medium-v2"
)

training_args = TrainingArguments(
    output_dir="output/",
    remove_unused_columns=False,
    num_train_epochs=120,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    report_to="wandb",
    run_name="halfcheetah-medium-v2",
    logging_steps=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collator,
)

trainer.train()



Step,Training Loss
1,0.6802
2,0.6922
3,0.6833
4,0.6924
5,0.6916
6,0.69
7,0.6869
8,0.6813
9,0.6904
10,0.6783


Checkpoint destination directory output/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory output/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory output/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


TrainOutput(global_step=1920, training_loss=0.08645531896230144, metrics={'train_runtime': 1911.0615, 'train_samples_per_second': 62.792, 'train_steps_per_second': 1.005, 'total_flos': 147340224000000.0, 'train_loss': 0.08645531896230144, 'epoch': 120.0})

# Testing 

In [8]:
import gymnasium as gym
import mujoco_py

env = gym.make("HalfCheetah-v4")

/data/home/rahulboipai/miniconda3/envs/state/lib/python3.10/site-packages/glfw/__init__.py:916: GLFWError: (65544) b'X11: Failed to open display localhost:10.0'


In [9]:
# Function that gets an action from the model using autoregressive prediction with a window of the previous 20 timesteps.
def get_action(model, states, actions, rewards, returns_to_go, timesteps):
    # This implementation does not condition on past rewards

    states = states.reshape(1, -1, model.config.state_dim)
    actions = actions.reshape(1, -1, model.config.act_dim)
    returns_to_go = returns_to_go.reshape(1, -1, 1)
    timesteps = timesteps.reshape(1, -1)

    states = states[:, -model.config.max_length :]
    actions = actions[:, -model.config.max_length :]
    returns_to_go = returns_to_go[:, -model.config.max_length :]
    timesteps = timesteps[:, -model.config.max_length :]
    padding = model.config.max_length - states.shape[1]
    # pad all tokens to sequence length
    attention_mask = torch.cat([torch.zeros(padding), torch.ones(states.shape[1])])
    attention_mask = attention_mask.to(dtype=torch.long).reshape(1, -1)
    states = torch.cat([torch.zeros((1, padding, model.config.state_dim)), states], dim=1).float()
    actions = torch.cat([torch.zeros((1, padding, model.config.act_dim)), actions], dim=1).float()
    returns_to_go = torch.cat([torch.zeros((1, padding, 1)), returns_to_go], dim=1).float()
    timesteps = torch.cat([torch.zeros((1, padding), dtype=torch.long), timesteps], dim=1)

    state_preds, action_preds, return_preds = model.original_forward(
        states=states,
        actions=actions,
        rewards=rewards,
        returns_to_go=returns_to_go,
        timesteps=timesteps,
        attention_mask=attention_mask,
        return_dict=False,
    )

    return action_preds[0, -1]

In [10]:
model = model.to("cpu")
max_ep_len = 10000
device = 'cpu'
scale = 1000.0  # normalization for rewards/returns
TARGET_RETURN = 12000 / scale  # evaluation is conditioned on a return of 12000, scaled accordingly

state_mean = collator.state_mean.astype(np.float32)
state_std = collator.state_std.astype(np.float32)
print(state_mean)

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
# Create the decision transformer model

state_mean = torch.from_numpy(state_mean).to(device=device)
state_std = torch.from_numpy(state_std).to(device=device)

[-0.06845804  0.01641462 -0.18355839 -0.2762486  -0.34113872 -0.0933961
 -0.21321191 -0.08774357  5.1731944  -0.04275185 -0.0361088   0.14053658
  0.06049891  0.09550849  0.06739013  0.00562735  0.01338256]


In [11]:
import wandb
run = wandb.init(
    project="Decision_Transformer",
    name="HalfCheetah-v2-Test",
    group="Half-Cheetah-medium-v2"
)

from gymnasium.experimental.wrappers import RecordVideoV0
env = gym.make("HalfCheetah-v2", render_mode="rgb_array")
env = RecordVideoV0(env, video_folder = './video')

#Interaction
episode_return, episode_length = 0,0
state = env.reset()
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1,1)
states = torch.from_numpy(state[0]).reshape(1, state_dim).to(device, dtype=torch.float32)
actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
rewards = torch.zeros(0, device=device, dtype=torch.float32)

timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
for t in range(max_ep_len):
    actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
    rewards = torch.cat([rewards, torch.zeros(1, device=device)])

    action = get_action(
        model,
        (states - state_mean) / state_std,
        actions,
        rewards,
        target_return,
        timesteps,
    )
    actions[-1] = action
    action = action.detach().cpu().numpy()

    #state, reward, done, _ = env.step(action)
    state, reward, done, truncated, info = env.step(action)

    cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
    states = torch.cat([states, cur_state], dim=0)
    rewards[-1] = reward

    pred_return = target_return[0, -1] - (reward / scale)
    target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
    timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

    episode_return += reward
    episode_length += 1
    wandb.log({'rewards':rewards, 'score': episode_return, 'episode': episode_length}, step=t)


    if done:
        break
    
env.close()


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import display


VBox(children=(Label(value='4.812 MB of 4.812 MB uploaded (0.001 MB deduped)\r'), FloatProgress(value=1.0, max…

  from IPython.core.display import HTML, display  # type: ignore


0,1
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▂▃▅▇███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
train/loss,█▆▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁
train/train_samples_per_second,▁
train/train_steps_per_second,▁

0,1
train/epoch,120.0
train/global_step,1920.0
train/learning_rate,0.0
train/loss,0.0606
train/total_flos,147340224000000.0
train/train_loss,0.08646
train/train_runtime,1911.0615
train/train_samples_per_second,62.792
train/train_steps_per_second,1.005


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import display


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112621933635738, max=1.0…

  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  logger.deprecation(
  logger.deprecation(
  logger.warn(


Found 3 GPUs for rendering. Using device 0.


IndexError: index out of range in self

In [None]:
import wandb
run = wandb.init(
    project="Decision_Transformer",
    name="HalfCheetah-v2-Performance",
    group="Half-Cheetah-medium-v2"
)


env = gym.make("HalfCheetah-v4", render_mode="rgb_array")


for target in range(1,12000,100):
    
    TARGET_RETURN = target / scale

    #Interaction
    episode_return, episode_length = 0,0
    state = env.reset()
    target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1,1)
    states = torch.from_numpy(state[0]).reshape(1, state_dim).to(device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
    for t in range(max_ep_len):
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = get_action(
            model,
            (states - state_mean) / state_std,
            actions,
            rewards,
            target_return,
            timesteps,
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        #state, reward, done, _ = env.step(action)
        state, reward, done, truncated, info = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        pred_return = target_return[0, -1] - (reward / scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1
        
        if done:
        
            break
        
    wandb.log({'Duration':episode_length, 'Performance': episode_return, 'episode': episode_length}, step=int(TARGET_RETURN))
    
env.close()


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import display


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

  from IPython.core.display import HTML, display  # type: ignore


0,1
episode,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
rewards,▁
score,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
episode,4096.0
score,20674.52493


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import display


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113757929868169, max=1.0…

  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


  from IPython.core.display import HTML, display  # type: ignore


IndexError: index out of range in self