In [3]:
import xarray as xr
import xsimlab as xs
from xsimlab import variable as var, global_ref as glob
import pandas as pd
import zarr

# 20210729_sandbox

Entities:
- Variables
    - State matrix (SM)
    - State transition matrix (TM)
- Processes
    - Multiple processes edit TM
        - FOI process edits TM only on index of S and E compartments
        - I -> R transition is governed by some process that
            - Ingests gamma
            - Edits TM
    - TM-editing processes interface
        - Inputs: SM(t), contacts per time, contact probability, other epi params, etc.
        - Output: TM(t)
    - Provide t0 SM
    - Step
        - Apply operation SM(t) * TM(t)
        - Outputs SM(t+1)

- Potential to populate different subsets of SM asynchronously from multiple processes?
    - Keep it at 1 for now

First pass:
- S -> E progression
    - FOI is $f(I, N)$, where $N$ is the sum over all compartments

## What do I mean when I say "edit TM"?

- Ideally, a process that "edits" TM is one that
    - Ingests a TM
    - Indexes TM, inserting along only the compartments it edits
    - Outputs a TM with fewer NaNs
    - Therefore represents TM as an `inout` variable in xsimlab
- The problem is that xsimlab does not allow more than one process to declare the same variable with `inout` intent, because the order in which those processes should run would be ambiguous in the task graph

### Solution

- What we could do is provide multiple processes that output *distinct* subsets of the TM
    - Each subset would be defined like `xs.variable(groups=['tm_subsets'])`
- Then an aggregator process `AggTM` ingests all the subsets of the TM using `xsimlab.group_dict`

In [4]:
@xs.process
class SetTimeZero:
    """Provide the initial (t0) state matrix.

    Writes the global ``state`` once, during ``initialize``, as a
    ``(compt, age)`` DataArray filled with ``100.`` in every cell.
    """
    COMPT_COORDS = ['S', 'I', 'R']
    AGE_COORDS = ['0-4', '5-17', '18-49', '50-64', '65+']

    state = glob('state', intent='out')

    def initialize(self):
        # dict insertion order fixes the dimension order: ('compt', 'age').
        t0_coords = {'compt': self.COMPT_COORDS, 'age': self.AGE_COORDS}
        self.state = xr.DataArray(
            100.,
            coords=t0_coords,
            dims=tuple(t0_coords),
        )

In [5]:
@xs.process
class CalcTM:
    """Compute the per-cell multipliers ``x`` that ``Step`` applies to ``state``.

    Placeholder values: cells of compartment ``S`` get 0.9, cells of
    compartment ``I`` get 1.1, and every other compartment gets 1.
    The multipliers are set once, in ``initialize``, and are not updated per step.
    """
    state = glob('state', intent='in')
    x = glob('x', intent='out')

    def initialize(self):
        # Force a float dtype: ones_like inherits state's dtype, so an integer
        # state would otherwise silently truncate 0.9 -> 0 and 1.1 -> 1.
        multipliers = xr.ones_like(self.state, dtype=float)
        multipliers.loc[dict(compt='S')] = 0.9
        multipliers.loc[dict(compt='I')] = 1.1
        # Assign once, after the array is fully built, rather than mutating
        # through the `self.x` property after it has been stored.
        self.x = multipliers

In [6]:
@xs.process
class FOI:
    """Placeholder force-of-infection process.

    Writes the global ``x`` as 2 in every cell. Only the shape and coords of
    ``state`` are used; its values are not read yet.
    """
    state = glob('state', intent='in')
    x = glob('x', intent='out')

    def initialize(self):
        unit = xr.ones_like(self.state)
        self.x = 2 * unit

In [7]:
@xs.process
class Step:
    """Advance the state matrix by one clock step: ``state <- state * x``."""
    STEP_DIMS = ('compt', 'age',)
    state = var(dims=STEP_DIMS, intent='inout', global_name='state')
    x = var(dims=STEP_DIMS, intent='in', global_name='x')

    def run_step(self):
        # join='exact' keeps the strictness of the previous in-place `*=`:
        # mismatched coords raise instead of being silently inner-joined away.
        current, multipliers = xr.align(self.state, self.x, join='exact')
        # Out-of-place multiply. `*=` would mutate, in place, whatever DataArray
        # the store holds for `state` (presumably the object SetTimeZero built).
        # It would also fail with a casting error if state were an integer and
        # x a float.
        self.state = current * multipliers


## Run Model

In [8]:
# FOI is not included: it writes the same global `x` as CalcTM
# (see the "edit TM" discussion above).
model = xs.Model({
    'set_t0': SetTimeZero,
    'calc_tm': CalcTM,
    'step': Step,
})
in_ds = xs.create_setup(
    model=model,
    clocks={
        # ISO dates avoid month/day ambiguity. 'D' replaces '24H' because the
        # uppercase 'H' alias is deprecated since pandas 2.2. The result is the
        # same 15 daily timestamps, 2020-03-01 .. 2020-03-15 inclusive.
        'step': pd.date_range(start='2020-03-01', end='2020-03-15', freq='D')
    },
    input_vars={},
    output_vars={
        'step__state': 'step'
    }
)
out_ds = in_ds.xsimlab.run(model=model, decoding=dict(mask_and_scale=False))

In [9]:
# Display the simulation output Dataset; `step__state` was requested in output_vars
# with the 'step' clock, so it is expected to hold one state snapshot per step.
out_ds

## Sandbox

# Copied from Episimlab

In [10]:
def try_coerce_to_da(self, name, value):
    """Coerce a config-file ``value`` for variable ``name`` into a DataArray.

    The variable's dimensions come from ``get_var_dims`` applied to the
    process found at ``self.KEYS_MAPPING[name]``. The coordinates for each
    dimension are read from the same-named attribute of ``self``. A dimension
    literally named ``'value'`` gets no coordinate.

    Returns ``value`` unchanged when the variable has no dims (scalar);
    otherwise returns an ``xarray.DataArray`` built from ``value``.
    """
    var_dims = get_var_dims(self.KEYS_MAPPING[name], name)
    if not var_dims:
        # Scalar variable: nothing to wrap.
        return value

    dim_coords = {}
    for dim_name in var_dims:
        if dim_name == 'value':
            continue
        dim_coords[dim_name] = getattr(self, dim_name)
    return xr.DataArray(data=value, dims=var_dims, coords=dim_coords)

In [11]:
def get_var_dims(process, name) -> tuple:
    """Return the ``dims`` declared for variable ``name`` on xsimlab process ``process``.

    Only the first entry of the variable's ``dims`` metadata is returned.

    Raises
    ------
    TypeError
        If ``__xsimlab_cls__`` is not among ``dir(process)``, i.e. ``process``
        does not look like an xsimlab process.
    AttributeError
        If ``process`` declares no xsimlab variable called ``name``.
    """
    if '__xsimlab_cls__' not in dir(process):
        raise TypeError(
            "Expected type 'xsimlab.Process' for arg `process`, received "
            f"'{type(process)}'"
        )
    # Named `declared` to avoid shadowing the module-level `var` alias.
    declared = xs.utils.variables_dict(process)
    if name not in declared:
        raise AttributeError(f"process '{process}' has no attribute '{name}'")
    first_dims = declared[name].metadata['dims'][0]
    return tuple(first_dims)
