In [1]:
pip install panda_gym torchvision av

Collecting panda_gym
  Downloading panda_gym-3.0.7-py3-none-any.whl.metadata (4.3 kB)
Collecting av
  Downloading av-14.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.5 kB)
Collecting pybullet (from panda_gym)
  Downloading pybullet-3.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Downloading panda_gym-3.0.7-py3-none-any.whl (23 kB)
Downloading av-14.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (33.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.0/33.0 MB[0m [31m58.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading pybullet-3.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (103.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.2/103.2 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: pybullet, av, panda_gym
Successfully installed av-14.0.1 panda_gym-3.0.7 pybullet-3.2.6
Note: y

In [3]:
import gymnasium as gym
import panda_gym

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.callbacks import BaseCallback

import torch
import torch.nn as nn
import torchvision.models as models

import panda_ppo_utils
from panda_ppo_utils import (
                             ObservationWrapper, 
                             FeaturesExtractor,
                             ValidationCallback,
)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
resnet18 = models.resnet18(pretrained=True)
feature_extractor = torch.nn.Sequential(*list(resnet18.children())[:-1]) 
feature_extractor.eval()  

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 171MB/s] 


Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

In [6]:
train_env = gym.make("PandaPickAndPlace-v3", render_mode="rgb_array")
train_env = ObservationWrapper(train_env)

val_env = gym.make("PandaPickAndPlace-v3", render_mode="rgb_array")
val_env = ObservationWrapper(val_env)

policy_kwargs = dict(
    features_extractor_class=FeaturesExtractor,
    features_extractor_kwargs=dict(feature_extractor=feature_extractor, features_dim=258),  
    net_arch=dict(pi=[128], vf=[128])
)

model = PPO(
    policy="CnnPolicy",
    env=train_env,
    policy_kwargs=policy_kwargs,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    learning_rate=3e-4,
    gamma=0.99,
    gae_lambda=0.95,
    verbose=1,
    device=device
)

validation_callback = ValidationCallback(
    validation_env=val_env,
    validation_steps=20000, 
    output_filename="validation_video",
    fps=10,
    max_steps=200,
    verbose=1
)

pybullet build time: Nov 28 2023 23:45:17


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [14]:
model.learn(total_timesteps=1000000, callback=validation_callback)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 50       |
|    ep_rew_mean     | -50      |
| time/              |          |
|    fps             | 4        |
|    iterations      | 1        |
|    time_elapsed    | 491      |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 49.4         |
|    ep_rew_mean          | -49.4        |
| time/                   |              |
|    fps                  | 3            |
|    iterations           | 2            |
|    time_elapsed         | 1040         |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0061873477 |
|    clip_fraction        | 0.0669       |
|    clip_range           | 0.2          |
|    e

KeyboardInterrupt: 