In [1]:
import sys
sys.path.append("/scratch/gpfs/graceliu/contrastive_rl")

import copy
import functools

from matplotlib import pyplot as plt
from matplotlib import animation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import HTML

import jax
import optax
import numpy as np
from acme import specs
import tensorflow as tf

from acme.tf.savers import SaveableAdapter

from contrastive.config import ContrastiveConfig
from contrastive import utils as contrastive_utils
from contrastive import make_networks
from contrastive.utils import make_environment
from contrastive import ContrastiveLearner

# disable tensorflow_probability warning: The use of `check_types` is deprecated and does not have any effect.
import logging
logger = logging.getLogger("root")


class CheckTypesFilter(logging.Filter):
    """Reject any log record whose rendered message mentions `check_types`.

    Installed to silence the tensorflow_probability deprecation notice about the
    `check_types` argument, which has no effect.
    """

    def filter(self, record):
        message = record.getMessage()
        # True keeps the record, False drops it.
        return not ("check_types" in message)


# NOTE(review): per the logging docs, a filter attached to a Logger is only
# consulted for records logged directly on that logger; records from child
# loggers that propagate up to it are not filtered here.
logger.addFilter(CheckTypesFilter())

from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
import seaborn as sn
import pandas as pd

No module named 'flow'
No module named 'dotmap'
pybullet build time: Nov 28 2023 23:52:03
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [2]:
def load_checkpoint(alpha, env_name, log_dir, seed, fix_goals = False, ckpt_num = None, alg_name = None, cfg = None):
    """Rebuild a ContrastiveLearner for `env_name` and restore its weights from disk.

    The checkpoint directory is
    `<log_dir>/<alg>_<env_name>_<alpha>_<suffix>_<seed>/checkpoints/learner`, where
    `<suffix>` is "None" for `contrastive_cpc` and "0.0" for any other algorithm.

    Args:
      alpha: value formatted into the run directory name (e.g. "0.1").
      env_name: environment name passed to `make_environment`.
      log_dir: root directory containing the run directories.
      seed: training seed of the run; only used to locate the run directory
        (the environment itself is seeded randomly).
      fix_goals: if True, pass a fixed start/end array to `make_environment`.
      ckpt_num: N of a specific `ckpt-N` to restore; None restores the latest one.
      alg_name: algorithm name used in the run directory; defaults to the
        notebook global `alg`.
      cfg: ContrastiveConfig used to build the networks and learner; defaults to
        the notebook global `config`.

    Returns:
      (learner_state, env, networks): the restored `ContrastiveLearner._state`,
      the environment built for it, and the networks.

    Raises:
      FileNotFoundError: if `ckpt_num` is None and the directory has no checkpoint.
    """
    # Fall back to notebook globals so existing call sites keep working.
    alg_name = alg if alg_name is None else alg_name
    cfg = config if cfg is None else cfg

    suffix = 'None' if alg_name == 'contrastive_cpc' else '0.0'
    misc_params = '{}_{}'.format(alpha, suffix)
    ckpt_dir = '{}/{}_{}_{}_{}/checkpoints/learner'.format(log_dir, alg_name, env_name, misc_params, seed)

    # This value doesn't actually matter for the sawyer box environment since the
    # goal position is overwritten in env_utils.py.
    fixed_start_end = np.array([0.12, 0.7, 0.02]) if fix_goals else None

    # make_environment's result is indexed as [0] -> env, [1] -> obs_dim. Build a
    # single environment and reuse it for the spec, obs_dim and the returned env.
    env_and_dim = make_environment(env_name, cfg.start_index, cfg.end_index,
                                   seed=np.random.randint(int(1e6)),
                                   fixed_start_end=fixed_start_end)
    env, obs_dim = env_and_dim[0], env_and_dim[1]
    environment_spec = specs.make_environment_spec(env)

    networks = make_networks(
        environment_spec, obs_dim=obs_dim, repr_dim=cfg.repr_dim,
        repr_norm=cfg.repr_norm, twin_q=False,
        use_image_obs=cfg.use_image_obs,
        hidden_layer_sizes=cfg.hidden_layer_sizes,
    )

    trained_learner = ContrastiveLearner(
        networks=networks,
        rng=jax.random.PRNGKey(np.random.choice(int(1e6))),
        policy_optimizer=optax.adam(learning_rate=cfg.actor_learning_rate),
        # NOTE(review): hard-coded 1e-3 rather than cfg.critic_learning_rate,
        # kept as-is; the optimizer is only rebuilt so its state can be restored.
        q_optimizer=optax.adam(.001),
        iterator=None,
        counter=None,
        logger=None,
        obs_to_goal=functools.partial(contrastive_utils.obs_to_goal_2d,
                                      start_index=cfg.start_index,
                                      end_index=cfg.end_index),
        config=cfg)

    ckpt = tf.train.Checkpoint(learner=SaveableAdapter(trained_learner))
    if ckpt_num is not None:
        ckpt_path = '{}/ckpt-{}'.format(ckpt_dir, ckpt_num)
    else:
        ckpt_path = tf.train.CheckpointManager(ckpt, ckpt_dir, 1).latest_checkpoint
        if ckpt_path is None:
            raise FileNotFoundError('No checkpoint found in {}'.format(ckpt_dir))
    # assert_consumed() raises if the checkpoint and the learner's objects don't match.
    ckpt.restore(ckpt_path).assert_consumed()

    print("Model loaded from: {}".format(ckpt_path))
    return trained_learner._state, env, networks

In [3]:
def plot_episode(trained_learner_state, env, networks, alpha, verbose = False, plot_tsne = True, plot_norms = True, show_path = True, show_video = True, plot_reachability = False):
    """Roll out the trained policy for one episode and visualise its critic.

    The policy's mode action is executed for exactly EPISODE_LENGTH steps from a
    randomly seeded reset. `networks.q_network` is then evaluated on the visited
    (observation, action) pairs, optionally with parts of each observation swapped:

      * "critic": the observation as recorded;
      * "home reachability": second half replaced by the initial state;
      * "reachability from home": first half replaced by the initial state and
        second half by the current state;
      * "state reachability": second half replaced by the current state.

    NOTE(review): this treats observations as [state, goal] halves split at
    `state_repr_shape`, as the indexing implies — confirm against the env.

    Reads notebook globals: EPISODE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH,
    obs_repr_shape, state_repr_shape, action_repr_shape, use_l2_critic and
    use_phi_critic.

    Args:
      trained_learner_state: learner state with `policy_params` and `q_params`.
      env: environment returned by `load_checkpoint`.
      networks: networks returned by `load_checkpoint`.
      alpha: forwarded to `networks.q_network.apply`.
      verbose: print per-step obs/action/critic details while the return is <= 2.
      plot_tsne: show a t-SNE embedding of the state-action representations.
      plot_norms: print norm/critic statistics and scatter the critic per step.
      show_path: unused; kept for backward compatibility.
      show_video: play the rendered episode inline.
      plot_reachability: together with `plot_norms`, also scatter the three
        reachability critics listed above.

    Returns:
      None; results are printed and plotted.
    """
    def critic_outputs(obs_batch, action_batch):
        # (q, sa_repr, g_repr) for a batch; the SA regulariser is always off here.
        return networks.q_network.apply(
            trained_learner_state.q_params, obs_batch, action_batch,
            use_l2_critic=use_l2_critic,
            use_phi_critic=use_phi_critic,
            use_sa_reg=False,
            alpha=alpha)

    def paired_critic(sa, g):
        # Row-wise <sa_i, g_i>; equals np.diag(sa @ g.T) without the T x T matrix.
        return np.einsum('ik,ik->i', sa, g)

    repr_obs = np.zeros((EPISODE_LENGTH, obs_repr_shape))
    repr_obs_home = np.zeros((EPISODE_LENGTH, obs_repr_shape))
    repr_obs_reach = np.zeros((EPISODE_LENGTH, obs_repr_shape))
    repr_obs_state = np.zeros((EPISODE_LENGTH, obs_repr_shape))
    repr_action = np.zeros((EPISODE_LENGTH, action_repr_shape))
    imgs = np.zeros([EPISODE_LENGTH, IMAGE_HEIGHT, IMAGE_WIDTH, 3], dtype=np.uint8)

    env.seed(np.random.randint(int(1e6)))
    timestep = env.reset()
    episode_return = 0
    obs = timestep.observation
    initial_state = np.zeros(state_repr_shape)
    for t in range(EPISODE_LENGTH):
        # resolution is (width, height); frames fill an (IMAGE_HEIGHT, IMAGE_WIDTH, 3) slot.
        imgs[t] = env.render(offscreen=True, camera_name="corner3",
                             resolution=(IMAGE_WIDTH, IMAGE_HEIGHT))

        obs = timestep.observation
        current_state = obs[:state_repr_shape]
        if t == 0:
            initial_state = np.array(current_state)

        dist = networks.policy_network.apply(trained_learner_state.policy_params, obs)
        action = np.array(dist.mode())
        timestep = env.step(action)

        repr_obs[t] = obs
        repr_action[t] = action

        # Home reachability: goal half <- initial state.
        home_obs = np.copy(obs)
        home_obs[state_repr_shape:] = initial_state
        repr_obs_home[t] = home_obs

        # Reachability from home: state half <- initial state, goal half <- current state.
        reach_obs = np.copy(obs)
        reach_obs[state_repr_shape:] = current_state
        reach_obs[:state_repr_shape] = initial_state
        repr_obs_reach[t] = reach_obs

        # State reachability: goal half <- current state.
        state_obs = np.copy(obs)
        state_obs[state_repr_shape:] = current_state
        repr_obs_state[t] = state_obs

        episode_return += timestep.reward
        if verbose and episode_return <= 2:
            print("timestep=" + str(t + 1))
            _, sa_step, g_step = critic_outputs(np.expand_dims(obs, axis=0),
                                                np.expand_dims(action, axis=0))
            print("obs: " + str(obs))
            print("action: " + str(action))
            print("sa norm: " + str(np.einsum('ik,ik->i', sa_step, sa_step)))
            print("critic: " + str(np.einsum('ik,jk->ij', sa_step, g_step)))
            print()

    # `obs` is the last observation the policy acted on.
    dist_from_start = np.linalg.norm(obs[:state_repr_shape] - initial_state)
    print("episode_return: " + str(episode_return))
    print("dist_from_start: " + str(dist_from_start))
    print()
    print()

    _, sa_repr, g_repr = critic_outputs(repr_obs, repr_action)

    if plot_norms:
        sa_norm_arr = np.einsum('ik,ik->i', sa_repr, sa_repr)
        g_norm = np.einsum('ik,ik->i', g_repr, g_repr)[0]
        critic_arr = paired_critic(sa_repr, g_repr)

        corr = np.corrcoef(sa_norm_arr, critic_arr)
        print("correlation between critic and phi sa norm: " + str(corr[0][1]))
        print("psi g norm: " + str(g_norm))

        steps = np.arange(EPISODE_LENGTH)
        fig, ax = plt.subplots()
        ax.scatter(steps, critic_arr, s=8, label="critic")
        if plot_reachability:
            for label, obs_batch in (("home reachability", repr_obs_home),
                                     ("reachability from home", repr_obs_reach),
                                     ("state reachability", repr_obs_state)):
                _, sa_batch, g_batch = critic_outputs(obs_batch, repr_action)
                ax.scatter(steps, paired_critic(sa_batch, g_batch), s=8, label=label)
        ax.set_xlabel("timestep")
        ax.set_ylabel("critic")
        ax.set_title("critic along the episode")
        ax.legend()
        plt.show()

    if plot_tsne:
        standardized = StandardScaler().fit_transform(sa_repr)
        tsne_data = TSNE(n_components=2, random_state=0).fit_transform(standardized)
        tsne_df = pd.DataFrame({
            "Dim_1": tsne_data[:, 0],
            "Dim_2": tsne_data[:, 1],
            "Timestep": np.arange(EPISODE_LENGTH),
        })
        sn.scatterplot(data=tsne_df, x='Dim_1', y='Dim_2', hue='Timestep')
        plt.show()

    if show_video:
        play_video([imgs])


In [4]:
def gen_images(alpha, env_name, log_dir, seed, resolution = None):
    """Restore a checkpoint and render NUM_EPISODES rollouts with fixed seeds.

    Episode `i` is reset after `env.seed(i)`, so different methods are rolled
    out from the same seeds. Per-episode returns, the average return and the
    success rate (fraction of episodes with return >= 1) are printed.

    Args:
      alpha, env_name, log_dir, seed: forwarded to `load_checkpoint` (with
        `fix_goals=True`).
      resolution: (width, height) passed to `env.render`. Defaults to
        (IMAGE_WIDTH, IMAGE_HEIGHT); the frame buffer is sized from it, so
        rendered frames always fit.

    Returns:
      A one-element list holding a (NUM_EPISODES * env._step_limit, height, width, 3)
      uint8 array of all episodes' frames back to back. Slots after an episode
      ends early remain black (zeros).
    """
    width, height = (IMAGE_WIDTH, IMAGE_HEIGHT) if resolution is None else resolution

    trained_learner_state, env, networks = load_checkpoint(alpha, env_name, log_dir, seed, fix_goals=True)
    step_limit = env._step_limit
    episode_returns = np.zeros(NUM_EPISODES)
    # Buffer shape must match the render resolution, otherwise frame assignment fails.
    imgs = np.zeros([NUM_EPISODES, step_limit, height, width, 3], dtype=np.uint8)

    for epi in range(NUM_EPISODES):
        env.seed(epi)  # use fixed seed for different methods
        timestep = env.reset()
        episode_return = 0
        t = 0
        while not timestep.last():
            imgs[epi, t] = env.render(offscreen=True, camera_name="corner3",
                                      resolution=(width, height))

            dist = networks.policy_network.apply(
                trained_learner_state.policy_params,
                timestep.observation
            )
            action = np.array(dist.mode())
            timestep = env.step(action)

            # Book-keeping.
            t += 1
            episode_return += timestep.reward

        print("episode return = {}".format(episode_return))
        episode_returns[epi] = episode_return

    print("avg episode return: {}".format(np.mean(episode_returns)))
    print("success rate: {}".format(np.mean(episode_returns >= 1)))
    print()

    return [imgs.reshape([-1, height, width, 3])]

In [5]:
def play_video(imgs_list, interval = 100):
    """Display each frame stack in `imgs_list` as an inline HTML5 video.

    Args:
      imgs_list: iterable of frame arrays; `imgs[i]` is the i-th image handed to
        `imshow` (e.g. a (T, H, W, 3) uint8 array).
      interval: delay between frames in milliseconds (FuncAnimation `interval`).

    Relies on IPython's `display` being available, as it is inside Jupyter.
    """
    for imgs in imgs_list:
        fig, ax = plt.subplots()
        im = ax.imshow(imgs[0])
        # Close the figure so Jupyter doesn't also show it as a static image; the
        # animation still draws into it while the video is encoded.
        plt.close(fig)

        def init():
            im.set_data(imgs[0])
            return (im,)

        def animate(i):
            im.set_data(imgs[i])
            return (im,)

        anim = animation.FuncAnimation(fig, animate, init_func=init,
                                       frames=len(imgs), interval=interval)
        display(HTML(anim.to_html5_video()))

In [20]:
# Fixed goal observation per environment name, passed to ContrastiveConfig.
# NOTE(review): the only key is 'sawyer_bin2' while the checkpoint loaded below is
# for 'sawyer_box2' — confirm this is intended.
fixed_goal_obs_dict = {'sawyer_bin2': [0.12, 0.7, 0.05, 0.4, 0.12, 0.7, 0.02]}

config = ContrastiveConfig(fixed_goal_obs_dict=fixed_goal_obs_dict)
config.use_cpc_symm = True
config.use_goal_action = True
config.use_quasimetric_logit = True
config.use_goal_potential = True
config.twin_q = True

seeds_list = [0, 1, 2, 3, 4]

# Rendering and rollout sizes (each set exactly once).
IMAGE_HEIGHT = 240
IMAGE_WIDTH = 320
NUM_EPISODES = 20
EPISODE_LENGTH = 500

# Goal options.
fixed_goal = False
goal_difficulty = None
fixed_goal_coords = [0.12, 0.7, 0.02]
goal_difficulty_list = [0.2, 0.6, 1.0]
fix_point_goals = True

# Critic options; use_l2_critic and use_phi_critic are read by plot_episode.
use_l2_critic = False
use_phi_critic = True
use_sa_reg = True
is_kl = False


## Load the hard Sawyer box checkpoint and show a video of one episode

In [18]:
# Which run to load: load_checkpoint looks under
# <log_dir>/<alg>_<env_name>_<alpha>_None_<seed>/checkpoints/learner for contrastive_cpc.
alpha = "0.1"
env_name = 'sawyer_box2'
log_dir = '/scratch/gpfs/graceliu/crl/hard_envs'
alg = 'contrastive_cpc'
seed = 1

# Representation sizes used by plot_episode. It writes a state-sized vector into
# obs[state_repr_shape:], so an observation is two state-sized halves.
state_repr_shape = 14
obs_repr_shape = 2 * state_repr_shape
action_repr_shape = 4

In [19]:
# Restore the most recent checkpoint in the run directory (pass ckpt_num=N to
# load ckpt-N instead), then roll out one episode and show only its video.
trained_learner_state, env, networks = load_checkpoint(
    alpha, env_name, log_dir, seed,
    fix_goals=True,
)
plot_episode(
    trained_learner_state, env, networks, alpha,
    plot_tsne=False,
    plot_norms=False,
    show_video=True,
)

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


Model loaded from: /scratch/gpfs/graceliu/crl/hard_envs/contrastive_cpc_sawyer_box2_0.1_None_1/checkpoints/learner
episode_return: 0.0
dist_from_start: 1.1348726


