In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [1]:
import gym
import numpy as np
import os
from VQVAE_environment import VQVAE_Env
from stable_baselines3.common.env_checker import check_env

## Testing the Environment Setup

In [5]:
# Create dummy surrogate model, decoder, and codebook to test the environment

import numpy as np

class MockSurrogateModel:
    def __init__(self):
        pass
    
    def evaluate(self, decoded_state):
        # Return a dummy accuracy value
        return np.random.random()

class MockDecoder:
    def __init__(self):
        pass
    
    def decode(self, state):
        # Return a dummy decoded state
        return state

# Create a dummy codebook as a numpy array
# Assuming the embed_dim is 10 and you have 100 embeddings plus 1 for the stop action
mock_codebook = np.random.rand(100, 10)


In [6]:
# Initialize your environment with the mock components
env = VQVAE_Env(embed_dim=10, num_embeddings=100, max_allowed_actions=200,
                surrogate_model=MockSurrogateModel(), decoder=MockDecoder(), codebook=mock_codebook,
                num_previous_actions=4)

In [7]:
# Using check_env from stable baselines 3 to check if the environment is compatible with stable baselines
check_env(env, warn=True)

In [8]:
# Manual testing of the environment

# Create an instance of the environment with dummy parameters
env = VQVAE_Env(
    embed_dim=10,
    num_embeddings=100,
    max_allowed_actions=20,
    surrogate_model=MockSurrogateModel(),  # Dummy surrogate model
    decoder=MockDecoder(),  # Dummy decoder
    codebook=mock_codebook  
)

# Reset the environment to start a new episode
observation = env.reset()
print("Initial Observation:", observation)

# Take actions in a loop until the episode ends
done = False
while not done:
    # Sample a random action
    action = env.sample_action()
    print("Taking action:", action)

    # Perform the action in the environment
    observation, reward, done, truncate, info = env.step(action)
    print("New Observation:", observation)
    print("Reward:", reward)
    print("Done:", done)
    print("Truncate:", truncate)
    print("Info:", info)
    print("---")

    if done:
        print("Episode finished after {} timesteps.".format(env.step_count))
        break

# Close the environment
env.close()


Initial Observation: ({'latent_vector': array([ 0.9183627 , -0.4344771 ,  1.2901202 ,  0.08363848, -0.44587874,
        2.771025  , -1.1156787 ,  1.4472796 , -0.6398812 , -0.24225442],
      dtype=float32), 'action_history': array([-1, -1, -1, -1], dtype=int32)}, {})
Taking action: 153
New Observation: {'latent_vector': array([ 0.9183627 , -0.4344771 ,  1.2901202 ,  0.8510552 , -0.44587874,
        2.771025  , -1.1156787 ,  1.4472796 , -0.6398812 , -0.24225442],
      dtype=float32), 'action_history': array([ -1,  -1,  -1, 153], dtype=int32)}
Reward: 0.07863918608774867
Done: False
Truncate: False
Info: {}
---
Taking action: 48
New Observation: {'latent_vector': array([ 0.9183627 , -0.4344771 ,  1.2901202 ,  0.8510552 , -0.44587874,
        2.771025  , -1.1156787 ,  1.4472796 ,  0.8455125 , -0.24225442],
      dtype=float32), 'action_history': array([ -1,  -1, 153,  48], dtype=int32)}
Reward: 0.24013735096384758
Done: False
Truncate: False
Info: {}
---
Taking action: 901
New Observatio

## Stable Baseline Training Script (with dummy Surrogate & Decoder)

In [26]:
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.env_util import make_vec_env
import wandb
from wandb.integration.sb3 import WandbCallback

In [9]:
model_dir = 'models'
log_dir = 'logs'
os.makedirs(model_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

In [13]:
# Instantiate the env
vec_env = make_vec_env(VQVAE_Env, n_envs=1, env_kwargs=dict(embed_dim=10,
    num_embeddings=100,
    max_allowed_actions=20,
    surrogate_model=MockSurrogateModel(),  # Dummy surrogate model
    decoder=MockDecoder(),  # Dummy decoder
    codebook=mock_codebook ))

In [18]:
vec_env.reset()

OrderedDict([('action_history', array([[-1, -1, -1, -1]], dtype=int32)),
             ('latent_vector',
              array([[ 0.10479282, -1.0372716 ,  1.3703396 ,  0.408466  , -0.2843564 ,
                      -1.0075978 ,  0.5536992 , -2.2233102 , -0.07724699, -1.0645074 ]],
                    dtype=float32))])

In [25]:
config = {
    "policy": 'MultiInputPolicy',
    "total_timesteps": 25000
}

run = wandb.init(
    config=config,
    sync_tensorboard=True,  # automatically upload SB3's tensorboard metrics to W&B
    project="Test",
    #monitor_gym=True,       # automatically upload gym environements' videos
    save_code=True,
)

In [27]:
# Train the agent
model = PPO(config['policy'], vec_env, verbose=1, tensorboard_log=log_dir)
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=WandbCallback(
        model_save_path=f"models/{run.id}",
        verbose=2,
    ),
)
run.finish()



Using cpu device
Logging to logs/PPO_2
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 19.8     |
|    ep_rew_mean     | 0.464    |
| time/              |          |
|    fps             | 4440     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 19.9        |
|    ep_rew_mean          | 0.518       |
| time/                   |             |
|    fps                  | 2075        |
|    iterations           | 2           |
|    time_elapsed         | 1           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.027024012 |
|    clip_fraction        | 0.438       |
|    clip_range           | 0.2         |
|    entropy_loss         | -6.9        |
|    explained_variance   | -1.51

0,1
global_step,▁▂▂▃▃▄▅▅▆▆▇▇█
rollout/ep_len_mean,▆▇█▃▅█▅▇▁▇▆▇▅
rollout/ep_rew_mean,▁▇▇▆█▂▅▁▆▄▃▄▄
time/fps,█▂▂▁▁▁▁▁▁▁▁▁▁
train/approx_kl,▁▅▆▇████▇▇█▇
train/clip_fraction,▁▆▇█▇▆▆▆▅▃▃▃
train/clip_range,▁▁▁▁▁▁▁▁▁▁▁▁
train/entropy_loss,▁▂▃▃▄▅▅▆▆▇▇█
train/explained_variance,▁███████████
train/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁

0,1
global_step,26624.0
rollout/ep_len_mean,19.83
rollout/ep_rew_mean,0.49422
time/fps,1456.0
train/approx_kl,0.034
train/clip_fraction,0.46099
train/clip_range,0.2
train/entropy_loss,-6.80211
train/explained_variance,0.0404
train/learning_rate,0.0003


In [21]:
# Test the trained agent
# using the vecenv
obs = vec_env.reset()
n_steps = 20
for step in range(n_steps):
    action, _ = model.predict(obs, deterministic=True)
    print(f"Step {step + 1}")
    print("Action: ", action)
    obs, reward, done, info = vec_env.step(action)
    print("obs=", obs, "reward=", reward, "done=", done)
    vec_env.render()
    if done:
        # Note that the VecEnv resets automatically
        # when a done signal is encountered
        print("Goal reached!", "reward=", reward)
        break

Step 1
Action:  [955]
obs= OrderedDict([('action_history', array([[ -1,  -1,  -1, 955]], dtype=int32)), ('latent_vector', array([[ 0.05283138, -0.09423001,  0.31997174, -0.37998945,  0.15682319,
         0.20595308,  0.30689126,  0.27883932, -1.5004084 ,  0.84475845]],
      dtype=float32))]) reward= [0.5419346] done= [False]
Step 2
Action:  [921]
obs= OrderedDict([('action_history', array([[ -1,  -1, 955, 921]], dtype=int32)), ('latent_vector', array([[ 0.05283138,  0.6752847 ,  0.31997174, -0.37998945,  0.15682319,
         0.20595308,  0.30689126,  0.27883932, -1.5004084 ,  0.84475845]],
      dtype=float32))]) reward= [0.4386297] done= [False]
Step 3
Action:  [921]
obs= OrderedDict([('action_history', array([[ -1, 955, 921, 921]], dtype=int32)), ('latent_vector', array([[ 0.05283138,  0.6752847 ,  0.31997174, -0.37998945,  0.15682319,
         0.20595308,  0.30689126,  0.27883932, -1.5004084 ,  0.84475845]],
      dtype=float32))]) reward= [-0.22091103] done= [False]
Step 4
Action:



In [1]:
# Reference code for later!!


import wandb
import random

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.02,
    "architecture": "CNN",
    "dataset": "CIFAR-100",
    "epochs": 10,
    }
)

# simulate training
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
    loss = 2 ** -epoch + random.random() / epoch + offset

    # log metrics to wandb
    wandb.log({"acc": acc, "loss": loss})

# [optional] finish the wandb run, necessary in notebooks
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33masaficontact[0m ([33mtrex-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


0,1
acc,▁▄▇▇█▇██
loss,█▅▃▃▁▂▂▁

0,1
acc,0.7922
loss,0.20247


## Setting up the Surrogate Model

In [49]:
import torch
import pandas as pd
from SurrogateModel import SurrogateModel
df = pd.read_csv('/Users/tawab/Desktop/columbia/Courses/Spring2024/HPML/Project/Analog_NAS/data/dataset_cifar10_v1.csv')

### Testing with Examples from Dataset

In [50]:
# Use orginal data from the training set to test the Surrogate model evaluate function
value_mapping = {"A": 1, "B": 2, "C": 3, "D": 4}
df["convblock1"] = df["convblock1"].replace(value_mapping).astype('float32')
df["convblock2"] = df["convblock2"].replace(value_mapping).astype('float32')
df["convblock3"] = df["convblock3"].replace(value_mapping).astype('float32')
df["convblock4"] = df["convblock4"].replace(value_mapping).astype('float32')
df["convblock5"] = df["convblock5"].replace(value_mapping).astype('float32')
data = df.iloc[:,:22]
x = data.head()
normal_input_value = x.values
normal_input_value = torch.tensor(normal_input_value, dtype=torch.float32)

  df["convblock1"] = df["convblock1"].replace(value_mapping).astype('float32')


In [51]:
S_model = SurrogateModel('/Users/tawab/Desktop/columbia/Courses/Spring2024/HPML/Project/Analog_NAS/env/models/surrogate_model.json')

In [52]:
pred = S_model.evaluate(normal_input_value)
print('Predicted accuracy:', pred)
print('Actual accuracy:', df['1_day_accuracy'].iloc[:5])

Predicted accuracy: [0.92884475 0.9123806  0.7297597  0.77937865 0.79452664]
Actual accuracy: 0    0.923597
1    0.922466
2    0.876552
3    0.658484
4    0.874308
Name: 1_day_accuracy, dtype: float64


### Testing with Predictions from Decoder

In [53]:
decoder_data = torch.tensor([[64.6800,  3.0988,  8.2265,  6.8180,  4.8326,  3.3803,  2.1295,  2.5059,
          2.5093,  6.3183,  1.9964,  1.8305,  4.9266,  1.5593,  1.7053,  3.8663,
          1.2503,  1.2723,  2.7812,  0.6364,  0.0000,  1.6426],
        [64.4652,  3.0861,  8.2008,  6.7940,  4.8138,  3.3655,  2.1192,  2.4984,
          2.5002,  6.3002,  1.9876,  1.8226,  4.9095,  1.5518,  1.6977,  3.8513,
          1.2425,  1.2670,  2.7679,  0.6332,  0.0000,  1.6346],
        [64.2907,  3.0758,  8.1799,  6.7745,  4.7985,  3.3535,  2.1109,  2.4924,
          2.4928,  6.2855,  1.9804,  1.8161,  4.8957,  1.5458,  1.6916,  3.8392,
          1.2363,  1.2627,  2.7571,  0.6306,  0.0000,  1.6282],
        [64.6227,  3.0954,  8.2196,  6.8116,  4.8276,  3.3763,  2.1267,  2.5039,
          2.5069,  6.3135,  1.9941,  1.8284,  4.9220,  1.5573,  1.7033,  3.8623,
          1.2482,  1.2709,  2.7776,  0.6355,  0.0000,  1.6405],
        [64.2712,  3.0746,  8.1775,  6.7723,  4.7968,  3.3522,  2.1100,  2.4917,
          2.4920,  6.2839,  1.9796,  1.8154,  4.8941,  1.5451,  1.6909,  3.8378,
          1.2356,  1.2623,  2.7559,  0.6303,  0.0000,  1.6274],
        [64.8440,  3.1085,  8.2461,  6.8364,  4.8470,  3.3916,  2.1373,  2.5116,
          2.5163,  6.3321,  2.0032,  1.8366,  4.9396,  1.5649,  1.7112,  3.8777,
          1.2562,  1.2763,  2.7913,  0.6388,  0.0000,  1.6487],
        [65.2637,  3.1333,  8.2964,  6.8833,  4.8837,  3.4204,  2.1572,  2.5261,
          2.5341,  6.3674,  2.0204,  1.8522,  4.9729,  1.5795,  1.7260,  3.9070,
          1.2713,  1.2867,  2.8173,  0.6451,  0.0000,  1.6642],
        [64.2908,  3.0758,  8.1799,  6.7745,  4.7985,  3.3535,  2.1109,  2.4924,
          2.4928,  6.2855,  1.9804,  1.8161,  4.8957,  1.5458,  1.6916,  3.8392,
          1.2363,  1.2627,  2.7571,  0.6306,  0.0000,  1.6282],
        [64.6034,  3.0943,  8.2173,  6.8095,  4.8259,  3.3750,  2.1258,  2.5032,
          2.5061,  6.3118,  1.9933,  1.8277,  4.9205,  1.5566,  1.7027,  3.8610,
          1.2475,  1.2704,  2.7764,  0.6352,  0.0000,  1.6397],
        [64.4813,  3.0870,  8.2027,  6.7958,  4.8152,  3.3666,  2.1200,  2.4990,
          2.5009,  6.3016,  1.9882,  1.8232,  4.9108,  1.5524,  1.6983,  3.8525,
          1.2431,  1.2674,  2.7689,  0.6334,  0.0000,  1.6352],
        [65.2572,  3.1329,  8.2956,  6.8826,  4.8831,  3.4200,  2.1569,  2.5258,
          2.5338,  6.3669,  2.0201,  1.8519,  4.9724,  1.5793,  1.7258,  3.9066,
          1.2710,  1.2865,  2.8169,  0.6450,  0.0000,  1.6640],
        [64.2538,  3.0736,  8.1754,  6.7704,  4.7953,  3.3510,  2.1092,  2.4911,
          2.4913,  6.2824,  1.9789,  1.8148,  4.8927,  1.5445,  1.6903,  3.8366,
          1.2349,  1.2618,  2.7548,  0.6300,  0.0000,  1.6268],
        [64.9387,  3.1141,  8.2575,  6.8470,  4.8552,  3.3981,  2.1418,  2.5148,
          2.5203,  6.3400,  2.0070,  1.8401,  4.9471,  1.5682,  1.7145,  3.8843,
          1.2596,  1.2787,  2.7972,  0.6402,  0.0000,  1.6522],
        [63.6988,  3.0408,  8.1090,  6.7083,  4.7467,  3.3128,  2.0828,  2.4719,
          2.4677,  6.2357,  1.9561,  1.7942,  4.8487,  1.5252,  1.6706,  3.7979,
          1.2150,  1.2482,  2.7204,  0.6218,  0.0000,  1.6062],
        [63.9417,  3.0552,  8.1381,  6.7355,  4.7679,  3.3295,  2.0943,  2.4803,
          2.4780,  6.2561,  1.9661,  1.8032,  4.8680,  1.5337,  1.6792,  3.8148,
          1.2237,  1.2542,  2.7354,  0.6254,  0.0000,  1.6152],
        [65.2935,  3.1350,  8.3000,  6.8867,  4.8863,  3.4225,  2.1587,  2.5271,
          2.5354,  6.3699,  2.0216,  1.8533,  4.9753,  1.5805,  1.7271,  3.9091,
          1.2723,  1.2874,  2.8192,  0.6455,  0.0000,  1.6653],
        [65.2400,  3.1319,  8.2936,  6.8807,  4.8816,  3.4188,  2.1561,  2.5252,
          2.5331,  6.3654,  2.0194,  1.8513,  4.9710,  1.5787,  1.7252,  3.9053,
          1.2704,  1.2861,  2.8159,  0.6447,  0.0000,  1.6633],
        [64.9390,  3.1141,  8.2575,  6.8470,  4.8553,  3.3981,  2.1418,  2.5148,
          2.5203,  6.3401,  2.0071,  1.8402,  4.9471,  1.5682,  1.7145,  3.8844,
          1.2596,  1.2787,  2.7972,  0.6402,  0.0000,  1.6522],
        [64.6351,  3.0961,  8.2211,  6.8130,  4.8287,  3.3772,  2.1273,  2.5043,
          2.5074,  6.3145,  1.9946,  1.8289,  4.9230,  1.5577,  1.7038,  3.8632,
          1.2487,  1.2712,  2.7784,  0.6357,  0.0000,  1.6409],
        [64.3260,  3.0779,  8.1841,  6.7784,  4.8016,  3.3560,  2.1126,  2.4936,
          2.4943,  6.2885,  1.9819,  1.8175,  4.8985,  1.5470,  1.6928,  3.8416,
          1.2376,  1.2636,  2.7593,  0.6311,  0.0000,  1.6295],
        [64.5792,  3.0928,  8.2144,  6.8068,  4.8238,  3.3734,  2.1247,  2.5024,
          2.5051,  6.3098,  1.9923,  1.8268,  4.9186,  1.5558,  1.7018,  3.8593,
          1.2466,  1.2698,  2.7749,  0.6349,  0.0000,  1.6389],
        [63.5627,  3.0328,  8.0927,  6.6930,  4.7348,  3.3035,  2.0763,  2.4672,
          2.4619,  6.2242,  1.9506,  1.7892,  4.8379,  1.5205,  1.6658,  3.7884,
          1.2101,  1.2449,  2.7120,  0.6197,  0.0000,  1.6012],
        [64.5792,  3.0928,  8.2144,  6.8068,  4.8238,  3.3734,  2.1247,  2.5024,
          2.5051,  6.3098,  1.9923,  1.8268,  4.9186,  1.5558,  1.7018,  3.8593,
          1.2466,  1.2698,  2.7749,  0.6349,  0.0000,  1.6389],
        [64.3263,  3.0779,  8.1841,  6.7785,  4.8016,  3.3560,  2.1126,  2.4936,
          2.4943,  6.2885,  1.9819,  1.8175,  4.8985,  1.5470,  1.6928,  3.8416,
          1.2376,  1.2636,  2.7593,  0.6311,  0.0000,  1.6295],
        [64.2321,  3.0723,  8.1728,  6.7679,  4.7934,  3.3495,  2.1081,  2.4904,
          2.4903,  6.2806,  1.9780,  1.8140,  4.8910,  1.5437,  1.6895,  3.8351,
          1.2342,  1.2613,  2.7534,  0.6297,  0.0000,  1.6260],
        [64.2907,  3.0758,  8.1799,  6.7745,  4.7985,  3.3535,  2.1109,  2.4924,
          2.4928,  6.2855,  1.9804,  1.8161,  4.8957,  1.5458,  1.6916,  3.8392,
          1.2363,  1.2627,  2.7571,  0.6306,  0.0000,  1.6282],
        [65.1269,  3.1252,  8.2800,  6.8680,  4.8717,  3.4110,  2.1507,  2.5213,
          2.5283,  6.3559,  2.0148,  1.8471,  4.9620,  1.5747,  1.7212,  3.8975,
          1.2664,  1.2833,  2.8089,  0.6430,  0.0000,  1.6591],
        [64.2912,  3.0758,  8.1799,  6.7745,  4.7986,  3.3536,  2.1110,  2.4924,
          2.4928,  6.2856,  1.9804,  1.8161,  4.8957,  1.5458,  1.6916,  3.8392,
          1.2363,  1.2628,  2.7571,  0.6306,  0.0000,  1.6282],
        [64.5770,  3.0927,  8.2142,  6.8065,  4.8236,  3.3732,  2.1246,  2.5023,
          2.5050,  6.3096,  1.9922,  1.8268,  4.9184,  1.5557,  1.7017,  3.8591,
          1.2466,  1.2698,  2.7748,  0.6348,  0.0000,  1.6388],
        [64.4834,  3.0872,  8.2029,  6.7960,  4.8154,  3.3668,  2.1201,  2.4991,
          2.5010,  6.3017,  1.9883,  1.8233,  4.9110,  1.5524,  1.6984,  3.8526,
          1.2432,  1.2675,  2.7690,  0.6334,  0.0000,  1.6353],
        [64.6300,  3.0958,  8.2205,  6.8125,  4.8282,  3.3769,  2.1271,  2.5041,
          2.5072,  6.3141,  1.9944,  1.8287,  4.9226,  1.5575,  1.7036,  3.8628,
          1.2485,  1.2711,  2.7781,  0.6356,  0.0000,  1.6407],
        [64.4953,  3.0879,  8.2044,  6.7974,  4.8164,  3.3676,  2.1207,  2.4995,
          2.5015,  6.3027,  1.9888,  1.8237,  4.9119,  1.5529,  1.6988,  3.8534,
          1.2436,  1.2678,  2.7698,  0.6336,  0.0000,  1.6357]], dtype = torch.float32)

In [54]:
decoder_data = decoder_data[:5, :]

In [55]:
decoder_pred  = S_model.evaluate(decoder_data)

In [56]:
decoder_pred

array([0.77182955, 0.7760769 , 0.7759796 , 0.77182955, 0.7759796 ],
      dtype=float32)

## Setting up the Decoder Model

In [15]:
from Decoder import Decoder
from SurrogateModel import SurrogateModel
import torch

In [16]:
x_dim = 22
h_nodes = 512
scale = 2
num_layers = 5
embed_dim = 8
dropout = 0.2

In [17]:
decoder_model = Decoder(x_dim, embed_dim= embed_dim, h_nodes = h_nodes, dropout = dropout, scale = scale, num_layers= num_layers, load_path = '/Users/tawab/Desktop/columbia/Courses/Spring2024/HPML/Project/Analog_NAS/env/models/decoder_model.pth').to('cpu')

Decoder model loaded from:  /Users/tawab/Desktop/columbia/Courses/Spring2024/HPML/Project/Analog_NAS/env/models/decoder_model.pth


### Test with random vector

In [28]:
random_vector = torch.rand(5, embed_dim)

In [29]:
decoder_output = decoder_model(random_vector)
decoder_output.shape

torch.Size([5, 22])

In [30]:
S_model = SurrogateModel('/Users/tawab/Desktop/columbia/Courses/Spring2024/HPML/Project/Analog_NAS/env/models/surrogate_model.json')

In [31]:
pred = S_model.evaluate(decoder_output)

In [32]:
pred

array([0.75241435, 0.7893968 , 0.7689091 , 0.76963234, 0.7433951 ],
      dtype=float32)