# WandB_parallel_sweep

This demonstration of how to use the `WandB_parallel_sweep` class is based on the hyperparameter-tuning tutorial by Weights & Biases.
The original tutorial and colab notebook can be found here:
- [https://wandb.ai/site/articles/hyperparameter-tuning-as-easy-as-1-2-3](https://wandb.ai/site/articles/hyperparameter-tuning-as-easy-as-1-2-3)
- [https://colab.research.google.com/drive/1gKixa6hNUB8qrn1CfHirOfTEQm0qLCSS#scrollTo=aIhxl7glaJ5k](https://colab.research.google.com/drive/1gKixa6hNUB8qrn1CfHirOfTEQm0qLCSS#scrollTo=aIhxl7glaJ5k)

In [1]:
import numpy as np

#ai
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from keras.optimizers import RMSprop, SGD, Adam, Nadam
from keras.datasets import mnist

#wandb
import wandb
from wandb.keras import WandbCallback


import sys
# Add the parent directory to the module search path so that the
# `astroLuSt` import below can be resolved when it is not installed
# (presumably the package sits one level above this notebook -- confirm).
sys.path.append('..')

from astroLuSt.ai.wandb_utils import WandB_parallel_sweep

# Print the TensorFlow version this notebook was executed with.
print(tf.__version__)

2.11.0


## Define VAE

In [2]:
class VariationalAutoencoder:
    """Dense variational autoencoder (VAE) built from Keras functional models.

    Call `build_model()` to create the encoder, the decoder and the combined
    VAE. The VAE loss (reconstruction error + KL divergence) is attached to
    the combined model with `add_loss`, so the returned model can be compiled
    without passing an explicit loss.

    Parameters
    ----------
    input_shape
        Shape handed to the encoder's `keras.layers.Input`.
    original_dim
        Number of units of the decoder's output layer; also used to scale
        the reconstruction loss.
    intermediate_dim
        Number of units of the hidden dense layer in encoder and decoder.
    latent_dim
        Dimensionality of the latent space.
    """

    def __init__(self, input_shape, original_dim, intermediate_dim, latent_dim):
        self.input_shape = input_shape
        self.original_dim = original_dim
        self.intermediate_dim = intermediate_dim
        self.latent_dim = latent_dim

    def build_model(self):
        """Build encoder, decoder and the combined VAE and return the VAE.

        Returns
        -------
        keras.models.Model
            The combined model (also stored in `self.vae`) with the VAE loss
            already attached.
        """
        self.build_encoder()
        self.build_decoder()

        # The encoder returns [z_mean, z_log_var, z]; only the sampled latent
        # vector z (index 2) is fed to the decoder.
        self.outputs = self.decoder(self.encoder(self.inputs)[2])

        self.vae = keras.models.Model(self.inputs, self.outputs, name='vae_mlp')
        self.add_loss()

        return self.vae

    def sampling(self, args):
        """Draw a latent sample via the reparameterization trick.

        Parameters
        ----------
        args
            Pair `(z_mean, z_log_var)` of tensors.

        Returns
        -------
        Tensor
            `z_mean + exp(0.5 * z_log_var) * epsilon` with `epsilon` drawn
            from a standard normal distribution.
        """
        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        dim = K.int_shape(z_mean)[1]
        # by default, random_normal has mean = 0 and std = 1.0
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    def build_encoder(self):
        """Create the encoder and store it in `self.encoder`.

        Also stores the input tensor (`self.inputs`) and the latent tensors
        (`self.z_mean`, `self.z_log_var`, `self.z`) needed by `add_loss` and
        `build_model`.
        """
        self.inputs = keras.layers.Input(shape=self.input_shape, name='encoder_input')
        x = keras.layers.Dense(self.intermediate_dim, activation='relu')(self.inputs)
        self.z_mean = keras.layers.Dense(self.latent_dim, name='z_mean')(x)
        self.z_log_var = keras.layers.Dense(self.latent_dim, name='z_log_var')(x)

        # Sampling is wrapped in a Lambda layer so it becomes part of the model.
        self.z = keras.layers.Lambda(
            self.sampling, output_shape=(self.latent_dim,), name='z'
        )([self.z_mean, self.z_log_var])

        # instantiate encoder model
        self.encoder = keras.models.Model(
            self.inputs, [self.z_mean, self.z_log_var, self.z], name='encoder'
        )

    def build_decoder(self):
        """Create the decoder and store it in `self.decoder`."""
        latent_inputs = keras.layers.Input(shape=(self.latent_dim,), name='z_sampling')
        x = keras.layers.Dense(self.intermediate_dim, activation='relu')(latent_inputs)
        outputs = keras.layers.Dense(self.original_dim, activation='sigmoid')(x)

        # instantiate decoder model
        self.decoder = keras.models.Model(latent_inputs, outputs, name='decoder')

    def add_loss(self):
        """Attach the VAE loss (reconstruction + KL divergence) to `self.vae`.

        Requires `self.vae`, `self.inputs`, `self.outputs`, `self.z_mean` and
        `self.z_log_var` to exist, i.e. it has to run after the models were
        created (as done in `build_model`).

        The global Keras session is deliberately NOT cleared here: resetting
        global state is the caller's responsibility and belongs before a
        model is built, not right after its construction.
        """
        # mse averages over the feature axis; rescale by the number of
        # features to obtain a per-sample summed squared error.
        reconstruction_loss = keras.losses.mse(self.inputs, self.outputs)
        reconstruction_loss *= self.original_dim

        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1).
        kl_loss = 1 + self.z_log_var - K.square(self.z_mean) - K.exp(self.z_log_var)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5

        vae_loss = K.mean(reconstruction_loss + kl_loss)

        self.vae.add_loss(vae_loss)


## Custom Callback

In [3]:

class ReconstructionLogger(tf.keras.callbacks.Callback):
    """Keras callback logging test images and their reconstructions to W&B.

    At the end of every epoch the first `batch_size` samples of the
    module-level array `x_test` are selected; up to `n_images` of them are
    passed through the model and both the inputs and the reconstructions
    are logged with `wandb.log`.

    Parameters
    ----------
    batch_size
        Number of leading `x_test` samples to select from.
    n_images
        Maximum number of images to reconstruct and log per epoch.
    image_shape
        2D shape every flattened sample is reshaped to before logging.
    """

    def __init__(self, batch_size, n_images=32, image_shape=(28, 28)):
        super().__init__()
        self.batch_size = batch_size
        self.n_images = n_images
        self.image_shape = tuple(image_shape)

    def on_epoch_end(self, epoch, logs=None):
        """Log input images and reconstructions (Keras hook).

        The parameter order follows the Keras callback API, which calls
        this hook as `on_epoch_end(epoch, logs)`.
        """
        # `x_test` is expected to be defined at module level before training.
        sample_images = x_test[:self.batch_size]

        # Never index past the available samples (e.g. batch_size < n_images).
        n_images = min(self.n_images, len(sample_images))
        if n_images == 0:
            return
        sample_images = sample_images[:n_images]

        # One batched forward pass instead of one `predict` call per image.
        predictions = self.model.predict(sample_images)

        images = [image.reshape(self.image_shape) for image in sample_images]
        reconstructions = [
            prediction.reshape(self.image_shape) for prediction in predictions
        ]

        wandb.log({"images": [wandb.Image(image)
                            for image in images]})
        wandb.log({"reconstructions": [wandb.Image(reconstruction)
                            for reconstruction in reconstructions]})



## Training Loop to be called by each agent during the sweep

The training loop usually contains
- run initialization
- model definition
- model building
- model compiling
- model fitting
- some extra logging if desired

In [4]:

def train():
    """Train one VAE for a single W&B run; called by each sweep agent.

    Steps: initialise the run (sweep values override `config_defaults`),
    build and compile the VAE, fit it on the module-level array `x_train`
    and close the run.

    Relies on `x_train` (used for fitting) and `x_test` (read by
    `ReconstructionLogger`) being defined at module level before the sweep
    is started.
    """

    config_defaults = {
       'epochs':5,
       'batch_size':64,
       'learning_rate':1e-3,
    }

    # Using the run as a context manager closes it even if building or
    # fitting the model raises, so a failed run does not stay open.
    with wandb.init(
        # Set the project where this run will be logged
        project="sweeptesting",
        config=config_defaults,
    ):
        config = wandb.config

        # Reset global Keras state BEFORE anything for this run is created.
        K.clear_session()

        #callbacks
        callbacks = [
            WandbCallback(save_model=False),
            ReconstructionLogger(config.batch_size),
        ]

        # `learning_rate` is the supported argument name; the old `lr` alias
        # is deprecated and not reliably honoured by current Keras optimizers.
        opt = Adam(
            learning_rate=config.learning_rate,
            beta_1=0.9,
            beta_2=0.999,
            clipnorm=1.0,
        )

        # NOTE: `(784)` is just the int 784; a one-element tuple is `(784,)`.
        vae = VariationalAutoencoder(input_shape=(784,),
                                    original_dim=784,
                                    intermediate_dim=512,
                                    latent_dim=2)
        model = vae.build_model()
        # No loss passed: the VAE loss is attached to the model via add_loss.
        model.compile(optimizer=opt)

        # Autoencoder: the input is also the reconstruction target.
        model.fit(x_train,
                  x_train,
                  epochs=config.epochs,
                  batch_size=config.batch_size,
                  callbacks=callbacks)


## Initializing wandb

In [5]:

wandb.login()


# Configure the sweep – specify the parameters to search through, the search strategy, the optimization metric etc.
sweep_config = {
    'method':'grid', #grid, random
    # The model is compiled without additional metrics, so `loss` is the
    # only quantity logged during training; optimizing for a metric that is
    # never logged (e.g. `accuracy`) would leave the sweep without a target.
    'metric':{
        'name'  : 'loss',
        'goal'  : 'minimize'
    },
    # maximum number of runs in this sweep
    'run_cap':4,
    'parameters':{
        'epochs':{
            'values':[2]
        },
        'batch_size':{
            'values':[256, 128]
        },
        'learning_rate':{
            'values':[1e-2, 1e-3]
        },
    }
}
# Register the sweep; the returned id is what the agents attach to.
sweep_id = wandb.sweep(sweep_config, entity=None, project="sweeptesting")


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlust[0m. Use [1m`wandb login --relogin`[0m to force relogin


Create sweep with ID: 1qpdl237
Sweep URL: https://wandb.ai/lust/sweeptesting/sweeps/1qpdl237


## Loading the data

In [6]:
# Load MNIST; the labels are loaded but not needed by the autoencoder.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Images are square, so the flattened length is side * side.
image_size = x_train.shape[1]
original_dim = image_size * image_size

# Flatten every image to a vector and rescale the pixel values by 1/255.
x_train = x_train.reshape(-1, original_dim).astype('float32') / 255
x_test = x_test.reshape(-1, original_dim).astype('float32') / 255

# Report the resulting array shapes.
print("X_train: ", x_train.shape)
print("X_test: ", x_test.shape)

X_train:  (60000, 784)
X_test:  (10000, 784)


## Instantiating and running a parallelized sweeper

In [7]:
# Create the sweeper for the registered sweep; `train` is the function the
# agents execute. `n_agents` is only a placeholder here and is set below.
PS = WandB_parallel_sweep(
    sweep_id,
    train,
    n_jobs=4,
    n_agents=-1,
    wandb_mode='online',
    verbose=3,
)

# Ask the sweeper for an upper bound on the number of agents needed for
# `sweep_config` and adopt that value.
n_comb = PS.get_upper_bound_agents(sweep_config)
PS.n_agents = n_comb

# Show the final configuration of the sweeper.
print(PS)


WandB_parallel_sweep(
    sweep_id=1qpdl237, function=<function train at 0x00000206E23291B0>,
    n_jobs=4, n_agents=4,
    wandb_mode=online,
    verbose=3,
)


In [8]:
# Run the sweep with the n_jobs / n_agents configured on `PS` above.
PS.sweep_parallel()



INFO(sweep_parallel()): Using n_jobs=4 jobs and n_agents=4 agents to run sweep.

######################################################################
INFO: Started sweep_parallel at 2023-05-03T18:55:15


[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.


LookupError: IndexError occured. Probably your provided 'taskname' has never been initialized.Make sure to initialize a 'taskname' before calling 'checkpoint_end()!
Original Error: index 0 is out of bounds for axis 0 with size 0.

In [9]:
# Print the protocol table kept in `PS.ET` (task name, start, end, duration
# and comments -- see the output below).
print(PS.ET.df_protocoll)

             Task               Start                  End         Duration  \
0  sweep_parallel 2023-04-17 19:39:44  2023-04-17T19:41:21  0 days 00:01:37   

  Comment_Start Comment_End  
0                            
