In [None]:
import json # for loading JSON
import os
from pathlib import Path

import gymnasium as gym
import mani_skill2.envs
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from mani_skill2.utils.wrappers import RecordEpisode
from torch.nn.functional import relu
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm import tqdm

from data.dataset import StackDatasetOriginalSequential
from utils.data_utils import flatten_obs, make_path
from utils.train_utils import init_deque, update_deque

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Output locations of this experiment, built with the project's make_path helper.
ckpt_path = make_path('BC_LSTM', 'checkpoints')
log_path = make_path('BC_LSTM', 'logs')
tensorboard_path = make_path('BC_LSTM', 'logs', 'tensorboard')

# Create every output directory (and missing parents) up front; directories
# that already exist are left untouched.
for out_dir in (ckpt_path, log_path, tensorboard_path):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

In [None]:
class BC(nn.Module):
    """Behaviour-cloning policy: a stacked LSTM followed by a linear action head.

    The head does not look at the per-step LSTM outputs. It consumes the final
    hidden state and the final cell state of the top LSTM layer, concatenated
    along the feature axis (hence the ``2 * hidden_size`` input width).

    Args:
        obs_dim: Size of one observation vector (LSTM input size).
        act_dim: Size of the predicted action vector.
        hidden_size: Width of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, obs_dim=55, act_dim=8, hidden_size=256, num_layers=4):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=obs_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(2 * hidden_size, act_dim)

    def forward(self, x):
        """Map a batch of observation sequences to one action per sequence.

        Args:
            x: Batched input laid out as (batch, seq_len, obs_dim), since the
                LSTM is built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, act_dim).
        """
        _, (h_n, c_n) = self.lstm(x)
        # Index -1 selects the top layer's final hidden and cell states.
        top_state = torch.cat([h_n[-1], c_n[-1]], dim=1)
        return self.fc(top_state)
    
# Smoke test: a random batch of 100 sequences (20 steps x 55 features) must
# produce one 8-dimensional action per sequence. Only the shape is displayed
# so the cell output stays small; no_grad avoids building an autograd graph.
bc = BC()
with torch.no_grad():
    smoke_out = bc(torch.randn(100, 20, 55))
assert tuple(smoke_out.shape) == (100, 8), smoke_out.shape
smoke_out.shape

In [25]:
def train(lr: float = 2e-4,
          weight_decay: float = 2e-6,
          batch_size: int = 256,
          seq_len: int = 8,
          epochs: int = 100,
          seed: int = 42,
          log_freq: int = 5):
    """Train the BC policy with MSE regression and return the best checkpoint.

    A checkpoint is saved and validated before every ``log_freq``-th epoch
    (starting with the untrained model at epoch 0) and once more after the
    last epoch. Scalars are written to TensorBoard under ``tensorboard_path``
    and a summary of the run is dumped to ``train_log.json`` in ``log_path``.

    Args:
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Mini-batch size of the training loader.
        seq_len: Sequence length passed to the dataset and to ``validate``.
        epochs: Number of passes over the training set (0 only checkpoints and
            validates the untrained model).
        seed: Value given to ``torch.manual_seed``.
        log_freq: A checkpoint is written and validated every ``log_freq`` epochs.

    Returns:
        Path of the checkpoint with the lowest validation loss, or ``None`` if
        no validation loss ever compared below infinity (e.g. NaN losses).
    """
    torch.manual_seed(seed)
    dataset = StackDatasetOriginalSequential(seq_len=seq_len, train=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = BC().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    criterion = nn.MSELoss(reduction='mean')

    train_epoch_idx = []
    train_losses = []
    validation_epoch_idx = []
    validation_losses = []
    best_ckpt = None
    best_loss = np.inf

    writer = SummaryWriter(tensorboard_path)

    def checkpoint_and_validate(step: int) -> None:
        """Save the weights reached after `step` epochs, validate them, track the best."""
        nonlocal best_ckpt, best_loss
        path = os.path.join(ckpt_path, f'bc_{step}.pt')
        torch.save(model.state_dict(), path)
        validation_loss = float(validate(model, seq_len))
        validation_epoch_idx.append(step)
        validation_losses.append(validation_loss)
        writer.add_scalar('Loss/Validation', validation_loss, step)
        if validation_loss < best_loss:
            best_loss = validation_loss
            best_ckpt = path

    try:
        # Trace the graph with a dummy batch whose feature width matches the model.
        writer.add_graph(
            model,
            torch.zeros(1, seq_len, model.lstm.input_size).to(device))

        for epoch in tqdm(range(epochs)):
            if epoch % log_freq == 0:
                checkpoint_and_validate(epoch)
                model.train()  # validate() leaves the model in eval mode

            # Accumulate on the device so only one host sync happens per epoch.
            loss_sum = torch.zeros((), device=device)
            sample_count = 0
            for obs, action in dataloader:
                obs = obs.to(device)
                action = action.to(device)

                pred = model(obs)
                train_loss = criterion(pred, action)

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                # Weight by batch size so a short final batch is not over-counted.
                loss_sum += train_loss.detach() * obs.size(0)
                sample_count += obs.size(0)

            # Log the mean over the whole epoch instead of the last mini-batch only.
            epoch_loss = (loss_sum.item() / sample_count
                          if sample_count else float('nan'))
            train_epoch_idx.append(epoch + 1)
            train_losses.append(epoch_loss)
            writer.add_scalar('Loss/Train', epoch_loss, epoch + 1)

        # Final checkpoint; `epochs` (not the loop variable) also works for epochs == 0.
        checkpoint_and_validate(epochs)
    finally:
        # Do not lose buffered TensorBoard events if training is interrupted.
        writer.flush()
        writer.close()

    log = dict(train_epochs=train_epoch_idx,
               validation_epochs=validation_epoch_idx,
               train_losses=train_losses,
               validation_losses=validation_losses,
               best_ckpt=best_ckpt,
               best_loss=best_loss,
               lr=lr,
               weight_decay=weight_decay,
               batch_size=batch_size,
               seq_len=seq_len,
               epochs=epochs,
               seed=seed,
               log_freq=log_freq)

    with open(os.path.join(log_path, 'train_log.json'), 'w') as f:
        json.dump(log, f, indent=4)

    return best_ckpt


def validate(model: BC, seq_len: int):
    """Return the mean squared error of `model` over the validation split.

    The squared error is summed over every element of every batch and divided
    by the total number of elements, i.e. the same per-element normalisation
    as ``nn.MSELoss(reduction='mean')``, so the value is directly comparable
    with the training loss regardless of the action dimension or of a short
    final batch.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` afterwards.

    Args:
        model: Policy to evaluate; expected to already live on ``device``.
        seq_len: Sequence length passed to the validation dataset.

    Returns:
        Per-element MSE as a Python float (NaN if the split yields no data).
    """
    model.eval()
    dataset = StackDatasetOriginalSequential(seq_len=seq_len, train=False)
    dataloader = DataLoader(dataset, batch_size=256, shuffle=False)
    criterion = nn.MSELoss(reduction='sum')

    squared_error = 0.0
    n_elements = 0
    with torch.no_grad():
        for obs, action in dataloader:
            obs = obs.to(device)
            action = action.to(device)

            pred = model(obs)
            squared_error += criterion(pred, action).item()
            n_elements += action.numel()

    if n_elements == 0:
        return float('nan')
    return squared_error / n_elements

In [26]:
def test(ckpt: str,
         seq_len: int,
         max_steps: int = 300,
         num_episodes: int = 100):
    """Roll out a trained policy in StackCube-v0 and return the best-scoring seed.

    One episode is run for each seed in ``range(num_episodes)``. The episode
    return is logged to TensorBoard under ``'Return'`` and a summary (returns
    per seed, best seed/return, success rate) is written to ``test_log.json``
    in ``log_path``. An episode counts as a success when the ``info`` dict of
    its last step has a truthy ``'success'`` entry.

    Args:
        ckpt: Path of the ``state_dict`` checkpoint to load into ``BC``.
        seq_len: Length of the observation window fed to the model.
        max_steps: ``max_episode_steps`` passed to ``gym.make``.
        num_episodes: Number of episodes (and seeds) to evaluate; must be > 0.

    Returns:
        The seed whose episode achieved the highest return.

    Raises:
        ValueError: If ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        # Fail before any environment is created; the success rate would be undefined.
        raise ValueError(f'num_episodes must be positive, got {num_episodes}')

    model = BC()
    # map_location lets a checkpoint saved on the GPU load on a CPU-only host.
    model.load_state_dict(torch.load(ckpt, map_location=device))
    model.to(device)
    model.eval()

    best_return = -np.inf
    best_seed = None
    returns = {}
    success = 0

    env = gym.make('StackCube-v0',
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   max_episode_steps=max_steps)
    writer = SummaryWriter(tensorboard_path)
    try:
        for seed in tqdm(range(num_episodes)):
            obs, _ = env.reset(seed=seed)
            obs = flatten_obs(obs)
            buffer = init_deque(obs, seq_len)
            sequence = np.array(buffer)
            episode_return = 0.0
            terminated = False
            truncated = False
            with torch.no_grad():
                while not terminated and not truncated:
                    # Add the batch axis the model expects: (1, seq_len, obs_dim).
                    batch = torch.from_numpy(sequence[None]).to(device)
                    action = model(batch).cpu().numpy()
                    obs, reward, terminated, truncated, info = env.step(action[0])
                    obs = flatten_obs(obs)
                    sequence = update_deque(obs=obs, window=buffer)
                    # Plain float keeps the return JSON-serialisable whatever
                    # scalar type the environment hands back.
                    episode_return += float(reward)

            if episode_return > best_return:
                best_return = episode_return
                best_seed = seed

            if info['success']:
                success += 1

            returns[seed] = episode_return
            writer.add_scalar('Return', episode_return, seed)
    finally:
        # Release the simulator and flush events even if a rollout fails.
        writer.flush()
        writer.close()
        env.close()

    log = dict(returns=returns,
               best_seed=best_seed,
               best_return=best_return,
               max_steps=max_steps,
               num_episodes=num_episodes,
               success_rate=success / num_episodes)

    with open(os.path.join(log_path, 'test_log.json'), 'w') as f:
        json.dump(log, f, indent=4)

    return best_seed

In [27]:
def render_video(ckpt: str,
                 seq_len: int,
                 seed: int,
                 max_steps: int = 300):
    """Replay one seeded StackCube-v0 episode with a trained policy and record it.

    The environment is wrapped in ``RecordEpisode`` (given ``log_path``, with
    trajectory saving disabled) and the video is flushed with the suffix
    ``BC_<seed>`` once the episode terminates or is truncated.

    Args:
        ckpt: Path of the ``state_dict`` checkpoint to load into ``BC``.
        seq_len: Length of the observation window fed to the model.
        seed: Seed passed to ``env.reset``.
        max_steps: ``max_episode_steps`` passed to ``gym.make``.
    """
    model = BC()
    # map_location lets a checkpoint saved on the GPU load on a CPU-only host.
    model.load_state_dict(torch.load(ckpt, map_location=device))
    model.to(device)
    model.eval()

    env = gym.make('StackCube-v0',
                   render_mode="cameras",
                   enable_shadow=True,
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   max_episode_steps=max_steps)
    try:
        env = RecordEpisode(
            env,
            log_path,
            info_on_video=True,
            save_trajectory=False
        )

        obs, _ = env.reset(seed=seed)
        obs = flatten_obs(obs)
        buffer = init_deque(obs, seq_len)
        sequence = np.array(buffer)
        terminated = False
        truncated = False

        with torch.no_grad():
            while not terminated and not truncated:
                # Add the batch axis the model expects: (1, seq_len, obs_dim).
                batch = torch.from_numpy(sequence[None]).to(device)
                action = model(batch).cpu().numpy()
                obs, _, terminated, truncated, _ = env.step(action[0])
                obs = flatten_obs(obs)
                sequence = update_deque(obs=obs, window=buffer)

        env.flush_video(suffix=f'BC_{seed}')
    finally:
        # Always release the simulator/recorder, even if the rollout fails.
        env.close()

In [28]:
# End-to-end run: train the policy, evaluate the best checkpoint over the
# default set of seeds, then record a video of the best-scoring seed.
SEQ_LEN = 20  # observation window length shared by all three stages

print('Training...')
best_ckpt = train(epochs=150, seq_len=SEQ_LEN)

print('Testing...')
best_seed = test(best_ckpt, SEQ_LEN)

print('Rendering...')
# The recorded episode is allowed 500 steps, more than test()'s default of 300.
render_video(best_ckpt, SEQ_LEN, best_seed, max_steps=500)

print('Done')

Training...


 15%|█▌        | 23/150 [04:05<22:42, 10.73s/it]