In [None]:
import warnings
import dreamerv3
from dreamerv3 import embodied
warnings.filterwarnings('ignore', '.*truncated to dtype int32.*')

In [None]:
# Layer the configuration: 'defaults' entry, then the 'medium' entry, then the
# notebook-specific overrides below. See configs.yaml for all options.
config = (
    embodied.Config(dreamerv3.configs['defaults'])
    .update(dreamerv3.configs['medium'])
    .update({
        'logdir': '~/logdir/run1',
        'run.train_ratio': 64,
        'run.log_every': 30,  # Seconds
        'batch_size': 16,
        'jax.prealloc': False,
        # NOTE(review): '$^' is presumably a pattern that matches no key, so
        # only the image goes through the CNN encoder/decoder - confirm.
        'encoder.mlp_keys': '$^',
        'decoder.mlp_keys': '$^',
        'encoder.cnn_keys': 'image',
        'decoder.cnn_keys': 'image',
        # 'jax.platform': 'cpu',
    })
)
# Parse with an empty argv so no command-line flags are read inside the notebook.
config = embodied.Flags(config).parse(argv=[])

In [None]:
# The logger is given the step counter and the list of outputs to write to.
step = embodied.Counter()
logdir = embodied.Path(config.logdir)
logger = embodied.Logger(
    step,
    [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, 'metrics.jsonl'),
        embodied.logger.TensorBoardOutput(logdir),
        # Optional outputs, disabled by default:
        # embodied.logger.WandBOutput(logdir.name, config),
        # embodied.logger.MLFlowOutput(logdir.name),
    ],
)

In [None]:
import crafter
from embodied.envs import from_gym
# Environment pipeline: raw Gym env -> embodied adapter -> DreamerV3 wrappers,
# then a batch containing that single env (parallel=False).
env = dreamerv3.wrap_env(
    from_gym.FromGym(crafter.Env()),  # Replace crafter.Env() with your Gym env.
    config,
)
env = embodied.BatchEnv([env], parallel=False)

In [None]:
# Agent built from the batched env's observation/action spaces.
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)

# Uniform replay stored under <logdir>/replay.
replay = embodied.replay.Uniform(
    config.batch_length,
    config.replay_size,
    logdir / 'replay',
)

# Run arguments: everything under config.run plus logdir and the number of
# steps per batch (batch_size * batch_length).
args = embodied.Config(
    **config.run,
    logdir=config.logdir,
    batch_steps=config.batch_size * config.batch_length,
)

In [None]:
# Rebind `step` to the counter held by the logger (presumably the same object
# that was passed to embodied.Logger - confirm).
step = logger.step
checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
# Attach the objects to checkpoint before load() is called; order matters here.
checkpoint.step = step
checkpoint.agent = agent
checkpoint.replay = replay
# NOTE(review): if args.from_checkpoint is empty this presumably falls back to
# <logdir>/checkpoint.ckpt and may fail when that file is missing - confirm.
checkpoint.load(args.from_checkpoint)

In [None]:
# Show the observation and action spaces, one per line.
print(env.obs_space, env.act_space, sep='\n')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# One-hot action for a batch of one env: row 0 of the 17x17 identity, i.e.
# shape (1, 17) with a 1 at index 0. The step also requests a reset.
action = np.eye(17)[[0]]
r = env.step({'reset': [True], 'action': action})
plt.imshow(r['image'][0])

In [None]:
# One-hot action with a 1 at index 2 (shape (1, 17)); continue without reset.
action = np.eye(17)[[2]]
r = env.step({'reset': [False], 'action': action})
plt.imshow(r['image'][0])

In [None]:
# No policy state yet; it is passed to agent.policy, which returns the next state.
state = None

In [None]:
# Ask the policy for an action given the last env result, carrying the policy
# state forward, then apply that action to the environment.
outputs, state = agent.policy(r, state)
print(outputs)

r = env.step(dict(reset=[False], action=outputs['action']))
plt.imshow(r['image'][0])

In [None]:
# Step once with the manually built `action` and inspect the observation.
obs = env.step(dict(action=action, reset=[False]))
print(obs.keys(), obs["image"].shape, sep='\n')

plt.imshow(obs['image'][0])

In [None]:
import dreamerv3.ninjax as nj

In [None]:
# Jit options shared by every purified function below: run on the first policy device.
kw = {'device': agent.policy_devices[0]}
# Purify the agent's preprocess method and jit it in one expression.
_preprocess = nj.jit(nj.pure(agent.agent.preprocess), **kw)

In [None]:
# Show the devices the policy uses; the first one is the jit target in `kw`.
print(agent.policy_devices)

In [None]:

# Convert the observation for the policy device(s) and draw fresh RNG keys.
obsc = agent._convert_inps(obs, agent.policy_devices)
rng = agent._next_rngs(agent.policy_devices)
varibs = agent.varibs if agent.single_device else agent.policy_varibs
# The purified function returns an (output, state) pair; keep only the output,
# the same way the `_env_step` results are unpacked further down.
obsp, _ = _preprocess(varibs, rng, obsc)
#print(agent._convert_outs(obsp, agent.policy_devices))


#states, _ = self.agent.agent.wm.rssm.observe(embed, data['action'], data['is_first'], self.input_state)

In [None]:
import types
def encode(self, obs):
    """Preprocess `obs` and return the world-model encoder's output for it.

    Written as a method: `self` must provide `preprocess` and `wm.encoder`.
    """
    return self.wm.encoder(self.preprocess(obs))
# Bind `encode` to the inner agent instance so it can be called as a method
# (self = agent.agent) and wrapped like the agent's own methods.
agent.agent.encode = types.MethodType(encode, agent.agent)

In [None]:
# Purify the bound `encode` method and jit it with the shared device options.
_encoder = nj.jit(nj.pure(agent.agent.encode), **kw)

In [None]:
# The purified function returns an (output, state) pair; unpack it so `embed`
# holds only the embedding and the print does not dump the returned state too.
embed, _ = _encoder(varibs, rng, obsc)
print(agent._convert_outs(embed, agent.policy_devices))

In [None]:
import jax.numpy as jnp

def env_step(self, obs, state, action):
    """Update the latent state from a real observation and decode it.

    Args:
        obs: Batched observation dict; must contain 'is_first'.
        state: Previous latent state, or None to start from the world
            model's initial state.
        action: Action passed to the RSSM observation step.

    Returns:
        (context, decoded): the first element returned by `rssm.obs_step`
        and the result of `self.decode(context)`.
    """
    if state is None:
        # Size the initial state to the observation batch instead of a
        # hard-coded 1, so the helper also works with more than one env.
        state, _ = self.wm.initial(len(obs['is_first']))
    obs = self.preprocess(obs)
    embed = self.wm.encoder(obs)
    context, _ = self.wm.rssm.obs_step(
        state, action, embed, obs['is_first'])
    return context, self.decode(context)

def imag_step(self, state, action):
    """Advance the latent state with `rssm.img_step` (no observation) and decode it.

    Returns the new state together with `self.decode` applied to it.
    """
    next_state = self.wm.rssm.img_step(state, action)
    decoded = self.decode(next_state)
    return next_state, decoded

def decode(self, state):
    """Run the decoder head on `state` and return the mode of each CNN output.

    Only keys listed in the decoder's `cnn_shapes` are included in the result.
    """
    recon = self.wm.heads['decoder'](state)
    cnn_keys = self.wm.heads['decoder'].cnn_shapes.keys()
    return {key: recon[key].mode() for key in cnn_keys}

# Bind the three helpers to the inner agent so they run with self = agent.agent.
agent.agent.env_step = types.MethodType(env_step, agent.agent)
agent.agent.imag_step = types.MethodType(imag_step, agent.agent)
agent.agent.decode = types.MethodType(decode, agent.agent)

# Purify and jit the two step functions with the shared device options.
# `decode` is not jitted on its own; it is only called from inside the steps.
_env_step = nj.jit(nj.pure(agent.agent.env_step), **kw)
_imag_step = nj.jit(nj.pure(agent.agent.imag_step), **kw)

In [None]:
# Convert the inputs for the policy device(s) and draw fresh RNG keys.
obsc = agent._convert_inps(obs, agent.policy_devices)
actionc = agent._convert_inps(action, agent.policy_devices)
rng = agent._next_rngs(agent.policy_devices)
if agent.single_device:
    varibs = agent.varibs
else:
    varibs = agent.policy_varibs

# Passing state=None makes env_step start from the world model's initial state.
results, _ = _env_step(varibs, rng, obsc, None, actionc)

# results is (latent, decoded images); convert both with agent._convert_outs.
latent = results[0]
latento = agent._convert_outs(latent, agent.policy_devices)
img = agent._convert_outs(results[1], agent.policy_devices)

# Compare the reconstruction with the real observation (shape and dtype).
print(len(results))
print(img.keys())
print(img["image"].shape)
print(obs["image"].shape)
print(img["image"][0].dtype)
print(obs["image"][0].dtype)

# Clamp the reconstruction to [0, 1] and convert it to 8-bit for display.
image = (np.clip(img["image"][0], 0, 1) * 255).astype(np.uint8)
plt.imshow(image)

In [None]:
# Step the real environment again and update the latent from the new observation.
obs = env.step({"action": action, "reset": [False]})
obsc = agent._convert_inps(obs, agent.policy_devices)
latentc = agent._convert_inps(latento, agent.policy_devices)
# Draw fresh RNG keys so this step does not reuse the previous step's keys.
rng = agent._next_rngs(agent.policy_devices)
results, _ = _env_step(varibs, rng, obsc, latentc, actionc)
# Take the new latent from this step's results BEFORE converting it; doing it
# the other way round left `latento` one step behind for the following cell.
latent = results[0]
latento = agent._convert_outs(latent, agent.policy_devices)
img = agent._convert_outs(results[1], agent.policy_devices)
# Clamp to [0, 1] as in the first reconstruction cell before displaying.
plt.imshow(np.clip(img["image"][0], 0, 1))

In [None]:
# Imagination step: advance the latent with the action only (no observation).
latentc = agent._convert_inps(latento, agent.policy_devices)
# Draw fresh RNG keys so repeated runs of this cell do not reuse the same keys.
rng = agent._next_rngs(agent.policy_devices)
results, _ = _imag_step(varibs, rng, latentc, actionc)
# Update `latent` from this step's results BEFORE converting it, so that
# `latento` is the state just produced and re-running the cell continues from it.
latent = results[0]
latento = agent._convert_outs(latent, agent.policy_devices)
img = agent._convert_outs(results[1], agent.policy_devices)
# Clamp to [0, 1] as in the first reconstruction cell before displaying.
plt.imshow(np.clip(img["image"][0], 0, 1))

In [None]:
from embodied.envs import atari
# Second environment: Atari "ms_pacman" (color, 64x64, actions="needed"),
# wrapped the same way as `env` and batched as a single env.
env1 = embodied.BatchEnv(
    [
        dreamerv3.wrap_env(
            atari.Atari("ms_pacman", gray=False, actions="needed", size=(64, 64)),
            config,
        )
    ],
    parallel=False,
)

In [None]:
# Show the Atari env's action space for comparison with the one printed for `env`.
print(env1.act_space)