In [1]:
import collections

import dreamerv3
from dreamerv3 import embodied
from dreamerv3.embodied.envs import color_dmc
from dreamerv3 import ninjax as nj

from wrappers import color_grid_utils

import jax
import jax.numpy as jnp
import matplotlib as mpl
import numpy as np
import pandas as pd
import scipy as sp

tree_map = jax.tree_util.tree_map

from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML, FileLink
import tqdm

In [2]:
MODELS = {
    'v3_action_seq_new_vaml': {
        'configs': [
            'action_seq_evil_2500_new',
            'image_value_gradient'
        ],
        'logdir': (
            '/media/miles/File_Storage/'
            'distracting_benchmarks_logdir/node_2/'
            'logdir/dreamerv3/'
            'cheetah_action_seq_evil_2500_new_vaml'
        ),
        'task': 'cheetah_run',
    },
    'v3_action_seq_new': {
        'configs': [
            'action_seq_evil_2500_new'
        ],
        'logdir': (
            '/media/miles/File_Storage/'
            'distracting_benchmarks_logdir/'
            'node_2/logdir/dreamerv3/'
            'cheetah_action_seq_evil_2500_new'
        ),
        'task': 'cheetah_run',
    },
    'v3_action_seq_new_vaml_guided_silu': {
        'configs': [
            'action_seq_evil_2500_new',
            'image_value_gradient',
            'guided_silu_expl',
        ],
        'logdir': (
            '/media/miles/File_Storage/'
            'distracting_benchmarks_logdir/node_2/'
            'logdir/dreamerv3/'
            'cheetah_action_seq_evil_2500_new_vaml'
        ),
        'task': 'cheetah_run',
    },
    'v3_action_seq_new_guided_silu': {
        'configs': [
            'action_seq_evil_2500_new',
            'guided_silu_expl',
        ],
        'logdir': (
            '/media/miles/File_Storage/'
            'distracting_benchmarks_logdir/'
            'node_2/logdir/dreamerv3/'
            'cheetah_action_seq_evil_2500_new'
        ),
        'task': 'cheetah_run',
    },
#     'v3_action_seq_new_vaml_reward_expl': {
#         'configs': [
#             'action_seq_evil_2500_new',
#             'image_value_gradient',
#             'reward_head_gradient_weighting',
#         ],
#         'logdir': (
#             '/media/miles/File_Storage/'
#             'distracting_benchmarks_logdir/node_2/'
#             'logdir/dreamerv3/'
#             'cheetah_action_seq_evil_2500_new_vaml'
#         ),
#         'task': 'cheetah_run',
#     },
#     'v3_action_seq_new_reward_expl': {
#         'configs': [
#             'action_seq_evil_2500_new',
#             'reward_head_gradient_weighting',
#         ],
#         'logdir': (
#             '/media/miles/File_Storage/'
#             'distracting_benchmarks_logdir/'
#             'node_2/logdir/dreamerv3/'
#             'cheetah_action_seq_evil_2500_new'
#         ),
#         'task': 'cheetah_run',
#     },
    'v3_none_vaml_guided_silu': {
        'configs': [
            'no_evil',
            'image_value_gradient',
            'guided_silu_expl',
        ],
        'logdir': (
            '/media/miles/File_Storage/'
            'distracting_benchmarks_logdir/node_1/'
            'dreamerv3/'
            'none_small_vaml_scaling_img'
        ),
        'task': 'cheetah_run',
    },
    'v3_none_guided_silu': {
        'configs': [
            'no_evil',
            'guided_silu_expl',
        ],
        'logdir': (
            '/media/miles/File_Storage/'
            'distracting_benchmarks_logdir/node_1/'
            'dreamerv3/'
            'none'
        ),
        'task': 'cheetah_run',
    },

  'v3_none_vaml': {
    'configs': [
      'no_evil',
      'image_value_gradient'
    ],
    'logdir': (
      '/media/miles/File_Storage/'
      'distracting_benchmarks_logdir/node_1/'
      'dreamerv3/'
      'none_small_vaml_scaling_img'
    ),
    'task': 'cheetah_run',
  },
  'v3_none': {
    'configs': [
      'no_evil'
    ],
    'logdir': (
      '/media/miles/File_Storage/'
      'distracting_benchmarks_logdir/node_1/'
      'dreamerv3/'
      'none'
    ),
    'task': 'cheetah_run',
  },
}

In [3]:
def get_config(model_info):
    config = embodied.Config(dreamerv3.configs['defaults'])
    config = config.update(dreamerv3.configs['dmc_vision'])
#     config = config.update(dreamerv3.configs['cpu_full'])
#     config = config.update(dreamerv3.configs['multicpu'])
    for config_name in model_info['configs']:
        config = config = config.update(
            dreamerv3.configs[config_name])
    config = embodied.Flags(config).parse([
        '--logdir', model_info['logdir'],
        '--task', model_info['task'],
#         '--jax.jit', 'False',
#         '--jax.policy_devices', '0'
    ])
    return config

In [4]:
def get_env(config, include_foreground_mask=False):
    return color_dmc.DMC(
        config.task,
        repeat=config.env.dmc.repeat,
        size=config.env.dmc.size,
        camera=config.env.dmc.camera,
        num_cells_per_dim=config.evil.num_cells_per_dim,
        num_colors_per_cell=config.evil.num_colors_per_cell,
        evil_level=color_grid_utils.EVIL_CHOICE_CONVENIENCE_MAPPING[
            config.evil.evil_level
        ],
        action_dims_to_split=config.evil.action_dims_to_split,
        action_power=(
            config.evil.action_power if config.evil.action_power >= 0
            else None),
        action_splits=(
            config.evil.action_splits if config.evil.action_power < 0
            else None),
        include_foreground_mask=include_foreground_mask
    )

## Functions for Setting Up DreamerV3 Agent & Dataset, Loading from Checkpoint

In [5]:
def get_dreamer_env(config, include_foreground_mask=False):
    env = get_env(
        config,
        include_foreground_mask=include_foreground_mask)
    env = dreamerv3.wrap_env(env, config)
    env = embodied.BatchEnv([env], parallel=False)
    return env

In [6]:
def get_checkpoint(config):
    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()
    env = get_dreamer_env(config)
    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
    checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
    checkpoint.agent = agent
    checkpoint.load(keys=['agent'])
    return checkpoint

In [11]:
model_episodes = collections.defaultdict(dict)
for k, config in tqdm.tqdm(
    MODELS.items(), total=len(MODELS), desc='model'):
    for v_expl_mode in [
        'gradient_x_intensity',
        'gradient',
        'integrated_gradient',
    ]:
        model_checkpoint = get_checkpoint(get_config(config))
        
        env = get_dreamer_env(
            get_config(config),
            include_foreground_mask=True)
        
        driver = embodied.Driver(env)
        eval_episodes = []
        def per_episode(ep):
            eval_episodes.append(ep)
        driver.on_episode(lambda ep, worker: per_episode(ep))
        
        agent = model_checkpoint._values['agent']
        policy = lambda *args: agent.policy(
            *args, mode='eval', include_recon=True, v_expl_mode=v_expl_mode)
        for _ in tqdm.trange(7*10, desc='policy'):
            driver(policy, steps=100)
            
        model_episodes[k][v_expl_mode] = eval_episodes
del model_checkpoint

model:   0%|          | 0/8 [00:00<?, ?it/s]

Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
JAX devices (1): [gpu(id=0)]
Policy devices: gpu:0
Train devices:  gpu:0
Tracing train function.
Optimizer model_opt has 15,687,811 variables.
Optimizer actor_opt has 1,056,780 variables.
Optimizer critic_opt has 1,181,439 variables.
Loading checkpoint: /media/miles/File_Storage/distracting_benchmarks_logdir/node_2/logdir/dreamerv3/cheetah_action_seq_evil_2500_new_vaml/checkpoint.ckpt
Loaded checkpoint from 4850896 seconds ago.



policy:   0%|          | 0/70 [00:00<?, ?it/s][A

Tracing policy function.


policy:   0%|          | 0/70 [1:37:26<?, ?it/s]
model:   0%|          | 0/8 [1:37:45<?, ?it/s]


KeyboardInterrupt: 

In [8]:
model_episodes['v3_none_vaml']

{'gradient_x_intensity': [{'reward': array([0.00000000e+00, 4.97998809e-03, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 1.59430923e-03, 5.39048051e-04, 0.00000000e+00,
          1.46176002e-03, 2.18551303e-03, 1.77431088e-02, 1.03630170e-01,
          1.68650150e-01, 2.60597467e-01, 3.68559361e-01, 3.89973938e-01,
          4.14905906e-01, 4.28697824e-01, 4.66158867e-01, 5.05913079e-01,
          5.16572833e-01, 5.16361117e-01, 5.10311246e-01, 5.15820026e-01,
          5.18683672e-01, 5.20215511e-01, 5.11716604e-01, 5.06637633e-01,
          5.24950624e-01, 5.46185791e-01, 5.40593386e-01, 5.27788877e-01,
          5.13428628e-01, 5.09651423e-01, 5.22720695e-01, 5.28301954e-01,
          5.05342007e-01, 4.40038860e-01, 4.35096622e-01, 4.26327020e-01,
          4.37273860e-01, 4.32665169e-01, 4.38637614e-01, 4.43808377e-01,
          4.47750628e-01, 4.41748857e-01, 4.38793361e-01, 3.87607515e-01,
          3.48871797e-01, 3.56112450e-01, 3.74943435e-01, 4.00986820e-01,
    

In [9]:
{k: v.shape for k, v in model_episodes['v3_none_vaml']['gradient_x_intensity'][0].items()}

{'reward': (501,),
 'is_first': (501,),
 'is_last': (501,),
 'is_terminal': (501,),
 'position': (501, 8),
 'velocity': (501, 9),
 'image': (501, 64, 64, 3),
 'action': (501, 6),
 'image_expl': (501, 64, 64, 3),
 'log_entropy': (501,),
 'recon': (501, 64, 64, 3),
 'v': (501,),
 'reset': (501,)}

In [8]:
def rolling_average(image_expl, window_size, axis=0):
    return sp.ndimage.convolve1d(
        image_expl,
        np.ones(window_size),
        axis=axis,
        mode='reflect') / window_size

In [11]:
rolling_average(
    model_episodes['v3_none_vaml']['gradient_x_intensity'][0]['image_expl'],
    3).shape

(501, 64, 64, 3)

In [9]:
def normalize(im):
    p99 = np.percentile(im, 99)
    return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)

def stack_episodes(episodes):
    stacked_episodes = {}
    for k in episodes[0]:
        stacked_episodes[k] = np.stack([
            episode[k] for episode in episodes[1:]
        ])
    return stacked_episodes

def show_episode(
        episodes, 
        download=False, 
        image_expl_window=None,
        name=None):
    episodes = stack_episodes(episodes)
    if image_expl_window is not None:
        episodes['image_expl'] = rolling_average(
            episodes['image_expl'],
            image_expl_window,
            axis=1
        )

    def show_for_i(i):
        im = episodes['image'][:, i] / 255.
        recon = episodes['recon'][:, i]

        error = np.sum((im - recon) ** 2, axis=-1)
        # bg = episodes['foreground_mask'][:, i]
        expl = np.zeros_like(im)
        expl[..., 1] = np.absolute(
            episodes['image_expl'][:, i].sum(axis=-1))
        expl = normalize(expl)


        color_coded_error = np.zeros_like(im)
        # color_coded_error[..., 2] = error * bg
        # color_coded_error[..., 0] = error * (~bg)
        bg_error = color_coded_error.copy()
        bg_error[..., 0] = 0
        fg_error = color_coded_error.copy()
        fg_error[..., 2] = 0

        recon = np.clip(recon, 0, 1)
        color_coded_error = normalize(color_coded_error)
        bg_error = normalize(bg_error)
        fg_error = normalize(fg_error)

        # bg = bg[..., None]
        # bg = bg * 1.
        # bg = np.repeat(bg, 3, axis=-1)

        frame = np.concatenate([           # [B, H, 7*W, C]
            im, recon, 
            # color_coded_error, 
            expl,
            # bg_error, 
            # fg_error,
            # bg
        ], axis=2)
        frame = frame.reshape(-1, *frame.shape[2:]) # [B * 7 * H, W, C]
        return frame

    fig = plt.figure()
    im = plt.imshow(show_for_i(0))
    plt.close()

    def init():
        im.set_data(show_for_i(0))

    def animate(i):
        im.set_data(show_for_i(i))
        return im

    anim = animation.FuncAnimation(
        fig, animate, init_func=init,
        frames=episodes['image'].shape[1],
        interval=50)
    if download:
        mpl.rcParams['animation.bitrate'] = 8192
        if name is None:
            name = str(np.random.randint(100000))
        fname = f'./tmp/{name}_dreamer_v3_movie.mov'
        anim.save(fname)
        local_file = FileLink(fname)
        return display(local_file)
    else:
        mpl.rcParams['animation.bitrate'] = -1
        return HTML(anim.to_html5_video())


In [10]:
for model_name in model_episodes:
    for technique in model_episodes[model_name]:
        episodes = model_episodes[model_name][technique]
        for window in [None, 3, 5, 7]:
            show_episode(
                model_episodes[model_name][technique],
                download=True,
                name=f'{model_name}_{technique}_{window}',
                image_expl_window=window
            )

  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)
  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)


  return np.clip((im - im.min()) / (p99 - im.min()), 0, 1)
