# Parameter Sweeps - Second Attempt

In [None]:
import datetime
import warnings
import itertools as it

import dill
from arctic.exceptions import NoDataFoundException
from arctic import Arctic

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import ipyparallel as ipp
import ipywidgets as ipw
import qgrid

## View

In [None]:
def param_grid(**kwargs):
    """Generate the Cartesian product of keyword-argument value lists.

    Each keyword names a parameter and maps to an iterable of candidate
    values for it.

    Returns
    -------
    (names, combinations)
        ``names`` is a tuple of the keyword names in argument order;
        ``combinations`` is a list of tuples (one value per name, same order)
        covering every combination of the supplied values.

    With no keywords at all, returns ``((), [()])`` -- the single empty
    combination, matching ``itertools.product()`` with no arguments.
    """
    if not kwargs:
        # zip(*{}.items()) would produce nothing to unpack and raise.
        return (), [()]
    name_list, values_list = zip(*kwargs.items())
    return name_list, list(it.product(*values_list))

def param_df(**kwargs):
    """Return a DataFrame holding one row per parameter combination.

    Columns are the keyword names (in argument order); rows come from the
    Cartesian product computed by ``param_grid``.
    """
    column_names, rows = param_grid(**kwargs)
    frame = pd.DataFrame(rows, columns=column_names)
    return frame

def param_qgrid(qgrid_layout=None, **kwargs):
    """Wrap the parameter-combination DataFrame in a Qgrid table widget.

    ``qgrid_layout`` is used as the widget layout; when it is missing (or
    falsy) a fresh ``ipw.Layout()`` is used instead. Remaining keyword
    arguments are forwarded to ``param_df``.
    """
    layout = qgrid_layout if qgrid_layout else ipw.Layout()
    table = param_df(**kwargs)
    return qgrid.QgridWidget(df=table, layout=layout)

def dict_list_from_df(df):
    """Turn each row into a dictionary indexed by column names.

    Returns a list with one ``{column: value}`` dict per row, in row order.

    Rows are read with ``itertuples`` rather than ``df.loc[label, :]``:
    row-wise ``loc`` builds a single Series with one common dtype (so an int
    column next to a float column comes back as floats), and on a duplicated
    index label it returns a DataFrame instead of a row.
    """
    columns = list(df.columns)
    return [dict(zip(columns, row)) for row in df.itertuples(index=False, name=None)]

In [None]:
class ParamSpanRemoteConfig(object):
    """Shared database and IPyParallel state for parameter sweeps.

    On construction this:
      * connects to Arctic (MongoDB) locally and initializes/opens the
        results library (``self.store``, ``self.library``);
      * connects to the IPyParallel controller and creates a direct view
        (``self.dview``) and a load-balanced view (``self.lview``);
      * pushes an ``engine_id`` global and the DB info to every engine, then
        has each engine open its own ``store``/``library`` globals.

    Parameters
    ----------
    db_host : str
        MongoDB host used by Arctic.
    db_port : int
        MongoDB port.
    lib_name : str
        Name of the Arctic library that holds sweep results.
    """

    def __init__(self, db_host='nautilus.optiputer.net', db_port=31017, lib_name='kale-param-spans'):
        self.db_host = db_host
        self.db_port = db_port
        self.lib_name = lib_name
        self.init_db()
        self.init_engines()

    def init_db(self):
        """Connect to Arctic locally and make sure the results library exists."""
        db_host, db_port, lib_name = self.db_host, self.db_port, self.lib_name
        self.store = Arctic('mongodb://{}:{}'.format(db_host, db_port))
        self.store.initialize_library(lib_name)
        self.library = self.store[lib_name]

        # Package database info to send to engines
        self.db_info = [
            db_host,
            db_port,
            lib_name
        ]

    def init_engines(self):
        """Connect to the controller and prepare every engine.

        All setup calls block, so remote errors are raised here and the
        engine-side ``library`` exists before any computation is submitted.
        """
        # Connect to controller
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.ipp_client = ipp.Client()

        # Establish views
        self.dview = self.ipp_client.direct_view()
        self.lview = self.ipp_client.load_balanced_view()

        # Tell engines their ids (their position in the client's engine list)
        for i, engine in enumerate(self.ipp_client):
            engine.push({'engine_id': i}, block=True)

        # Distribute database information
        self.dview.push({'db_info': self.db_info}, block=True)

        def load_library():
            """Runs on each engine: open Arctic using the pushed ``db_info``."""
            global store, library
            from arctic import Arctic
            db_host, db_port, lib_name = db_info
            store = Arctic('mongodb://{}:{}'.format(db_host, db_port))
            library = store[lib_name]

        # apply_sync waits for every engine and re-raises remote exceptions,
        # instead of dropping them in an un-awaited AsyncResult.
        self.dview.apply_sync(load_library)

In [None]:
class ParamSpanWidget(ipw.VBox):
    """Interactive widget for running and inspecting a parameter sweep.

    Shows an output area above a Qgrid table of parameter combinations.
    ``submit_computations`` sends one task per table row to the IPyParallel
    load-balanced view; each task writes ``compute_func(**params)`` to the
    Arctic library under the label ``'<param_span_name>-<row index>'``.
    Selecting a row reads that record back and renders it with
    ``vis_func(**results)`` inside the output area.

    Parameters
    ----------
    param_span_name : str
        Prefix used for the record labels of this sweep.
    remote_config : object
        Provides ``ipp_client``, ``dview``, ``lview``, ``db_info``, ``store``
        and ``library`` (e.g. a ``ParamSpanRemoteConfig``).
    compute_func : callable
        Run on an engine as ``compute_func(**params)``; its return value is
        written to the library.
    vis_func : callable
        Run locally as ``vis_func(**results)`` for the selected row.
    output_layout, qgrid_layout : ipywidgets.Layout, optional
        Layouts for the output area and the parameter table.
    """

    def __init__(self, param_span_name, remote_config, compute_func, vis_func, output_layout=None, qgrid_layout=None):
        self.compute_func = compute_func
        self.name = param_span_name
        self.remote_config = remote_config
        self.vis_func = vis_func
        self.output_layout = output_layout
        self.qgrid_layout = qgrid_layout

        super().__init__()

        self.load_remote_config()
        self.init_widgets()
        self.init_layout()

    def load_remote_config(self):
        """Copy client, views and database handles from ``remote_config``."""
        self.ipp_client = self.remote_config.ipp_client
        self.dview = self.remote_config.dview
        self.lview = self.remote_config.lview
        self.db_info = self.remote_config.db_info
        self.store = self.remote_config.store
        self.library = self.remote_config.library

    def init_widgets(self):
        """Create the output area and an empty placeholder parameter table."""
        if not self.output_layout:
            self.output_layout = ipw.Layout(height='400px', border='1px solid', overflow_x='scroll', overflow_y='scroll')
        if not self.qgrid_layout:
            self.qgrid_layout = ipw.Layout()

        self.output = ipw.Output(layout=self.output_layout)
        # param_table is empty until set_params is called
        self.param_table = param_qgrid(self.qgrid_layout, **{'':[]})

    def init_logic(self):
        """Re-run the visualization whenever the table selection changes."""
        self.param_table.observe(self.visualize_wrapper, names='_selected_rows')

    def init_layout(self):
        """Stack the output area above the parameter table."""
        self.children = [
            self.output,
            self.param_table
        ]

    def set_params(self, **all_params):
        """Provide parameter set to search over
        all_params = {
            'param1': [val1, val2, ...],
            'param2': [val3, val4, ...],
            ...
        }
        Every combination of the values becomes one row of the table.
        """
        self.param_table = param_qgrid(self.qgrid_layout, **all_params)
        self.init_logic()
        self.init_layout()

    def submit_computations(self, *change):
        """Submit one remote task per row of the parameter table.

        ``*change`` is accepted and ignored so this can serve as a widget
        callback. The AsyncResults are stored in ``self.compute_futures``.
        """
        def compute_wrapper(compute_func, name, paramset_id, params):
            """Perform computation and send results to MongoDB for one set of params"""
            results = compute_func(**params)

            # Index collection by paramset_id and paramspan name.
            # `library` is the engine-side global opened during
            # ParamSpanRemoteConfig.init_engines.
            record_label = '{}-{}'.format(name, paramset_id)
            library.write(
                record_label,
                results
            )

        df = self.param_table.df
        self.compute_futures = []
        # paramset_id is the row index label, paramset the row as a dict.
        # Rows are read via itertuples so each value keeps its column dtype;
        # transposing (df.T) would upcast all columns to one common dtype,
        # e.g. turning an int `N` into a float.
        for paramset_id, values in zip(df.index, df.itertuples(index=False, name=None)):
            paramset = dict(zip(df.columns, values))
            # Submit task to IPyParallel
            fut = self.lview.apply(compute_wrapper, self.compute_func, self.name, paramset_id, paramset)
            self.compute_futures.append(fut)

    def visualize_wrapper(self, *change):
        """Call visualization function for the selected row and capture output."""
        selected = self.param_table.get_selected_rows()
        # Do nothing if selection is empty
        if not selected:
            return

        # Take the first row if more than one is selected.
        # NOTE(review): this is used as the row label in record_label; that
        # matches submit_computations because param_df builds a default
        # RangeIndex (positions == labels).
        paramset_id = selected[0]
        record_label = '{}-{}'.format(self.name, paramset_id)

        try:
            compute_results = self.library.read(record_label).data
        except NoDataFoundException:
            # Results for this row are not available (yet): blank the output
            # rather than calling vis_func with no arguments.
            self.output.clear_output()
            return

        # Avoid using output context to ensure that
        # only this function's output is included.
        @self.output.capture(clear_output=True, wait=True)
        def wrapper():
            self.vis_func(**compute_results)
        wrapper()


In [None]:
# Connect to Arctic/MongoDB and the IPyParallel controller and prepare the
# engines; the ParamSpanWidget constructed below is built on this object.
remote_config = ParamSpanRemoteConfig()

# User functions

In [None]:
def exp_compute(N, mean, std, color):
    """Draw ``N`` normal samples and return them with summary statistics.

    Meant to run on an IPyParallel engine: ``engine_id`` is read as a global
    (each engine gets one pushed in ParamSpanRemoteConfig.init_engines), and
    the imports are local so the function is self-contained when shipped.
    The returned keys match the parameters of ``exp_viz``.
    """
    from datetime import datetime
    import numpy as np

    samples = np.random.normal(loc=mean, scale=std, size=N)
    sample_mean = np.mean(samples)
    sample_std = np.std(samples)

    summary = dict(
        engine_id=engine_id,
        date=datetime.now().ctime(),
        x=samples,
        N=N,
        realmean=sample_mean,
        realstd=sample_std,
        color=color,
    )
    return summary
    
def exp_viz(engine_id, date, x, N, realmean, realstd, color):
    """Plot the distribution of one ``exp_compute`` result and print its data.

    Arguments are the keys of the dict returned by ``exp_compute``.
    """
    header = "Computed on engine {} at {}".format(engine_id, date)
    print(header)

    plt.figure(figsize=[8, 5])
    with warnings.catch_warnings():
        # Silence any warnings emitted while seaborn draws the plot.
        warnings.simplefilter("ignore")
        sns.distplot(x, color=color)

    title = r'$N={}$, $\mu={:.2f}$, $\sigma={:.2f}$'.format(N, realmean, realstd)
    plt.title(title)
    plt.show()

    print("Data: {}".format(x))

In [None]:
# Build a sweep named 'test_span' over every combination of the values below:
# N takes 2 values (np.logspace(2,6,2) -> 100 and 1000000), so the table has
# 2 * 2 * 1 * 2 = 8 rows. One remote task is submitted per row, and the bare
# `psw` on the last line displays the widget as this cell's output.
psw = ParamSpanWidget('test_span', remote_config, exp_compute, exp_viz)
psw.set_params(
    N=np.logspace(2,6,2).astype(int), 
    mean=[0,99], 
    std=[5], 
    color=['red', 'blue']
)
psw.submit_computations()
psw