<a href="https://colab.research.google.com/github/PsorTheDoctor/visuomotor-robot-policies/blob/main/behavior_transformer/mini_bet_vision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# miniBET: Behavior Transformer

In [None]:
!pip3 install -q torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 \
scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4
!git clone https://github.com/PsorTheDoctor/visuomotor-robot-policies.git
%cd visuomotor-robot-policies/

fatal: destination path 'visuomotor-robot-policies' already exists and is not an empty directory.
/content/visuomotor-robot-policies


In [None]:
from typing import Callable
import os
import numpy as np
import gdown
import torch
import collections
from skvideo.io import vwrite
from IPython.display import Video
from tqdm.auto import tqdm

from utils.env import PushTImageEnv
from utils.dataset import PushTImageDataset, normalize_data, unnormalize_data

# Smoke-test the simulator: create it with a fixed seed, reset it, then apply
# one action sampled at random from the action space.
env = PushTImageEnv()
env.seed(1000)
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)

# Format the values under compact NumPy print settings, then print them
# together with their legends.
# NOTE(review): the "Obs:" legend lists five state components; confirm it
# matches what PushTImageEnv actually returns.
compact_print = dict(precision=4, suppress=True, threshold=5)
with np.printoptions(**compact_print):
  obs_text = repr(obs)
  action_text = repr(action)

print("Obs: ", obs_text)
print("Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]")
print("Action: ", action_text)
print("Action:   [target_agent_x, target_agent_y]")

In [None]:
# Download the demonstration archive once; later runs reuse the local file.
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
  # Named `file_id` so the `id` builtin is not shadowed for the rest of the
  # notebook. The "&confirm=t" suffix is passed to gdown unchanged.
  file_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
  gdown.download(id=file_id, output=dataset_path, quiet=False)

# Earlier settings, kept for reference:
#   pred_horizon = 16, obs_horizon = 2, action_horizon = 8
batch_size = 64
# A single window length is used for the prediction, observation and action
# horizons alike.
horizon = 16

dataset = PushTImageDataset(
    dataset_path=dataset_path,
    pred_horizon=horizon,
    obs_horizon=horizon,
    action_horizon=horizon
)
# Dataset statistics exposed by PushTImageDataset, kept for later use.
stats = dataset.stats

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=True,
    pin_memory=True,
    # Keep worker processes alive between epochs instead of respawning them.
    persistent_workers=True
)

# Pull one batch to inspect the tensor shapes the model will be fed.
batch = next(iter(dataloader))
print("batch['image'].shape:", batch['image'].shape)
print("batch['agent_pos'].shape:", batch['agent_pos'].shape)
print("batch['action'].shape", batch['action'].shape)



batch['image'].shape: torch.Size([64, 16, 3, 96, 96])
batch['agent_pos'].shape: torch.Size([64, 16, 2])
batch['action'].shape torch.Size([64, 16, 2])


In [None]:
# Collapse every axis after (batch, time) so each frame becomes one flat
# feature vector. For the shapes printed above this equals the former
# reshape to (batch_size, horizon, 3*96*96), but it no longer hard-codes the
# batch size or the image dimensions.
obs_seq = batch['image'].flatten(start_dim=2)
goal_seq = batch['agent_pos']
action_seq = batch['action']
print('obs_seq.shape:', obs_seq.shape)
print('goal_seq.shape:', goal_seq.shape)
print('action_seq.shape:', action_seq.shape)

obs_seq.shape: torch.Size([64, 16, 27648])
goal_seq.shape: torch.Size([64, 16, 2])
action_seq.shape: torch.Size([64, 16, 2])


In [None]:
!git clone https://github.com/notmahi/miniBET.git
%cd miniBET
%pip install --upgrade .

/content/visuomotor-robot-policies/miniBET


In [None]:
import torch
from behavior_transformer import BehaviorTransformer, GPT, GPTConfig
# from examples import dataset

# ---- Hyperparameters -------------------------------------------------------
conditional = False
obs_dim = 3 * 96 * 96  # 27648, the width of one flattened image frame
act_dim = 2
# The goal input has zero width unless goal conditioning is switched on.
goal_dim = 0 if not conditional else obs_dim
K = 32  # passed to BehaviorTransformer as n_clusters
T = 16
# batch_size = 256
epochs = 5

# ---- Model and optimizer ---------------------------------------------------
gpt_config = GPTConfig(
    block_size=144,
    input_dim=obs_dim,
    n_layer=6,
    n_head=8,
    n_embd=256,
)
bet = BehaviorTransformer(
    obs_dim=obs_dim,
    act_dim=act_dim,
    goal_dim=goal_dim,
    gpt_model=GPT(gpt_config),
    n_clusters=K,
    kmeans_fit_steps=5,
)
# The optimizer is built by the model's own configure_optimizers helper.
optimizer = bet.configure_optimizers(
    weight_decay=2e-4,
    learning_rate=1e-5,
    betas=[0.9, 0.999],
)
# Train for `epochs` passes over the dataloader. The outer progress bar shows
# the mean loss of the finished epoch, the inner one the loss of the current
# batch.
with tqdm(range(epochs), desc='Epoch') as tglobal:
  for epoch_idx in tglobal:
    epoch_loss = list()
    with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
      for nbatch in tepoch:
        # Flatten each frame into one feature vector. Unlike a reshape to a
        # fixed (batch_size, horizon, 3*96*96), this also handles the final
        # batch of an epoch, which may hold fewer than `batch_size` samples.
        obs_seq = nbatch['image'].flatten(start_dim=2)
        goal_seq = nbatch['agent_pos']
        action_seq = nbatch['action']

        # Passing the actions makes the model return a training loss.
        # For action inference, call bet(obs_seq, goal_seq, None) instead.
        train_action, train_loss, train_loss_dict = bet(obs_seq, goal_seq, action_seq)

        # Optimisation step: without it the weights are never updated.
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Log a plain float so the progress bar shows a number rather than
        # the repr of a tensor that still carries its autograd graph.
        batch_loss = float(train_loss)
        epoch_loss.append(batch_loss)
        tepoch.set_postfix(loss=batch_loss)
    tglobal.set_postfix(loss=np.mean(epoch_loss))

In [None]:
# Install the two visualisation packages imported by the cells below
# (torchviz for make_dot, torchview for draw_graph).
%pip install torchviz torchview

In [None]:
from torchviz import make_dot

# Build the autograd graph behind the most recent `train_action`, passing the
# model's named parameters to make_dot, and render it to a PNG file.
named_params = dict(bet.named_parameters())
autograd_graph = make_dot(train_action, params=named_params)
autograd_graph.render(format='png')

In [None]:
from torchview import draw_graph

# draw_graph must receive the module itself: the previous `bet()` invoked
# forward() with no arguments and raised a TypeError before any graph was built.
# The call in the training cell passes three inputs to the model, so a single
# `input_size` cannot describe them. Real tensors are supplied through
# `input_data` instead, cut to one sample to keep the trace cheap.
model_graph = draw_graph(
    bet,
    input_data=(obs_seq[:1], goal_seq[:1], action_seq[:1]),
    expand_nested=False,
)
model_graph.visual_graph