In [4]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import torch.nn as nn
import time

from envs.random_walk import RandomWalkEnv
from algorithms.random_policy import RandomPolicy
from algorithms.sequence_models.decision_transformer.decision_transformer import DecisionTransformer
from algorithms.sequence_models.decision_transformer.trainer import DecisionTransformerTrainer
from datasets.random_walk_dataset import RandomWalkDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--env', type=str, default='hopper')
# parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
# parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
# parser.add_argument('--K', type=int, default=20)
# parser.add_argument('--pct_traj', type=float, default=1.)
# parser.add_argument('--batch_size', type=int, default=64)
# parser.add_argument('--model_type', type=str, default='dt')  # dt for decision transformer, bc for behavior cloning
# parser.add_argument('--embed_dim', type=int, default=128)
# parser.add_argument('--n_layer', type=int, default=3)
# parser.add_argument('--n_head', type=int, default=1)
# parser.add_argument('--activation_function', type=str, default='relu')
# parser.add_argument('--dropout', type=float, default=0.1)
# parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
# parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
# parser.add_argument('--warmup_steps', type=int, default=10000)
# parser.add_argument('--num_eval_episodes', type=int, default=100)
# parser.add_argument('--max_iters', type=int, default=10)
# parser.add_argument('--num_steps_per_iter', type=int, default=10000)
# parser.add_argument('--device', type=str, default='cuda')
# parser.add_argument('--log_to_wandb', '-w', type=bool, default=False)


class Defaults:
    """Default optimisation hyperparameters shared by the training cells."""

    # Step size handed to the AdamW optimizer.
    lr: float = 0.0001
    # Weight-decay coefficient handed to the AdamW optimizer.
    weight_decay: float = 0.0001
    # Mini-batch size handed to the trainer.
    batch_size: int = 64
    # Number of scheduler steps over which the learning rate is warmed up.
    warmup_steps: int = 10_000


In [None]:
def eval_episodes(target_rew):
    """Build an evaluation callback bound to a single target return.

    The returned ``fn(model)`` rolls out ``num_eval_episodes`` episodes with
    gradients disabled and summarises the episode returns and lengths.

    NOTE(review): the callback reads several names from the notebook's global
    namespace at call time (``num_eval_episodes``, ``model_type``, ``env``,
    ``state_dim``, ``act_dim``, ``max_ep_len``, ``scale``, ``mode``,
    ``state_mean``, ``state_std``, ``device``, ``evaluate_episode_rtg`` and
    ``evaluate_episode``); they must be defined before ``fn`` is invoked.

    Args:
        target_rew: Target return; it is divided by the global ``scale``
            before being passed to the episode evaluator and is embedded in
            the keys of the result dictionary.

    Returns:
        A function mapping a model to a dict with the mean and standard
        deviation of the episode returns and episode lengths.
    """
    def fn(model):
        episode_stats = []
        for _ in range(num_eval_episodes):
            with torch.no_grad():
                if model_type != 'dt':
                    # Every non-'dt' model type goes through the plain evaluator,
                    # which is not given `scale`.
                    ret, length = evaluate_episode(
                        env,
                        state_dim,
                        act_dim,
                        model,
                        max_ep_len=max_ep_len,
                        target_return=target_rew/scale,
                        mode=mode,
                        state_mean=state_mean,
                        state_std=state_std,
                        device=device,
                    )
                else:
                    # 'dt' models use the return-to-go evaluator, which also
                    # receives `scale`.
                    ret, length = evaluate_episode_rtg(
                        env,
                        state_dim,
                        act_dim,
                        model,
                        max_ep_len=max_ep_len,
                        scale=scale,
                        target_return=target_rew/scale,
                        mode=mode,
                        state_mean=state_mean,
                        state_std=state_std,
                        device=device,
                    )
            episode_stats.append((ret, length))

        all_returns = [episode_return for episode_return, _ in episode_stats]
        all_lengths = [episode_length for _, episode_length in episode_stats]
        tag = f'target_{target_rew}'
        return {
            f'{tag}_return_mean': np.mean(all_returns),
            f'{tag}_return_std': np.std(all_returns),
            f'{tag}_length_mean': np.mean(all_lengths),
            f'{tag}_length_std': np.std(all_lengths),
        }
    return fn


In [42]:
# Experiment setup: environment, offline dataset, model, optimizer,
# learning-rate schedule and trainer.
env = RandomWalkEnv()
dataset = RandomWalkDataset(n_trajectories=100000)

model = DecisionTransformer(
        hidden_size=64,
        dataset=dataset,
        block_size=64,  # TODO: experiment with block_size; it may need to be >= SPLIT_SEQUENCE_LENGTH (unconfirmed)
        max_length=None,
        action_tanh=True,
        gpt_config={}
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=Defaults.lr,
    weight_decay=Defaults.weight_decay,
)

# Linear warm-up: the multiplier applied to the base learning rate grows from
# 1/warmup_steps up to 1 and is then held at 1.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda steps: min((steps+1)/Defaults.warmup_steps, 1)
)

# Target returns; one evaluation callback is built per entry below.
env_targets = [0.8]

# NOTE(review): eval_episodes reads notebook globals (num_eval_episodes,
# model_type, scale, ...) that no visible cell defines -- confirm they exist
# before the trainer runs its evaluation callbacks.
trainer = DecisionTransformerTrainer(
        model=model,
        optimizer=optimizer,
        batch_size=Defaults.batch_size,
        scheduler=scheduler,
        # Only the action prediction contributes to the loss (mean squared
        # error); the state and reward arguments are ignored.
        loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a)**2),
        eval_fns=[eval_episodes(tar) for tar in env_targets],
)



# dataloader = DataLoader(dataset, batch_size=3, shuffle=True)
# observations, actions, rewards, dones, returns_to_go = next(iter(dataloader))
# observations, actions, rewards, dones, returns_to_go

In [26]:
# TODO: there is probably a better way to embed the return, e.g. expanding it in a Gaussian basis.
# The current implementation just uses a linear layer.


In [9]:
class Trainer:
    """Minimal training loop that feeds a sequence model from a ``DataLoader``.

    Batches are drawn with replacement from ``dataset``; each batch is expected
    to unpack into ``(observations, actions, rewards, dones, returns_to_go)``.

    Args:
        model: Module called as ``model(states, actions, rewards, masks=None,
            attention_mask=..., target_return=...)`` and returning
            ``(state_preds, action_preds, reward_preds)``.
        optimizer: Optimizer stepping the model's parameters.
        batch_size: Number of samples per training batch.
        dataset: Map-style dataset the training loader samples from.
        loss_fn: Callable ``loss_fn(s_hat, a_hat, r_hat, s, a, r)`` returning
            a scalar tensor.
        scheduler: Optional LR scheduler, stepped once per training step.
        eval_fns: Optional list of callables ``fn(model) -> dict`` run after
            every training iteration.
        device: Device batches are moved to. When ``None`` it is taken from
            the model's first parameter, falling back to CPU for a model
            without parameters.
    """

    def __init__(self, model, optimizer, batch_size, dataset, loss_fn, scheduler=None, eval_fns=None, device=None):
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size

        # The original code read an undefined global `device`; resolve it here
        # so the class does not depend on notebook state.
        if device is None:
            first_param = next(iter(model.parameters()), None)
            device = first_param.device if first_param is not None else 'cpu'
        self.device = torch.device(device)

        self.dataset = dataset
        # `DataLoader` is referenced through `torch.utils.data` because the bare
        # name is never imported in this notebook.
        self.train_loader = torch.utils.data.DataLoader(
            dataset,
            # Sampling with replacement and a huge `num_samples` makes the loader
            # effectively endless; `train_iteration` decides how many batches to draw.
            sampler=torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=int(1e10)),
            shuffle=False,
            # Pinned host memory is only useful (and only available) with CUDA.
            pin_memory=self.device.type == 'cuda',
            batch_size=batch_size,
            num_workers=0, # todo or 1?
        )

        self.loss_fn = loss_fn
        self.scheduler = scheduler
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()

        self.start_time = time.time()

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` optimisation steps, then every evaluation callback.

        Args:
            num_steps: Number of batches to train on.
            iter_num: Iteration index, used only in the printed header.
            print_logs: When true, print every log entry.

        Returns:
            Dict with timing entries, the mean/std of the training losses,
            ``evaluation/<key>`` entries from the callbacks and the contents
            of ``self.diagnostics``.
        """
        train_losses = []
        logs = dict()

        train_start = time.time()

        self.model.train()
        # todo maybe we should pass attention mask from the dataset
        # todo for the ones after done, should it only attend to itself or to all the previous ones?
        # `range` comes first in the zip so that no extra batch is fetched once
        # `num_steps` batches have been consumed.
        for _, batch in zip(range(num_steps), self.train_loader):
            train_loss = self.train_step(batch)

            train_losses.append(train_loss)
            if self.scheduler is not None:
                self.scheduler.step()

        logs['time/training'] = time.time() - train_start

        eval_start = time.time()

        self.model.eval()
        for eval_fn in self.eval_fns:
            outputs = eval_fn(self.model)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    def process_batch(self, batch):
        """Convert one loader batch to tensors on ``self.device``.

        Args:
            batch: Sequence unpacking into
                ``(observations, actions, rewards, dones, returns_to_go)``.

        Returns:
            ``(observations, actions, rewards, dones, attention_mask,
            returns_to_go)``; ``dones`` and ``attention_mask`` are ``long``,
            the rest ``float32``.
        """
        observations, actions, rewards, dones, returns_to_go = batch

        def to_device(values, dtype):
            # Batches from the loader are presumably already collated along the
            # batch axis (TODO confirm for custom collate functions), so they are
            # converted as-is instead of being concatenated along axis 0.
            return torch.as_tensor(values).to(dtype=dtype, device=self.device)

        observations = to_device(observations, torch.float32)
        actions = to_device(actions, torch.float32)
        rewards = to_device(rewards, torch.float32)
        dones = to_device(dones, torch.long)
        returns_to_go = to_device(returns_to_go, torch.float32)

        # The dataset supplies no attention mask, so every position is marked
        # as attendable. Assumes the first two axes of `observations` are
        # (batch, sequence) -- TODO confirm, and mask steps after `done` if needed.
        attention_mask = torch.ones(observations.shape[:2], dtype=torch.long, device=self.device)

        return observations, actions, rewards, dones, attention_mask, returns_to_go

    def train_step(self, batch):
        """Run one forward/backward/optimizer step and return the loss as a float."""
        states, actions, rewards, dones, attention_mask, returns = self.process_batch(batch)
        state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards)

        # Call the module itself rather than `.forward` so registered hooks run.
        state_preds, action_preds, reward_preds = self.model(
            states, actions, rewards, masks=None, attention_mask=attention_mask, target_return=returns,
        )

        # NOTE(review): state/reward targets are shifted by one step ([:, 1:])
        # while the predictions are passed unsliced; the original authors flag
        # indexing & masking here as not fully correct -- confirm against loss_fn.
        loss = self.loss_fn(
            state_preds, action_preds, reward_preds,
            state_target[:,1:], action_target, reward_target[:,1:],
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.detach().cpu().item()

# observations, actions, rewards, dones, returns_to_go

In [11]:
# Small DecisionTransformer with scalar state and action dimensions.
# NOTE(review): this call passes state_dim/act_dim/max_ep_len, whereas the
# other construction in this notebook passes dataset/block_size -- presumably
# the two cells target different versions of DecisionTransformer; confirm
# which signature is current.
model = DecisionTransformer(state_dim=1,
                    act_dim=1,
                    hidden_size=16, # todo configure for the problem
                    max_length=None,
                    max_ep_len=16, # todo configure for the problem
                    action_tanh=True,
                    gpt_config={})

number of parameters: 0.09M


In [16]:
from mingpt.trainer import Trainer




1.0