## TODO

* impose a naming convention
* add some background on the example model
* nested field structure, for example BASICS: periods, PARAMETERS: rho, which mimics the respy/soepy/norpy structure
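
One way the nested layout from the last item might look is sketched below. The group names `BASICS` and `PARAMETERS` come from the item above; the `Basics`, `Parameters`, and `NestedSpec` classes and the `from_nested_dict` helper are hypothetical names chosen for illustration and are not taken from respy, soepy, or norpy.

```python
from collections import namedtuple

# Hypothetical containers for the two groups named in the TODO item.
Basics = namedtuple('Basics', ['periods'])
Parameters = namedtuple('Parameters', ['rho'])
NestedSpec = namedtuple('NestedSpec', ['basics', 'parameters'])


def from_nested_dict(init_dict):
    """Build a nested, immutable specification from a nested dictionary."""
    return NestedSpec(
        basics=Basics(**init_dict['BASICS']),
        parameters=Parameters(**init_dict['PARAMETERS']),
    )


nested = from_nested_dict({'BASICS': {'periods': 10}, 'PARAMETERS': {'rho': 0.5}})
assert nested.basics.periods == 10
assert nested.parameters.rho == 0.5
```

Fields are then reached through their group, for example `nested.parameters.rho`, and each level stays immutable.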

This notebook hosts some ideas for how to set up a model specification in our OSE projects.

In [25]:
from collections import namedtuple
import pathlib
import copy

import numpy as np
import yaml

# The model parameters are defined here, once and for all.
fields = ['rho', 'periods']

class ModelCls(namedtuple('model_base', ' '.join(fields))):
 
    def __eq__(self, other):
        assert isinstance(other, type(self))
        assert set(self._fields) == set(other._fields)
        for field in self._fields:
            if getattr(self, field) != getattr(other, field):
                return False
        return True
    
    def copy(self):
        return copy.deepcopy(self)
    
    def __ne__(self, other):
        return not self.__eq__(other)
    
    def __repr__(self):
        str_ = ''
        for field in self._fields:
            str_ += '{:}: {:}\n'.format(field, getattr(self, field))
        return str_
    
    def as_dict(self):
        return self._asdict()
    
    def replace(self, *args, **kwargs):
        return self._replace(*args, **kwargs)
    
    def to_yaml(self, fname='test.yml'):
        with open(fname, 'w') as out_file:
            yaml.dump(self._asdict(), out_file)
            
    def validate(self):
        '''This method validates the model specification. All validation is done here and no further checks 
        are necessary later in the program for the immutable parameters describing the model.
        
        The design ensures that all fields require explicit checks.
        '''
        for field in self._fields:
            if field == 'periods':
                attr = getattr(self, field)
                assert isinstance(attr, int)
                assert attr > 0
            elif field == 'rho':
                attr = getattr(self, field)
                assert isinstance(attr, float)
            else:
                raise NotImplementedError('validation of {:} missing'.format(field))
                
def generate_random_init(constr=None):
    # Avoid a mutable default argument.
    constr = dict() if constr is None else constr

    def process_constraints(constr):
        '''This function processes all constraints passed in by the user for the random model specification.'''
        # Test for membership, not truthiness, so that falsy values such as 0.0 are respected.
        if 'periods' in constr:
            init_dict['periods'] = constr['periods']
        if 'rho' in constr:
            init_dict['rho'] = constr['rho']

    init_dict = dict()
    init_dict['rho'] = np.random.uniform(0.01, 0.99)
    init_dict['periods'] = np.random.randint(1, 10)

    process_constraints(constr)

    return init_dict
    
def get_model_obj(input_, constr=None):
    """This is a factory method to create a model specification from a variety of inputs."""
    # We want to enforce the use of Path objects going forward.
    if isinstance(input_, str):
        input_ = pathlib.Path(input_)

    if isinstance(input_, dict):
        model_spec = ModelCls(**input_)
    elif isinstance(input_, pathlib.Path):
        # Checking against Path instead of PosixPath keeps this portable across platforms,
        # and the context manager closes the file handle.
        with open(input_, 'r') as in_file:
            model_spec = ModelCls(**yaml.load(in_file, Loader=yaml.FullLoader))
    elif input_ is None:
        model_spec = ModelCls(**generate_random_init(constr or {}))
    else:
        raise NotImplementedError

    model_spec.validate()

    return model_spec

## Use cases 

We want to explore some use cases with a basic model and translate them into tests.

We can specify a model programmatically using a dictionary.

In [9]:
init_dict = {'rho': 0.5, 'periods': 10}
get_model_obj(init_dict);

As an alternative we can also read it in from a *.yml* file.

In [10]:
# %load model_spec.yml
# periods: 2
# rho: 0.5

get_model_obj('test.yml');

* We want to be able to update the parameters of the model specification during an optimization.

In [11]:
spec_1 = get_model_obj(None)
spec_2 = spec_1.replace(rho=0.9)
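
A minimal sketch of how such updates might look inside an optimization loop, assuming a stand-in `Spec` namedtuple and a made-up `criterion` function (neither is part of this notebook). The sketch calls `_replace` directly, which is what the `replace` method above wraps. Each candidate value yields a new specification object and the starting one is left untouched.

```python
from collections import namedtuple

# Stand-in for the model class above, kept minimal so the sketch is self-contained.
Spec = namedtuple('Spec', ['rho', 'periods'])


def criterion(spec):
    """Hypothetical objective: squared distance of rho from 0.7."""
    return (spec.rho - 0.7) ** 2


start = Spec(rho=0.5, periods=10)

# Crude grid search: every candidate is a fresh, immutable specification.
candidates = [start._replace(rho=rho) for rho in (0.1, 0.3, 0.5, 0.7, 0.9)]
best = min(candidates, key=criterion)

assert best.rho == 0.7
assert start.rho == 0.5  # the starting specification is unchanged
```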

* We want to easily compare different model specifications.

In [12]:
spec_1 = get_model_obj(None)
assert spec_1 != spec_1.replace(rho=0.9)

spec_1 = get_model_obj(None)
spec_2 = spec_1.copy()
assert spec_1 == spec_2

* We want to be able to go back and forth between the different ways a model is stored.

In [24]:
for _ in range(100):
    spec_1 = get_model_obj(None)
    spec_1.to_yaml()

    spec_2 = get_model_obj('test.yml')
    assert spec_1 == spec_2

    spec_3 = get_model_obj(pathlib.Path('test.yml'))
    assert spec_1 == spec_3
    
    spec_4 = get_model_obj(spec_1.as_dict())
    assert spec_1 == spec_4

* We want to easily validate the integrity of our model specification.

In [14]:
for _ in range(100):
    spec = get_model_obj(None)
    spec.validate()
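
Because `assert` statements are skipped when Python runs with the `-O` flag, a validation routine that must always run might raise explicit exceptions instead. The following is a minimal sketch under that assumption; the `validate_spec` helper and its `checks` table are hypothetical and operate on a plain dictionary, not on the model class above.

```python
def validate_spec(spec_dict):
    """Hypothetical validator that raises instead of asserting."""
    # One explicit check per field, mirroring the rules used in validate().
    checks = {
        'periods': lambda v: isinstance(v, int) and v > 0,
        'rho': lambda v: isinstance(v, float),
    }
    for field, value in spec_dict.items():
        if field not in checks:
            raise NotImplementedError('validation of {:} missing'.format(field))
        if not checks[field](value):
            raise ValueError('invalid value for {:}: {!r}'.format(field, value))


# A valid specification passes silently.
validate_spec({'rho': 0.5, 'periods': 10})

# An invalid number of periods is rejected.
try:
    validate_spec({'rho': 0.5, 'periods': 0})
    rejected = False
except ValueError:
    rejected = True
assert rejected
```

As in the method above, a field without a registered check still triggers a `NotImplementedError`, so new fields cannot slip through unvalidated.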

* We want to easily access all fields.

In [15]:
for _ in range(100):
    spec = get_model_obj(None)
    for field in spec._fields:
        field, getattr(spec, field)

* We want to easily inspect the model specification.

In [16]:
spec = get_model_obj(None)
print(spec)

rho: 0.28615267980194814
periods: 3



* We do not want to change parts of our model specification by accident.

In [17]:
spec = get_model_obj(None)

# We cannot change a field already defined.
with np.testing.assert_raises(AttributeError):
    spec.periods = 1

# We cannot add a field dynamically.
spec_1 = spec.copy()
spec.period = 1
# Note that the statement above does not raise an error. Because ModelCls does not
# define __slots__, it creates an ordinary instance attribute, but the set of model
# fields is unaffected.
assert set(spec._fields) == set(['rho', 'periods'])
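
If silently accepting new attributes is undesirable, one option is to declare empty `__slots__` on the subclass. This removes the per-instance `__dict__`, so assigning an unknown attribute raises an `AttributeError`. The sketch below uses a hypothetical `StrictModel` class and is not the `ModelCls` defined above.

```python
from collections import namedtuple


class StrictModel(namedtuple('model_base', ['rho', 'periods'])):
    # Empty __slots__ prevents the creation of an instance __dict__.
    __slots__ = ()


spec = StrictModel(rho=0.5, periods=10)

# Changing an existing field fails, as before.
try:
    spec.periods = 1
    changed = True
except AttributeError:
    changed = False

# Adding a new attribute now fails as well.
try:
    spec.period = 1
    added = True
except AttributeError:
    added = False

assert not changed and not added
```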

## Integration

This model class can then be used to work with the specified model.

In [18]:
def simulate(model_spec):
    '''This function simulates a simple AR(1) process.'''
    assert isinstance(model_spec, ModelCls)
    
    sequence = np.full(model_spec.periods, np.nan)
    sequence[0] = np.random.normal()
    for i in range(1, model_spec.periods):
        sequence[i] = np.random.normal() + model_spec.rho * sequence[i - 1]
    return sequence

model_spec = get_model_obj('test.yml')
simulate(model_spec);

We can then combine the testing features.

In [19]:
for _ in range(100):
    rslt = simulate(get_model_obj(None))
    assert not np.isnan(rslt).any()