In [6]:
import numpy as np
import robosuite as suite
import gymnasium as gym
from robosuite.wrappers import GymWrapper, DomainRandomizationWrapper
from robosuite.controllers import load_part_controller_config, ALL_COMPOSITE_CONTROLLERS, ALL_PART_CONTROLLERS
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from utils import to_numpy
import imageio
from robosuite_testing import PPONetwork, RobosuitePolicy, RobosuiteValue
import skimage as sk
import skimage.io as skio
import mediapy as media

# Torch device string shared by every cell below: GPU when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
def transform_obs(obs, camera_name='agentview_image'):
    """Pack one camera frame and the robot proprioception into a single array.

    The first channel is the camera image flipped along axis 0, averaged over
    its last axis (keeping that axis), divided by 255 and rescaled by 0.4.
    The second channel is a zero-filled plane of the same height and width
    whose leading entries (in row-major order) carry the selected
    proprioceptive values.

    Args:
        obs: Mapping that contains ``camera_name`` and
            ``'robot0_proprio-state'``.
        camera_name: Key of the camera image to encode.

    Returns:
        The image channel and the proprioception plane concatenated along the
        last axis.
    """
    gray = np.flip(obs[camera_name], 0).mean(-1, keepdims=True) / 255
    gray = sk.transform.rescale(gray, 0.4)
    height, width = gray.shape[0], gray.shape[1]

    scaled_state = obs['robot0_proprio-state'] * 10
    # Layout per the original author's note: indices 0-34 run up to eef_quat,
    # 35-38 are the site quaternion (dropped here), 39-40 are gripper qpos.
    # Indices 41-42 are dropped too — presumably gripper qvel; TODO confirm.
    selected = np.concatenate([scaled_state[:35], scaled_state[39:41]])

    # Write the state vector into the start of an otherwise-zero image plane.
    state_plane = np.zeros(height * width)
    state_plane[:selected.size] = selected
    state_plane = state_plane.reshape(height, width, 1)

    return np.concatenate([gray, state_plane], axis=-1)

def transform_framestacked_obs(obs):
    """Fold the frame and channel axes of a frame stack into one leading axis.

    Args:
        obs: 4-D array laid out as ``(frames, height, width, channels)``.

    Returns:
        Array of shape ``(frames * channels, height, width)``; the planes are
        ordered frame by frame, and channel by channel within each frame.
    """
    height, width = obs.shape[1], obs.shape[2]
    # Reject anything that is not 4-D, exactly like unpacking obs.shape into four names.
    _, _, _, _ = obs.shape
    channels_first = obs.transpose((0, 3, 1, 2))
    return channels_first.reshape((-1, height, width))

def make_env(env_name, camera_dim, seed, framestack=4, eval=False):
    """Return a zero-argument factory that builds a wrapped robosuite environment.

    Args:
        env_name: robosuite task name passed to ``suite.make`` (e.g. ``"Lift"``).
        camera_dim: Side length declared in the observation spaces. It must
            match the size ``transform_obs`` produces from the 250x250 cameras.
        seed: Seed for the domain-randomization wrapper and for the action and
            observation spaces.
        framestack: Number of consecutive observations stacked together.
        eval: Forwarded to ``suite.make`` as ``reward_shaping``.

    Returns:
        A callable with no arguments that creates and returns the environment.
    """
    def thunk():
        env = suite.make(
            env_name=env_name,
            robots="Sawyer",
            render_collision_mesh=False,
            has_offscreen_renderer=True,
            use_camera_obs=True,
            camera_names=["agentview", 'frontview'],
            use_object_obs=False,
            camera_heights=250,
            camera_widths=250,
            reward_shaping=eval,
            hard_reset=False,
            horizon=256,
            control_freq=10,
            table_full_size=(0.8, 2.0, 0.05),
            table_offset=(0, 0, 0.7)
        )
        env = DomainRandomizationWrapper(env, seed=seed, randomize_every_n_steps=0, randomize_color=False)
        # flatten_obs=False keeps the observation a dict so transform_obs can
        # pick out the camera image and the proprioception by key.
        env = GymWrapper(env, flatten_obs=False)
        # NOTE(review): Box defaults to float32 while transform_obs appears to
        # return float64 arrays — confirm downstream code does not rely on dtype.
        env = gym.wrappers.TransformObservation(env,
                                                transform_obs,
                                                gym.spaces.Box(-np.inf, np.inf, shape=(camera_dim, camera_dim, 2))
                                                )
        env = gym.wrappers.FrameStackObservation(env, framestack)
        # Merge the frame and channel axes: 2 channels per frame -> 2 * framestack planes.
        env = gym.wrappers.TransformObservation(env,
                                                transform_framestacked_obs,
                                                gym.spaces.Box(-np.inf, np.inf, shape=(2 * framestack, camera_dim, camera_dim))
                                                )
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk

In [8]:
def eval(ppo_network, eval_env):
    """Roll out one episode with the policy and record camera frames.

    Each step calls ``ppo_network.policy_network.policy_fn(obs, det=True)``,
    squashes the returned action with ``tanh`` and steps the environment. The
    environment is reset again once the episode finishes.

    Args:
        ppo_network: Object exposing ``policy_network.policy_fn(obs, det=...)``.
        eval_env: A single (non-vectorized) environment wrapped so that
            ``eval_env.env.env.env.env`` exposes ``_get_observations()``.

    Returns:
        Tuple ``(total_reward, frames1, frames2)``: the sum of per-step
        rewards, the flipped ``'frontview_image'`` frames and the flipped
        ``'agentview_image'`` frames, one of each per step.
    """
    obs, _ = eval_env.reset()
    # Layer beneath the four outer gym wrappers that exposes the raw camera images.
    sim_env = eval_env.env.env.env.env
    total_reward = 0
    frames1 = []
    frames2 = []
    done = False
    # Inference only: do not build autograd graphs during the rollout.
    with torch.no_grad():
        while not done:
            obs = torch.as_tensor(obs).to(device).float()
            action, _ = ppo_network.policy_network.policy_fn(obs, det=True)
            action = torch.tanh(action)
            obs, reward, term, trunc, _ = eval_env.step(to_numpy(action))
            done = term or trunc
            # Per-step rewards already add up to the episode return. Adding
            # info['episode']['r'] on the last step would count it twice.
            total_reward += reward
            raw_obs = sim_env._get_observations()
            frames1.append(np.flip(raw_obs['frontview_image'], 0))
            frames2.append(np.flip(raw_obs['agentview_image'], 0))
    eval_env.reset()
    return total_reward, frames1, frames2

def eval_action(action, eval_env):
    """Roll out one episode applying the same fixed action at every step.

    The environment is reset before the rollout and again after it finishes.

    Args:
        action: Action passed unchanged to ``eval_env.step`` on every step.
        eval_env: A single (non-vectorized) environment wrapped so that
            ``eval_env.env.env.env.env`` exposes ``_get_observations()``.

    Returns:
        Tuple ``(total_reward, frames1, frames2)``: the sum of per-step
        rewards, the flipped ``'frontview_image'`` frames and the flipped
        ``'agentview_image'`` frames, one of each per step.
    """
    eval_env.reset()
    # Layer beneath the four outer gym wrappers that exposes the raw camera images.
    sim_env = eval_env.env.env.env.env
    total_reward = 0
    frames1 = []
    frames2 = []
    done = False
    while not done:
        _, reward, term, trunc, _ = eval_env.step(action)
        done = term or trunc
        # Per-step rewards already add up to the episode return. Adding
        # info['episode']['r'] on the last step would count it twice.
        total_reward += reward
        raw_obs = sim_env._get_observations()
        frames1.append(np.flip(raw_obs['frontview_image'], 0))
        frames2.append(np.flip(raw_obs['agentview_image'], 0))
    eval_env.reset()
    return total_reward, frames1, frames2

In [9]:
# Build the factory for the "Lift" task, then call it once to get the single evaluation environment.
eval_env = make_env(env_name="Lift", camera_dim=100, seed=1, framestack=4, eval=True)()

[1m[32m[robosuite INFO] [0mLoading controller configuration from: /home/antony/106b-final-project/robosuite/controllers/config/robots/default_sawyer.json (composite_controller_factory.py:121)


In [10]:
# Roll out the checkpoint stored in 'dense.pt' and export the frontview video.
policy = RobosuitePolicy(7, 100)
value = RobosuiteValue(100)
ppo_network = PPONetwork(policy, value).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ppo_network.load_state_dict(torch.load('dense.pt', map_location=device))
_, frames, _ = eval(ppo_network, eval_env)
media.show_video(frames)
# Keep every 4th frame to shrink the GIF.
imageio.mimwrite('dense.gif', frames[::4], loop=0, fps=20)

0
This browser does not support the video tag.


In [18]:
# Save the first frame of `frames` as a grayscale thumbnail:
# average over the last axis, rescale by 0.4, then truncate to uint8.
first_frame_gray = frames[0].mean(-1)
imageio.imsave('obs.png', sk.transform.rescale(first_frame_gray, 0.4).astype(np.uint8))

In [6]:
# Roll out the checkpoint stored in 'denseRND.pt' and export the frontview video.
policy = RobosuitePolicy(7, 100)
value = RobosuiteValue(100)
ppo_network = PPONetwork(policy, value).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ppo_network.load_state_dict(torch.load('denseRND.pt', map_location=device))
_, frames, _ = eval(ppo_network, eval_env)
media.show_video(frames)
# Keep every 4th frame to shrink the GIF.
imageio.mimwrite('denseRND.gif', frames[::4], loop=0, fps=20)

0
This browser does not support the video tag.


In [7]:
# Roll out the checkpoint stored in 'sparse.pt' and export the frontview video.
policy = RobosuitePolicy(7, 100)
value = RobosuiteValue(100)
ppo_network = PPONetwork(policy, value).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ppo_network.load_state_dict(torch.load('sparse.pt', map_location=device))
_, frames, _ = eval(ppo_network, eval_env)
media.show_video(frames)
# Keep every 4th frame to shrink the GIF.
imageio.mimwrite('sparse.gif', frames[::4], loop=0, fps=20)

0
This browser does not support the video tag.


In [8]:
# Roll out the checkpoint stored in 'sparseRND.pt' and export the frontview video.
policy = RobosuitePolicy(7, 100)
value = RobosuiteValue(100)
ppo_network = PPONetwork(policy, value).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ppo_network.load_state_dict(torch.load('sparseRND.pt', map_location=device))
_, frames, _ = eval(ppo_network, eval_env)
media.show_video(frames)
# Keep every 4th frame to shrink the GIF.
imageio.mimwrite('sparseRND.gif', frames[::4], loop=0, fps=20)

0
This browser does not support the video tag.


In [23]:
# Baseline: no checkpoint is loaded, so the networks keep their freshly
# constructed weights. Both camera views are shown and exported.
policy = RobosuitePolicy(7, 100)
value = RobosuiteValue(100)
ppo_network = PPONetwork(policy, value).to(device)
_, frames1, frames2 = eval(ppo_network, eval_env)

for clip in (frames1, frames2):
    media.show_video(clip)

# Keep every 4th frame to shrink the GIFs.
for gif_path, clip in (('random1.gif', frames1), ('random2.gif', frames2)):
    imageio.mimwrite(gif_path, clip[::4], loop=0, fps=20)

0
This browser does not support the video tag.


0
This browser does not support the video tag.
