In [17]:
import xarray as xr
import xsimlab as xs
import zarr
import numpy as np
from dask.diagnostics import Profiler, ResourceProfiler, CacheProfiler, ProgressBar
from time import sleep
import pandas as pd
from collections.abc import Iterable
import itertools

# 20201007_try_xarray_simlab
----

Continuation of [20201007_try_xarray_simlab](20201007_try_xarray_simlab.ipynb). Recap of progress is below:

In [18]:
def ravel_to_midx(coords):
    """TODO: pass dims so we dont rely on coords.keys() order
    """
    # Type checking
    assert isinstance(coords, dict)
    c = coords.copy()
    for k, v in c.items():
        # Since we're using the `size` attr...
        if isinstance(v, np.ndarray):
            pass
        elif isinstance(v, Iterable):
            c[k] = np.array(v)
        else:
            raise TypeError()
        
    # Generate pandas MultiIndex
    shape = [_c.size for _c in c.values()]
    midx = pd.MultiIndex.from_product(c.values(), names=c.keys())
    return np.ravel_multi_index(midx.codes, shape)

encoded_midx = ravel_to_midx(dict(
    ag=np.array(['0-4', '5-17', '18-49', '50-64', '65+']),
    rg=np.array(['low', 'high'])
))
encoded_midx

array([1, 0, 5, 4, 3, 2, 7, 6, 9, 8])

In [19]:
def unravel_encoded_midx(midx, coords):
    """TODO: pass dims so we dont rely on coords.keys() order
    """
    # Type checking
    assert isinstance(midx, np.ndarray)
    assert isinstance(coords, dict)
    c = coords.copy()
    for k, v in c.items():
        # Since we're using the `size` attr...
        if isinstance(v, np.ndarray):
            pass
        elif isinstance(v, Iterable):
            c[k] = np.array(v)
        else:
            raise TypeError()
        
    # Decode to a MultiIndex
    shape = [_c.size for _c in c.values()]
    indices = np.unravel_index(midx, shape)
    arrays = [c[dim][index] for dim, index in zip(c.keys(), indices)]
    return pd.MultiIndex.from_arrays(arrays)

decoded_midx = unravel_encoded_midx(
    midx=encoded_midx,
    coords=dict(
        ag=np.array(['0-4', '5-17', '18-49', '50-64', '65+']),
        rg=np.array(['low', 'high'])
    )
)
decoded_midx

MultiIndex([(  '0-4', 'high'),
            (  '0-4',  'low'),
            ('18-49', 'high'),
            ('18-49',  'low'),
            ( '5-17', 'high'),
            ( '5-17',  'low'),
            ('50-64', 'high'),
            ('50-64',  'low'),
            (  '65+', 'high'),
            (  '65+',  'low')],
           )

In [20]:
@xs.process
class MidxSetter:
    """TODO: handle the coords dynamically as a `group`
    """
    
    age_group = xs.index(dims='age_group')
    risk_group = xs.index(dims='risk_group')
    midx1 = xs.index(dims=('midx1'))
    midx2 = xs.index(dims=('midx2'))
    midx_mapping = xs.variable(dims=('age_group', 'risk_group'), static=True, intent='out')
    
    def initialize(self):
        self.age_group = ['0-4', '5-17', '18-49', '50-64', '65+']
        self.risk_group = ['low', 'high']  
        encoded_midx = ravel_to_midx(coords=dict(age_group=self.age_group, risk_group=self.risk_group))
        self.midx1 = encoded_midx
        self.midx2 = encoded_midx
        
        self.midx_mapping = self._get_midx_mapping()
        
    def _get_midx_mapping(self):
        shape = [len(self.age_group), len(self.risk_group)]
        coords = dict(
            age_group=self.age_group,
            risk_group=self.risk_group
        )
        da = xr.DataArray(
            data=self.midx1.reshape(shape),
            dims=('age_group', 'risk_group'),
            coords=dict(
                age_group=self.age_group,
                risk_group=self.risk_group
            )
        )
        return da
        
@xs.process
class ToyPhi:
    """
    """
    
    midx1 = xs.foreign(MidxSetter, 'midx1', intent='in')
    midx2 = xs.foreign(MidxSetter, 'midx2', intent='in')
    midx_mapping = xs.foreign(MidxSetter, 'midx_mapping', intent='in')
    age_group = xs.foreign(MidxSetter, 'age_group', intent='in')
    risk_group = xs.foreign(MidxSetter, 'risk_group', intent='in')
    phi = xs.variable(
        dims=('midx1', 'midx2'),
        # dims='midx1',
        intent='inout')
    
    def initialize(self):
        self.phi = np.zeros(shape=[self.midx1.size, self.midx2.size], dtype='int32')
        
    def finalize_step(self):
        """Toy behavior of how the SEIR model would access this array
        
        TODO: a version of this with matrix multiplication
        """
        
        # Iterate over every pair of age-risk pairs
        for a1, r1, a2, r2 in itertools.product(*[self.age_group, self.risk_group] * 2):
            # print(combo)
            
            # Get the indices in midx1/midx2
            i = self.midx_mapping.loc[(a1, r1)].values
            j = self.midx_mapping.loc[(a2, r2)].values
            # print(i, j)
            
            # ...then index the symmetrical array using
            # the derived i j indices
            # self.phi[i, j] += 1
        
        # Validate that this encoded index actually works
        # print(self._validate_midx())
    
    def _validate_midx(self):
        # Validate that this is actually the correct index in midx1
        return unravel_encoded_midx(
            midx=self.midx1,
            coords=dict(
                age_group=self.age_group,
                risk_group=self.risk_group
            )
        )

In [21]:
# Model definition
toy_midx = xs.Model(dict(
    midx_setter=MidxSetter,
    phi=ToyPhi
))

in_ds = xs.create_setup(
    model=toy_midx,
    clocks={
        'step': range(70)
    },
    input_vars={},
    output_vars={
        'phi__phi': 'step',
        'midx_setter__midx_mapping': 'step'
    }
)

# Track progress with Dask profilers
# with Profiler() as prof:
with ResourceProfiler(dt=0.25) as prof:
    out_ds = in_ds.xsimlab.run(
        model=toy_midx,
        # parallel=True,
        parallel=False,
        # scheduler='threads',
        # store=zarr.TempStore()
    )
    out_ds.load()
    sleep(5)
prof.visualize(show=True, save=False)
out_ds

KeyError: "Missing variables ['phi__phi'] in Dataset"

Recap from [20201007_try_xarray_simlab](20201005_try_xarray_simlab.ipynb):

> This encoded index appears to be working well. Only thing is that the `ravel_endoced_midx` and `ravel_to_midx` functions rely on `coords.keys()` to get an ordered list of dims. I'm actually not sure if this would make a difference (as long as `coords.values()` is in the same order), definitely not for this toy implementation, but it's an easy enough fix and is worth the added consistency.
> 
> Another important consideration is how this framework will handle sparse or recurring values on the `time`/`step` axis. This is important for things like:
> * Varying phi based on school calendar
> * Varying adjacency matrix based on day of the week, recurring every week
> * Switching on/off "stochastic mode" at a given timepoint
    > * Does this even need it's own data variable?
    > * ~~Perhaps one could just use `@xs.runtime(args='step')` wrapping a function with `if step == config.start_deterministic`...~~
    > * Probably best to just give the data variable `deterministic` its own clock.
    > 
> In any case, `xarray-simlab` provides the concept of multiple clocks. One could, for instance, have a different clock that corresponds to each of these events. ~~This is probably at best equivalently clunky to defining the above properties at all points on the time axis.~~ Actually, no. Each one of these clocks could be dynamically shaped (could have only a few timepoints), and would apply to only one of the data variables. I'm still thinking in terms of arrays, not Datasets. So it's quite viable to have separate clocks for the `counts`, `phi`, and `adjacency_weight` arrays, as well as the arrays for every other epidemiological parameter.
> 
> Now that I mention epidemiological parameters:
> * A lot of the epidemiological params in SEIRcity v1 and v2 are simply multipliers on the value `phi`. These **could** be safely deprecated, I think. The benefit to keeping them is that you get near optimal memory efficiency if you want to vary phi only on one axis. The downside to keeping all these variables is that this multiplication-on-a-single-axis operation will be hard-coded in the simulation engine.

Let's try adding another clock, and see if we can provide a `phi` input on that other clock value.

In [23]:
# Model definition
toy_midx = xs.Model(dict(
    midx_setter=MidxSetter,
    phi=ToyPhi
))

in_ds = xs.create_setup(
    model=toy_midx,
    clocks={
        'step': range(10),
        'otime': [0.0, 1.0]
    },
    master_clock='step',
    input_vars={
        'phi__phi': ('otime', [6, 9])
    },
    output_vars={
        'phi__phi': 'step',
        'midx_setter__midx_mapping': 'step'
    }
)

# Track progress with Dask profilers
# with Profiler() as prof:
with ResourceProfiler(dt=0.25) as prof:
    out_ds = in_ds.xsimlab.run(
        model=toy_midx,
        # parallel=True,
        parallel=False,
        # scheduler='threads',
        # store=zarr.TempStore()
    )
prof.visualize(show=False, save=False)
out_ds

ValueError: Invalid dimension(s) for variable 'phi__phi': found ('otime',), must be one of ('midx1', 'midx2'),('step', 'midx1', 'midx2')

It seems that I misunderstood what the alternate clocks do. I hoped that one would be able to replace the "master" ("main") clock with an alternate clock, but it seems that all variables **must** tick on the main clock, or no clock at all. Only time that alternate clock can be used is for taking snapshots of output vars.

This isn't a dealbreaker: we can still define each parameter with either no time coordinates, or for all time coordinates, per `xarray-simlab`'s design. I'm curious to see if one could dynamically generate parameter values at each timepoint, and what the performance overhead of that approach would be. For instance, one could implement conditional logic that recurses over a seven-day phi matrix.

One potential way to do this is to manually define a "day of the week" index variable, and then do some modulo magic to make the indexing function recurse over those 7 timepoints:

In [26]:
@xs.process
class MidxSetter:
    """TODO: handle the coords dynamically as a `group`
    """
    
    age_group = xs.index(dims='age_group')
    risk_group = xs.index(dims='risk_group')
    midx1 = xs.index(dims=('midx1'))
    midx2 = xs.index(dims=('midx2'))
    midx_mapping = xs.variable(dims=('age_group', 'risk_group'), static=True, intent='out')
    day_of_week = xs.index(dims=('day_of_week'))
    
    def initialize(self):
        self.age_group = ['0-4', '5-17', '18-49', '50-64', '65+']
        self.risk_group = ['low', 'high']  
        encoded_midx = ravel_to_midx(coords=dict(age_group=self.age_group, risk_group=self.risk_group))
        self.midx1 = encoded_midx
        self.midx2 = encoded_midx
        self.day_of_week = np.arange(7)
        
        self.midx_mapping = self._get_midx_mapping()
        
    def _get_midx_mapping(self):
        shape = [len(self.age_group), len(self.risk_group)]
        coords = dict(
            age_group=self.age_group,
            risk_group=self.risk_group
        )
        da = xr.DataArray(
            data=self.midx1.reshape(shape),
            dims=('age_group', 'risk_group'),
            coords=dict(
                age_group=self.age_group,
                risk_group=self.risk_group
            )
        )
        return da
        
@xs.process
class ToyPhi:
    """
    """
    
    midx1 = xs.foreign(MidxSetter, 'midx1', intent='in')
    midx2 = xs.foreign(MidxSetter, 'midx2', intent='in')
    midx_mapping = xs.foreign(MidxSetter, 'midx_mapping', intent='in')
    age_group = xs.foreign(MidxSetter, 'age_group', intent='in')
    risk_group = xs.foreign(MidxSetter, 'risk_group', intent='in')
    day_of_week = xs.foreign(MidxSetter, 'day_of_week', intent='in')
    phi = xs.variable(
        dims=('day_of_week', 'midx1', 'midx2'),
        # dims='midx1',
        static=True,
        intent='out')
    
    def initialize(self):
        self.phi = np.zeros(shape=[self.day_of_week.size, self.midx1.size, self.midx2.size], dtype='int32')
    
    @xs.runtime(args='step')
    def run_step(self, step):
                
        # Get the index on `day_of_week`
        day_idx = step % self.day_of_week.size
        self.phi_t = self.phi[day_idx]
        # print(step, self.day_of_week.size, day_idx)
        
    def finalize_step(self):
        """Toy behavior of how the SEIR model would access this array
        
        TODO: a version of this with matrix multiplication
        """
        
        # Iterate over every pair of age-risk pairs
        for a1, r1, a2, r2 in itertools.product(*[self.age_group, self.risk_group] * 2):
            # print(combo)
            
            # Get the indices in midx1/midx2
            i = self.midx_mapping.loc[(a1, r1)].values
            j = self.midx_mapping.loc[(a2, r2)].values
            # print(i, j)
            
            # ...then index the symmetrical array using
            # the derived i j indices
            self.phi_t[i, j] += 1
        
        # Validate that this encoded index actually works
        # print(self._validate_midx())
    
    def _validate_midx(self):
        # Validate that this is actually the correct index in midx1
        return unravel_encoded_midx(
            midx=self.midx1,
            coords=dict(
                age_group=self.age_group,
                risk_group=self.risk_group
            )
        )

In [34]:
# Model definition
toy_midx = xs.Model(dict(
    midx_setter=MidxSetter,
    phi=ToyPhi
))

in_ds = xs.create_setup(
    model=toy_midx,
    clocks={
        'step': range(70)
    },
    input_vars={},
    output_vars={
        'phi__phi': 'step',
        'midx_setter__midx_mapping': 'step'
    }
)

# Track progress with Dask profilers
# with Profiler() as prof:
with ResourceProfiler(dt=0.25) as prof:
    out_ds = in_ds.xsimlab.run(
        model=toy_midx,
        # parallel=True,
        parallel=False,
        # scheduler='threads',
        # store=zarr.TempStore()
    )
prof.visualize(show=False, save=False)
out_ds

In [28]:
out_ds

In [25]:
out_ds.phi__phi.loc[dict(step=8, midx1=0, midx2=0)]

Great! This dynamic behavior was not that difficult to implement. Of course, it is a little more CPU-heavy, but lighter on RAM. Also, if one wanted to modulate phi for every day of the year (e.g. including holidays and such), one could take the hit on RAM and simply construct a workflow that ingests phi without a `day_of_week` dimension.

I made a small tweak to the above model. Slicing `phi` along the `day_of_week` axis at every timepoint returns an array with `midx1, midx2` dims. This would be a good workflow for abstracting additional dims such as `day_of_week` separate from the main phi computation. For instance, one could imagine moving `ToyPhi.run_step` and `ToyPhi.initialize` to a different process.

It seems that `xarray-simlab` has met or surpassed my expectations with regards to its usefulness for this project. It allows for streamlined, reproducible model development with a minimal boilerplate. Importantly, it natively supports the software stack that I wanted to introduce: `xarray.Dataset`, Dask, and Zarr/NetCDF.

It addresses my two major concerns with the array-based data representations in SEIRcity v2: 1.) low memory efficiency and the associated 2.) data redundancy. The switch from `DataArray`s to `Dataset`s will definitely help address the data redundancy issue, since it allows us to define parameters on a **subset** of dimensions. In addition, `Dataset` compatibility with NetCDF/Zarr allows for RAM spillover to /tmp, if necessary.

Notably, the interplay between CPU and memory efficiency can easily modulated using `xarray-simlab`. The structuring of the above `phi` matrix is an excellent example, which can be generated by either memory or CPU efficient approaches. `xarray-simlab` would allow us to test and even **maintain** two separate versions of the array construction, allowing for fine modulation of resource usage. This optimization process is also accelerated since we get Dask diagnostics for free.

## Tentative Software Specification for Episimlab

After playing around with `xarray-simlab` this week, a more detailed software specification for the `episimlab` package is emerging.

### Purpose

Provide a set of reusable model components (known as "processes" in `xarray-simlab`) that enable the user to write, test, and run custom epidemiological models. Development boilerplate should be minimal for default settings. However, the full feature set of `xarray-simlab` should be exposed to the user, so that they can modify the default models, write their own models using process components, and optimize for performance.

### Feature Set

1. **Flexibility and adaptability**: The user should be able to able to leverage tools in an _interactive development environment_ to answer a wide variety of epidemiology questions
    * "Given a meta-population model, does social vulnerability index correlate with outcomes of infected individuals?"
    * "What if I use zip codes instead of census tracts?"
    * "What if I make my beta parameter time-dependent?"
    * "What if I define schools and hospitals as node types?"
2. **Performance**: User should be able to finish a production job in one day on an HPC node. Development jobs should finish in seconds or minutes on a laptop.
3. **Extensibility**: Models written by users should be able to be easily included into the production package, so that they can be shared with other users. Under the hood, this will require more than just `git merge`:
    * Well-defined regime for code organization
        * If you write a custom process, it should have its own file and belongs in this directory.
    * Well-defined testing regime
        * If you write a custom process, it should pass integration tests with (all | a subset of) the models in production
        * If you write a custom model, it should be able to...
        * CI/CD is promising here
    * Extensive documentation
    * Especially well-defined contribution guide

### Software Stack

* Top level interface should provide pre-defined `process`es and `Model`s from the [`xarray-simlab`](https://xarray-simlab.readthedocs.io/) package. This package ships with support for:
    * [`xarray.Dataset`](http://xarray.pydata.org/en/stable/data-structures.html#dataset)
    * [Dask](https://docs.dask.org)
    * [Zarr](https://zarr.readthedocs.io/)
    * [attrs](https://www.attrs.org/)
* Low level interface takes current counts and parameters for a single step in the simulation, and returns the counts for the next timepoint.
    * This can be implemented as "native xsimlab" in Python
    * Simulations that are more CPU intensive and less memory intensive can use a Cython implementation
    * This is a similar paradigm to SEIRcity v2
    
### `Process`es

The majority of the code base will reside in a set of `xarray-simlab` `process`es. There will be no formal categories for processes (except for module organization), but processes will generally fall into one of the following:
* **I/O**: processes that read input variables from a file
    * File types
        * YAML
        * NetCDF
        * Zarr
    * Possibility in future: processes that query databases, such as Kelly Gaither's SafeGraph SQL
    * Should be easily replaced with user-supplied inputs
* **Simulation setup**: processes that instantiate the variables necessary for simulation.
    * Processes like
        * Coordinate construction
        * Array construction from these coordinates (notably `phi` and adjacency matrix)
    * Most of the functionality in these processes resides in their `initialize` method
    * Will leverage `run_step` method in some cases.
        * For instance, setting `phi` based on day of the week
* **Simulation**: processes that directly influence the `counts` variable (known as the simulation space array in SEIRcity v2)
    * Processes like
        * Calculating force of infection given a `phi` matrix
        * Using force of infection to get the change in `counts` for this timestep
        * Using adjacency matrix to determine change in `counts` due to travel
        * Applying all changes to `counts` to the next timestep
    * Most of the functionality in these processes resides in their `run_step` and `finalize_step` methods
    * As lightweight and static as possible
        * While all processes should be mutable and replacable to some degree, these **Simulation** processes should be written with high reusability in mind
        * This often means that it is preferable to factor out simulation setup logic to a less reusable **Simulation Setup** process
    * Written in "native xsimlab" where performance allows, and makes calls to more performant Python or Cython engines if necessary
        
### `Model`s

* Models in `episimlab` are instances of `xarray-simlab` `Model` class.
* They are comprised of unique `process`es.
* In `episimlab`, models are provided primarily for testing and demonstration purposes
    * **Testing**: "can the developers recapitulate the results of _this_ model that is not part of `episimlab`?"
    * **Demonstration**: "here is an example of a basic, single-city SEIR model that is comprised of processes from `episimlab`"
* As such, the "standard library" of models that ships with `episimlab` should be minimal.
* The library of _user-defined_ models will grow as users push their own models to the production codebase.

## Developer Best Practices

Apart from the usual software engineering best practices, the following conventions will be especially important when developing and extending `episimlab`:
1. New features should **always** be contained in a new process.
2. Updated features should **almost always** be contained in a **new** process.
    * The exception is bug fixes
    * In other words, only change or remove features of an existing process if it is wrong (**not** if it is simply used less often than it was)
3. Processes should be atomic
    * A process should have few inputs and few outputs, just like a normal Python function.
    * Advanced, multi-faceted features should be broken up into atomic processes. The advanced feature should function (and pass tests) when its component processes are combined in a `Model`.
4. New processes should be written in "native xsimlab" whenever possible
    * That is, almost all variables that are defined in Python should also be defined using xsimlab (`xs.variable`).
    * The main exception is when the logic in the new process is not performant, in which case it should make a call to a more efficient external Cython function.
    * In the case where a process calls a more performant backend, there must always be an equivalent process written in "native xsimlab" for testing purposes.
5. New processes should be tested as stand-alone components
    * A new process should have a corresponding `pytest` module that covers only that process.
6. Paradigm for testing new models
    * Work in progress