## Imports

In [15]:
import functools
import warnings

import dreamerv3
from dreamerv3 import embodied
from dreamerv3.embodied.envs import color_dmc
from dreamerv3.embodied.envs import from_gym
from dreamerv3 import ninjax as nj

from wrappers import color_grid_utils

import jax

tree_map = jax.tree_util.tree_map

## Setting up config and DMC Color Environment

In [2]:
config = embodied.Config(dreamerv3.configs['defaults'])
config = config.update(dreamerv3.configs['medium'])
config = config.update(dreamerv3.configs['dmc_vision'])
for name in [
    'action_evil',
    'dynamic_value_gradient',
    'latent_value_gradient_normed',
    'latent_value_gradient_norm_keep_magnitude'
]:
    config = config.update(dreamerv3.configs[name])
config = embodied.Flags(config).parse([
    '--logdir',
    '/media/hdd/Storage/distracting_benchmarks/logdir/dreamerv3'
    '/action_vaml_scaling_dyn_normed_keep_magnitude',
    '--task', 'cheetah_run',
    '--jax.platform', 'cpu',
    '--jax.jit', 'False'
])
logdir = embodied.Path(config.logdir)

In [3]:
env = color_dmc.DMC(
    config.task,
    repeat=config.env.dmc.repeat,
    size=config.env.dmc.size,
    camera=config.env.dmc.camera,
    num_cells_per_dim=config.evil.num_cells_per_dim,
    num_colors_per_cell=config.evil.num_colors_per_cell,
    evil_level=color_grid_utils.EVIL_CHOICE_CONVENIENCE_MAPPING[
        config.evil.evil_level
    ],
    action_dims_to_split=config.evil.action_dims_to_split,
    action_power=(
        config.evil.action_power if config.evil.action_power >= 0
        else None),
    action_splits=(
        config.evil.action_splits if config.evil.action_power < 0
        else None),
)

## Setting up DreamerV3 Agent & Dataset, Loading from Checkpoint

In [4]:
env = dreamerv3.wrap_env(env, config)
env = embodied.BatchEnv([env], parallel=False)

In [5]:
step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
replay = embodied.replay.Uniform(
    config.batch_length, config.replay_size, logdir / 'replay')

Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
JAX devices (1): [CpuDevice(id=0)]
Policy devices: TFRT_CPU_0
Train devices:  TFRT_CPU_0
Optimizer model_opt has 15,687,811 variables.
Optimizer actor_opt has 1,056,780 variables.
Optimizer critic_opt has 1,181,439 variables.


In [6]:
checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
checkpoint.step = step
checkpoint.agent = agent
checkpoint.replay = replay
# checkpoint.load(logdir / 'checkpoint.ckpt')
checkpoint.load_or_save()

Found existing checkpoint.
Loading checkpoint: /media/hdd/Storage/distracting_benchmarks/logdir/dreamerv3/action_vaml_scaling_dyn_normed_keep_magnitude/checkpoint.ckpt
Loaded checkpoint from 2021177 seconds ago.


In [8]:
def get_embed_post_prior_v(sample, state):
    vf = agent.agent.task_behavior.ac.critics['extr'].net
    post, (embed, prior) = agent.agent.wm.get_embed_post_prior({}, sample, state)
    v_mean = jax.jacrev(lambda post: vf(post).mean())
    latent_v_grad = v_mean(post)
    
    return embed, post, prior, latent_v_grad
get_embed_post_prior_v = jax.vmap(get_embed_post_prior_v, in_axes=[0, 0,])

In [9]:
def preprocess_and_get_embed_post_prior_v(sample):
    state = agent.agent.train_initial(len(sample['is_first']))
    sample = agent.agent.preprocess(sample)
    return get_embed_post_prior_v(sample, state)

In [10]:
preprocess_and_get_embed_post_prior_v = nj.pure(preprocess_and_get_embed_post_prior_v)

In [12]:
dataset = agent.dataset(replay.dataset)
sample = next(dataset)

In [13]:
varibs = {}
rng = agent._next_rngs(agent.policy_devices)
(embed, post, prior, latent_v_grad), varibs = preprocess_and_get_embed_post_prior_v(varibs, rng, sample)

In [16]:
tree_map(lambda x: (x > 0).any(), latent_v_grad)

{'deter': Array(False, dtype=bool),
 'logit': Array(False, dtype=bool),
 'stoch': Array(False, dtype=bool)}