In [3]:
import json # for loading json file
import os
from pathlib import Path

import gymnasium as gym
import mani_skill2.envs
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from mani_skill2.utils.wrappers import RecordEpisode
from torch.nn.functional import relu
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm import tqdm

from data.dataset import StackDatasetOriginalSequential
from utils.data_utils import flatten_obs, make_path
from utils.train_utils import init_deque, update_deque

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ckpt_path = make_path('BC_LSTM', 'checkpoints')
log_path = make_path('BC_LSTM', 'logs')
tensorboard_path = make_path('BC_LSTM', 'logs', 'tensorboard')

Path(ckpt_path).mkdir(exist_ok=True, parents=True)
Path(log_path).mkdir(exist_ok=True, parents=True)
Path(tensorboard_path).mkdir(exist_ok=True, parents=True)

In [7]:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

class BC(nn.Module):
    def __init__(self, obs_dim=55, act_dim=8, hidden_size=256, num_layers=4, num_heads=8, k=10):
        super(BC, self).__init__()

        self.obs_dim = obs_dim
        self.k = k
        self.positional_encoding = PositionalEncoding(d_model=hidden_size)

        # Transformer Encoder
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=hidden_size, 
            nhead=num_heads, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        self.fc_in = nn.Linear(obs_dim * k, hidden_size)  # To match the Transformer input dimension
        self.fc_out = nn.Linear(hidden_size, act_dim)
    
    def forward(self, x):
        # Flatten the last k observations
        x = x.view(-1, self.k * self.obs_dim)

        # Pass through a fully connected layer to get the right dimension
        x = self.fc_in(x)
        x = x.unsqueeze(1)  # Add a dummy sequence length dimension expected by Transformer

        # Add positional encoding
        x = self.positional_encoding(x)

        # Transformer expects a sequence, but we're handling individual observations
        # So we treat each batch of k observations as a 'sequence' of length 1
        transformer_output = self.transformer_encoder(x)

        # Since we have a 'sequence' of length 1, we just select that one output
        transformer_output = transformer_output.squeeze(1)

        # And finally, pass it through the output fully connected layer
        action = self.fc_out(transformer_output)

        return action

# Initialize the model
bc_transformer = BCTransformer()

# Create a dummy input tensor of the correct shape
# Assuming input is (batch_size, k, obs_dim), where k is the number of past observations considered
dummy_input = torch.randn(100, 10, 55)  # 100 examples, each with 10 past observations

# Forward pass
output = bc_transformer(dummy_input)
print(output.shape)  # Should print torch.Size([100, 8]), assuming act_dim is the output dimension

bc=BC()
bc(torch.randn(100, 20, 55))

tensor([[-5.8444e-02, -3.2075e-02,  4.4568e-02,  8.4997e-03, -3.8891e-02,
         -7.3362e-02,  9.5360e-03, -1.8447e-03],
        [-5.8006e-02, -3.2533e-02,  4.3997e-02,  8.7425e-03, -3.8996e-02,
         -7.3627e-02,  8.0062e-03, -3.7727e-03],
        [-5.8667e-02, -2.9730e-02,  4.3368e-02,  7.6694e-03, -3.9701e-02,
         -7.5524e-02,  1.2014e-02, -2.7713e-03],
        [-5.9446e-02, -3.2544e-02,  4.4476e-02,  6.3791e-03, -3.7693e-02,
         -7.3764e-02,  1.0237e-02, -1.1036e-03],
        [-5.7198e-02, -3.1380e-02,  4.3718e-02,  5.8235e-03, -3.9074e-02,
         -7.4221e-02,  9.8430e-03, -1.2061e-03],
        [-6.0016e-02, -3.0119e-02,  4.4182e-02,  7.4890e-03, -3.8501e-02,
         -7.6406e-02,  1.2176e-02, -1.3518e-03],
        [-5.8501e-02, -3.1601e-02,  4.2506e-02,  9.0375e-03, -3.9788e-02,
         -7.4813e-02,  9.1560e-03, -2.8596e-03],
        [-5.8827e-02, -3.1592e-02,  4.4111e-02,  7.8881e-03, -3.7627e-02,
         -7.3742e-02,  9.8016e-03, -6.4881e-04],
        [-5.7797

In [8]:
def train(lr: float = 2e-4,
          weight_decay: float = 2e-6,
          batch_size: int = 256,
          seq_len: int = 8,
          epochs: int = 100,
          seed: int = 42,
          log_freq: int = 5):

    torch.manual_seed(seed)
    dataset = StackDatasetOriginalSequential(seq_len=seq_len, train=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = BC().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    criterion = nn.MSELoss(reduction='mean')

    train_epoch_idx = []
    train_losses = []
    validation_epoch_idx = []
    validation_losses = []
    best_ckpt = None
    best_loss = np.inf

    writer = SummaryWriter(tensorboard_path)
    writer.add_graph(model, torch.zeros(1, seq_len, 55).to(device))

    # summary(model, (SEQ_LEN, 55))

    for epoch in tqdm(range(epochs)):
        if epoch % log_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(ckpt_path, f'bc_{epoch}.pt'))
            validation_loss = validate(model, seq_len)
            validation_epoch_idx.append(epoch)
            validation_losses.append(validation_loss)
            writer.add_scalar('Loss/Validation', validation_loss, epoch)
            model.train()
            if validation_loss < best_loss:
                best_loss = validation_loss
                best_ckpt = os.path.join(ckpt_path, f'bc_{epoch}.pt')

        for obs, action in dataloader:
            obs = obs.to(device)
            action = action.to(device)

            pred = model(obs)
            train_loss = criterion(pred, action)

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        train_epoch_idx.append(epoch + 1)
        train_losses.append(train_loss.item())
        writer.add_scalar('Loss/Train', train_loss.item(), epoch + 1)

    torch.save(model.state_dict(), os.path.join(ckpt_path, f'bc_{epoch+1}.pt'))
    validation_loss = validate(model, seq_len)
    validation_epoch_idx.append(epoch+1)
    validation_losses.append(validation_loss)
    writer.add_scalar('Loss/Validation', validation_loss, epoch+1)
    if validation_loss < best_loss:
        best_loss = validation_loss
        best_ckpt = os.path.join(ckpt_path, f'bc_{epoch+1}.pt')

    log = dict(train_epochs=train_epoch_idx,
               validation_epochs=validation_epoch_idx,
               train_losses=train_losses,
               validation_losses=validation_losses,
               best_ckpt=best_ckpt,
               best_loss=best_loss,
               lr=lr,
               weight_decay=weight_decay,
               batch_size=batch_size,
               epochs=epochs,
               seed=seed,
               log_freq=log_freq)

    with open(os.path.join(log_path, 'train_log.json'), 'w') as f:
        json.dump(log, f, indent=4)

    writer.flush()
    writer.close()

    return best_ckpt


def validate(model: BC, seq_len: int):
    model.eval()
    dataset = StackDatasetOriginalSequential(seq_len=seq_len, train=False)
    dataloader = DataLoader(dataset, batch_size=256, shuffle=False)
    criterion = nn.MSELoss(reduction='sum')
    losses = []
    with torch.no_grad():
        for obs, action in dataloader:
            obs = obs.to(device)
            action = action.to(device)

            pred = model(obs)
            loss = criterion(pred, action)
            losses.append(loss.item())

    return np.sum(losses) / len(dataset)

In [9]:
def test(ckpt: str,
         seq_len: int,
         max_steps: int = 300,
         num_episodes: int = 100):

    env = gym.make('StackCube-v0',
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   max_episode_steps=max_steps)

    model = BC()
    model.load_state_dict(torch.load(ckpt))
    model.to(device)
    model.eval()

    best_return = -np.inf
    best_seed = None
    returns = {}
    success = 0
    writer = SummaryWriter(tensorboard_path)

    for seed in tqdm(range(num_episodes)):
        obs, _ = env.reset(seed=seed)
        obs = flatten_obs(obs)
        buffer = init_deque(obs, seq_len)
        sequence = np.array(buffer)
        G = 0
        terminated = False
        truncated = False
        with torch.no_grad():
            while not terminated and not truncated:
                sequence = torch.from_numpy(sequence[None]).to(device)
                action = model(sequence)
                action = action.detach().cpu().numpy()
                obs, reward, terminated, truncated, info = env.step(action[0])
                obs = flatten_obs(obs)
                sequence = update_deque(obs=obs, window=buffer)
                G += reward

        if G > best_return:
            best_return = G
            best_seed = seed

        if info['success']:
            success += 1

        returns[seed] = G
        writer.add_scalar('Return', G, seed)
    env.close()

    log = dict(returns=returns,
               best_seed=best_seed,
               best_return=best_return,
               max_steps=max_steps,
               num_episodes=num_episodes,
               success_rate = success / num_episodes)

    with open(os.path.join(log_path, 'test_log.json'), 'w') as f:
        json.dump(log, f, indent=4)

    writer.flush()
    writer.close()
    return best_seed

In [10]:
def render_video(ckpt: str,
                 seq_len: int,
                 seed: int,
                 max_steps: int = 300):
    
    env = gym.make('StackCube-v0',
                render_mode="cameras",
                enable_shadow=True,
                obs_mode="state_dict",
                control_mode="pd_joint_delta_pos", 
                max_episode_steps=max_steps)

    env = RecordEpisode(
        env,
        log_path,
        info_on_video=True,
        save_trajectory=False
    )


    model = BC()
    model.load_state_dict(torch.load(ckpt))
    model.to(device)
    model.eval()

    obs, _ = env.reset(seed=seed)
    obs = flatten_obs(obs)
    buffer = init_deque(obs, seq_len)
    sequence = np.array(buffer)
    terminated = False
    truncated = False
    
    with torch.no_grad():
        while not terminated and not truncated:
            sequence = torch.from_numpy(sequence[None]).to(device)
            action = model(sequence)
            action = action.detach().cpu().numpy()
            obs, reward, terminated, truncated, info = env.step(action[0])
            obs = flatten_obs(obs)
            sequence = update_deque(obs=obs, window=buffer)

    env.flush_video(suffix=f'BC_{seed}')
    env.close()

In [11]:
SEQ_LEN = 20
print('Training...')
best_ckpt = train(seq_len=SEQ_LEN, epochs=150)
print('Testing...')
best_seed = test(ckpt=best_ckpt, seq_len=SEQ_LEN)
print('Rendering...')
render_video(ckpt=best_ckpt, seq_len=SEQ_LEN, seed=best_seed, max_steps=500)
print('Done')

Training...


 99%|█████████▉| 149/150 [21:26<00:08,  8.08s/it]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

class BCTransformer(nn.Module):
    def __init__(self, obs_dim=55, act_dim=8, hidden_size=256, num_layers=4, num_heads=8, k=10):
        super(BCTransformer, self).__init__()

        self.obs_dim = obs_dim
        self.k = k
        self.positional_encoding = PositionalEncoding(d_model=hidden_size)

        # Transformer Encoder
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=hidden_size, 
            nhead=num_heads, 
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        self.fc_in = nn.Linear(obs_dim * k, hidden_size)  # To match the Transformer input dimension
        self.fc_out = nn.Linear(hidden_size, act_dim)
    
    def forward(self, x):
        # Flatten the last k observations
        x = x.view(-1, self.k * self.obs_dim)

        # Pass through a fully connected layer to get the right dimension
        x = self.fc_in(x)
        x = x.unsqueeze(1)  # Add a dummy sequence length dimension expected by Transformer

        # Add positional encoding
        x = self.positional_encoding(x)

        # Transformer expects a sequence, but we're handling individual observations
        # So we treat each batch of k observations as a 'sequence' of length 1
        transformer_output = self.transformer_encoder(x)

        # Since we have a 'sequence' of length 1, we just select that one output
        transformer_output = transformer_output.squeeze(1)

        # And finally, pass it through the output fully connected layer
        action = self.fc_out(transformer_output)

        return action

# Initialize the model
bc_transformer = BCTransformer()

# Create a dummy input tensor of the correct shape
# Assuming input is (batch_size, k, obs_dim), where k is the number of past observations considered
dummy_input = torch.randn(100, 10, 55)  # 100 examples, each with 10 past observations

# Forward pass
output = bc_transformer(dummy_input)
print(output.shape)  # Should print torch.Size([100, 8]), assuming act_dim is the output dimension
