In [2]:
import numpy as np
from pathlib import Path
from itertools import product, chain

# ---- run configuration ----
# Algorithm whose template notebook is executed (matched case-insensitively):
# hgmr, jrmpc, clusterGAN
algorithm = 'clusterGAN'
# Only the parameter grids of these datasets are expanded and executed.
datasets_to_run = ['pokerdvs']
# True: distribute the runs over a multiprocessing pool; False: run them one by one.
run_parallel = False
# Directory that receives the executed notebooks.
save_path = Path('milled_nbs')

#### clusterGAN parameters

In [None]:
def _clustergan_grid(dataset_name, **overrides):
    """Return a fresh clusterGAN parameter grid for one dataset.

    Every value is a list of candidate settings. A keyword override replaces
    the candidate list of an existing key without moving it in the key order;
    unknown keys are appended at the end. A new dict with new lists is built
    on every call, so the grids of different datasets never share objects.
    """
    grid = {
        'dataset_name': [dataset_name],
        'frame_time': [5000],
        'subsample': [100],
        'batch_size': [64],
        'latent_dim': [30],
        'beta_cycle_gen': [10.0],
        'beta_cycle_label': [10.0],
        'train': [True],
        'timestamp': [''],
    }
    grid.update(overrides)
    return grid


# One grid per dataset; all four currently use the same candidate values.
clustergan_dict = {
    name: _clustergan_grid(name)
    for name in ('pokerdvs', 'ncars', 'nmnist', 'dvsgesture')
}

#### HGMR parameters

In [None]:
def _hgmr_grid(dataset_name, **overrides):
    """Return a fresh HGMR parameter grid for one dataset.

    Every value is a list of candidate settings. A keyword override replaces
    the candidate list of an existing key without moving it in the key order;
    unknown keys are appended at the end. A new dict with new lists is built
    on every call, so the grids of different datasets never share objects.
    """
    grid = {
        'dataset_name': [dataset_name],
        'download_dataset': [True],
        'first_saccade_only': [False],
        'subsample': [100],
        'tree_level': [3],
        'inference_level': [3],
        'spatial_histograms': [True],
        'K': [10],
        'coresets': [True],
        'Np': [2**16],
        't_res': [1000],
    }
    grid.update(overrides)
    return grid


# pokerdvs is the only dataset with 'coresets' switched off.
hgmr_dict = {
    'pokerdvs': _hgmr_grid('pokerdvs', coresets=[False]),
    'ncars': _hgmr_grid('ncars'),
    'nmnist': _hgmr_grid('nmnist'),
    'dvsgesture': _hgmr_grid('dvsgesture'),
}

#### JRMPC parameters

In [None]:
def _jrmpc_grid(dataset_name, **overrides):
    """Return a fresh JRMPC parameter grid for one dataset.

    Every value is a list of candidate settings. A keyword override replaces
    the candidate list of an existing key without moving it in the key order;
    unknown keys are appended at the end. A new dict with new lists is built
    on every call, so the grids of different datasets never share objects.
    """
    grid = {
        'dataset_name': [dataset_name],
        'download_dataset': [True],
        'first_saccade_only': [False],
        'subsample': [100],
        'spatial_histograms': [True],
        'K': [10],
        'coresets': [True],
        'Np': [1000],
        't_res': [1000],
        'C': [10],
        'maxNumIter': [1000],
    }
    grid.update(overrides)
    return grid


# pokerdvs differs from the other datasets in 'coresets' and 'C'.
jrmpc_dict = {
    'pokerdvs': _jrmpc_grid('pokerdvs', coresets=[False], C=[5]),
    'ncars': _jrmpc_grid('ncars'),
    'nmnist': _jrmpc_grid('nmnist'),
    'dvsgesture': _jrmpc_grid('dvsgesture'),
}

In [None]:
# Select the template notebook and the per-dataset parameter grids that belong
# to the configured algorithm (the name is matched case-insensitively).
if algorithm.lower() == 'clustergan':
    input_nb_path = Path('clusterGAN.ipynb')
    parameter_dict = clustergan_dict
elif algorithm.lower() == 'hgmr':
    input_nb_path = Path('hgmr.ipynb')
    parameter_dict = hgmr_dict
elif algorithm.lower() == 'jrmpc':
    input_nb_path = Path('jrmpc.ipynb')
    parameter_dict = jrmpc_dict
else:
    raise ValueError('wrong algorithm')

# Expand the grid of every selected dataset into the Cartesian product of its
# candidate lists. Each permutation is a tuple whose positions follow the key
# order of that dataset's grid. Datasets are visited in grid order, and names in
# datasets_to_run that have no grid contribute nothing.
value_permutations = []
for dic in parameter_dict.keys():
    if dic in datasets_to_run:
        value_permutations.extend(product(*parameter_dict[dic].values()))

# (filename tag, tuple position) pairs used to name the output notebooks. The tag
# is the first four characters of the parameter name wrapped in underscores; the
# parameter names are taken from the first dataset's grid, and the two
# dataset-loading switches are left out of the filename.
short_keys = [
    ('_' + k[:4] + '_', i)
    for i, k in enumerate(parameter_dict[list(parameter_dict.keys())[0]].keys())
    if k != 'download_dataset' and k != 'first_saccade_only'
]
print('number of permutations:', len(value_permutations))

### execute notebooks

In [None]:
import time
import multiprocessing as mp
import papermill as pm
from functools import partial

def process_notebook(params):
    """Execute the template notebook once for a single parameter permutation.

    Args:
        params: one tuple of parameter values; its positions are matched
            against the parameter names of the first dataset's grid in
            ``parameter_dict``.

    The executed copy is written into ``save_path`` (created if missing). Its
    file name is the concatenation of every ``short_keys`` tag followed by the
    corresponding parameter value, plus the ``.ipynb`` extension.
    """
    # '<tag><value>' for each entry of short_keys, in order.
    name_parts = [tag + str(params[position]) for tag, position in short_keys]
    output_name = ''.join(name_parts) + '.ipynb'

    # Pair the parameter names with the values of this permutation.
    first_dataset = list(parameter_dict.keys())[0]
    parameter_names = parameter_dict[first_dataset].keys()
    notebook_parameters = dict(zip(parameter_names, params))

    save_path.mkdir(parents=True, exist_ok=True)
    output = save_path / output_name

    pm.execute_notebook(
        str(input_nb_path),
        str(output),
        parameters=notebook_parameters,
        nest_asyncio=True,
    )

if __name__ == '__main__':
    # Run every permutation, sequentially or on a process pool, and report the
    # elapsed wall-clock time at the end.
    start_time = time.time()
    if not run_parallel:
        for perm in value_permutations:
            process_notebook(perm)
    else:
        # Pool() without arguments starts one worker per available CPU.
        with mp.Pool() as pool:
            pool.map(process_notebook, value_permutations)
    print("--- %s seconds ---" % (time.time() - start_time))