In [1]:
import numpy
import json
import utils
from utils import device
import hashlib
import torch
import numpy as np
from copy import deepcopy
import pickle
import matplotlib.pyplot as plt


def hash_state(state):
    state_string = pickle.dumps(state)
    return hashlib.sha256(state_string).hexdigest()



def depth_first_search(env, agent, depth):
    if depth == 0:
        return [hash_state(env)]
    
    next_states = get_next_states(env, agent)

    # for each of the next states, get their next states and append them to all_states
    all_states = []

    for next_state in next_states:
        env_new = pickle.loads(pickle.dumps(env)) # restore the environment state
        all_states.extend(depth_first_search(env_new, agent, depth=depth-1))
    
    return all_states

def hash_grid(env):
    # Retrieve the grid
    grid = env.grid.encode()

    # Retrieve the agent's position and direction
    agent_pos = env.agent_pos
    agent_dir = env.agent_dir

    # Flatten the grid and convert to string
    grid_string = ''.join(str(cell) for row in grid for cell in row)

    # Add the agent's position and direction to the string
    state_string = f'{grid_string},{agent_pos},{agent_dir}'
    
    print(agent_pos, agent_dir)
    # Hash the string
    return hashlib.sha256(state_string.encode('utf-8')).hexdigest()


def get_next_states(env, agent):
    """
    Returns all possible next states given current state
    """
    # find all possible actions
    action_space = env.action_space.n

    # initialize next_states list
    next_states = []
    saved_env = pickle.dumps(env)
    # get next state for each action
    for action in range(action_space):
        # Create a new environment instance by pickling and unpickling
        env_new = pickle.loads(saved_env)

        obs_new, _, _, _, _ = env_new.step(action)
        next_states.append(hash_state(env_new))
        if gif:
            frames.append(numpy.moveaxis(env_new.get_frame(), 2, 0))


    return set(next_states)


frames = []
# Replace command line arguments with hard-coded values.
env_name = "MiniGrid-DoorKey-6x6-v0"
model_name = "DoorKeya2c"
seed = 0
shift = 0
argmax = False
pause = 0.1
gif = True
episodes = 1
memory = False
text = False


# Set seed for all randomness sources
utils.seed(seed)

# Set device
print(f"Device: {device}\n")

# Load environment
env = utils.make_env(env_name, seed)
for _ in range(shift):
    env.reset()
print("Environment loaded\n")

# Load agent
model_dir = utils.get_model_dir(model_name)
agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                    argmax=argmax, use_memory=memory, use_text=text)
print("Agent loaded\n")

# Run the agent
if gif:
    from array2gif import write_gif
    frames = []

for episode in range(episodes):
    obs, _ = env.reset()
    cycle = 0

    while True:
        action = agent.get_action(obs)
        obs, reward, terminated, truncated, _ = env.step(action)

        next_states = depth_first_search(env, agent, depth=3)
        if gif:
            
            print(f"Saving gif {str(cycle) + 'decisions.gif'} of {len(frames)} length \n", end="")
            write_gif(numpy.array(frames), str(cycle) +"decisions.gif", fps=2/pause)
            print("Done.")
            frames =[]
            cycle+=1
        
        print(len(next_states))

        done = terminated | truncated
        agent.analyze_feedback(reward, done)

        if done:
            break




# if gif:
#     print("Saving gif... ", end="")
#     write_gif(numpy.array(frames), gif+".gif", fps=1/pause)
#     print("Done.")


pygame 2.4.0 (SDL 2.26.4, Python 3.7.12)
Hello from the pygame community. https://www.pygame.org/contribute.html
Device: cuda

Environment loaded

Agent loaded

Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 147 length 
Done.
64
Saving gif of 147 length 
Done.
64
Saving gif of 217 length 
Done.
125
Saving gif of 217 length 
Done.
125
Saving gif of 217 length 
Done.
125
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving gif of 91 length 
Done.
27
Saving 

KeyboardInterrupt: 

In [None]:
a = 1
print(f"{=a}")

In [None]:
env.get_frame()

In [None]:
env2 = utils.make_env(env_name, seed, render_mode="human")
env2.grid = env.grid

image_data = env.get_frame()
plt.imshow(image_data)
plt.show()
        

In [None]:
env_copy2 = env.grid


In [None]:
dir(env)

In [None]:
env_copy2.grid

In [None]:
env_copy.grid

In [None]:
action = agent.get_action(obs)
obs, reward, terminated, truncated, _ = env.step(action)

# next_states now contains a dictionary of possible states at t+1, t+2, etc.        
next_states =  depth_first_search(env, obs, agent, depth=3)
print(len(next_states))

done = terminated | truncated
agent.analyze_feedback(reward, done)



In [None]:

for episode in range(args.episodes):
    obs, _ = env.reset()

#     while True:



