In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import gym
import numpy as np
from VQVAE_environment import VQVAE_Env
from stable_baselines3.common.env_checker import check_env

## Testing the Environment Setup

In [3]:
# Create dummy surrogate model, decoder, and codebook to test the environment

import numpy as np

class MockSurrogateModel:
    def __init__(self):
        pass
    
    def evaluate(self, decoded_state):
        # Return a dummy accuracy value
        return np.random.random()

class MockDecoder:
    def __init__(self):
        pass
    
    def decode(self, state):
        # Return a dummy decoded state
        return state

# Create a dummy codebook as a numpy array
# Assuming the embed_dim is 10 and you have 100 embeddings plus 1 for the stop action
mock_codebook = np.random.rand(101, 10)


In [4]:
# Initialize your environment with the mock components
env = VQVAE_Env(embed_dim=10, num_embeddings=101, max_allowed_actions=200,
                surrogate_model=MockSurrogateModel(), decoder=MockDecoder(), codebook=mock_codebook,
                num_previous_actions=4)

In [5]:
# Using check_env from stable baselines 3 to check if the environment is compatible with stable baselines
check_env(env, warn=True)

In [18]:
# Manual testing of the environment

# Create an instance of the environment with dummy parameters
env = VQVAE_Env(
    embed_dim=10,
    num_embeddings=101,
    max_allowed_actions=20,
    surrogate_model=MockSurrogateModel(),  # Dummy surrogate model
    decoder=MockDecoder(),  # Dummy decoder
    codebook=mock_codebook  
)

# Reset the environment to start a new episode
observation = env.reset()
print("Initial Observation:", observation)

# Take actions in a loop until the episode ends
done = False
while not done:
    # Sample a random action
    action = env.sample_action()
    print("Taking action:", action)

    # Perform the action in the environment
    observation, reward, done, truncate, info = env.step(action)
    print("New Observation:", observation)
    print("Reward:", reward)
    print("Done:", done)
    print("Truncate:", truncate)
    print("Info:", info)
    print("---")

    if done:
        print("Episode finished after {} timesteps.".format(env.step_count))
        break

# Close the environment
env.close()


Initial Observation: ({'latent_vector': array([ 1.3413056 , -0.6393575 , -0.79040223, -1.2632823 ,  1.256437  ,
        0.5965506 , -0.72908646, -1.217553  , -0.60058206,  1.2916957 ],
      dtype=float32), 'action_history': array([-1, -1, -1, -1], dtype=int32)}, {})
Taking action: 655
New Observation: {'latent_vector': array([ 1.3413056 , -0.6393575 , -0.79040223, -1.2632823 ,  1.256437  ,
        0.43826824, -0.72908646, -1.217553  , -0.60058206,  1.2916957 ],
      dtype=float32), 'action_history': array([ -1,  -1,  -1, 655], dtype=int32)}
Reward: 0.9952473670909366
Done: False
Truncate: False
Info: {}
---
Taking action: 126
New Observation: {'latent_vector': array([ 1.3413056 , -0.6393575 , -0.79040223, -1.2632823 ,  1.256437  ,
        0.43826824,  0.6787315 , -1.217553  , -0.60058206,  1.2916957 ],
      dtype=float32), 'action_history': array([ -1,  -1, 655, 126], dtype=int32)}
Reward: -0.28470189742054386
Done: False
Truncate: False
Info: {}
---
Taking action: 586
New Observati