# Parameter Sweeps - Second Attempt

In [66]:
import ipyparallel as ipp
import ipywidgets as ipw
import pandas as pd
import qgrid
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display
from arctic import Arctic
import itertools as it
import warnings

In [8]:
import dill

In [9]:
import IPython

In [10]:
import time

In [11]:
# Connect to the IPyParallel controller (no connection arguments are given)
rc = ipp.Client()
# Show the ids of the engines the client currently knows about
rc.ids

            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@oliver-arch')
            or instruct your controller to listen on an external IP.


[0, 1]

In [12]:
# Tell engines their ids
for i, engine in enumerate(rc):
    engine.push({'engine_id': i})

# Ask them to report back
%px print(engine_id)

[stdout:0] 0
[stdout:1] 1


In [13]:
def printme(*args, **kwargs):
    """Echo the call arguments together with the id of the executing engine.

    Relies on a global ``engine_id`` being defined wherever the function runs
    (the notebook pushes it to every engine beforehand).

    Returns
    -------
    tuple
        ``(engine_id, args, kwargs)``
    """
    echoed = (engine_id, args, kwargs)
    return echoed

In [14]:
# Load-balanced view: tasks submitted through it are assigned to an engine by the scheduler
lview = rc.load_balanced_view()

In [15]:
# Submit printme through the load-balanced view and wait for its return value;
# the first tuple element shows which engine ran it.
lview.apply(printme, 1, 3, x=4, y=79.1).result()

(0, (1, 3), {'x': 4, 'y': 79.1})

In [16]:
# pymongo is not imported by any earlier cell, so import it here to keep
# this cell runnable after Restart Kernel -> Run All.
import pymongo

# Open the 'kale-param-sweeps' database and select the 'sweep1' collection
# used by the test cells that follow.
db = pymongo.MongoClient('nautilus.optiputer.net', 31017)['kale-param-sweeps']
coll = db['sweep1']

In [17]:
# Store one test document: a parameter set alongside the results computed for it
coll.insert_one({
    'params': dict(n=3, x=5),
    'results': dict(mean=3, std=6)
})

<pymongo.results.InsertOneResult at 0x7f1540d26f88>

In [18]:
# Look the test document up again using its parameter set as the query
coll.find_one({'params': {'n': 3, 'x': 5}})

{'_id': ObjectId('5ad86832298bc7739c0ea7fb'),
 'params': {'n': 3, 'x': 5},
 'results': {'mean': 3, 'std': 6}}

In [19]:
# With no filter, find_one returns a single document from the collection
# (in the recorded output it is an older {'test': 'a'} entry).
coll.find_one()

{'_id': ObjectId('5ad8469d298bc765b913c59b'), 'test': 'a'}

In [20]:
%%time
# Test execution scheduling
def printsleep(dur, msg):
    """Sleep for ``dur`` seconds, then return ``msg`` prefixed with the engine id.

    Relies on a global ``engine_id`` being defined wherever the function runs.
    """
    # Imported inside the function so the name exists where the function executes
    import time
    time.sleep(dur)
    return '{}: {}'.format(engine_id, msg)

# Submit three 2-second tasks through the load-balanced view, then wait for all of them
futures = [lview.apply(printsleep, 2, 'hello') for i in range(3)]
res = [f.result() for f in futures]

CPU times: user 20.5 ms, sys: 2.59 ms, total: 23.1 ms
Wall time: 4.03 s


# Kale functions

## View

In [21]:
def param_grid(**kwargs):
    """Generate Cartesian product of keyword arguments.

    Parameters
    ----------
    **kwargs
        Each keyword is a parameter name mapped to an iterable of the values
        that parameter should take.

    Returns
    -------
    name_list : tuple of str
        Parameter names, in the order they were passed.
    value_combinations : list of tuple
        Every combination of values; position ``i`` of each tuple belongs to
        ``name_list[i]``.

    Notes
    -----
    Called without keyword arguments this returns ``((), [()])`` (the empty
    product) instead of raising ``ValueError``.
    """
    # Keyword order is preserved, so names and values stay aligned.
    name_list = tuple(kwargs)
    value_combinations = list(it.product(*kwargs.values()))
    return name_list, value_combinations

def param_df(**kwargs):
    """Build a pandas DataFrame holding every combination of the given parameter values.

    The keyword arguments are forwarded to ``param_grid``; the resulting frame
    has one column per parameter name and one row per value combination.
    """
    column_names, rows = param_grid(**kwargs)
    frame = pd.DataFrame(data=rows, columns=column_names)
    return frame

def param_qgrid(qgrid_layout=None, **kwargs):
    """Build a Qgrid table widget showing every combination of the given parameter values.

    Parameters
    ----------
    qgrid_layout : optional
        Layout handed to the Qgrid widget; a fresh ``ipw.Layout()`` is used
        when this is missing or falsy.
    **kwargs
        Parameter names mapped to their values, forwarded to ``param_df``.
    """
    layout = qgrid_layout or ipw.Layout()
    table_df = param_df(**kwargs)
    return qgrid.QgridWidget(df=table_df, layout=layout)

In [22]:
# Demo: a 300px-wide table listing all combinations of x, y and z
g = param_qgrid(qgrid_layout=ipw.Layout(width='300px'), x=[1,2], y=[6,3], z=[True, False])
# Leave the widget as the last expression so the notebook renders it
g

QgridWidget(grid_options={'fullWidthRows': True, 'syncColumnCellResize': True, 'forceFitColumns': True, 'defau…

## Submit

In [117]:
def dict_list_from_df(df):
    """Convert a DataFrame into a list with one ``{column: value}`` dict per row.

    Rows are visited in index order and each one is looked up by its index label.
    """
    records = []
    for row_label in df.index:
        row_values = df.loc[row_label, :]
        records.append(dict(zip(df.columns, row_values)))
    return records

In [137]:
# Placeholder for a ParamSpan class that was sketched but left unimplemented
# (presumably `tr` was meant to be traitlets - confirm before reviving it).
#class ParamSpan(tr.HasTraits):
#    pass
# Standalone Output widget; the demo cell displays it below the ParamSpanWidget.
log_output = ipw.Output()

class ParamSpanWidget(ipw.VBox):
    """Run a compute function over a grid of parameters and browse the results.

    The widget is a VBox holding an Output area and a Qgrid table of parameter
    sets. ``submit_computations`` sends one task per table row to the
    IPyParallel engines; each task writes its result to an Arctic library
    under the row index. Selecting a row in the table reads that result back
    and passes it to the visualization function, whose output is captured in
    the Output area.

    Parameters
    ----------
    param_span_name : str
        Name of the sweep; also used as the Arctic library name.
    compute_func : callable
        Called on an engine as ``compute_func(**params)``; the return value is
        written to the library.
    vis_func : callable
        Called locally as ``vis_func(**results)`` with the stored result, so
        the stored result must be a mapping of keyword arguments.
    output_layout : ipywidgets.Layout, optional
        Layout of the output area; a bordered, scrollable 300px box by default.
    qgrid_layout : ipywidgets.Layout, optional
        Layout of the parameter table; a plain ``Layout()`` by default.
    """

    def __init__(self, param_span_name, compute_func, vis_func, output_layout=None, qgrid_layout=None):
        self.name = param_span_name
        self.compute_func = compute_func
        self.vis_func = vis_func
        self.output_layout = output_layout
        self.qgrid_layout = qgrid_layout

        super().__init__()

        self.init_db()
        self.init_engines()
        self.init_widgets()
        self.init_layout()

    def init_db(self):
        """Connect to the Arctic store and open the library named after this sweep."""
        # Collect database info
        db_host = 'nautilus.optiputer.net'
        db_port = 31017
        lib_name = self.name
        self.store = Arctic('mongodb://{}:{}'.format(db_host, db_port))
        self.store.initialize_library(lib_name)
        self.library = self.store[lib_name]

        # Package database info to send to engines
        self.db_info = [
            db_host,
            db_port,
            lib_name
        ]

    def init_engines(self):
        """Connect to the IPyParallel controller and prepare every engine.

        Each engine receives its own ``engine_id`` and the database connection
        info, then opens the Arctic library as the global ``library``. All of
        these steps block until done, so tasks submitted afterwards cannot run
        on an engine that is not set up yet, and a setup failure is raised
        here instead of being silently dropped.
        """
        # Connect to controller (the client's connection warnings are suppressed)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.ipp_client = ipp.Client()

        # Establish views
        self.dview = self.ipp_client.direct_view()
        self.lview = self.ipp_client.load_balanced_view()

        # Tell engines their ids. Use the ids reported by the controller:
        # an enumerate() counter only matches them when they are 0..n-1.
        for eid in self.ipp_client.ids:
            self.ipp_client[eid].push({'engine_id': eid}, block=True)

        # Distribute database information
        self.dview.push({'db_info': self.db_info}, block=True)

        def load_library():
            # Runs on the engine: `db_info` is the global pushed just above.
            global store, library
            from arctic import Arctic
            db_host, db_port, lib_name = db_info
            store = Arctic('mongodb://{}:{}'.format(db_host, db_port))
            library = store[lib_name]
        self.dview.apply_sync(load_library)

    def init_widgets(self):
        """Create the output area and an empty placeholder parameter table."""
        if not self.output_layout:
            self.output_layout = ipw.Layout(height='300px', border='1px solid', overflow_x='scroll', overflow_y='scroll')
        if not self.qgrid_layout:
            # Store the default on the instance so later tables reuse it
            self.qgrid_layout = ipw.Layout()

        self.output = ipw.Output(layout=self.output_layout)
        # param_table is empty until set_params is called
        self.param_table = param_qgrid(self.qgrid_layout, **{'':[]})

    def init_logic(self):
        """Re-run the visualization whenever the table's row selection changes."""
        self.param_table.observe(self.visualize_wrapper, names='_selected_rows')

    def init_layout(self):
        """Stack the output area above the parameter table."""
        self.children = [
            self.output,
            self.param_table
        ]

    def set_params(self, **all_params):
        """Provide parameter set to search over
        all_params = {
            'param1': [val1, val2, ...],
            'param2': [val3, val4, ...],
            ...
        }
        """
        # A new table widget replaces the old one, so the selection observer
        # and the VBox children both have to be set up again.
        self.param_table = param_qgrid(self.qgrid_layout, **all_params)
        self.init_logic()
        self.init_layout()

    def submit_computations(self, *change):
        """Submit one task per table row to the load-balanced view.

        The resulting futures are kept in ``self.compute_futures``. Positional
        arguments are accepted and ignored so the method can also be used as a
        widget callback.
        """
        def compute_wrapper(compute_func, paramset_id, params):
            """Perform computation and send results to MongoDB for one set of params"""
            print("Compute.")
            print("params = {}".format(params))
            results = compute_func(**params)

            # Index collection by parameters
            # (`library` is the engine-side global created in init_engines)
            library.write(
                str(paramset_id),
                results
            )
            print("Compute done.")

        # Loop over all sets of parameters
        # paramset_id is the row index,
        # paramset is the dictionary of params.
        # orient='index' builds the per-row dicts directly; transposing first
        # would force one common dtype per row and can turn integer
        # parameters into floats when every column is numeric.
        self.compute_futures = []
        for paramset_id, paramset in self.param_table.df.to_dict(orient='index').items():
            # Submit task to IPyParallel
            print("Sumbitting {}: {}".format(paramset_id, paramset))
            fut =  self.lview.apply(compute_wrapper, self.compute_func, paramset_id, paramset)
            self.compute_futures.append(fut)
            # Printed right after submission, so these only show whatever the
            # task has produced so far (blank in the recorded run).
            print(fut.stdout)
            print(fut.stderr)

    def visualize_wrapper(self, *change):
        """Call visualization function and capture output"""
        # Do nothing if selection is empty:
        # Empty list evaluates to False
        selected_rows = self.param_table.get_selected_rows()
        if not selected_rows:
            return

        # Get params from selected row (take first row if >1 selected)
        paramset_id = selected_rows[0]
        paramset = self.param_table.df.loc[paramset_id, :]

        print("Vis")
        print("paramset_id: {}".format(paramset_id))
        print("paramset: {}".format(paramset))
        # Search collection by parameters
        compute_results = self.library.read(str(paramset_id)).data
        print("results: {}".format(compute_results))

        # Avoid using output context to ensure that
        # only this function's output is included.
        @self.output.capture(clear_output=True, wait=True)
        def wrapper():
            self.vis_func(**compute_results)
        wrapper()


# User functions

In [138]:
def compute2(N, mean, std, color):
    """Draw ``N`` normal samples and report their empirical mean and standard deviation.

    Relies on a global ``engine_id`` being defined wherever the function runs.

    Parameters
    ----------
    N : int
        Number of samples to draw.
    mean, std : float
        Location and scale of the normal distribution sampled from.
    color : str
        Passed through unchanged into the result.

    Returns
    -------
    dict
        Keys ``engine_id``, ``x`` (the samples), ``N``, ``realmean``,
        ``realstd`` and ``color``.
    """
    # Imported inside the function so the name exists where the function executes
    import numpy as np
    samples = np.random.normal(loc=mean, scale=std, size=N)
    sample_mean = samples.mean()
    sample_std = samples.std()
    summary = {
        'engine_id': engine_id,
        'x': samples,
        'N': N,
        'realmean': sample_mean,
        'realstd': sample_std,
        'color': color
    }
    return summary
    
def visualize2(engine_id, x, N, realmean, realstd, color):
    """Plot the distribution of the samples produced by ``compute2``.

    Prints which engine computed the result, draws a seaborn distribution plot
    of ``x`` in the given color, titled with ``N`` and the empirical mean and
    standard deviation, and finally prints the samples. The parameters are the
    keys of the dict returned by ``compute2``.
    """
    print("Computed on engine {}".format(engine_id))

    plt.figure(figsize=(8, 3))
    # Silence any warnings raised while plotting
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sns.distplot(x, color=color)

    title_text = r'$N={}$, $\mu={:.2f}$, $\sigma={:.2f}$'.format(N, realmean, realstd)
    plt.title(title_text)
    plt.show()

    print("Data: {}".format(x))

In [139]:
# Build the sweep widget: results are stored under the name 'test_span'
psw = ParamSpanWidget('test_span', compute2, visualize2)
# Sweep over two sample sizes and two colors, with fixed mean and std
# (four parameter sets in total)
psw.set_params(
    N=np.logspace(2,6,2).astype(int), 
    mean=[0], 
    std=[5], 
    color=['red', 'blue']
)
# Send one task per parameter set to the engines
psw.submit_computations()
# Show the sweep widget with the separate log output area underneath
display(ipw.VBox([psw,log_output]))

Library created, but couldn't enable sharding: no such command: 'enablesharding', bad cmd: '{ enablesharding: "arctic", lsid: { id: UUID("d0b44a03-0285-4736-90cf-6161b12aa649") }, $readPreference: { mode: "secondaryPreferred" }, $db: "admin" }'. This is OK if you're not 'admin'


Sumbitting 0: {'N': 100, 'mean': 0, 'std': 5, 'color': 'red'}


Sumbitting 1: {'N': 100, 'mean': 0, 'std': 5, 'color': 'blue'}


Sumbitting 2: {'N': 1000000, 'mean': 0, 'std': 5, 'color': 'red'}


Sumbitting 3: {'N': 1000000, 'mean': 0, 'std': 5, 'color': 'blue'}




VBox(children=(ParamSpanWidget(children=(Output(layout=Layout(border='1px solid', height='300px', overflow_x='…