In [1]:
import json
import math
import os
from pathlib import Path

import gymnasium as gym
import mani_skill2.envs
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from mani_skill2.utils.wrappers import RecordEpisode
from torch.nn import (Flatten, Linear, TransformerEncoder,
                      TransformerEncoderLayer)
from torch.nn.functional import relu
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from data.dataset import StackDatasetOriginalSequential
from utils.data_utils import flatten_obs, make_path
from utils.train_utils import init_deque, update_deque
from torch.distributions import Normal

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Output locations for this experiment, built with the project's make_path helper.
ckpt_path = make_path('BC_Transformer_Gaussian', 'checkpoints')
log_path = make_path('BC_Transformer_Gaussian', 'logs')
tensorboard_path = make_path('BC_Transformer_Gaussian', 'logs', 'tensorboard')

# Create every output directory up front; directories that already exist are fine.
for _out_dir in (ckpt_path, log_path, tensorboard_path):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

In [3]:

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    A ``[max_len, d_model]`` table is computed once in ``__init__`` and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sine terms
    and odd feature columns hold cosine terms.

    Args:
        d_model: Size of the feature (embedding) dimension. Odd values are
            supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions stored in the table.
    """

    def __init__(self,
                 d_model: int,
                 dropout: float = 0.1,
                 max_len: int = 100):

        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-math.log(10000.0) / d_model))  # [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine terms are trimmed to fit. For even d_model the slice is
        # a no-op and the table is identical to the untrimmed version.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            Tensor of the same shape as ``x``: the input plus the first
            ``seq_len`` rows of the encoding table, followed by dropout.
        """
        # pe[:seq_len] is [seq_len, d_model] and broadcasts over the batch axis.
        x = x + self.pe[:x.size(1)]
        return self.dropout(x)


class BC(nn.Module):
    """Transformer-encoder behaviour-cloning policy with a Gaussian head.

    A window of observations is linearly embedded, given positional
    encodings, run through a stack of encoder layers, flattened over the
    sequence and feature axes, and mapped by two linear heads to an action
    mean and a single log standard deviation.

    Args:
        seq_len: Number of observations per input window.
        obs_dim: Size of one flattened observation.
        act_dim: Size of the action vector.
        dropout: Dropout probability for the positional encoding and encoder.
        d_model: Width of the transformer.
        dim_ff: Width of each encoder layer's feed-forward sub-layer.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self,
                 seq_len=8,
                 obs_dim=55,
                 act_dim=8,
                 dropout=0.1,
                 d_model=128,
                 dim_ff=128,
                 num_heads=8,
                 num_layers=3):
        super().__init__()

        self.d_model = d_model

        # Lift each observation from obs_dim to the transformer width.
        self.embedding = Linear(obs_dim, d_model)

        self.pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout)

        # One attention block; TransformerEncoder stacks num_layers of them.
        layer = TransformerEncoderLayer(d_model=d_model,
                                        nhead=num_heads,
                                        dim_feedforward=dim_ff,
                                        dropout=dropout,
                                        batch_first=True)
        self.encoder = TransformerEncoder(layer, num_layers=num_layers)

        # Collapse [batch, seq_len, d_model] into [batch, seq_len * d_model].
        self.flatten = Flatten(1, -1)

        flat_dim = d_model * seq_len
        # Gaussian head: per-dimension mean, one shared log-std per input window.
        self.mean = Linear(flat_dim, act_dim)
        self.log_std = Linear(flat_dim, 1)

    def forward(self, x):
        """Map an observation window to Gaussian action parameters.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, obs_dim]``

        Returns:
            Tuple ``(mean, log_std)`` with shapes ``[batch_size, act_dim]``
            and ``[batch_size, 1]``.
        """
        scaled = math.sqrt(self.d_model) * self.embedding(x)
        encoded = self.encoder(self.pos_encoder(scaled))
        flat = self.flatten(encoded)
        mean = self.mean(flat)
        log_std = self.log_std(flat)
        return mean, log_std

In [4]:
def train(lr: float = 3e-5,
          weight_decay: float = 1e-7,
          batch_size: int = 256,
          seq_len: int = 16,
          epochs: int = 100,
          seed: int = 42,
          log_freq: int = 5):
    """Train the BC policy and return the path of the best checkpoint.

    Each epoch is one pass over the training DataLoader. Every ``log_freq``
    epochs, and once more after the last epoch, the weights are saved under
    ``ckpt_path`` and scored with ``validate``; the checkpoint with the lowest
    validation loss is remembered. Loss curves are written to TensorBoard and
    a JSON summary is written to ``train_log.json`` under ``log_path``.

    Args:
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Training mini-batch size.
        seq_len: Observation window length passed to the dataset and model.
        epochs: Number of training epochs.
        seed: Seed passed to ``torch.manual_seed``.
        log_freq: Checkpoint/validation period, in epochs.

    Returns:
        Path of the checkpoint with the lowest validation loss.

    Raises:
        ValueError: If the training DataLoader yields no batches.
    """
    torch.manual_seed(seed)
    dataset = StackDatasetOriginalSequential(seq_len=seq_len, train=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = BC(seq_len=seq_len).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    criterion = nn.MSELoss(reduction='mean')

    train_epoch_idx = []
    train_losses = []
    validation_epoch_idx = []
    validation_losses = []
    best_ckpt = None
    best_loss = np.inf

    writer = SummaryWriter(tensorboard_path)

    def checkpoint_and_validate(step):
        """Save the weights as ``bc_{step}.pt``, validate, and track the best."""
        nonlocal best_ckpt, best_loss
        path = os.path.join(ckpt_path, f'bc_{step}.pt')
        torch.save(model.state_dict(), path)
        validation_loss = validate(model, seq_len)
        validation_epoch_idx.append(step)
        validation_losses.append(validation_loss)
        writer.add_scalar('Loss/Validation', validation_loss, step)
        if validation_loss < best_loss:
            best_loss = validation_loss
            best_ckpt = path

    try:
        for epoch in tqdm(range(epochs)):
            if epoch % log_freq == 0:
                checkpoint_and_validate(epoch)
                model.train()  # validate() leaves the model in eval mode

            # Accumulate on-device so the epoch statistic costs a single
            # host sync instead of one per batch.
            loss_sum = torch.zeros((), device=device)
            num_samples = 0
            for obs, action in dataloader:
                obs = obs.to(device)
                action = action.to(device)

                mean, log_std = model(obs)
                dist = Normal(mean, torch.exp(log_std))
                # NOTE(review): MSE on a reparameterised sample has expectation
                # (mean - action)^2 + std^2, so this objective pushes std
                # towards zero. If a calibrated Gaussian is intended, the
                # negative log-likelihood (-dist.log_prob(action)) is the
                # usual loss; confirm the intent before changing it.
                pred = dist.rsample()
                train_loss = criterion(pred, action)

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                loss_sum += train_loss.detach() * obs.size(0)
                num_samples += obs.size(0)

            if num_samples == 0:
                raise ValueError('training dataloader yielded no batches')

            # Log the sample-weighted epoch mean rather than whichever
            # mini-batch happened to come last.
            epoch_loss = (loss_sum / num_samples).item()
            train_epoch_idx.append(epoch + 1)
            train_losses.append(epoch_loss)
            writer.add_scalar('Loss/Train', epoch_loss, epoch + 1)

        # Final checkpoint. Uses ``epochs`` rather than the loop variable so
        # that epochs=0 does not fail on an undefined name.
        checkpoint_and_validate(epochs)

        log = dict(train_epochs=train_epoch_idx,
                   validation_epochs=validation_epoch_idx,
                   train_losses=train_losses,
                   validation_losses=validation_losses,
                   best_ckpt=best_ckpt,
                   best_loss=best_loss,
                   lr=lr,
                   weight_decay=weight_decay,
                   batch_size=batch_size,
                   seq_len=seq_len,
                   epochs=epochs,
                   seed=seed,
                   log_freq=log_freq)

        with open(os.path.join(log_path, 'train_log.json'), 'w') as f:
            json.dump(log, f, indent=4)
    finally:
        # Release the event file even when training is interrupted.
        writer.flush()
        writer.close()

    return best_ckpt


def validate(model: BC, seq_len: int):
    """Score ``model`` on the validation split.

    The model is switched to eval mode and left there. For every validation
    batch an action is sampled from the predicted Gaussian and the summed
    squared error against the target action is recorded.

    Args:
        model: Policy to evaluate.
        seq_len: Observation window length used to build the dataset.

    Returns:
        Total summed squared error divided by the number of dataset items.
    """
    model.eval()
    val_set = StackDatasetOriginalSequential(seq_len=seq_len, train=False)
    val_loader = DataLoader(val_set, batch_size=256, shuffle=False)
    sum_sq_error = nn.MSELoss(reduction='sum')

    batch_totals = []
    with torch.no_grad():
        for obs, action in val_loader:
            obs, action = obs.to(device), action.to(device)
            mean, log_std = model(obs)
            sampled = Normal(mean, log_std.exp()).sample()
            batch_totals.append(sum_sq_error(sampled, action).item())

    return np.sum(batch_totals) / len(val_set)

In [5]:
def test(ckpt: str,
         seq_len: int,
         max_steps: int = 300,
         num_episodes: int = 100):
    """Roll out a checkpoint in ``StackCube-v0`` and log per-episode returns.

    One episode is run for each environment seed in ``range(num_episodes)``.
    Actions are sampled from the policy's Gaussian, so results for a given
    seed are not guaranteed to repeat between calls. Returns are written to
    TensorBoard and a JSON summary to ``test_log.json`` under ``log_path``.

    Args:
        ckpt: Path to a ``state_dict`` checkpoint of ``BC``.
        seq_len: Observation window length the checkpoint was trained with.
        max_steps: Episode step limit passed to the environment.
        num_episodes: Number of episodes (and seeds) to evaluate.

    Returns:
        List of seeds whose final ``info['success']`` was truthy.
    """
    # Load the policy before creating the simulator so a bad checkpoint path
    # fails without leaving an environment open. map_location lets a
    # checkpoint saved on GPU load on a CPU-only machine.
    model = BC(seq_len=seq_len)
    model.load_state_dict(torch.load(ckpt, map_location=device))
    model.to(device)
    model.eval()

    best_return = -np.inf
    best_seed = None
    returns = {}
    success_seeds = []
    writer = SummaryWriter(tensorboard_path)

    try:
        env = gym.make('StackCube-v0',
                       obs_mode="state_dict",
                       control_mode="pd_joint_delta_pos",
                       reward_mode="sparse",
                       max_episode_steps=max_steps)
        try:
            for seed in tqdm(range(num_episodes)):
                obs, _ = env.reset(seed=seed)
                obs = flatten_obs(obs)
                # presumably a sliding window of the last seq_len observations
                # - confirm against utils.train_utils
                buffer = init_deque(obs, seq_len)
                sequence = np.array(buffer)
                G = 0
                terminated = False
                truncated = False
                with torch.no_grad():
                    while not terminated and not truncated:
                        # [None] adds the batch axis the model expects.
                        sequence = torch.from_numpy(sequence[None]).to(device)
                        mean, log_std = model(sequence)
                        dist = Normal(mean, torch.exp(log_std))
                        action = dist.sample()
                        action = action.cpu().numpy()
                        obs, reward, terminated, truncated, info = env.step(action[0])
                        obs = flatten_obs(obs)
                        sequence = update_deque(obs=obs, window=buffer)
                        G += reward

                # Plain float so numpy scalar rewards cannot break json.dump.
                G = float(G)

                if G > best_return:
                    best_return = G
                    best_seed = seed

                if info['success']:
                    success_seeds.append(seed)

                returns[seed] = G
                writer.add_scalar('Return', G, seed)
        finally:
            env.close()

        log = dict(returns=returns,
                   best_seed=best_seed,
                   best_return=best_return,
                   max_steps=max_steps,
                   num_episodes=num_episodes,
                   success_seeds=success_seeds,
                   # Guard the num_episodes == 0 case instead of dividing by zero.
                   success_rate=(len(success_seeds) / num_episodes
                                 if num_episodes else 0.0))

        with open(os.path.join(log_path, 'test_log.json'), 'w') as f:
            json.dump(log, f, indent=4)
    finally:
        writer.flush()
        writer.close()

    return success_seeds

In [6]:
def render_video(ckpt: str,
                 seq_len: int,
                 seed: int,
                 max_steps: int = 300):
    """Record one ``StackCube-v0`` episode of a checkpoint as a video.

    The environment is wrapped in ``RecordEpisode`` writing to ``log_path``,
    and the video is flushed with the suffix ``BC_{seed}`` once the episode
    terminates or is truncated. Actions are sampled from the policy's
    Gaussian, so the recorded rollout for a given environment seed is not
    guaranteed to repeat between calls.

    Args:
        ckpt: Path to a ``state_dict`` checkpoint of ``BC``.
        seq_len: Observation window length the checkpoint was trained with.
        seed: Environment reset seed.
        max_steps: Episode step limit passed to the environment.
    """
    # Load the policy first so a bad checkpoint path fails before the
    # simulator is created. map_location lets a checkpoint saved on GPU load
    # on a CPU-only machine.
    model = BC(seq_len=seq_len)
    model.load_state_dict(torch.load(ckpt, map_location=device))
    model.to(device)
    model.eval()

    env = gym.make('StackCube-v0',
                   render_mode="cameras",
                   enable_shadow=True,
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   max_episode_steps=max_steps)

    env = RecordEpisode(
        env,
        log_path,
        info_on_video=True,
        save_trajectory=False
    )

    try:
        obs, _ = env.reset(seed=seed)
        obs = flatten_obs(obs)
        # presumably a sliding window of the last seq_len observations
        # - confirm against utils.train_utils
        buffer = init_deque(obs, seq_len)
        sequence = np.array(buffer)
        terminated = False
        truncated = False

        with torch.no_grad():
            while not terminated and not truncated:
                # [None] adds the batch axis the model expects.
                sequence = torch.from_numpy(sequence[None]).to(device)
                mean, log_std = model(sequence)
                dist = Normal(mean, torch.exp(log_std))
                action = dist.sample()
                action = action.cpu().numpy()
                obs, _, terminated, truncated, _ = env.step(action[0])
                obs = flatten_obs(obs)
                sequence = update_deque(obs=obs, window=buffer)

        env.flush_video(suffix=f'BC_{seed}')
    finally:
        # Always release the simulator, even if the rollout fails.
        env.close()

In [7]:
# Pipeline driver: train, evaluate the best checkpoint, then record a video
# for every evaluation seed that ended in success.
SEQ_LEN = 8

print('Training...')
best_ckpt = train(epochs=200, seq_len=SEQ_LEN)

print('Testing...')
success_seeds = test(best_ckpt, SEQ_LEN)

print('Rendering...')
for seed in success_seeds:
    render_video(best_ckpt, SEQ_LEN, seed, max_steps=500)

print('Done')

Training...


100%|██████████| 200/200 [37:47<00:00, 11.34s/it]


Testing...


[2023-11-11 00:49:05.376] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing
100%|██████████| 100/100 [03:27<00:00,  2.08s/it]


Rendering...


[2023-11-11 00:52:38.573] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing
[2023-11-11 00:52:49.271] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing
[2023-11-11 00:53:12.370] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing
[2023-11-11 00:53:35.084] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing


Done
