In [1]:
import os
os.environ["MUJOCO_GL"] = "egl"

from dm_control import suite
from dm_control.suite.wrappers import pixels
import numpy as np



# Load one task:
env = suite.load(domain_name="cartpole", task_name="swingup")
env = pixels.Wrapper(env, render_kwargs={'width': 64, 'height': 64})

# Step through an episode and print out reward, discount and observation.
action_spec = env.action_spec()
time_step = env.reset()
while not time_step.last():
  action = np.random.uniform(action_spec.minimum,
                             action_spec.maximum,
                             size=action_spec.shape)
  time_step = env.step(action)
  print(time_step.reward, time_step.discount, time_step.observation)
  break

7.081447400492573e-05 1.0 OrderedDict([('pixels', array([[[ 39,  56,  72],
        [ 35,  52,  69],
        [ 35,  52,  69],
        ...,
        [ 36,  53,  70],
        [ 37,  54,  70],
        [ 37,  54,  71]],

       [[ 35,  52,  69],
        [ 34,  51,  68],
        [ 34,  51,  68],
        ...,
        [ 42,  58,  74],
        [ 43,  59,  75],
        [ 48,  64,  79]],

       [[ 35,  51,  68],
        [ 34,  50,  67],
        [ 36,  52,  69],
        ...,
        [ 34,  50,  67],
        [ 39,  55,  71],
        [ 35,  52,  69]],

       ...,

       [[ 40,  70,  99],
        [ 40,  71, 100],
        [ 40,  71, 101],
        ...,
        [ 49,  80, 110],
        [ 49,  79, 110],
        [ 48,  79, 109]],

       [[ 57,  86, 117],
        [ 57,  88, 118],
        [ 58,  87, 118],
        ...,
        [ 33,  64,  93],
        [ 32,  62,  91],
        [ 31,  62,  91]],

       [[ 59,  88, 118],
        [ 59,  89, 119],
        [ 58,  88, 118],
        ...,
        [ 32,  62,  91],

In [2]:
def preprocess_obs(obs, bit_depth=5):
    """
    Reduces the bit depth of image for the ease of training
    and convert to [-0.5, 0.5]
    In addition, add uniform random noise same as original implementation
    """
    obs = obs.astype(np.float32)
    reduced_obs = np.floor(obs / 2 ** (8 - bit_depth))
    normalized_obs = reduced_obs / 2**bit_depth - 0.5
    normalized_obs += np.random.uniform(0.0, 1.0 / 2**bit_depth, normalized_obs.shape)
    return normalized_obs

In [3]:
x = preprocess_obs(time_step.observation['pixels']) 

In [4]:
from planet.models.encoder import ImageEncoderModel
model = ImageEncoderModel(
    hidden_state_size=32,
    observation_size=32,
    state_size=32,
    hidden_layer_size=128,
    in_channels=3,
)

In [5]:
import torch
hidden_state = torch.Tensor(np.zeros((1, 32)))
observation = torch.Tensor(x).permute((2, 0, 1)).unsqueeze(0)

In [6]:
hidden_state.shape, observation.shape

(torch.Size([1, 32]), torch.Size([1, 3, 64, 64]))

In [8]:
a, b = model(hidden_state, observation)

In [9]:
a.shape, b.shape

(torch.Size([1, 32]), torch.Size([1, 32]))