In [3]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [4]:
import gym
import numpy as np
import os
from VQVAE_environment import VQVAE_Env
from stable_baselines3.common.env_checker import check_env

## Testing the Environment Setup

In [5]:
# Create dummy surrogate model, decoder, and codebook to test the environment

import numpy as np

class MockSurrogateModel:
    def __init__(self):
        pass
    
    def evaluate(self, decoded_state):
        # Return a dummy accuracy value
        return np.random.random()

class MockDecoder:
    def __init__(self):
        pass
    
    def decode(self, state):
        # Return a dummy decoded state
        return state

# Create a dummy codebook as a numpy array
# Assuming the embed_dim is 10 and you have 100 embeddings plus 1 for the stop action
mock_codebook = np.random.rand(100, 10)


In [6]:
# Initialize your environment with the mock components
env = VQVAE_Env(embed_dim=10, num_embeddings=100, max_allowed_actions=200,
                surrogate_model=MockSurrogateModel(), decoder=MockDecoder(), codebook=mock_codebook,
                num_previous_actions=4)

In [7]:
# Using check_env from stable baselines 3 to check if the environment is compatible with stable baselines
check_env(env, warn=True)

In [8]:
# Manual testing of the environment

# Create an instance of the environment with dummy parameters
env = VQVAE_Env(
    embed_dim=10,
    num_embeddings=100,
    max_allowed_actions=20,
    surrogate_model=MockSurrogateModel(),  # Dummy surrogate model
    decoder=MockDecoder(),  # Dummy decoder
    codebook=mock_codebook  
)

# Reset the environment to start a new episode
observation = env.reset()
print("Initial Observation:", observation)

# Take actions in a loop until the episode ends
done = False
while not done:
    # Sample a random action
    action = env.sample_action()
    print("Taking action:", action)

    # Perform the action in the environment
    observation, reward, done, truncate, info = env.step(action)
    print("New Observation:", observation)
    print("Reward:", reward)
    print("Done:", done)
    print("Truncate:", truncate)
    print("Info:", info)
    print("---")

    if done:
        print("Episode finished after {} timesteps.".format(env.step_count))
        break

# Close the environment
env.close()


Initial Observation: ({'latent_vector': array([ 0.9183627 , -0.4344771 ,  1.2901202 ,  0.08363848, -0.44587874,
        2.771025  , -1.1156787 ,  1.4472796 , -0.6398812 , -0.24225442],
      dtype=float32), 'action_history': array([-1, -1, -1, -1], dtype=int32)}, {})
Taking action: 153
New Observation: {'latent_vector': array([ 0.9183627 , -0.4344771 ,  1.2901202 ,  0.8510552 , -0.44587874,
        2.771025  , -1.1156787 ,  1.4472796 , -0.6398812 , -0.24225442],
      dtype=float32), 'action_history': array([ -1,  -1,  -1, 153], dtype=int32)}
Reward: 0.07863918608774867
Done: False
Truncate: False
Info: {}
---
Taking action: 48
New Observation: {'latent_vector': array([ 0.9183627 , -0.4344771 ,  1.2901202 ,  0.8510552 , -0.44587874,
        2.771025  , -1.1156787 ,  1.4472796 ,  0.8455125 , -0.24225442],
      dtype=float32), 'action_history': array([ -1,  -1, 153,  48], dtype=int32)}
Reward: 0.24013735096384758
Done: False
Truncate: False
Info: {}
---
Taking action: 901
New Observatio

## Stable Baseline Training Script (with dummy Surrogate & Decoder)

In [26]:
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.env_util import make_vec_env
import wandb
from wandb.integration.sb3 import WandbCallback

In [9]:
model_dir = 'models'
log_dir = 'logs'
os.makedirs(model_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

In [13]:
# Instantiate the env
vec_env = make_vec_env(VQVAE_Env, n_envs=1, env_kwargs=dict(embed_dim=10,
    num_embeddings=100,
    max_allowed_actions=20,
    surrogate_model=MockSurrogateModel(),  # Dummy surrogate model
    decoder=MockDecoder(),  # Dummy decoder
    codebook=mock_codebook ))

In [18]:
vec_env.reset()

OrderedDict([('action_history', array([[-1, -1, -1, -1]], dtype=int32)),
             ('latent_vector',
              array([[ 0.10479282, -1.0372716 ,  1.3703396 ,  0.408466  , -0.2843564 ,
                      -1.0075978 ,  0.5536992 , -2.2233102 , -0.07724699, -1.0645074 ]],
                    dtype=float32))])

In [25]:
config = {
    "policy": 'MultiInputPolicy',
    "total_timesteps": 25000
}

run = wandb.init(
    config=config,
    sync_tensorboard=True,  # automatically upload SB3's tensorboard metrics to W&B
    project="Test",
    #monitor_gym=True,       # automatically upload gym environements' videos
    save_code=True,
)

In [27]:
# Train the agent
model = PPO(config['policy'], vec_env, verbose=1, tensorboard_log=log_dir)
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()



Using cpu device
Logging to logs/PPO_2
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 19.8     |
|    ep_rew_mean     | 0.464    |
| time/              |          |
|    fps             | 4440     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 19.9        |
|    ep_rew_mean          | 0.518       |
| time/                   |             |
|    fps                  | 2075        |
|    iterations           | 2           |
|    time_elapsed         | 1           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.027024012 |
|    clip_fraction        | 0.438       |
|    clip_range           | 0.2         |
|    entropy_loss         | -6.9        |
|    explained_variance   | -1.51

0,1
global_step,▁▂▂▃▃▄▅▅▆▆▇▇█
rollout/ep_len_mean,▆▇█▃▅█▅▇▁▇▆▇▅
rollout/ep_rew_mean,▁▇▇▆█▂▅▁▆▄▃▄▄
time/fps,█▂▂▁▁▁▁▁▁▁▁▁▁
train/approx_kl,▁▅▆▇████▇▇█▇
train/clip_fraction,▁▆▇█▇▆▆▆▅▃▃▃
train/clip_range,▁▁▁▁▁▁▁▁▁▁▁▁
train/entropy_loss,▁▂▃▃▄▅▅▆▆▇▇█
train/explained_variance,▁███████████
train/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁

0,1
global_step,26624.0
rollout/ep_len_mean,19.83
rollout/ep_rew_mean,0.49422
time/fps,1456.0
train/approx_kl,0.034
train/clip_fraction,0.46099
train/clip_range,0.2
train/entropy_loss,-6.80211
train/explained_variance,0.0404
train/learning_rate,0.0003


In [21]:
# Test the trained agent
# using the vecenv
obs = vec_env.reset()
n_steps = 20
for step in range(n_steps):
    action, _ = model.predict(obs, deterministic=True)
    print(f"Step {step + 1}")
    print("Action: ", action)
    obs, reward, done, info = vec_env.step(action)
    print("obs=", obs, "reward=", reward, "done=", done)
    vec_env.render()
    if done:
        # Note that the VecEnv resets automatically
        # when a done signal is encountered
        print("Goal reached!", "reward=", reward)
        break

Step 1
Action:  [955]
obs= OrderedDict([('action_history', array([[ -1,  -1,  -1, 955]], dtype=int32)), ('latent_vector', array([[ 0.05283138, -0.09423001,  0.31997174, -0.37998945,  0.15682319,
         0.20595308,  0.30689126,  0.27883932, -1.5004084 ,  0.84475845]],
      dtype=float32))]) reward= [0.5419346] done= [False]
Step 2
Action:  [921]
obs= OrderedDict([('action_history', array([[ -1,  -1, 955, 921]], dtype=int32)), ('latent_vector', array([[ 0.05283138,  0.6752847 ,  0.31997174, -0.37998945,  0.15682319,
         0.20595308,  0.30689126,  0.27883932, -1.5004084 ,  0.84475845]],
      dtype=float32))]) reward= [0.4386297] done= [False]
Step 3
Action:  [921]
obs= OrderedDict([('action_history', array([[ -1, 955, 921, 921]], dtype=int32)), ('latent_vector', array([[ 0.05283138,  0.6752847 ,  0.31997174, -0.37998945,  0.15682319,
         0.20595308,  0.30689126,  0.27883932, -1.5004084 ,  0.84475845]],
      dtype=float32))]) reward= [-0.22091103] done= [False]
Step 4
Action:



In [1]:
# Reference code for later!!


import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.02,
    "architecture": "CNN",
    "dataset": "CIFAR-100",
    "epochs": 10,
    }
)

# simulate training
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
    loss = 2 ** -epoch + random.random() / epoch + offset

    # log metrics to wandb
    wandb.log({"acc": acc, "loss": loss})

# [optional] finish the wandb run, necessary in notebooks
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33masaficontact[0m ([33mtrex-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


0,1
acc,▁▄▇▇█▇██
loss,█▅▃▃▁▂▂▁

0,1
acc,0.7922
loss,0.20247
