In [69]:
from typing import Tuple
from gymnasium import spaces
from PIL import Image
from PIL import ImageDraw
import numpy as np

import jax
import jax.numpy as jnp
from jax import random
from flax.struct import dataclass

from evojax.task.base import TaskState
from evojax.task.base import VectorizedTask

SIZE_GRID = 3
AGENT_VIEW = 2


@dataclass
class AgentState(object):
    posx: jnp.int32
    posy: jnp.int32
    inventory: jnp.int32


@dataclass
class State(TaskState):
    obs: jnp.ndarray
    last_action: jnp.ndarray
    reward: jnp.ndarray
    state: jnp.ndarray
    agent: AgentState
    steps: jnp.int32
    permutation_recipe: jnp.ndarray
    key: jnp.ndarray
    grid_color: jnp.ndarray
    immo: jnp.int32


def get_obs(state: jnp.ndarray, posx: jnp.int32, posy: jnp.int32, grid_color: jnp.ndarray) -> jnp.ndarray:
    state_color = (jnp.expand_dims(state[:, :, 1:], axis=-1) * grid_color[:, :, 1:]).sum(axis=2)

    obs = jnp.ravel(jax.lax.dynamic_slice(
        jnp.pad(state_color, ((AGENT_VIEW, AGENT_VIEW), (AGENT_VIEW, AGENT_VIEW), (0, 0)), constant_values=1),
        (posx - AGENT_VIEW + AGENT_VIEW, posy - AGENT_VIEW + AGENT_VIEW, 0),
        (2 * AGENT_VIEW + 1, 2 * AGENT_VIEW + 1, 3)))
    return obs


def get_init_state_fn(key: jnp.ndarray) -> jnp.ndarray:
    grid = jnp.zeros((SIZE_GRID, SIZE_GRID, 6))
    posx, posy = (1, 1)
    grid = grid.at[posx, posy, 0].set(1)
    next_key, key = random.split(key)
    # pos_obj=jax.random.randint(key,(6,2),0,SIZE_GRID)
    # pos_obj=jnp.array([[0,1],[1,2],[2,1]])
    # grid=grid.at[pos_obj[0,0],pos_obj[0,1],1].add(1)
    # grid=grid.at[pos_obj[1,0],pos_obj[1,1],2].add(1)
    # grid=grid.at[pos_obj[2,0],pos_obj[2,1],3].add(1)
    # grid=grid.at[pos_obj[3,0],pos_obj[3,1],1].add(1)
    # grid=grid.at[pos_obj[4,0],pos_obj[4,1],2].add(1)
    # grid=grid.at[pos_obj[5,0],pos_obj[5,1],3].add(1)

    # next_key, key = random.split(next_key)
    # perm=jax.random.permutation(key,3)+1

    # grid=grid.at[pos_obj[0,0],pos_obj[0,1],perm[0]].add(1)
    # grid=grid.at[pos_obj[1,0],pos_obj[1,1],perm[1]].add(1)
    # grid=grid.at[pos_obj[2,0],pos_obj[2,1],perm[2]].add(1)

    # pos_ax=jax.random.randint(key,(3,),0,SIZE_GRID)

    next_key, key = random.split(next_key)
    perm = jax.random.randint(key, (3,), 0, SIZE_GRID)
    # pos_obj=jnp.array([[perm[0],0],[perm[1],1],[perm[2],2]])
    next_key, key = random.split(next_key)
    perm2 = jax.random.permutation(key, 3)

    pos_obj = jnp.array([[perm[0], perm2[0]], [perm[1], perm2[1]], [perm[2], perm2[2]]])

    # next_key, key = random.split(next_key)
    # perm=jax.random.permutation(key,3)+1

    grid = grid.at[pos_obj[0, 0], pos_obj[0, 1], 1].add(1)
    grid = grid.at[pos_obj[1, 0], pos_obj[1, 1], 2].add(1)
    grid = grid.at[pos_obj[2, 0], pos_obj[2, 1], 3].add(1)

    return (grid)


def test_recipes(items, recipes):
    recipe_done = jnp.where(items[recipes[0]] * items[recipes[1]] > 0, jnp.array([recipes[0], recipes[1], 4]),
                            jnp.zeros(3, jnp.int32))
    # recipe_done=jnp.where(items[recipes[2]]*items[4]>0,jnp.array([recipes[2],4,5]),recipe_done)
    product = recipe_done[2]
    reward = jnp.select([product == 0, product == 4], [0., 1.])
    return recipe_done, reward


def drop(grid, posx, posy, inventory, recipes):
    # vanilla drop

    # grid=grid.at[posx,posy,inventory].add(1)
    # inventory=0
    # cost=-0.
    # test recipe
    # recipe_done,reward=jax.lax.cond(grid[posx,posy,1:].sum()==2,test_recipes,lambda x,y:(jnp.zeros(3,jnp.int32),0.),*(grid[posx,posy,:],recipes))
    # grid=jnp.where(recipe_done[2]>0,grid.at[posx,posy,recipe_done[0]].set(0).at[posx,posy,recipe_done[1]].set(0).at[posx,posy,recipe_done[2]].set(1),grid)
    # reward=reward+cost

    # drop  only if right recipe otherwise stay in inventory
    grid = grid.at[posx, posy, inventory].add(1)
    # inventory=0
    cost = -0.
    # test recipe
    recipe_done, reward = jax.lax.cond(grid[posx, posy, 1:].sum() == 2, test_recipes,
                                       lambda x, y: (jnp.zeros(3, jnp.int32), 0.), *(grid[posx, posy, :], recipes))
    grid = jnp.where(recipe_done[2] > 0,
                     grid.at[posx, posy, recipe_done[0]].set(0).at[posx, posy, recipe_done[1]].set(0).at[
                         posx, posy, recipe_done[2]].set(1), grid.at[posx, posy, inventory].set(0))
    inventory = jnp.where(recipe_done[2] > 0, 0, inventory)

    empty_inv = jnp.logical_and(grid[posx, posy, 1:].sum() == 0, inventory > 0)
    grid = jnp.where(empty_inv, grid.at[posx, posy, inventory].set(1), grid)
    inventory = jnp.where(empty_inv, 0, inventory)

    reward = reward + cost

    return grid, inventory, reward


def collect(grid, posx, posy, inventory, key):
    # inventory=jnp.where(grid[posx,posy,1:].sum()>0,jnp.argmax(grid[posx,posy,1:])+1,0)
    inventory = jnp.where(grid[posx, posy, 1:].sum() > 0,
                          jax.random.categorical(key, jnp.log(grid[posx, posy, 1:] / (grid[posx, posy, 1:].sum()))) + 1,
                          0)
    grid = jnp.where(inventory > 0, grid.at[posx, posy, inventory].add(-1), grid)
    return grid, inventory


class Gridworld(VectorizedTask):
    """gridworld task."""

    def __init__(self,
                 max_steps: int = 200,
                 test: bool = False, spawn_prob=0.005):
        self.max_steps = max_steps
        self.obs_shape = tuple([(AGENT_VIEW * 2 + 1) * (AGENT_VIEW * 2 + 1) * 3 + 3 + 1, ])
        self.act_shape = tuple([7, ])
        self.test = test

        def reset_fn(key):
            next_key, key = random.split(key)
            posx, posy = (1, 1)
            agent = AgentState(posx=posx, posy=posy, inventory=0)
            grid = get_init_state_fn(key)

            next_key, key = random.split(next_key)
            # permutation_recipe=jax.random.permutation(key,3)[:3]+1
            permutation_recipe = jnp.array([1, 2, 3])

            next_key, key = random.split(next_key)
            # grid_color=jnp.concatenate([jnp.array([[[[1,0,0]]]]),jax.random.choice(key,jnp.array([0.1,0.5,1.]),(1,1,5,3))],axis=2)
            grid_color = jnp.concatenate([jnp.array([[[[1, 0, 0]]]]), jax.random.uniform(key, (1, 1, 5, 3))], axis=2)
            # rand=jax.random.uniform(key)
            # permutation_recipe=jnp.where(rand>0.5,jnp.array([1,2,3]),jnp.array([1,3,2]))
            # permutation_recipe=jnp.where(rand<0.5,jnp.array([2,3,1]),permutation_recipe)
            return State(state=grid, obs=jnp.concatenate(
                [get_obs(state=grid, posx=posx, posy=posy, grid_color=grid_color), jnp.zeros(3), jnp.zeros(1)]),
                         last_action=jnp.zeros((7,)), reward=jnp.zeros((1,)), agent=agent,
                         steps=jnp.zeros((), dtype=int), grid_color=grid_color, permutation_recipe=permutation_recipe,
                         key=next_key, immo=0)

        self._reset_fn = jax.jit(jax.vmap(reset_fn))

        def rest_keep_recipe(key, recipes, grid_color, steps,reward):
            next_key, key = random.split(key)
            posx, posy = (1, 1)
            agent = AgentState(posx=posx, posy=posy, inventory=0)
            grid = get_init_state_fn(key)

            return State(state=grid, obs=jnp.concatenate(
                [get_obs(state=grid, posx=posx, posy=posy, grid_color=grid_color), jnp.zeros(3), jnp.zeros(1)]),
                         last_action=jnp.zeros((7,)), reward=jnp.ones((1,)) * reward, agent=agent,
                         steps=steps, permutation_recipe=recipes, grid_color=grid_color, key=next_key, immo=0)

        def step_fn(state, action):
            # spawn food
            grid = state.state
            reward = 0

            # move agent
            key, subkey = random.split(state.key)
            # maybe later make the agent to output the one hot categorical
            action = action* (state.immo <= 0)
            action = jax.nn.one_hot(action, 7)

            action_int = action.astype(jnp.int32)

            posx = state.agent.posx - action_int[1] + action_int[3]
            posy = state.agent.posy - action_int[2] + action_int[4]
            posx = jnp.clip(posx, 0, SIZE_GRID - 1)
            posy = jnp.clip(posy, 0, SIZE_GRID - 1)
            grid = grid.at[state.agent.posx, state.agent.posy, 0].set(0)
            grid = grid.at[posx, posy, 0].set(1)
            # collect or drop
            inventory = state.agent.inventory
            key, subkey = random.split(key)
            grid, inventory, reward = jax.lax.cond(jnp.logical_and(action[5] > 0, inventory > 0), drop,
                                                   (lambda a, b, c, d, e: (a, d, 0.)),
                                                   *(grid, posx, posy, inventory, state.permutation_recipe))
            grid, inventory = jax.lax.cond(jnp.logical_and(action[6] > 0, inventory == 0), collect,
                                           (lambda a, b, c, d, e: (a, d)), *(grid, posx, posy, inventory, subkey))

            steps = state.steps + 1
            r_done = jnp.logical_or(grid[:, :, -2].sum() > 0, steps > self.max_steps-1)
            done= steps > self.max_steps-1
            immo = jnp.where(reward < 0, 0, state.immo - 1)
            immo = jnp.clip(immo, 0, 5)

            # key, subkey = random.split(key)
            # rand=jax.random.uniform(subkey)
            # catastrophic=jnp.logical_and(steps>40,rand<1)
            # done=jnp.logical_or(done, catastrophic)
            # a=state.permutation_recipe[1]
            # b=state.permutation_recipe[2]
            # permutation_recipe=jnp.where(catastrophic,state.permutation_recipe.at[1].set(b).at[2].set(a), state.permutation_recipe)
            # steps = jnp.where(catastrophic, jnp.zeros((), jnp.int32), steps)

            cur_state = State(state=grid, obs=jnp.concatenate(
                [get_obs(state=grid, posx=posx, posy=posy, grid_color=state.grid_color),
                 (jnp.expand_dims(jax.nn.one_hot(inventory, 6)[1:], axis=-1) * state.grid_color[0, 0, 1:]).sum(0),
                 (immo / 25) * jnp.ones(1)]),
                              last_action=action, reward=jnp.ones((1,)) * reward,
                              agent=AgentState(posx=posx, posy=posy, inventory=inventory),
                              steps=steps, permutation_recipe=state.permutation_recipe, grid_color=state.grid_color,
                              key=key, immo=immo)

            # keep it in case we let agent several trials
            state = jax.lax.cond(
                r_done,
                lambda x: rest_keep_recipe(key, state.permutation_recipe, grid_color=state.grid_color, steps=steps,reward=reward),
                lambda x: x, cur_state)

            return state, reward, done

        self._step_fn = jax.jit(jax.vmap(step_fn))

    def reset(self, key: jnp.ndarray) -> State:
        return self._reset_fn(key)

    def step(self,
             state: State,
             action: jnp.ndarray) -> Tuple[State, jnp.ndarray, jnp.ndarray]:
        return self._step_fn(state, action)



import gym
import numpy as np
import time


class Grid:
    def __init__(self,max_episode_steps=256):
        self.max_episode_steps=max_episode_steps
        self._env = Gridworld(max_steps=max_episode_steps)
        self.key=jax.random.PRNGKey(0)
        # Whether to make CartPole partial observable by masking out the velocity.


    @property
    def observation_space(self):
        self._observation_space = spaces.Box(
            low=-1.,
            high=1.0,
            shape=((2*AGENT_VIEW+1)*(2*AGENT_VIEW+1)*3+3+1+1+7,),
            dtype=np.float32)
        return self._observation_space

    @property
    def action_space(self):
        return spaces.Discrete(7)

    def reset(self):
        self._rewards = []
        self.key,key=jax.random.split(self.key)
        key=jax.random.split(key,1)
        state = self._env.reset(key)
        self.state=state
        obs=state.obs
        r=state.reward
        la=state.last_action
        return np.array(jnp.concatenate([obs,r,la],axis=1))[0]

    def step(self, action):

        state, reward, done = self._env.step(self.state,action[0]*jnp.ones((1)))
        self.state=state
        obs = state.obs
        reward=state.reward
        la = state.last_action
        obs=jnp.concatenate([obs,reward,la],axis=1)

        self._rewards.append(reward)
        if done[0]:
            self._rewards=np.array(self._rewards)
            info = [{"reward": sum(self._rewards[:,w]),
                    "length": self._rewards.shape[0]} for w in range(self._rewards.shape[1])]
        else:

            info = None
        return np.array(obs )[0], np.array(reward[:,0])[0] , np.array(done)[0], info

    def render(self):

        state_color = (jnp.expand_dims(self.state.state[0,:, :, :], axis=-1) * self.state.grid_color[0,:, :, :]).sum(axis=2)
        state_color=jnp.repeat(state_color,150,0)
        state_color=jnp.repeat(state_color,150,1)
        return(state_color)


    def close(self):
        print('not imp')

In [70]:
import os
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
from IPython.display import HTML, display, clear_output

class VideoWriter:
  def __init__(self, filename, fps=30.0, **kw):
    self.writer = None
    self.params = dict(filename=filename, fps=fps, **kw)

  def add(self, img):
    img = np.asarray(img)
    if self.writer is None:
      h, w = img.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
    if img.dtype in [np.float32, np.float64]:
      img = np.uint8(img.clip(0, 1)*255)
    if len(img.shape) == 2:
      img = np.repeat(img[..., None], 3, -1)
    self.writer.write_frame(img)

  def close(self):
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()

  def show(self, **kw):
      self.close()
      fn = self.params['filename']
      display(mvp.ipython_display(fn, **kw))


In [72]:
max_steps=256
env = Grid(max_episode_steps=max_steps)

In [73]:
!pip install ruamel_yaml


You should consider upgrading via the '/home/ghamon/miniconda3/bin/python -m pip install --upgrade pip' command.[0m[33m
[0m

In [93]:

import numpy as np
import pickle
import torch
#torch.set_default_tensor_type("torch.cuda.FloatTensor")

from docopt import docopt
from model import ActorCriticModel
from utils import create_env


def init_transformer_memory(trxl_conf, max_episode_steps, device):
    """Returns initial tensors for the episodic memory of the transformer.

    Arguments:
        trxl_conf {dict} -- Transformer configuration dictionary
        max_episode_steps {int} -- Maximum number of steps per episode
        device {torch.device} -- Target device for the tensors

    Returns:
        memory {torch.Tensor}, memory_mask {torch.Tensor}, memory_indices {torch.Tensor} -- Initial episodic memory, episodic memory mask, and sliding memory window indices
    """
    # Episodic memory mask used in attention
    memory_mask = torch.tril(torch.ones((trxl_conf["memory_length"], trxl_conf["memory_length"])), diagonal=-1)
    # Episdic memory tensor
    memory = torch.zeros((1, max_episode_steps, trxl_conf["num_blocks"], trxl_conf["embed_dim"])).to(device)
    # Setup sliding memory window indices
    repetitions = torch.repeat_interleave(torch.arange(0, trxl_conf["memory_length"]).unsqueeze(0), max_episode_steps - 1, dim = 0).long()
    memory_indices = torch.stack([torch.arange(i, i + trxl_conf["memory_length"]) for i in range(max_episode_steps - trxl_conf["memory_length"] + 1)]).long()
    memory_indices = torch.cat((repetitions, memory_indices))
    return memory, memory_mask, memory_indices

_USAGE = """
Usage:
    enjoy.py [options]
    enjoy.py --help

Options:
    --model=<path>              Specifies the path to the trained model [default: ./models/run.nn].
"""
model_path = "models/my-trxl-training_mem16_3blocks_no_penalty_gamma95.nn"

# Set inference device and default tensor type
device = torch.device("cpu")
torch.set_default_tensor_type("torch.FloatTensor")

# Load model and config
state_dict, config = pickle.load(open(model_path, "rb"))
#print(state_dict)


# Instantiate environment


# Initialize model and load its parameters
model = ActorCriticModel(config, env.observation_space, (env.action_space.n,), env.max_episode_steps)
model.load_state_dict(state_dict)
model.train()
model.to(device)


# Run and render episode
done = False
episode_rewards = []
memory, memory_mask, memory_indices = init_transformer_memory(config["transformer"], env.max_episode_steps, device)

memory_length = config["transformer"]["memory_length"]
t = 0
obs = env.reset()
with VideoWriter("out.mp4",12.0) as vid:




    while not done:
        # Prepare observation and memory
        obs = torch.tensor(np.expand_dims(obs, 0), dtype=torch.float32, device=device)
        
        in_memory = memory[0, memory_indices[t].unsqueeze(0)]
        
        t_ = max(0, min(t, memory_length - 1))
        mask = memory_mask[t_].unsqueeze(0)
        #indices = memory_indices[t].unsqueeze(0)
        # Render environment
        rgb_im=env.render()
        vid.add(rgb_im)
        # Forward model
        #print(obs.device,in_memory.device,mask.device,indices.device)
        print(indices)
        policy, value, new_memory = model(obs, in_memory, mask, indices)
        
        memory[:, t] = new_memory
        # Sample action
        action = []
        for action_branch in policy:
            action.append(action_branch.sample().item())
        #print(action)
        # Step environemnt
        obs, reward, done, info = env.step(action)
        episode_rewards.append(reward)
        t += 1
    vid.show()

# after done, render last state
env.render()

print("Episode length: " + str(info[0]["length"]))
print("Episode reward: " + str(info[0]["reward"]))

env.close()

print(memory.shape)
print(mask.shape)
print(indices)
print(memory_indices)

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3, 

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3, 

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3, 

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3, 

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])


Episode length: 256
Episode reward: [16.]
not imp
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


In [90]:
print(indices)

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])


In [88]:
print(mask.shape)
print(obs.shape)

torch.Size([1, 16])
torch.Size([1, 87])


In [79]:
print(memory.shape)

torch.Size([1, 256, 3, 256])


In [83]:
print(memory[0,0,

tensor([-3.5152e-02, -1.9442e-02,  3.7384e-01, -4.9361e-01, -7.3528e-02,
        -2.3118e-01, -2.6802e-01, -1.6547e-01,  1.1482e+00, -1.0367e-01,
        -3.2227e-03, -5.8481e-01,  2.3024e-02,  2.1371e+00, -4.2670e-01,
        -4.8509e-01,  6.2796e-02, -1.6044e-01, -6.0848e-01, -3.3947e-01,
        -6.8730e-01,  1.4101e-01, -3.7369e-01, -4.3724e-01,  6.6944e-01,
        -1.6394e-01, -4.1897e-01, -6.7200e-01, -5.4942e-01, -2.8958e-01,
        -2.3336e-01,  7.3447e-01, -4.7756e-01,  1.2498e+00, -3.2395e-01,
        -5.0730e-01, -6.7888e-01, -3.1369e-01, -1.9708e-01,  1.4229e+00,
        -1.5309e-02, -3.6904e-01,  9.1445e-02, -3.6449e-01, -1.2770e-01,
         1.8963e-01, -4.0836e-02, -2.3859e-01,  2.3598e-01, -3.3576e-01,
         8.6697e-02, -1.5435e+00,  4.2491e-01,  2.5121e-01, -4.0899e-01,
         2.6914e-02,  5.4118e-01, -2.7005e-01, -7.3023e-01,  2.0949e-01,
        -6.8689e-01,  2.8380e-01, -1.8262e-01, -1.5466e-01,  5.2786e-01,
         3.5447e-01,  2.3430e-01, -2.5634e-01, -1.0

In [76]:
a=[]
for _ in range(16):
    _USAGE = """
    Usage:
        enjoy.py [options]
        enjoy.py --help

    Options:
        --model=<path>              Specifies the path to the trained model [default: ./models/run.nn].
    """
    model_path = "models/my-trxl-training_mem16_3blocks_no_penalty_gamma95.nn"

    # Set inference device and default tensor type
    device = torch.device("cpu")
    torch.set_default_tensor_type("torch.FloatTensor")

    # Load model and config
    state_dict, config = pickle.load(open(model_path, "rb"))


    # Instantiate environment


    # Initialize model and load its parameters
    model = ActorCriticModel(config, env.observation_space, (env.action_space.n,), env.max_episode_steps)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)


    # Run and render episode
    done = False
    episode_rewards = []
    memory, memory_mask, memory_indices = init_transformer_memory(config["transformer"], env.max_episode_steps, device)

    memory_length = config["transformer"]["memory_length"]
    t = 0
    obs = env.reset()
    with VideoWriter("out.mp4",12.0) as vid:




        while not done:
            # Prepare observation and memory
            obs = torch.tensor(np.expand_dims(obs, 0), dtype=torch.float32, device=device)

            in_memory = memory[0, memory_indices[t].unsqueeze(0)]
            t_ = max(0, min(t, memory_length - 1))
            mask = memory_mask[t_].unsqueeze(0)
            indices = memory_indices[t].unsqueeze(0)
            # Render environment
            rgb_im=env.render()
            vid.add(rgb_im)
            # Forward model
            #print(obs.device,in_memory.device,mask.device,indices.device)
            policy, value, new_memory = model(obs, in_memory, mask, indices)

            memory[:, t] = new_memory
            # Sample action
            action = []
            for action_branch in policy:
                action.append(action_branch.sample().item())
            #print(action)
            # Step environemnt
            obs, reward, done, info = env.step(action)
            episode_rewards.append(reward)
            t += 1
        vid.show()

    # after done, render last state
    env.render()

    print("Episode length: " + str(info[0]["length"]))
    print("Episode reward: " + str(info[0]["reward"]))
    a.append(info[0]["reward"][0])


    print(memory.shape)
    print(mask.shape)
    print(indices)
    print(memory_indices)
    

Episode length: 256
Episode reward: [47.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [1.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [2.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [1.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [1.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [3.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [44.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [48.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [48.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [38.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [48.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [44.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [47.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [41.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [41.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


Episode length: 256
Episode reward: [51.]
torch.Size([1, 256, 3, 256])
torch.Size([1, 16])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        [  0,   1,   2,  ...,  13,  14,  15],
        ...,
        [238, 239, 240,  ..., 251, 252, 253],
        [239, 240, 241,  ..., 252, 253, 254],
        [240, 241, 242,  ..., 253, 254, 255]])


In [63]:
print(np.mean(a),np.std(a),a)

0.75 0.8291562 [2.0, 1.0, 0.0, 2.0, 2.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0]


In [40]:
from numpy import array
from numpy import float32

In [66]:
a=[{'reward': array([2.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([1.], dtype=float32), 'length': 128}, {'reward': array([2.], dtype=float32), 'length': 128}, {'reward': array([1.], dtype=float32), 'length': 128}, {'reward': array([2.], dtype=float32), 'length': 128}, {'reward': array([2.], dtype=float32), 'length': 128}, {'reward': array([3.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([3.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([1.], dtype=float32), 'length': 128}]

In [64]:
a=[{'reward': array([1.], dtype=float32), 'length': 128}, {'reward': array([1.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([1.], dtype=float32), 'length': 128}, {'reward': array([2.], dtype=float32), 'length': 128}, {'reward': array([1.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([2.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([1.], dtype=float32), 'length': 128}, {'reward': array([0.], dtype=float32), 'length': 128}, {'reward': array([2.], dtype=float32), 'length': 128}]

In [67]:
b=[]
for item in a:
    b.append(item["reward"])

b=np.array(b)
print(b)
print(np.mean(b))
print(np.std(b))

[[2.]
 [0.]
 [0.]
 [0.]
 [1.]
 [2.]
 [1.]
 [2.]
 [2.]
 [3.]
 [0.]
 [3.]
 [0.]
 [0.]
 [0.]
 [1.]]
1.0625
1.0879309
