In [None]:
from toy_problems.random_cones import random_cone
%load_ext autoreload
%autoreload 2

In [None]:
import wandb
import numpy as np
import jax
import jax.numpy as jnp
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Load Custom Modules
from plots import plot_true_function, add_optimization_path
from core import optimize_function, initial_point, initial_point_circle
from problems import *

In [None]:
# Objective function and the per-run sweep entry point

def calc_loss(x, x_opt):
    """Return the Euclidean distance between the point ``x`` and the optimum ``x_opt``."""
    displacement = x - x_opt
    return jnp.linalg.norm(displacement)

def main():
    """Run one W&B sweep run: optimize the test function from random starts.

    Reads the hyperparameters for this run from ``wandb.config``. It then runs
    ``trials_per_run`` independent optimizations, each from a starting point
    produced by ``initial_point`` within ``domain`` x ``domain``. Finally it
    logs the mean and standard deviation of the final iterate's distance to
    ``x_opt``. The run is always finished, even if a trial raises.
    """
    wandb.init()
    try:
        config = wandb.config

        # Set configuration
        fn = get_function_by_name(config.test_function)
        n_trials = config.trials_per_run
        domain = [-10, 10]
        lower = jnp.array([domain[0], domain[0]])
        upper = jnp.array([domain[1], domain[1]])

        # NOTE(review): hard-coded optimum; it matches true_params = [0, 0, 1]
        # in the sweep config -- presumably the first two entries are the
        # optimum's position, so keep these in sync if true_params changes.
        x_opt = jnp.array([0.0, 0.0])

        # One noise level per entry of true_params: the first two entries
        # reuse noise_values_pos, the third uses noise_values_slope.
        noise_values = [config.noise_values_pos, config.noise_values_pos, config.noise_values_slope]

        # Explicit int64 dtype: the legacy default integer type is int32 on
        # Windows (NumPy < 2), where an upper bound of 2**32 raises ValueError.
        init_seed = int(np.random.randint(0, 2**32, dtype=np.int64))
        init_key = jax.random.PRNGKey(init_seed)
        keys = jax.random.split(init_key, n_trials)

        objective_values = []
        for key in keys:
            # Random starting point in the square domain
            pi = initial_point(minval=lower, maxval=upper)

            path = optimize_function(
                fn,
                pi,
                config.true_params,
                noise_values,
                steps=config.steps,
                learning_rate=config.learning_rate,
                batch_size=config.batch_size,
                method=config.optimizer,
                seed=int(key[0]),
            )

            # Objective: distance of the final iterate from the optimum
            objective_values.append(calc_loss(path[-1], x_opt))

        # Compute objective value statistics
        objective_values = np.array(objective_values)
        wandb.log({
            'mean_objective_value': np.mean(objective_values),
            'sdev_objective_value': np.std(objective_values),
            # Log plain Python ints explicitly instead of a raw JAX array so the
            # summary is JSON-serializable regardless of W&B's array handling.
            'init_key': np.asarray(init_key).tolist(),
            'init_seed': init_seed,
        })
    finally:
        wandb.finish()

In [None]:
# W&B sweep definition: Bayesian search minimizing the mean final distance to
# the optimum reported by main().
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'mean_objective_value',
        'goal': 'minimize'
    },
    'parameters': {
        # The learning rate spans three orders of magnitude, so sample it
        # log-uniformly. A uniform prior would put ~90% of draws above 0.1.
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 0.001,
            'max': 1.0
        },
        'batch_size': {
            'values': [1, 3, 5, 10]
        },
        'steps': {
            'value': 10
        },
        'optimizer': {
            'value': 'sgd'
        },
        'trials_per_run': {
            'value': 10
        },
        'true_params': {'value': [0, 0, 1]},
        'test_function': {'value': 'cone'},
        # NOTE(review): the noise levels are swept together with the optimizer
        # hyperparameters. A Bayesian search that minimizes the objective will
        # therefore tend to favour low-noise settings. If the goal is to tune
        # the optimizer *per* noise level, run one sweep per noise setting (or
        # use a grid) instead.
        'noise_values_pos': {
            'values': [0.0, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0]
        },
        'noise_values_slope': {
            'values': [0.0, 0.5, 1.0]
        }
    }
}

In [None]:
# Register the sweep with W&B; the returned sweep id is what the agent runs.
sweep_id = wandb.sweep(sweep=sweep_config, project='Test Function Algorithm Hyperparameter Optimization')

In [None]:
# Launch an agent in this process that executes main() for at most 100 sweep runs
wandb.agent(sweep_id=sweep_id, function=main, count=100)

In [None]:
# Scratch check of the seeding scheme in main(): one random root key split
# into per-trial keys. The explicit int64 dtype avoids the ValueError raised on
# platforms where NumPy's default integer is int32 (Windows, NumPy < 2).
init_key = jax.random.PRNGKey(int(np.random.randint(0, 2**32, dtype=np.int64)))
keys = jax.random.split(init_key, 100)