# Model Selection

## Packages

We begin by loading the necessary software packages and setting the options that control model fitting and the saving of results.

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import netCDF4 as nc4
import numpy as np
import pandas as pd
import logging

import pystan
import re
import time

In [None]:
# use test data (not all data is used for fitting/training)
use_testdata = False

# create plots of the data
show_data = True

# netCDF output file (set to None to not save output)
savename_output = '../results/zinser_results_window.nc'

# save the full Stan output instead of only a few summary statistics
# (only active if a filename is specified above)
save_stan_output = True

# specify the Stan variable names to save; if set to None, all variables are saved 
# (only active if save_stan_output is True)
varnames_save = None

# the number of tries to fit each Stan model to achieve an R-hat < 1.1
num_tries = 3

# the number of chains to run
num_chains = 6

# the prior_only option passed to each Stan model
prior_only = False

# Whether or not to extend the time series to four days by appending the dataset to itself
extend = True

# Number of days in the full dataset
# Maximum 2 if extend is False, maximum 4 if extend is True
limit_days = 4

# Number of days in each window
# Recommended maximum: half of limit_days
limit_days_window = 2

## Load, Plot Data

In [None]:
# Input: processed size-distribution data for the culture experiment
datafile = '../data/size_distribution/zinser_processed.nc'
dataname = 'zinser'
desc = 'Culture dataset'

# Hold-out validation: path of a file with test indices, or None.
# To enable it, set use_testdata = True and pick one of:
#   '../data/hold_out/keep_twothirds.csv'
#   '../data/hold_out/keep_half.csv'
#   '../data/hold_out/keep_onethird.csv'
itestfile = None

# Units of the size axis used in plot labels (matplotlib mathtext)
size_units = 'fg C cell$^{-1}$'


def get_data(datafile, size_units, itestfile, dataname, desc, extend=False):
    """Load the gridded size-distribution data and annotate its description.

    Every variable of the netCDF file ``datafile`` is read into a dict. An
    integer ``'counts'`` array is derived by scaling the size-class
    proportions ``'w_obs'`` (size class x time) with the per-time totals in
    ``'count'`` or, if that variable is missing, ``'abundance'``.

    Parameters
    ----------
    datafile : str
        Path of the netCDF file to read.
    size_units, itestfile, dataname
        Not used here; kept so existing calls keep working.
    desc : str
        Base description; the size-grid parameters are appended to it.
    extend : bool, optional
        If True, append the series (without its first time point) to itself,
        producing a pseudo dataset of roughly twice the length.

    Returns
    -------
    data_gridded : dict
        Variable name -> array as read from the file, plus ``'counts'``.
    desc : str
        Description with ``m`` and ``delta_v_inv`` appended.

    Raises
    ------
    RuntimeError
        If the file has neither a ``count`` nor an ``abundance`` variable.
    """
    data_gridded = {}
    with nc4.Dataset(datafile) as nc:
        for var in nc.variables:
            data_gridded[var] = nc.variables[var][:]

    # create "counts" entry: per-time totals (1-D) broadcast over size classes
    if 'count' in data_gridded:
        totals = data_gridded['count']
    elif 'abundance' in data_gridded:
        logging.warning('Using "abundance" data to generate count data for the model.')
        # previously this branch read data_gridded['count'], which cannot
        # exist here, so it always failed with a KeyError
        totals = data_gridded['abundance']
    else:
        raise RuntimeError('Cannot find a "count" or "abundance" entry in "{}".'.format(datafile))
    data_gridded['counts'] = (totals[None, :] * data_gridded['w_obs']).astype(int)

    # Appends the time series to itself to create a pseudo four-day dataset
    if extend:
        time = data_gridded['time']
        # NOTE(review): the appended times are time[:-1] + time[-1] + time[1]
        # while the appended data are columns 1..n-1; these line up only for an
        # evenly spaced series starting at 0 -- confirm for new datasets.
        data_gridded['time'] = np.concatenate((time,
                                               (time + time[-1] + time[1])[:-1]))

        for item in ('w_obs', 'PAR', 'abundance', 'count', 'counts'):
            # only one of 'abundance'/'count' is required to exist
            if item not in data_gridded:
                continue
            values = data_gridded[item]
            if values.ndim == 2:
                # (size class, time): drop the first time column of the copy
                data_gridded[item] = np.concatenate((values, values[:, 1:]), axis=1)
            else:
                data_gridded[item] = np.concatenate((values, values[1:]))

    # add description (raw string: "\D" is not a valid escape sequence)
    desc += r' (m={data[m]}, $\Delta_v^{{-1}}$={data[delta_v_inv]})'.format(data=data_gridded)

    return data_gridded, desc

# Load the data (optionally appended to itself) and the annotated description.
data_gridded, desc = get_data(
    datafile,
    size_units,
    itestfile,
    dataname,
    desc,
    extend=extend,
)

In [None]:
def add_colorbar(ax, **cbarargs):
    """Draw a slim vertical colorbar in an inset placed just right of ``ax``.

    All keyword arguments (e.g. ``norm``, ``cmap``, ``label``) are forwarded
    to ``matplotlib.colorbar.ColorbarBase``.
    """
    # Inset 3% wide and 90% tall, anchored at loc=5 ('right') and shifted by
    # 5% of the axes width via bbox_to_anchor (in axes coordinates).
    cbar_ax = inset_axes(
        ax,
        width='3%',
        height='90%',
        loc=5,
        bbox_to_anchor=(0.05, 0.0, 1, 1),
        bbox_transform=ax.transAxes,
    )
    mpl.colorbar.ColorbarBase(cbar_ax, orientation='vertical', **cbarargs)

if show_data:
    nrows = 3

    # size-class grid: v_i = v_min * 2**(i * delta_v), i = 0..m-1
    v_min = data_gridded['v_min']
    delta_v = 1.0/data_gridded['delta_v_inv']
    v = v_min * 2**(np.arange(data_gridded['m'])*delta_v)

    fig, axs = plt.subplots(nrows=nrows, sharex=True, figsize=(12, 4*nrows))

    # top: light (PAR) time series
    ax = axs[0]
    ax.set_title('raw '+desc, size=20)
    ax.plot(data_gridded['time'], data_gridded['PAR'], color='gold')
    ax.set(ylabel='PAR')

    # middle: observed size-class proportions
    ax = axs[1]
    pc = ax.pcolormesh(data_gridded['time'], v, data_gridded['w_obs'],
                       shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap, label='size class proportion')

    # bottom: derived per-size-class counts
    ax = axs[2]
    pc = ax.pcolormesh(data_gridded['time'], v, data_gridded['counts'],
                       shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap, label='counts')

    # Call (not assign) set_xlabel; inside the if-block so that axs exists.
    axs[-1].set_xlabel('time (minutes)')
None

## Process and Re-plot Data

In [None]:
# prepare data for Stan model
def data_prep(data_gridded, dt=20, limit_days=2, start=0, use_testdata=False,
              itestfile=None, prior_only=False):
    """Build the Stan input dict for one time window of the gridded data.

    Parameters
    ----------
    data_gridded : dict
        Needs 'm', 'v_min', 'delta_v_inv', 'w_obs' (size class x time),
        'time' (minutes), 'PAR' and 'counts' (size class x time).
    dt : int, optional
        Model time step in minutes.
    limit_days : float, optional
        Window length in days. If <= 0, every observation from ``start``
        onwards is used and the model grid runs to the last observation.
    start : float, optional
        Window start in hours.
    use_testdata : bool, optional
        If True, load hold-out indices from ``itestfile`` (last entry dropped);
        otherwise ``i_test`` is all zeros.
    itestfile : str or None, optional
        Text file with the hold-out indices.
    prior_only : bool, optional
        Stored as the integer flag ``'prior_only'``.

    Returns
    -------
    dict
        Stan data: grid parameters, windowed 'obs'/'t_obs'/'obs_count',
        'nt', 'nt_obs', 'i_test', 'prior_only', light 'E' interpolated onto
        the model grid, and 'start'.

    Raises
    ------
    ValueError
        If the number of testing indices differs from ``nt_obs``.
    """
    data = {'dt': dt}
    for v in ('m', 'v_min', 'delta_v_inv'):
        data[v] = data_gridded[v]

    t_obs_all = data_gridded['time']
    start_minutes = start*60

    if limit_days > 0:
        limit_minutes = limit_days*1440
        ind_obs = (start_minutes <= t_obs_all) & (t_obs_all < limit_minutes + start_minutes)
    else:
        # no window length: keep everything from the start time on and let
        # the model grid cover up to the last observation (previously this
        # path left ind_obs and 'nt' undefined and crashed)
        ind_obs = start_minutes <= t_obs_all
        limit_minutes = t_obs_all[-1] - start_minutes

    if not np.all(ind_obs):
        total = data_gridded['w_obs'].shape[1]
        remove = total - int(np.count_nonzero(ind_obs))
        print('start is set to {}, limit_days is set to {}, removing {}/{} observation times'.format(start,
                                                                                                     limit_days,
                                                                                                     remove,
                                                                                                     total))

    data['obs'] = data_gridded['w_obs'][:, ind_obs]
    data['t_obs'] = t_obs_all[ind_obs]
    # number of model time steps covering [start, start + limit] inclusive
    data['nt'] = int(limit_minutes//dt + 1)
    data['nt_obs'] = data['t_obs'].size

    if use_testdata:
        # load cross-validation testing indices and add them to data
        data['i_test'] = np.loadtxt(itestfile).astype(int)
        # remove last index, so that dimensions agree
        data['i_test'] = data['i_test'][:-1]
    else:
        # set all indices to zero
        data['i_test'] = np.zeros(data['nt_obs'], dtype=int)

    # switch on or off data fitting
    data['prior_only'] = int(prior_only)

    # add light data, interpolated from the windowed observations onto the
    # model grid (np.interp holds end values constant outside the window)
    t = np.arange(data['nt'])*dt + start_minutes
    data['E'] = np.interp(t, xp=t_obs_all[ind_obs], fp=data_gridded['PAR'][ind_obs])

    # real count data
    data['obs_count'] = data_gridded['counts'][:, ind_obs]

    data['start'] = start

    # consistency check
    if len(data['i_test']) != data['nt_obs']:
        raise ValueError('Invalid number of testing indices (expected {}, got {}).'.format(data['nt_obs'],
                                                                                       len(data['i_test'])))
    return data


# Single window covering limit_days from t = 0, built with the notebook options.
data = data_prep(
    data_gridded,
    dt=20,
    limit_days=limit_days,
    start=0,
    use_testdata=use_testdata,
    itestfile=itestfile,
    prior_only=prior_only,
)

In [None]:
if show_data:
    nrows = 3

    # size-class grid: v_i = v_min * 2**(i * delta_v), i = 0..m-1
    v_min = data['v_min']
    delta_v = 1.0/data['delta_v_inv']
    v = v_min * 2**(np.arange(data['m'])*delta_v)
    # model time grid, offset by the window start (hours -> minutes) exactly
    # as data_prep builds it
    t_start = data['start']*60
    t = np.arange(data['nt'])*data['dt'] + t_start

    fig, axs = plt.subplots(nrows=nrows, sharex=True, figsize=(12, 4*nrows))

    # top: light interpolated onto the model grid
    ax = axs[0]
    ax.set_title('processed '+desc, size=20)
    ax.plot(t, data['E'], color='gold')
    ax.set(ylabel='E')

    # middle: windowed size-class proportions
    ax = axs[1]
    pc = ax.pcolormesh(data['t_obs'], v, data['obs'], shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap,
                 label='size class proportion')
    ax.set_xlim(left=t_start)

    # bottom: windowed counts
    ax = axs[2]
    pc = ax.pcolormesh(data['t_obs'], v, data['obs_count'], shading='auto')
    ax.set(ylabel='size ({})'.format(size_units))
    add_colorbar(ax, norm=pc.norm, cmap=pc.cmap, label='counts')
    ax.set_xlim(left=t_start)

    # inside the if-block so that axs exists when show_data is False
    axs[-1].set_xlabel('time (minutes)')
None

## Choose models to fit

In [None]:
# Stan source file of each candidate model, keyed by model name
stan_files = dict(
    m_bmx='../stan_code/m_bmx.stan',
    m_bmb='../stan_code/m_bmb.stan',
    m_pmb='../stan_code/m_pmb.stan',
    m_fmb='../stan_code/m_fmb.stan',
    m_fmf='../stan_code/m_fmf.stan',
    m_btb='../stan_code/m_btb.stan',
    m_ptb='../stan_code/m_ptb.stan',
    m_ftb='../stan_code/m_ftb.stan',
    m_ftf='../stan_code/m_ftf.stan',
)

## Fit Models to the Data in a rolling window

In [None]:
def get_max_rhat(fit):
    """Return the largest R-hat in ``fit.summary()``, ignoring NaN entries.

    ``fit.summary()`` must return a dict with a ``'summary_colnames'``
    sequence and a 2-D ``'summary'`` array whose columns follow those names.
    """
    summary = fit.summary()
    rhat_col = summary['summary_colnames'].index('Rhat')
    rhats = summary['summary'][:, rhat_col]
    return np.nanmax(rhats)

# State that persists across re-runs of this cell: compiled models, fits,
# R-hat history and sampling times are only created if they do not exist yet.
if 'models' not in globals():
    models = {}
if 'mcmcs' not in globals():
    mcmcs = {}
if 'maxrhats' not in globals():
    maxrhats = {}
if 'sampling_time' not in globals():
    sampling_time = {}
if 'num_tries' not in globals():
    num_tries = 3

# re-run existing fits whose max R-hat is >= 1.1
try_again = True
# re-run every model/window even if results already exist
refit_all = False

refit_required = {}
stan_base_code = {}
for key, stan_file in stan_files.items():
    with open(stan_file, encoding='utf-8') as f:
        stan_base_code[key] = f.read()

# Code that is actually compiled for each model. It is currently identical to
# the file contents; per-model modifications of the code would go here.
stan_code = dict(stan_base_code)

In [None]:
# Compile each model; skip recompilation if the compiled code is unchanged.
for model in stan_files:
    refit_required[model] = True
    if model in models and models[model].model_code == stan_code[model]:
        print('{}: unchanged code, not recompiling'.format(model))
        refit_required[model] = False
    else:
        if model in models:
            print('{}: code change detected, recompiling'.format(model))
        else:
            print('{}: compiling'.format(model))
        models[model] = pystan.StanModel(model_code=stan_code[model],
                                         model_name=model,
                                         obfuscate_model_name=False)

# Get window slices of data; window start times are in hours, every 2 hours.
# NOTE(review): the "+ 2" adds one window starting 2 h after
# (limit_days - limit_days_window) days, which runs past the end of the
# limit_days span -- confirm this extra window is intended.
max_start_time = limit_days*24 - limit_days_window*24 + 2
windows = np.arange(0, max_start_time+1, 2)  # Start times of each rolling window
data = {}
for window in windows:
    # prior_only is forwarded so that the notebook option also applies to the
    # windowed fits. use_testdata/itestfile are not forwarded: data_prep does
    # not subset the hold-out indices to the window, so their length would
    # generally not match the window's nt_obs.
    data[window] = data_prep(data_gridded, dt=20, limit_days=limit_days_window,
                             start=window, prior_only=prior_only)

# run a bunch of experiments -- this may take a while
# (iterate over stan_files, not models: models may hold stale entries from
# earlier runs that have no refit_required flag)
for model in stan_files:
    maxrhats.setdefault(model, {})
    sampling_time.setdefault(model, {})
    for window in windows:
        maxrhats[model].setdefault(window, [])
        sampling_time[model].setdefault(window, [])
        if model in mcmcs:
            if window in mcmcs[model] and not refit_all and not refit_required[model]:
                print('{}: found existing results:'.format(model))
                print('{}'.format(model)) 
                print('\n'.join(x for x in mcmcs[model][window].__str__().split('\n') if '[' not in x))
                rhat_max = get_max_rhat(mcmcs[model][window])
                if try_again and rhat_max >= 1.1:
                    print('{}: found Rhat={:.3f}, trying again'.format(model, rhat_max))
                else:
                    print('{}: not re-running model'.format(model))
                    print()
                    continue
            elif refit_all:
                print('{}: refit_all is active, re-running model'.format(model))
            elif refit_required[model]:
                print('{}: change in model code requires re-running model'.format(model))
        else:
            mcmcs[model] = {}
        # sample until max R-hat < 1.1 or num_tries attempts are used up;
        # the last attempt's fit is kept
        for itry in range(num_tries):
            t0 = time.time()
            mcmcs[model][window] = models[model].sampling(data=data[window], iter=2000, chains=num_chains)
            sampling_time[model][window].append(time.time() - t0) # in seconds
            # get max Rhat
            rhat_max = get_max_rhat(mcmcs[model][window])
            maxrhats[model][window].append(rhat_max)
            print('{}: in try {}/{} found Rhat={:.3f}'.format(model, itry+1, num_tries, rhat_max), end='')
            if rhat_max < 1.1 or itry == num_tries - 1:
                print()
                break
            print(', trying again')
        
        # print the fit summary without per-element (indexed "[...]") rows
        print('{}'.format(model)) 
        print('\n'.join(x for x in mcmcs[model][window].__str__().split('\n') if '[' not in x))
        print()


## Save Results

In [None]:
if 'varnames_save' not in globals():
    varnames_save = None

# if True, skip windows whose max R-hat exceeds 1.1 (full Stan output only)
save_only_converged = False

if savename_output is not None:
    with nc4.Dataset(savename_output, 'w') as nc:
        for model in mcmcs:
            # one netCDF group per model
            ncm = nc.createGroup(model)
            
            # write model description
            # NOTE(review): this stores the Stan file *path*; confirm whether
            # the code itself (stan_code[model]) was intended.
            ncm.setncattr('code', stan_files[model])
            
            if save_stan_output:
                # one sub-group per window, named by its start time (hours)
                for window in windows:
                    fit = mcmcs[model][window]
                    if save_only_converged and get_max_rhat(fit) > 1.1:
                        logging.warning('Model "{}" did not converge -- skipping.'.format(model))
                        continue
                    ncg = ncm.createGroup(str(window))
                    dimensions = {
                        'obstime':int(data[window]['nt_obs']),
                        'time':int(data[window]['nt']),
                        'sizeclass':int(data[window]['m']),
                        'm_minus_j_plus_1':int(data[window]['m']-data[window]['delta_v_inv']),
                        'm_minus_1':int(data[window]['m']-1),
                        'knots_minus_1':int(6-1),  # hardcoded, adjust for varying nknots
                        'sample': fit['mod_obspos'].shape[0],
                    }
                    # size -> dimension name, used to translate array shapes
                    # into netCDF dimensions.
                    # NOTE(review): if two dimensions share a size, the later one
                    # wins and arrays may be written with the wrong dimension name.
                    dimensions_inv = {v:k for k,v in dimensions.items()}
                    
                    for d in dimensions:
                        ncg.createDimension(d, dimensions[d])
                    
                    if 'tau[1]' in fit.flatnames:
                        dimensions['tau'] = fit['tau'].shape[1]
                        dimensions_inv[dimensions['tau']] = 'tau'
                        ncg.createDimension('tau', dimensions['tau'])

                    ncg.createVariable('time', int, ('time',))
                    ncg.variables['time'][:] = int(data[window]['dt']) * np.arange(data[window]['nt'])
                    ncg.variables['time'].units = 'minutes since start of experiment'

                    ncg.createVariable('obstime', int, ('obstime',))
                    ncg.variables['obstime'][:] = data[window]['t_obs'].astype(int)
                    ncg.variables['obstime'].units = 'minutes since start of experiment'
                    ncg.variables['obstime'].long_name = 'time of observations'

                    # copy the Stan input data of this window
                    for v in ('dt', 'm', 'v_min', 'delta_v_inv', 'obs', 'i_test',
                              'E', 'obs_count'):
                        if isinstance(data[window][v], int):
                            ncg.createVariable(v, int, zlib=True)
                            ncg.variables[v][:] = data[window][v]
                        elif isinstance(data[window][v], float):
                            ncg.createVariable(v, float, zlib=True)
                            ncg.variables[v][:] = data[window][v]
                        else:
                            dims = tuple(dimensions_inv[d] for d in data[window][v].shape)
                            ncg.createVariable(v, data[window][v].dtype, dims, zlib=True)
                            ncg.variables[v][:] = data[window][v]

                    # copy the requested (or all) Stan output variables
                    varnames = set(v.split('[')[0] for v in fit.flatnames)
                    if varnames_save is None:
                        varnames_curr = varnames
                    else:
                        varnames_curr = varnames_save

                    for v in varnames_curr:
                        if v in varnames:
                            dims = tuple(dimensions_inv[d]
                                         for d in fit[v].shape)
                            ncg.createVariable(v, float, dims, zlib=True)
                            ncg.variables[v][:] = fit[v]
                        else:
                            logging.warning('Cannot find variable "{}" for model "{}".'.format(v,
                                                                                               model))
            else:
                # summary statistics only: one row per window, in the order of
                # `windows` (which is what the row index i enumerates)
                for i, window in enumerate(windows):
                    fit = mcmcs[model][window]
                    if i == 0:
                        ncm.createDimension('window', len(windows))
                        ncm.createDimension('sample',
                                            fit['divrate'].shape[0])

                        ncm.createVariable('divrate', float, ('window','sample'))
                        ncm.createVariable('sumsqdiff', float, ('window','sample'))
                        ncm.variables['sumsqdiff'].setncattr('long_name',
                                                             'sum of squared column differences')

                    ncm.variables['divrate'][i,:] = fit['divrate']

                    obs = data[window]['obs']

                    # normalize each modeled column to proportions and take the
                    # squared difference to the observations; computed out of
                    # place so the arrays returned by the fit are not modified
                    mod_obspos = fit['mod_obspos']
                    sqdiff = (mod_obspos / np.sum(mod_obspos, axis=1)[:, None, :]
                              - obs[None, :, :])**2

                    i_test = data[window]['i_test']
                    if np.all(i_test == 0):
                        # sum over size classes, mean over observation times
                        ncm.variables['sumsqdiff'][i,:] = np.mean(np.sum(sqdiff, axis=1),
                                                                  axis=1)
                        if i == 0:
                            ncm.variables['sumsqdiff'].setncattr('data_used',
                                                                 'all data')
                    else:
                        # previously written to the root group (nc), which has
                        # no 'sumsqdiff' variable
                        ncm.variables['sumsqdiff'][i,:] = np.mean(np.sum(sqdiff[:, :, i_test == 1],
                                                                         axis=1), axis=1)
                        if i == 0:
                            ncm.variables['sumsqdiff'].setncattr('data_used', 'testing data')

                    for v in ('gamma_max', 'rho_max', 'xi', 'xir', 'E_star'):
                        if i == 0:
                            # 'window' is the dimension created above (there is
                            # no 'model' dimension in this group)
                            ncm.createVariable(v, float, ('window','sample'))
                        if v in fit.flatnames:
                            ncm.variables[v][i,:] = fit[v]