In [36]:
import json
import math
import os
from pathlib import Path

import gymnasium as gym
import mani_skill2.envs
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from mani_skill2.utils.wrappers import RecordEpisode
from torch.nn import (Flatten, Linear, TransformerEncoder,
                      TransformerEncoderLayer)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.data_utils import flatten_obs, make_path
from utils.train_utils import init_deque, update_deque
from torch.distributions import Normal
from collections import deque, namedtuple

In [37]:
# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output locations for this experiment; each one is created below if missing.
ckpt_path = make_path('PPO', 'checkpoints')                 # network state_dicts
log_path = make_path('PPO', 'logs')                         # JSON logs and videos
tensorboard_path = make_path('PPO', 'logs', 'tensorboard')  # TensorBoard event files

Path(ckpt_path).mkdir(parents=True, exist_ok=True)
Path(log_path).mkdir(parents=True, exist_ok=True)
Path(tensorboard_path).mkdir(parents=True, exist_ok=True)

# One environment step as recorded during a rollout:
#   state            observation window the action was chosen from
#   action           action sampled by the policy
#   reward           reward returned by the environment for that action
#   next_value       value-network estimate for the window after the step
#   log_prob         log-probability of `action` under the sampling policy
#   terminated_next  `terminated` flag returned by the step
#   truncated_next   `truncated` flag returned by the step
Transition = namedtuple(
    'Transition',
    'state action reward next_value log_prob terminated_next truncated_next',
)

In [38]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first sequences.

    A ``[max_len, d_model]`` table is precomputed once and stored as the
    non-trainable buffer ``pe``. Even feature columns hold sines and odd
    columns hold cosines of ``position * 10000^(-2i / d_model)``. The first
    ``seq_len`` rows are added to the input and dropout is applied to the sum.

    Arguments:
        d_model: Feature width of the sequences (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the table supports.
    """

    def __init__(self,
                 d_model: int,
                 dropout: float = 0.1,
                 max_len: int = 100):

        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-math.log(10000.0) / d_model))  # [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was
                built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}")
        # [seq_len, d_model] broadcasts over the batch dimension.
        x = x + self.pe[:seq_len]
        return self.dropout(x)


class Policy(nn.Module):
    """Transformer-encoder actor that parameterises a Gaussian over actions.

    The input is a window of observations shaped
    ``[batch, seq_len, obs_dim]``. Every observation is embedded to
    ``d_model`` features, positionally encoded, passed through the encoder
    stack, and the whole window is flattened before the two output heads.

    ``forward`` returns ``(mean, log_std)`` where ``mean`` is
    ``[batch, act_dim]`` and ``log_std`` is ``[batch, 1]`` (a single
    log standard deviation per sample).
    """

    def __init__(self,
                 seq_len=8,
                 obs_dim=55,
                 act_dim=8,
                 dropout=0.1,
                 d_model=128,
                 dim_ff=128,
                 num_heads=8,
                 num_layers=3):
        super().__init__()

        # Width of the flattened encoder output fed to both heads.
        flat_dim = d_model * seq_len
        self.d_model = d_model

        # Per-timestep projection: obs_dim -> d_model.
        self.embedding = Linear(obs_dim, d_model)

        self.pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout)

        # A single multi-head self-attention block ...
        attention_block = TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        # ... stacked num_layers times.
        self.encoder = TransformerEncoder(attention_block, num_layers=num_layers)

        self.flatten = Flatten(1, -1)

        # Output heads: action mean and one shared log standard deviation.
        self.mean = Linear(flat_dim, act_dim)
        self.log_std = Linear(flat_dim, 1)

    def forward(self, x):
        """Map ``[batch, seq_len, obs_dim]`` to ``(mean, log_std)``."""
        scaled = self.embedding(x) * math.sqrt(self.d_model)
        encoded = self.encoder(self.pos_encoder(scaled))
        flat = self.flatten(encoded)
        return self.mean(flat), self.log_std(flat)

class Value(nn.Module):
    """Transformer-encoder critic producing one scalar per observation window.

    The input is shaped ``[batch, seq_len, obs_dim]`` and the output is
    ``[batch, 1]``. The trunk has the same layout as ``Policy`` (embedding,
    positional encoding, encoder stack, flatten) followed by a single linear
    head. ``act_dim`` is accepted but not used by this network.
    """

    def __init__(self,
                 seq_len=8,
                 obs_dim=55,
                 act_dim=8,
                 dropout=0.1,
                 d_model=128,
                 dim_ff=128,
                 num_heads=8,
                 num_layers=3):
        super().__init__()

        self.d_model = d_model

        # Per-timestep projection: obs_dim -> d_model.
        self.embedding = Linear(obs_dim, d_model)

        self.pos_encoder = PositionalEncoding(d_model=d_model, dropout=dropout)

        # A single multi-head self-attention block ...
        attention_block = TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        # ... stacked num_layers times.
        self.encoder = TransformerEncoder(attention_block, num_layers=num_layers)

        self.flatten = Flatten(1, -1)

        # Scalar value head over the flattened window.
        self.linear = Linear(d_model * seq_len, 1)

    def forward(self, x):
        """Map ``[batch, seq_len, obs_dim]`` to a ``[batch, 1]`` value estimate."""
        scaled = self.embedding(x) * math.sqrt(self.d_model)
        encoded = self.encoder(self.pos_encoder(scaled))
        return self.linear(self.flatten(encoded))

In [39]:
class PPO:
    """PPO trainer for a single (non-vectorised) environment.

    The actor (``Policy``) and critic (``Value``) are separate networks that
    consume a sliding window of the last ``seq_len`` flattened observations.
    :meth:`train` alternates between collecting a fixed-size rollout and
    taking exactly one gradient step on it.

    NOTE(review): both networks contain dropout and are never switched to
    ``eval()`` here, so rollout log-probabilities and the log-probabilities
    recomputed in :meth:`train` use different dropout masks. Confirm this is
    intended.
    """

    def __init__(self,
                 env: "gym.Env",
                 gamma: float = 0.99,
                 seq_len: int = 8,
                 seed: int = 42,
                 pol_ckpt: str = None) -> None:
        """Build the networks and optionally warm-start the policy.

        Arguments:
            env: Environment to interact with (Gymnasium step/reset API).
            gamma: Discount factor.
            seq_len: Number of observations in the window fed to the networks.
            seed: Seed for the global torch and NumPy RNGs (the environment
                itself is not seeded here).
            pol_ckpt: Optional path to a policy ``state_dict`` to load.
        """
        torch.manual_seed(seed)
        np.random.seed(seed)
        self._env = env
        self._gamma = gamma
        self._seq_len = seq_len
        self._pol_net = Policy(seq_len=seq_len).to(device)
        self._val_net = Value(seq_len=seq_len).to(device)
        if pol_ckpt is not None:
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # machine (and vice versa).
            self._pol_net.load_state_dict(
                torch.load(pol_ckpt, map_location=device))

    def _policy(self,
                obs: np.ndarray) -> tuple:
        """Sample an action for a batch of observation windows.

        Arguments:
            obs: Array or tensor shaped ``[batch, seq_len, obs_dim]``.

        Returns:
            ``(action, log_prob)``; ``log_prob`` is per action dimension and
            has the same shape as ``action``.
        """
        # as_tensor accepts both NumPy arrays and tensors of any float dtype.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
        mean, log_std = self._pol_net(obs)
        std = torch.exp(log_std)
        normal = Normal(mean, std)
        action = normal.sample()
        log_prob = normal.log_prob(action)
        return action, log_prob

    def _value(self,
               obs: np.ndarray) -> torch.Tensor:
        """Return the critic's estimate for a batch of observation windows."""
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
        return self._val_net(obs)

    def _unroll(self,
                steps: int,
                buffer: deque) -> tuple:
        """Collect ``steps`` transitions, resetting the environment as needed.

        Arguments:
            steps: Number of environment steps to take.
            buffer: Current observation window; it is updated in place by
                ``update_deque`` and replaced after every episode reset.

        Returns:
            ``(trajectory, buffer)``: the list of ``Transition`` records and
            the window to continue from on the next call.
        """
        state = torch.tensor(buffer).unsqueeze(0)
        trajectory = []
        with torch.no_grad():
            for _ in range(steps):
                # action and log_prob both carry a leading batch dim of 1
                action, log_prob = self._policy(state)
                obs, reward, terminated, truncated, info = self._env.step(
                    action[0].cpu().numpy())
                obs = flatten_obs(obs)
                next_state = torch.tensor(
                    update_deque(obs, buffer)).unsqueeze(0)
                next_value = self._value(next_state)
                transition = Transition(
                    state=state,
                    action=action.cpu(),
                    reward=reward,
                    next_value=next_value.cpu(),
                    log_prob=log_prob.cpu(),
                    terminated_next=terminated,
                    truncated_next=truncated
                )
                trajectory.append(transition)
                if terminated or truncated:
                    obs, info = self._env.reset()
                    obs = flatten_obs(obs)
                    # Use the configured window length; a hard-coded 8 would
                    # break any agent built with a different seq_len.
                    buffer = init_deque(obs, seq_len=self._seq_len)
                    state = torch.tensor(buffer).unsqueeze(0)
                else:
                    state = next_state
        return trajectory, buffer

    def _compute_advantages(self,
                            deltas: torch.Tensor,
                            terminates: torch.Tensor,
                            truncates: torch.Tensor,
                            lmda: float) -> torch.Tensor:
        """Generalised advantage estimation over one rollout.

        Walks the rollout backwards accumulating
        ``A_t = delta_t + gamma * lmda * A_{t+1}``; the carry from ``t + 1``
        is dropped whenever step ``t`` terminated or truncated its episode.

        Arguments:
            deltas: Per-step TD errors, first dimension is time.
            terminates: Per-step 0/1 termination flags.
            truncates: Per-step 0/1 truncation flags.
            lmda: GAE lambda.

        Returns:
            Tensor with the same shape as ``deltas``.
        """
        advantages = []
        advantage = 0
        for delta, terminate, truncate in zip(reversed(deltas), reversed(terminates), reversed(truncates)):
            advantage = delta + (1 - terminate) * (1 - truncate) * self._gamma * lmda * advantage
            advantages.append(advantage)
        advantages.reverse()
        advantages = torch.stack(advantages, dim=0)
        return advantages

    def _normalize_advantages(self,
                              advantages: torch.Tensor) -> torch.Tensor:
        """Standardise advantages along the first (time/batch) dimension."""
        return (advantages - advantages.mean(dim=0)) / (advantages.std(dim=0) + 1e-10)

    def _process_traj(self,
                      trajectory: list) -> tuple:
        """Stack a list of ``Transition`` records into batched tensors.

        Returns:
            ``(states, actions, log_probs, returns, terminates, truncates)``.
            ``returns`` are one-step bootstrapped targets
            ``r + gamma * V(s')``; the bootstrap term is dropped only on
            termination, so truncated steps still bootstrap.
            ``terminates`` / ``truncates`` are int32 0/1 vectors.
        """
        states = []
        actions = []
        log_probs = []
        returns = []
        terminates = []
        truncates = []
        for transition in trajectory:
            states.append(transition.state)
            actions.append(transition.action)
            log_probs.append(transition.log_prob)
            reward = transition.reward
            terminated_next = transition.terminated_next
            terminates.append(terminated_next)
            truncated_next = transition.truncated_next
            truncates.append(truncated_next)
            value_next = transition.next_value
            target = reward + (1 - terminated_next) * self._gamma * value_next
            returns.append(target)
        states = torch.cat(states, dim=0)
        actions = torch.cat(actions, dim=0)
        log_probs = torch.cat(log_probs, dim=0)
        returns = torch.cat(returns, dim=0)
        terminates = np.array(terminates, dtype=np.int32)
        terminates = torch.from_numpy(terminates)
        truncates = np.array(truncates, dtype=np.int32)
        truncates = torch.from_numpy(truncates)
        return states, actions, log_probs, returns, terminates, truncates

    def train(self,
              grad_steps=10000,
              batch_size=128,
              pol_lr=1e-4,
              val_lr=3e-4,
              eps_clip=0.2,
              lmda=0.90,
              log_freq=10,
              ckpt_freq=100):
        """Run the collect-then-update loop.

        Arguments:
            grad_steps: Number of rollouts / gradient steps.
            batch_size: Environment steps collected per gradient step.
            pol_lr: Adam learning rate for the policy network.
            val_lr: Adam learning rate for the value network.
            eps_clip: PPO ratio clipping range.
            lmda: GAE lambda.
            log_freq: Write losses to TensorBoard every this many steps.
            ckpt_freq: Save both networks to ``ckpt_path`` every this many
                steps.
        """
        writer = SummaryWriter(tensorboard_path)
        pol_optim = optim.Adam(self._pol_net.parameters(), lr=pol_lr)
        val_optim = optim.Adam(self._val_net.parameters(), lr=val_lr)

        # try/finally so the event file is flushed and closed even when a
        # long run is interrupted or fails part-way through.
        try:
            obs, info = self._env.reset()
            obs = flatten_obs(obs)
            buffer = init_deque(obs, seq_len=self._seq_len)
            for grad_step in tqdm(range(1, grad_steps + 1)):

                traj, buffer = self._unroll(steps=batch_size, buffer=buffer)

                states, actions, old_log_probs, returns, terminates, truncates = self._process_traj(
                    traj)

                states = states.to(device)
                actions = actions.to(device)
                old_log_probs = old_log_probs.to(device)
                returns = returns.to(device)
                terminates = terminates.to(device)
                truncates = truncates.to(device)

                # Re-evaluate the rollout actions under the current policy.
                means, log_stds = self._pol_net(states)
                normal = Normal(means, torch.exp(log_stds))
                new_log_probs = normal.log_prob(actions)
                ratio = torch.exp(new_log_probs - old_log_probs)

                # TD errors carry gradient w.r.t. the value network.
                deltas = returns - self._val_net(states)
                advantages = self._compute_advantages(deltas,
                                                      terminates,
                                                      truncates,
                                                      lmda)

                # Detached so the policy loss does not train the critic.
                norm_adv = self._normalize_advantages(advantages).detach()
                surr1 = ratio * norm_adv
                surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * norm_adv
                pol_loss = -torch.min(surr1, surr2).mean()

                # The critic is trained on the squared (un-normalised) GAE
                # advantages.
                val_loss = advantages.pow(2).mean()

                loss = pol_loss + val_loss

                pol_optim.zero_grad()
                val_optim.zero_grad()
                loss.backward()
                pol_optim.step()
                val_optim.step()

                if grad_step % log_freq == 0:
                    writer.add_scalar('Loss/Policy', pol_loss, grad_step)
                    writer.add_scalar('Loss/Value', val_loss, grad_step)
                if grad_step % ckpt_freq == 0:
                    torch.save(self._pol_net.state_dict(),
                               os.path.join(ckpt_path, f'pol_{grad_step}.pt'))
                    torch.save(self._val_net.state_dict(),
                               os.path.join(ckpt_path, f'val_{grad_step}.pt'))
        finally:
            writer.flush()
            writer.close()

In [40]:
def test(ckpt: str,
         seq_len: int,
         max_steps: int = 300,
         num_episodes: int = 100):
    """Evaluate a saved policy on ``StackCube-v0`` and log the results.

    One episode is run for every seed in ``range(num_episodes)``. Actions
    are sampled from the policy's Gaussian (not the mean). The per-seed
    return is written to TensorBoard and a summary is saved to
    ``test_log.json`` under ``log_path``.

    Arguments:
        ckpt: Path to a ``Policy`` ``state_dict``.
        seq_len: Observation-window length the policy was built with.
        max_steps: Episode step limit passed to the environment.
        num_episodes: Number of seeds / episodes to evaluate.

    Returns:
        List of seeds whose final ``info['success']`` was truthy.
    """
    env = gym.make('StackCube-v0',
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   render_mode="cameras",
                   reward_mode="normalized_dense",
                   max_episode_steps=max_steps)
    writer = SummaryWriter(tensorboard_path)

    best_return = -np.inf
    best_seed = None
    returns = {}
    success_seeds = []

    # try/finally so the simulator and the event file are released even if
    # loading the checkpoint or an episode fails.
    try:
        model = Policy(seq_len=seq_len)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(ckpt, map_location=device))
        model.to(device)
        model.eval()

        for seed in tqdm(range(num_episodes)):
            obs, _ = env.reset(seed=seed)
            obs = flatten_obs(obs)
            buffer = init_deque(obs, seq_len)
            window = np.array(buffer)
            G = 0
            terminated = False
            truncated = False
            with torch.no_grad():
                while not terminated and not truncated:
                    batch = torch.from_numpy(window[None]).to(device)
                    mean, log_std = model(batch)
                    dist = Normal(mean, torch.exp(log_std))
                    action = dist.sample().cpu().numpy()
                    obs, reward, terminated, truncated, info = env.step(action[0])
                    obs = flatten_obs(obs)
                    window = update_deque(obs=obs, window=buffer)
                    G += reward

            # Plain Python float so the JSON dump below cannot trip over a
            # NumPy scalar type.
            G = float(G)

            if G > best_return:
                best_return = G
                best_seed = seed

            if info['success']:
                success_seeds.append(seed)

            returns[seed] = G
            writer.add_scalar('Return', G, seed)
    finally:
        env.close()
        writer.flush()
        writer.close()

    log = dict(returns=returns,
               best_seed=best_seed,
               best_return=best_return,
               max_steps=max_steps,
               num_episodes=num_episodes,
               success_seeds=success_seeds,
               # avoid ZeroDivisionError when no episodes were requested
               success_rate=len(success_seeds) / num_episodes if num_episodes else 0.0)

    with open(os.path.join(log_path, 'test_log.json'), 'w') as f:
        json.dump(log, f, indent=4)

    return success_seeds

In [41]:
def render_video(ckpt: str,
                 seq_len: int,
                 seed: int,
                 max_steps: int = 300):
    """Record one ``StackCube-v0`` episode of a saved policy as a video.

    The environment is wrapped in ``RecordEpisode`` writing to ``log_path``.
    Actions are sampled from the policy's Gaussian, and the video is flushed
    with the suffix ``PPO_<seed>``.

    Arguments:
        ckpt: Path to a ``Policy`` ``state_dict``.
        seq_len: Observation-window length the policy was built with.
        seed: Seed passed to ``env.reset``.
        max_steps: Episode step limit passed to the environment.
    """
    env = gym.make('StackCube-v0',
                   render_mode="cameras",
                   enable_shadow=True,
                   obs_mode="state_dict",
                   control_mode="pd_joint_delta_pos",
                   max_episode_steps=max_steps)

    # try/finally so the simulator is released even if wrapping, loading the
    # checkpoint, or the rollout fails.
    try:
        env = RecordEpisode(
            env,
            log_path,
            info_on_video=True,
            save_trajectory=False
        )

        model = Policy(seq_len=seq_len)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(ckpt, map_location=device))
        model.to(device)
        model.eval()

        obs, _ = env.reset(seed=seed)
        obs = flatten_obs(obs)
        buffer = init_deque(obs, seq_len)
        window = np.array(buffer)
        terminated = False
        truncated = False

        with torch.no_grad():
            while not terminated and not truncated:
                batch = torch.from_numpy(window[None]).to(device)
                mean, log_std = model(batch)
                dist = Normal(mean, torch.exp(log_std))
                action = dist.sample().cpu().numpy()
                obs, reward, terminated, truncated, info = env.step(action[0])
                obs = flatten_obs(obs)
                window = update_deque(obs=obs, window=buffer)

        env.flush_video(suffix=f'PPO_{seed}')
    finally:
        env.close()

In [42]:
# Train a PPO agent from scratch on StackCube-v0 with dense, normalised rewards.
env = gym.make('StackCube-v0',
               obs_mode="state_dict",
               control_mode="pd_joint_delta_pos",
               reward_mode="normalized_dense",
               render_mode="cameras",
               max_episode_steps=250)
# To warm-start the policy instead, pass a checkpoint path as pol_ckpt, e.g.
# init_ckpt = os.path.join('.', 'BC_Transformer_Gaussian', 'checkpoints', 'bc_195.pt')
try:
    agent = PPO(env=env, pol_ckpt=None)
    agent.train(grad_steps=10000, batch_size=64)
finally:
    # Release the simulator even if the run is interrupted or fails.
    env.close()

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [1:35:39<00:00,  1.74it/s]


In [44]:
# Policy checkpoint saved at the final step of the 10000-step training run;
# built from ckpt_path so it points at the directory train() writes to.
best_ckpt = os.path.join(ckpt_path, 'pol_10000.pt')
# Seeds to render are hard-coded; to recompute them from an evaluation run use
# success_seeds = test(best_ckpt, seq_len=8, num_episodes=100)
success_seeds = [43]
for seed in success_seeds:
    render_video(best_ckpt, seq_len=8, seed=seed)

