In [None]:
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb


from mynam.data.toydataset import ToyDataset
from mynam.data.generator import *

from mynam.trainer.wandbtrainer import *

In [None]:
from nam.config import defaults
from nam.models import NAM, get_num_units

In [None]:
# Enable IPython's autoreload so edits to the project packages (mynam, nam)
# are picked up without restarting the kernel; mode 2 reloads all (non-excluded)
# modules before each cell executes.
# NOTE(review): conventionally this cell goes before the import cells.
%reload_ext autoreload 
%autoreload 2

In [None]:
# Base configuration for the NAM experiment, starting from the library defaults.
# NOTE(review): presumably `sweep_train` overlays the sweep's sampled
# hyperparameters onto this object for each run — confirm.
cfg = defaults()
cfg.experiment_name = 'nam-api-sparse-features'  # also used as the W&B project name

cfg.log_loss_frequency = 10
cfg.batch_size = 64
cfg.num_epochs = 100
cfg.hidden_sizes = []
cfg.num_basis_functions = 1024
cfg.regression = True  # regression task; the sweep optimises val_MAE
print(cfg)

In [None]:
# Build train / validation / test splits from the same generator functions.
gen_funcs, gen_func_names = task()

# Noise level for the training set (presumably the std of noise added to the
# targets). Passed through below so changing it here actually takes effect.
# NOTE(review): val/test use ToyDataset's default sigma — confirm that is intended.
sigma = 1.0
trainset = ToyDataset(gen_funcs, gen_func_names, num_samples=1000, sigma=sigma)
valset = ToyDataset(gen_funcs, gen_func_names, num_samples=200)
testset = ToyDataset(gen_funcs, gen_func_names, num_samples=200, use_test=True)

in_features = trainset.in_features
trainset.plot()

In [None]:
# Instantiate the Neural Additive Model for this dataset.
nam = NAM(
    config=cfg,
    name="NAM",
    # input width taken from the first training sample; presumably
    # trainset[i] is an (x, y) pair — confirm
    num_inputs=len(trainset[0][0]),
    num_units=get_num_units(cfg, trainset.X),
)

In [None]:
# Show the model's repr to inspect its structure.
nam

In [None]:
# Authenticate with Weights & Biases before creating the sweep.
wandb.login()
wandb.finish() # mark runs as finished before starting new runs

In [None]:
# Hyperparameter search space for the W&B sweep.
# Each parameter must map to a dict describing how to sample it; a bare scalar
# is not a valid entry (use {'value': x} for a constant).
# - 'log_uniform_values': samples log-uniformly between the literal min and max.
#   (By contrast, 'log_uniform' treats min/max as natural-log exponents.)
# - 'values': samples from the listed candidates.
parameters_list = dict(
    lr=dict(distribution='log_uniform_values', min=1e-3, max=1e-1),
    output_regularization=dict(distribution='log_uniform_values', min=1e-3, max=1e-1),
    l2_regularization=dict(distribution='log_uniform_values', min=1e-6, max=1e-4),
    dropout=dict(values=[0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
    feature_dropout=dict(values=[0, 0.05, 0.1, 0.2]),
    decay_rate=dict(values=[0, 0.005]),
    activation=dict(values=['relu', 'exu']),
)
# Bayesian search minimising validation MAE, with Hyperband early termination
# whose first bracket is at min_iter logged values of the metric.
# NOTE(review): 'val_MAE' must match the key sweep_train logs — confirm.
sweep_configuration = dict(
    method='bayes',
    name='sweep',
    metric=dict(goal='minimize', name='val_MAE'),
    early_terminate=dict(type='hyperband', min_iter=3),
    parameters=parameters_list,
)

# Register the sweep under this experiment's W&B project and report its id.
sweep_id = wandb.sweep(sweep=sweep_configuration, project=cfg.experiment_name)
print(f"sweep id: {sweep_id}")

In [None]:
# Run a W&B agent that repeatedly pulls a hyperparameter set from the sweep and
# calls `sweep_train` with the fixed config and data bound via functools.partial.
# NOTE(review): the same `cfg` object is shared across all runs — if sweep_train
# mutates it in place, settings could leak between runs; confirm.
wandb.agent(
    sweep_id,
    function=partial(
        sweep_train,
        config=cfg,
        dataloader_train=trainset.loader,
        dataloader_val=valset.loader,
        testset=testset,
    ),
    count=50,  # maximum number of runs this agent executes
)

In [None]:
# FIXME: `nam_trainer` is never constructed anywhere in this notebook, so on a
# fresh kernel this cell used to raise NameError. Build the trainer in an
# earlier cell (presumably from `nam`, `cfg` and the dataset loaders), then re-run.
if "nam_trainer" in globals():
    # Presumably (train loss, train metric, val loss, val metric) — confirm.
    nam_lt, nam_mt, nam_lv, nam_mv = nam_trainer.train()
else:
    print("Skipping: `nam_trainer` is not defined; construct it in an earlier cell first.")

In [None]:
# FIXME: `mynam_trainer` is never constructed anywhere in this notebook, so on a
# fresh kernel this cell used to raise NameError. Build the trainer in an
# earlier cell, then re-run.
if "mynam_trainer" in globals():
    # Presumably (train loss, train metric, val loss, val metric) — confirm.
    mnam_lt, mnam_mt, mnam_lv, mnam_mv = mynam_trainer.train()
else:
    print("Skipping: `mynam_trainer` is not defined; construct it in an earlier cell first.")