In [1]:
import json
import os
from pathlib import Path

import gymnasium as gym
import mani_skill2.envs
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from mani_skill2.utils.wrappers import RecordEpisode
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm import tqdm

from data.dataset import StackDatasetOriginal
from utils.data_utils import flatten_obs, make_path

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Output locations for this experiment (path layout is decided by make_path).
ckpt_path = make_path('BC_MLP', 'checkpoints')
log_path = make_path('BC_MLP', 'logs')
tensorboard_path = make_path('BC_MLP', 'logs', 'tensorboard')

# Create every output directory up front; existing directories are left untouched.
for _out_dir in (ckpt_path, log_path, tensorboard_path):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

In [3]:
class BC(nn.Module):
    """MLP behaviour-cloning policy mapping a flat observation to an action.

    Each hidden stage is ``Linear -> Mish -> BatchNorm1d``; a final ``Linear``
    head with no output activation produces the action. With the default
    ``hidden_dims`` the module order inside ``self.mlp`` is identical to the
    original hand-written stack. So ``state_dict`` keys (``mlp.0.weight`` ...
    ``mlp.12.weight``) are unchanged and existing checkpoints still load.

    Args:
        obs_dim: Width of the flattened observation vector.
        act_dim: Width of the predicted action vector.
        hidden_dims: Widths of the hidden layers, in order. An empty sequence
            gives a single linear map from observation to action.
    """

    def __init__(self, obs_dim=55, act_dim=8,
                 hidden_dims=(512, 1024, 512, 256)):
        super().__init__()
        layers = []
        in_dim = obs_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_dim, width), nn.Mish(), nn.BatchNorm1d(width)]
            in_dim = width
        layers.append(nn.Linear(in_dim, act_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Predict actions for a batch of observations.

        Expects ``x`` of shape ``(batch, obs_dim)``; returns ``(batch, act_dim)``.
        In train mode BatchNorm needs ``batch > 1``.
        """
        return self.mlp(x)

In [4]:
def train(lr: float = 1e-4,
          weight_decay: float = 1e-7,
          batch_size: int = 256,
          epochs: int = 100,
          seed: int = 42,
          log_freq: int = 5):
    """Train the BC policy on the StackCube demonstration dataset.

    Checkpointing and validation happen every ``log_freq`` epochs (before that
    epoch's training pass) and once more after the last epoch. Each time, the
    current weights are saved to ``ckpt_path/bc_{epoch}.pt``, scored with
    ``validate`` and logged. The checkpoint with the lowest validation loss
    is returned.

    Args:
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Mini-batch size of the training loader.
        epochs: Number of passes over the training set (>= 1).
        seed: Value passed to ``torch.manual_seed`` (affects weight init and
            shuffling).
        log_freq: Checkpoint/validate every this many epochs (>= 1).

    Returns:
        Path of the checkpoint with the lowest validation loss.

    Raises:
        ValueError: if ``epochs < 1``, ``log_freq < 1`` or the training set is
            empty.

    Side effects:
        Writes checkpoints to ``ckpt_path``, ``train_log.json`` to ``log_path``,
        and the model graph plus loss scalars to TensorBoard.
    """
    if epochs < 1:
        raise ValueError(f'epochs must be >= 1, got {epochs}')
    if log_freq < 1:
        raise ValueError(f'log_freq must be >= 1, got {log_freq}')

    torch.manual_seed(seed)
    dataset = StackDatasetOriginal(train=True)
    if len(dataset) == 0:
        raise ValueError('training dataset is empty')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = BC().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    criterion = nn.MSELoss(reduction='mean')
    # Take the input width from the model instead of repeating a literal.
    obs_dim = model.mlp[0].in_features

    train_epoch_idx = []
    train_losses = []
    validation_epoch_idx = []
    validation_losses = []
    best_ckpt = None
    best_loss = np.inf

    writer = SummaryWriter(tensorboard_path)

    def _checkpoint_and_validate(epoch: int) -> None:
        """Save weights as bc_{epoch}.pt, validate, log, and track the best."""
        nonlocal best_ckpt, best_loss
        ckpt = os.path.join(ckpt_path, f'bc_{epoch}.pt')
        torch.save(model.state_dict(), ckpt)
        validation_loss = validate(model)
        model.train()  # validate() switches the model to eval mode
        validation_epoch_idx.append(epoch)
        validation_losses.append(validation_loss)
        writer.add_scalar('Loss/Validation', validation_loss, epoch)
        if validation_loss < best_loss:
            best_loss = validation_loss
            best_ckpt = ckpt

    try:
        writer.add_graph(model, torch.zeros(1, obs_dim, device=device))
        # summary() runs a forward pass on random input; in train mode that
        # pass would update BatchNorm running statistics before training.
        model.eval()
        summary(model, (obs_dim,))
        model.train()

        for epoch in tqdm(range(epochs)):
            if epoch % log_freq == 0:
                _checkpoint_and_validate(epoch)

            # Sample-weighted sum of per-batch mean losses, kept on-device so
            # there is no host sync on every optimisation step.
            loss_sum = torch.zeros((), device=device)
            n_seen = 0
            for obs, action in dataloader:
                obs = obs.to(device)
                action = action.to(device)

                pred = model(obs)
                train_loss = criterion(pred, action)

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                loss_sum += train_loss.detach() * obs.shape[0]
                n_seen += obs.shape[0]

            # Log the epoch-average loss, not just the final mini-batch's loss.
            epoch_loss = (loss_sum / n_seen).item()
            train_epoch_idx.append(epoch + 1)
            train_losses.append(epoch_loss)
            writer.add_scalar('Loss/Train', epoch_loss, epoch + 1)

        # Final checkpoint after the last epoch (named bc_{epochs}.pt).
        _checkpoint_and_validate(epochs)

        log = dict(train_epochs=train_epoch_idx,
                   validation_epochs=validation_epoch_idx,
                   train_losses=train_losses,
                   validation_losses=validation_losses,
                   best_ckpt=best_ckpt,
                   best_loss=best_loss,
                   lr=lr,
                   weight_decay=weight_decay,
                   batch_size=batch_size,
                   epochs=epochs,
                   seed=seed,
                   log_freq=log_freq)

        with open(os.path.join(log_path, 'train_log.json'), 'w') as f:
            json.dump(log, f, indent=4)
    finally:
        writer.flush()
        writer.close()

    return best_ckpt


def validate(model: nn.Module, dataloader=None):
    """Mean squared error of ``model`` on the validation split.

    The error is averaged over every scalar action component (samples x
    action dims). That is the same quantity ``nn.MSELoss(reduction='mean')``
    reports, so it is on the same scale as the training loss. The previous
    ``sum / len(dataset)`` form was ``act_dim`` times larger. Because the
    scaling is monotone, best-checkpoint selection is unaffected.

    Args:
        model: Policy network mapping observation batches to action batches.
        dataloader: Optional iterable of ``(obs, action)`` tensor batches. If
            omitted, a loader over ``StackDatasetOriginal(train=False)`` with
            batch size 256 (no shuffling) is built.

    Returns:
        The mean squared error as a Python float.

    Raises:
        ValueError: if the loader yields no samples.

    Evaluation runs in eval mode under ``torch.no_grad``. The model's previous
    train/eval mode is restored before returning.
    """
    if dataloader is None:
        dataloader = DataLoader(StackDatasetOriginal(train=False),
                                batch_size=256, shuffle=False)

    # Evaluate on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    criterion = nn.MSELoss(reduction='sum')
    total_sq_err = 0.0
    total_elems = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for obs, action in dataloader:
                obs = obs.to(model_device)
                action = action.to(model_device)

                pred = model(obs)
                total_sq_err += criterion(pred, action).item()
                total_elems += action.numel()
    finally:
        model.train(was_training)

    if total_elems == 0:
        raise ValueError('validation set is empty')
    return total_sq_err / total_elems

In [5]:
def test(ckpt: str,
         max_steps: int = 300,
         num_episodes: int = 100):
    """Evaluate a BC checkpoint with closed-loop rollouts in StackCube-v0.

    Episode ``i`` is reset with ``seed=i`` for ``i in range(num_episodes)``.
    The policy then acts (a single no-grad forward pass per step) until the
    env reports ``terminated`` or ``truncated``; truncation follows
    ``max_episode_steps=max_steps``.

    Args:
        ckpt: Path to a ``state_dict`` saved by ``train``.
        max_steps: Episode step limit passed to ``gym.make``.
        num_episodes: Number of seeded evaluation episodes (>= 1).

    Returns:
        The seed whose episode achieved the highest summed reward.

    Raises:
        ValueError: if ``num_episodes < 1``.

    Side effects:
        Writes ``test_log.json`` (per-seed returns, best seed/return, success
        rate) to ``log_path`` and a ``Return`` scalar per seed to TensorBoard.
    """
    if num_episodes < 1:
        raise ValueError(f'num_episodes must be >= 1, got {num_episodes}')

    env = gym.make('StackCube-v0',
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   max_episode_steps=max_steps)
    writer = SummaryWriter(tensorboard_path)
    try:
        model = BC()
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        model.load_state_dict(torch.load(ckpt, map_location=device))
        model.to(device)
        model.eval()

        best_return = -np.inf
        best_seed = None
        returns = {}
        success = 0

        for seed in tqdm(range(num_episodes)):
            obs, _ = env.reset(seed=seed)
            G = 0.0
            terminated = truncated = False
            with torch.no_grad():
                while not (terminated or truncated):
                    # .float(): the weights are float32; guard against float64 obs.
                    obs_t = torch.from_numpy(flatten_obs(obs)[None]).float().to(device)
                    action = model(obs_t).cpu().numpy()[0]
                    obs, reward, terminated, truncated, info = env.step(action)
                    # float() keeps the return JSON-serialisable even if the
                    # env hands back numpy scalars.
                    G += float(reward)

            if G > best_return:
                best_return = G
                best_seed = seed

            if info['success']:
                success += 1

            returns[seed] = G
            writer.add_scalar('Return', G, seed)

        log = dict(returns=returns,
                   best_seed=best_seed,
                   best_return=best_return,
                   max_steps=max_steps,
                   num_episodes=num_episodes,
                   success_rate=success / num_episodes)

        with open(os.path.join(log_path, 'test_log.json'), 'w') as f:
            json.dump(log, f, indent=4)
    finally:
        env.close()
        writer.flush()
        writer.close()

    return best_seed

In [6]:
def render_video(ckpt: str,
                 seed: int,
                 max_steps: int = 300):
    """Record a video of a single BC rollout in StackCube-v0.

    Args:
        ckpt: Path to a ``state_dict`` saved by ``train``.
        seed: Seed passed to ``env.reset`` (e.g. the best seed from ``test``).
        max_steps: Episode step limit passed to ``gym.make``.

    Side effects:
        The env is wrapped in ``RecordEpisode`` writing to ``log_path``, with
        trajectory saving disabled. The video is flushed with suffix
        ``BC_{seed}``. The env is closed even if the rollout raises.
    """
    env = gym.make('StackCube-v0',
                   render_mode="cameras",
                   enable_shadow=True,
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   max_episode_steps=max_steps)

    env = RecordEpisode(
        env,
        log_path,
        info_on_video=True,
        save_trajectory=False
    )

    try:
        model = BC()
        # map_location lets a GPU-trained checkpoint load on a CPU-only host.
        model.load_state_dict(torch.load(ckpt, map_location=device))
        model.to(device)
        model.eval()

        obs, _ = env.reset(seed=seed)
        terminated = truncated = False

        with torch.no_grad():
            while not (terminated or truncated):
                # .float(): the weights are float32; guard against float64 obs.
                obs_t = torch.from_numpy(flatten_obs(obs)[None]).float().to(device)
                action = model(obs_t).cpu().numpy()[0]
                obs, _, terminated, truncated, _ = env.step(action)

        env.flush_video(suffix=f'BC_{seed}')
    finally:
        env.close()

In [7]:
# End-to-end pipeline: train, then evaluate the best checkpoint, then record
# a video of that checkpoint on the seed where it scored highest.
print('Training...')
best_ckpt = train()

print('Testing...')
best_seed = test(ckpt=best_ckpt)

print('Rendering...')
render_video(seed=best_seed, ckpt=best_ckpt)

print('Done')

Training...
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]          28,672
              Mish-2                  [-1, 512]               0
       BatchNorm1d-3                  [-1, 512]           1,024
            Linear-4                 [-1, 1024]         525,312
              Mish-5                 [-1, 1024]               0
       BatchNorm1d-6                 [-1, 1024]           2,048
            Linear-7                  [-1, 512]         524,800
              Mish-8                  [-1, 512]               0
       BatchNorm1d-9                  [-1, 512]           1,024
           Linear-10                  [-1, 256]         131,328
             Mish-11                  [-1, 256]               0
      BatchNorm1d-12                  [-1, 256]             512
           Linear-13                    [-1, 8]           2,056
Total params: 1,216,776
Tra

100%|██████████| 100/100 [04:55<00:00,  2.95s/it]


Testing...


[2023-10-30 18:55:10.885] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing
100%|██████████| 100/100 [02:41<00:00,  1.62s/it]


Rendering...


[2023-10-30 18:57:57.711] [svulkan2] [error] GLFW error: X11: The DISPLAY environment variable is missing


Done
