# Template notebook for fitting synthetic (model-generated) data

The synthetic data used for fitting needs to be contained in a netCDF file, just like the regular size-distribution data. One simple way to create a synthetic file is to follow these steps:
 1. Run the `fit_models.ipynb` notebook with `prior_only = True`.
 2. Select a sample from the output file (specified by the `savename_output` option).
 3. Use the [extract_synthetic_data.py](../scripts/extract_synthetic_data.py) script to convert the selected sample into a synthetic data file in the correct format.

## Preparation

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import netCDF4 as nc4
import numpy as np
import pandas as pd
import logging

import pystan
import re
import time

In [None]:
# use test data (not all data is used for fitting/training)
use_testdata = False

# create plots of the data
show_data = True

# add noise to synthetic data
add_noise=True

# the sigma parameter determining the amount of noise (higher value -> less noise)
sigma_noise = 10000.

# netCDF output file (set to None to not save output)
savename_output = '../results/m_ftf_synthetic01.nc'

# save the Stan output instead of a few stats (only active if filename is specified above)
save_stan_output = True

# skip saving the Stan output of models whose maximum R-hat exceeds 1.1
# (only active if save_stan_output is True)
save_only_converged = True

# specify the Stan variable names to save; if set to None, all variables are saved
# (only active if save_stan_output is True)
varnames_save = None

# the number of tries to fit each Stan model to achieve an R-hat < 1.1
num_tries = 3

# the number of chains to run
num_chains = 4

# the prior_only option passed to each Stan model
prior_only = False

# Number of days of data to fit
limit_days = 2

# Whether or not data limit is inclusive (include boundary point)
inclusive = False

# Whether to append dataset to itself to create a 96-hour dataset
# This option is used for validation experiments in rolling_window.ipynb
extend = False

# Whether to append results to an existing file or overwrite
append = False

## Load, plot synthetic data

In [None]:
# data file containing the synthetic data without added noise
datafile = '../data/size_distribution/synthetic_m_ftf_01.nc'

# label of the data set (passed on to get_data)
dataname = 'zinser'

# base description used in the plot titles; get_data appends the grid parameters
desc = 'synthetic dataset'

# Indices of data to hold out for hold-out validation
# (file name; only read when use_testdata is True)
itestfile = None

# units of the size axis, used for the plot labels
size_units = 'fg C cell$^{-1}$'

def get_data(datafile, size_units, itestfile, dataname, desc, extend=False):
    """Load the gridded synthetic data set from a netCDF file.

    Parameters
    ----------
    datafile : str
        Path of the netCDF file; every variable in it is read into the
        returned dictionary under its own name.
    size_units, itestfile, dataname
        Accepted for interface compatibility; not used inside this function.
    desc : str
        Base description; the values of ``m`` and ``delta_v_inv`` read from
        the file are appended to it.
    extend : bool, optional
        If True, append the time series to itself (without repeating the
        boundary point) to create a pseudo four-day data set.

    Returns
    -------
    data_gridded : dict
        The file's variables plus a derived ``'counts'`` entry
        (``count[None, :] * w_obs`` truncated to integers).
    desc : str
        The extended description.

    Raises
    ------
    RuntimeError
        If the file does not contain a ``count`` variable.
    """
    data_gridded = {}
    with nc4.Dataset(datafile) as nc:
        for var in nc.variables:
            data_gridded[var] = nc.variables[var][:]

    if 'count' not in data_gridded:
        raise RuntimeError('Cannot find a "count" entry in "{}".'.format(datafile))

    # create "counts" entry: counts per size class at each observation time
    data_gridded['counts'] = (data_gridded['count'][None, :]
                              * data_gridded['w_obs']).astype(int)

    # Appends the time series to itself to create a pseudo four-day dataset
    if extend:
        # The second copy is shifted by the last time stamp plus time[1].
        # NOTE(review): this presumably assumes a uniform time grid that
        # starts at zero, so that time[1] equals the sampling interval.
        time_orig = data_gridded['time']
        time_shifted = (time_orig + time_orig[-1] + time_orig[1])[:-1]
        data_gridded['time'] = np.concatenate((time_orig, time_shifted))

        # drop the first entry of the repeated copy so that the boundary
        # point is not duplicated (2-D arrays carry time on axis 1)
        for item in ('w_obs', 'PAR', 'abundance', 'count', 'counts'):
            values = data_gridded[item]
            if len(values.shape) == 2:
                data_gridded[item] = np.concatenate((values, values[:, 1:]), axis=1)
            else:
                data_gridded[item] = np.concatenate((values, values[1:]))

    # add description (raw string: "\D" is not a valid escape sequence)
    desc += r' (m={data[m]}, $\Delta_v^{{-1}}$={data[delta_v_inv]})'.format(data=data_gridded)

    return data_gridded, desc

# load the synthetic data and extend the description with the grid parameters
data_gridded, desc = get_data(datafile, size_units, itestfile, dataname, desc, extend=extend)

In [None]:
def add_colorbar(ax, **cbarargs):
    """Draw a vertical colorbar in an inset axes attached to *ax*.

    The inset is 3% wide and 90% high relative to *ax*, placed with
    ``loc=5`` inside a bounding box shifted by 0.05 in x (axes coordinates).
    All keyword arguments are forwarded to ``mpl.colorbar.ColorbarBase``.
    """
    inset_options = dict(
        width='3%',
        height='90%',
        loc=5,
        bbox_to_anchor=(0.05, 0.0, 1, 1),
        bbox_transform=ax.transAxes,
    )
    cbar_axes = inset_axes(ax, **inset_options)
    mpl.colorbar.ColorbarBase(cbar_axes, orientation='vertical', **cbarargs)

# Plot the raw data: PAR, size-class proportions and counts over time.
if show_data:
    nrows = 3

    # lower bounds of the size classes: v_min * 2**(i/delta_v_inv)
    v_min = data_gridded['v_min']
    delta_v = 1.0/data_gridded['delta_v_inv']
    v = v_min * 2**(np.arange(data_gridded['m'])*delta_v)

    fig,axs = plt.subplots(nrows=nrows, sharex=True, figsize=(12,4*nrows))

    ax = axs[0]
    ax.set_title('raw '+desc, size=20)
    ax.plot(data_gridded['time'], data_gridded['PAR'], color='gold')
    ax.set(ylabel='PAR')

    ax = axs[1]
    pc = ax.pcolormesh(data_gridded['time'], v, data_gridded['w_obs'],
                       shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap, label='size class proportion')

    ax = axs[2]
    pc = ax.pcolormesh(data_gridded['time'], v, data_gridded['counts'],
                       shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap, label='counts')

    # call set_xlabel (the original assigned to the attribute, which set no
    # label) and keep it inside the guard so that axs is always defined
    axs[-1].set_xlabel('time (minutes)')
# a bare expression as last statement suppresses the notebook cell output
None

## Process data, add noise, and re-plot

In [None]:
# prepare data for Stan model
def data_prep(data_gridded, dt=20, limit_days=2, start=0, use_testdata=False,
              itestfile=None, prior_only=False, inclusive=False, add_noise=True,
              sigma_noise=10000.0):
    """Assemble the data dictionary that is passed to the Stan models.

    Parameters
    ----------
    data_gridded : dict
        Gridded data with the entries ``m``, ``v_min``, ``delta_v_inv``,
        ``w_obs``, ``time``, ``PAR``, ``count`` and ``counts``. It is not
        modified.
    dt : int
        Model time step in minutes.
    limit_days : int
        Length of the fitted window in days. A value <= 0 disables the upper
        limit: all observations from ``start`` onwards are used and the model
        time grid is extended to cover the last of them.
    start : int
        Start of the window in hours.
    use_testdata : bool
        If True, read hold-out indices from ``itestfile``; otherwise all
        indices are set to zero.
    itestfile : str or None
        Text file with the hold-out indices (required if ``use_testdata``).
    prior_only : bool
        Stored as an integer flag in the ``prior_only`` entry.
    inclusive : bool
        Whether an observation exactly at the end of the window is kept.
    add_noise : bool
        If True, replace every observed size distribution by a normalised
        multinomial sample whose probabilities are drawn from a Dirichlet
        distribution centred on the original distribution.
    sigma_noise : float
        Scale of the Dirichlet concentration (higher value -> less noise).

    Returns
    -------
    dict
        The data dictionary for Stan.

    Raises
    ------
    ValueError
        If ``use_testdata`` is set without ``itestfile``, if no observation
        lies at or after ``start`` when ``limit_days <= 0``, or if the number
        of testing indices does not match the number of observations.
    """
    data = {'dt': dt}
    for v in ('m', 'v_min', 'delta_v_inv'):
        data[v] = data_gridded[v]

    # Copy the observations: the noise step below writes into the array in
    # place and must not alter the caller's data (otherwise re-running this
    # function would add noise on top of noise).
    data['obs'] = data_gridded['w_obs'].copy()
    data['t_obs'] = data_gridded['time']
    par = data_gridded['PAR']

    # add noise
    if add_noise:
        for icol in range(data['obs'].shape[1]):
            proportions = data['obs'][:, icol] / np.sum(data['obs'][:, icol])
            alpha = proportions * sigma_noise + 1.0
            data['obs'][:, icol] = np.random.multinomial(data_gridded['count'][icol],
                                                         np.random.dirichlet(alpha))
        # convert the sampled counts back to proportions
        data['obs'] /= np.sum(data['obs'], axis=0)[None, :]

    start_minutes = start*60
    if limit_days > 0:
        limit_minutes = limit_days*1440
        end_minutes = limit_minutes + start_minutes

        if inclusive:
            ind_obs = (start_minutes <= data['t_obs']) & (data['t_obs'] <= end_minutes)
        else:
            ind_obs = (start_minutes <= data['t_obs']) & (data['t_obs'] < end_minutes)

        data['nt'] = int(limit_minutes//data['dt']+1)
    else:
        # no upper limit: keep everything from start onwards and make the
        # model time grid long enough to reach the last observation
        ind_obs = start_minutes <= data['t_obs']
        if not np.any(ind_obs):
            raise ValueError('No observations at or after start={} hours.'.format(start))
        span = np.max(data['t_obs'][ind_obs]) - start_minutes
        data['nt'] = int(np.ceil(span/data['dt'])) + 1

    if not np.all(ind_obs):
        total = data['obs'].shape[1]
        remove = total - data['obs'][:, ind_obs].shape[1]
        print('start is set to {}, limit_days is set to {}, removing {}/{} observation times'.format(start,
                                                                                                     limit_days,
                                                                                                     remove,
                                                                                                     total))

    data['t_obs'] = data['t_obs'][ind_obs]
    data['obs'] = data['obs'][:, ind_obs]

    data['nt_obs'] = data['t_obs'].size

    if use_testdata:
        if itestfile is None:
            raise ValueError('use_testdata is set but no itestfile was specified.')
        # load cross-validation testing indices and add them to data
        data['i_test'] = np.loadtxt(itestfile).astype(int)
        # remove last index, so that dimensions agree
        data['i_test'] = data['i_test'][:-1]
    else:
        # set all indices to zero
        data['i_test'] = np.zeros(data['nt_obs'], dtype=int)

    # switch on or off data fitting
    data['prior_only'] = int(prior_only)

    # add light data, interpolated onto the model time grid
    t = np.arange(data['nt'])*data['dt'] + start_minutes
    data['E'] = np.interp(t, xp=data_gridded['time'][ind_obs], fp=par[ind_obs])

    # real count data
    data['obs_count'] = data_gridded['counts'][:, ind_obs]

    data['start'] = start

    # consistency check
    if len(data['i_test']) != data['nt_obs']:
        raise ValueError('Invalid number of testing indices (expected {}, got {}).'.format(data['nt_obs'],
                                                                                       len(data['i_test'])))
    return data


# Build the Stan data dictionary from the notebook settings. The inclusive
# setting is passed through (it was previously hard-coded to False) so that
# the fitted window matches what the fit log message reports.
data = data_prep(data_gridded, dt=20, limit_days=limit_days, start=0, use_testdata=use_testdata,
                 itestfile=itestfile, prior_only=prior_only, inclusive=inclusive, add_noise=add_noise,
                 sigma_noise=sigma_noise)

In [None]:
# Plot the processed data: interpolated light, (noisy) size-class
# proportions and counts over time.
if show_data:
    nrows = 3

    # lower bounds of the size classes: v_min * 2**(i/delta_v_inv)
    v_min = data['v_min']
    delta_v = 1.0/data['delta_v_inv']
    v = v_min * 2**(np.arange(data['m'])*delta_v)
    # model time grid in minutes, offset by the window start (given in
    # hours) exactly as in data_prep where E is interpolated
    t = np.arange(data['nt'])*data['dt'] + data['start']*60

    fig,axs = plt.subplots(nrows=nrows, sharex=True, figsize=(12,4*nrows))

    ax = axs[0]
    ax.set_title('processed '+desc, size=20)
    ax.plot(t, data['E'], color='gold')
    ax.set(ylabel='E')

    ax = axs[1]
    pc = ax.pcolormesh(data['t_obs'], v, data['obs'], shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap,
                 label='size class proportion')
    ax.set_xlim(left=0.0)

    ax = axs[2]
    pc = ax.pcolormesh(data['t_obs'], v, data['obs_count'], shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap, label='counts')
    ax.set_xlim(left=0.0)

    # kept inside the guard so that axs is always defined when it is used
    axs[-1].set_xlabel('time (minutes)')
# a bare expression as last statement suppresses the notebook cell output
None

## Choose model to fit

Only fit the model that was used to generate the synthetic data.

In [None]:
# Code files: Stan source file for each model to fit, keyed by model name.
# The key is used as the Stan model name and as the name of the netCDF group
# the results are written to. Uncomment an entry to fit that model as well.
stan_files = {
    #'m_bmx': '../stan_code/m_bmx.stan',
    #'m_bmb': '../stan_code/m_bmb.stan',
    #'m_pmb': '../stan_code/m_pmb.stan',
    #'m_fmb': '../stan_code/m_fmb.stan',
    #'m_fmf': '../stan_code/m_fmf.stan',
    #'m_btb': '../stan_code/m_btb.stan',
    #'m_ptb': '../stan_code/m_ptb.stan',
    #'m_ftb': '../stan_code/m_ftb.stan',
    'm_ftf': '../stan_code/m_ftf.stan',
}

## Fit model and save results

In [None]:
def get_max_rhat(fit):
    """Return the largest R-hat value in the summary of *fit*, ignoring NaNs.

    *fit* must provide a ``summary()`` method returning a mapping with the
    column names under ``'summary_colnames'`` and the 2-D table of summary
    statistics under ``'summary'``.
    """
    summary = fit.summary()
    rhat_column = summary['summary_colnames'].index("Rhat")
    rhat_values = summary['summary'][:, rhat_column]
    return np.nanmax(rhat_values)

# Keep compiled models and the retry setting from earlier runs of the
# notebook; fall back to defaults when they have not been defined yet.
models = globals().get('models', {})
num_tries = globals().get('num_tries', 3)

try_again = True
refit_all = False

refit_required = {}

# Read the Stan source of every selected model. stan_base_code holds the
# text as read from disk; stan_code holds the text that is compiled (a
# line-wise split/join round trip of the base code).
stan_base_code = {}
stan_code = {}
for model, stan_file in stan_files.items():
    with open(stan_file) as f:
        stan_base_code[model] = f.read()
    stan_code[model] = '\n'.join(stan_base_code[model].split('\n'))

In [None]:
# Compile each selected Stan model, reusing an already compiled model when
# its source code has not changed since the last run.
for model in stan_files:
    already_compiled = model in models
    if already_compiled and models[model].model_code == stan_code[model]:
        print('{}: unchanged code, not recompiling'.format(model))
        refit_required[model] = False
        continue

    refit_required[model] = True
    if already_compiled:
        print('{}: code change detected, recompiling'.format(model))
    else:
        print('{}: compiling'.format(model))
    models[model] = pystan.StanModel(model_code=stan_code[model],
                                     model_name=model,
                                     obfuscate_model_name=False)


# run a bunch of experiments -- this may take a while
#
# Every model listed in stan_files is fitted (with up to num_tries attempts
# to reach a maximum R-hat below 1.1), a summary is printed and, if
# savename_output is set, the results are written to a netCDF group named
# after the model.
if 'varnames_save' not in globals():
    varnames_save = None

# Iterate over stan_files rather than models: models may still contain
# entries compiled in an earlier run that are no longer selected and have no
# entry in stan_files.
for imodel, model in enumerate(stan_files):
    for itry in range(num_tries):
        t0 = time.time()
        # suppress log messages up to WARNING level while sampling and make
        # sure logging is re-enabled even if sampling fails
        logging.disable(logging.WARNING)
        try:
            mcmcs = models[model].sampling(data=data, iter=2000, chains=num_chains)
        finally:
            logging.disable(logging.NOTSET)
        sampling_time = time.time() - t0  # in seconds
        print('Model {} for {}-hour window starting at {} hours fit in {} minutes.'.format(model,
                                                                                           limit_days*24+2*int(inclusive),
                                                                                           data['start'],
                                                                                           np.round(sampling_time/60, 2)))
        # get max Rhat
        rhat_max = get_max_rhat(mcmcs)
        print('{}: in try {}/{} found Rhat={:.3f}'.format(model, itry+1, num_tries, rhat_max), end='')
        if rhat_max < 1.1 or itry == num_tries - 1:
            print()
            break
        print(', trying again')

    # print the fit summary without the rows of array-valued parameters
    print('{}'.format(model))
    print('\n'.join(x for x in str(mcmcs).split('\n') if '[' not in x))
    print()

    if savename_output is None:
        continue

    # the first model starts a new file unless results are to be appended
    mode = 'w' if imodel == 0 and not append else 'a'

    with nc4.Dataset(savename_output, mode) as nc:
        ncm = nc.createGroup(model)

        # write model description
        ncm.setncattr('code', stan_files[model])

        if save_stan_output:
            # rhat_max belongs to the fit that was kept (last try)
            if save_only_converged and rhat_max > 1.1:
                logging.warning('Model "{}" did not converge -- skipping.'.format(model))
                continue
            dimensions = {
                'obstime':int(data['nt_obs']),
                'time':int(data['nt']),
                'sizeclass':int(data['m']),
                'm_minus_j_plus_1':int(data['m']-data['delta_v_inv']),
                'm_minus_1':int(data['m']-1),
                'knots_minus_1':int(6-1),  # hardcoded, adjust for varying nknots
                'sample': mcmcs['mod_obspos'].shape[0],
            }
            # Maps an axis length back to a dimension name.
            # NOTE(review): if two dimensions have the same length, the one
            # listed last wins and arrays may get the wrong dimension name.
            dimensions_inv = {v:k for k,v in dimensions.items()}

            for d in dimensions:
                if d not in ncm.dimensions:
                    ncm.createDimension(d, dimensions[d])

            if 'tau[1]' in mcmcs.flatnames:
                dimensions['tau'] = mcmcs['tau'].shape[1]
                dimensions_inv[dimensions['tau']] = 'tau'
                if 'tau' not in ncm.dimensions:
                    ncm.createDimension('tau', dimensions['tau'])

            if 'time' not in ncm.variables:
                ncm.createVariable('time', int, ('time',))
            ncm.variables['time'][:] = int(data['dt']) * np.arange(data['nt'])
            ncm.variables['time'].units = 'minutes since start of experiment'

            if 'obstime' not in ncm.variables:
                ncm.createVariable('obstime', int, ('obstime',))
            ncm.variables['obstime'][:] = data['t_obs'].astype(int)
            ncm.variables['obstime'].units = 'minutes since start of experiment'
            ncm.variables['obstime'].long_name = 'time of observations'

            # store the data the model was fitted to
            for v in ('dt', 'm', 'v_min', 'delta_v_inv', 'obs', 'i_test',
                      'E', 'obs_count'):
                if isinstance(data[v], int):
                    if v not in ncm.variables:
                        ncm.createVariable(v, int, zlib=True)
                    ncm.variables[v][:] = data[v]
                elif isinstance(data[v], float):
                    if v not in ncm.variables:
                        ncm.createVariable(v, float, zlib=True)
                    ncm.variables[v][:] = data[v]
                else:
                    dims = tuple(dimensions_inv[d] for d in data[v].shape)
                    if v not in ncm.variables:
                        ncm.createVariable(v, data[v].dtype, dims, zlib=True)
                    ncm.variables[v][:] = data[v]

            # store the requested Stan variables (all of them by default)
            varnames = set(v.split('[')[0] for v in mcmcs.flatnames)
            if varnames_save is None:
                varnames_curr = varnames
            else:
                varnames_curr = varnames_save

            for v in varnames_curr:
                if v in varnames:
                    dims = tuple(dimensions_inv[d]
                                 for d in mcmcs[v].shape)
                    if v not in ncm.variables:
                        ncm.createVariable(v, float, dims, zlib=True)
                    ncm.variables[v][:] = mcmcs[v]
                else:
                    logging.warning('Cannot find variable "{}" for model "{}".'.format(v,
                                                                                       model))
        else:
            # only save a few summary statistics per posterior sample
            if 'sample' not in ncm.dimensions:
                ncm.createDimension('sample',
                                    mcmcs['divrate'].shape[0])

            # dimensions are given as proper one-element tuples
            if 'divrate' not in ncm.variables:
                ncm.createVariable('divrate', float, ('sample',))

            if 'sumsqdiff' not in ncm.variables:
                ncm.createVariable('sumsqdiff', float, ('sample',))

            ncm.variables['sumsqdiff'].setncattr('long_name',
                                                 'sum of squared column differences')

            ncm.variables['divrate'][:] = mcmcs['divrate']

            obs = data['obs']

            # squared difference between the normalised model output and the
            # observations; axis 0 is the sample, axis 1 the size class
            tmp = mcmcs['mod_obspos']
            tmp /= np.sum(tmp, axis=1)[:, None, :]
            tmp -= obs[None, :, :]
            tmp **= 2

            if np.all(data['i_test'] == 0):
                ncm.variables['sumsqdiff'][:] = np.mean(np.sum(tmp, axis=1),
                                                          axis=1)
                ncm.variables['sumsqdiff'].setncattr('data_used',
                                                     'all data')
            else:
                ncm.variables['sumsqdiff'][:] = np.mean(np.sum(tmp[:, :, data['i_test'] == 1],
                                                                axis=1), axis=1)
                ncm.variables['sumsqdiff'].setncattr('data_used', 'testing data')

            # Scalar parameters that only some models have. They are stored
            # along the group's own "sample" dimension; the "model" dimension
            # used before is never defined in this group.
            for v in ('gamma_max', 'rho_max', 'xi', 'xir', 'E_star'):
                if v not in mcmcs.flatnames:
                    continue
                if v not in ncm.variables:
                    ncm.createVariable(v, float, ('sample',))
                ncm.variables[v][:] = mcmcs[v]