
# PGAM Tutorial

## Introduction
This tutorial introduces some key concepts of Generalized Additive Models (GAMs), explains how these concepts are implemented in the PGAM library, and describes in detail how to estimate tuning functions with the PGAM library on an example synthetic dataset.


## Why GAMs?
<!--
Estimating tuning functions entails finding maps that characterize how a set of task variables affects the firing rate of a recorded neuron. Some of the challenges that come with this estimation problem are: (i) the experimenter has no direct access to the firing rate, but can only measure spikes; (ii) there may be no a priori hypothesis on the shape of the tuning functions (excluding some well-characterized special cases, e.g. the Gabor filter-like responses in V1); (iii) correlations between task variables may be a confounding factor (e.g. if eye movement and hand movement are correlated during a reaching movement, it becomes hard to tell whether the hand, the eye or a combination of both is driving a neuron).  

During naturalistic experiments, where there are no identical trial repeats and behavior is less constrained, these challenges become even more prominent: (i) the firing rate cannot be easily estimated by trial averaging over "identical" experimental conditions; (ii) cortical neurons manifest strong mixed selectivity to a multitude of behavioral covariates and stimulus features, which may not be apparent in trial-based experiments due to their simpler, usually multi-alternative stimulus and response space. Mixed selective responses are not yet fully characterized for most brain areas; (iii) finally, weaker (or no) control over the animal's behavior and sensory experience may introduce additional correlations (e.g. a correlation between eye position and visual stimuli will arise if no eye fixation is imposed).
-->

Generalized Linear Models (GLMs) have been successful in characterizing mixed selective responses: they capture the statistics of spike trains well (spike counts can be modelled as Poisson-distributed observations, with no averaging needed) and jointly estimate the contribution of a (potentially) large number of task variables. 

However, GLMs come with their own limitations. In particular, one needs to carefully choose how to represent tuning functions (for GLMs this translates into choosing an appropriate basis of functions, e.g. Gaussian-shaped, Fourier, raised cosine, and the number of basis elements to be used). More importantly, identifying the minimal subset of variables that drive the neural activity becomes cumbersome with a naive implementation. Variables are usually selected through model comparison, an approach that quickly becomes infeasible as the number of task variables increases (combinatorial explosion of candidate models: $p$ variables yield $2^p$ subsets). Additionally, classical stepwise methods come with well-known theoretical and practical flaws (see, e.g., Frank Harrell (2001)).

Our solution <a href="#1">[1]</a> tackles these limitations by taking advantage of GAM theory. GAMs are non-linear extensions of GLMs that retain the advantages mentioned above (they model counts directly and jointly infer responses), but additionally learn from the data the type of non-linearity suited to represent each individual response function. As we will see in the tutorial, this translates into learning the proper prior distribution over a set of possible response functions. The approach has the additional benefit of providing confidence intervals over the model parameters, which can be used to select the minimal subset of variables driving neural activity. Since selection is based on statistical testing, costly model comparison is avoided entirely. 

Overall, our approach is both user-friendly, requiring fewer choices from the user, and computationally advantageous when variable selection is required.
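The "prior" mentioned above can be made concrete. Assuming the standard penalized-likelihood formulation of GAMs, each response function $f_i = \sum_j \beta_{ij} b_j$ receives a Gaussian smoothing prior whose precision is a penalty matrix $S_i$ (e.g. built from the integrated squared second derivative of $f_i$) scaled by a smoothing parameter $\lambda_i$, and the coefficients maximize the penalized log-likelihood:

\begin{align}
p(\boldsymbol{\beta}_i \mid \lambda_i) &\propto \exp\left(-\frac{\lambda_i}{2}\, \boldsymbol{\beta}_i^\top S_i\, \boldsymbol{\beta}_i\right) \\
\hat{\boldsymbol{\beta}} &= \arg\max_{\boldsymbol{\beta}} \left[\log p(\mathbf{y} \mid \boldsymbol{\beta}) - \sum_i \frac{\lambda_i}{2}\, \boldsymbol{\beta}_i^\top S_i\, \boldsymbol{\beta}_i\right]
\end{align}

A large $\lambda_i$ concentrates the prior on functions with a small penalty (for a second-derivative penalty, nearly straight lines), while a small $\lambda_i$ lets $f_i$ follow the data closely; learning the $\lambda_i$ from the data is what selecting the prior amounts to.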



## What is covered in the tutorial?

The tutorial will cover the main components of the model, in particular:

1. <a href="#repr-nl">**Representing non-linearities**: </a>

    1.1 <a href="#b-spline">**B-spline definition and properties**</a> To familiarize ourselves with the concept of B-splines, we will plot B-splines for different types of response functions: 1D and 2D responses, cyclic or not, and temporal kernels.
    
    1.2 <a href="#sm-prior">**Smoothing prior**</a> We will introduce the concept of smoothing prior and clarify its role in GAM fitting. We will provide intuitive insight into the role of the prior by drawing and plotting functions sampled with different levels of smoothing. 


2. <a href="#pgam-lib">**Introduction to the PGAM library**</a> We will illustrate how the concepts introduced in the tutorial are implemented in the PGAM library and applied to the problem of tuning function estimation.

    2.1 <a href="#sm-handl">**Define the B-spline via the *smooths_handler* class**</a>: We will use a dedicated class (smooths_handler) to format the model covariates, as well as to design the B-spline and the corresponding smoothing penalization.

    2.2. <a href="#model-fit">**Model fit**</a>: Fit the model by jointly learning the smoothing levels and the B-spline parameters.
    
    2.3. <a href="#post-proc">**Post processing**</a>: Post-process model outputs, plot and explore the outputs.

In [None]:
# Import libraries
import sys

# Add the PGAM source folder to the Python path.
# If working outside the docker container, set pgam_path to [YOUR PATH TO PGAM FOLDER]/src/
pgam_path = '/Users/dusiyi/Documents/Multifirefly-Project/multiff_analysis/external/pgam/src/'
if pgam_path not in sys.path:
    sys.path.append(pgam_path)

import numpy as np
import scipy.stats as sts
import matplotlib.pyplot as plt
import pandas as pd
from scipy.io import savemat

from PGAM.GAM_library import *
import PGAM.gam_data_handlers as gdh
from post_processing import postprocess_results

# 2 Introduction to the PGAM library<a name="pgam-lib"></a>

The PGAM library simplifies the process of constructing and fitting GAM models for tuning function estimation. The key classes of the library are:
    
* **smooths_handler**: constructs the B-spline basis and the penalty matrix for each variable, and concatenates multiple B-spline bases into a global model matrix.
    
* **general_additive_model**: contains the methods for fitting GAMs by means of dGCV optimization.
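To give a feel for what optimizing a GCV criterion does, here is a minimal NumPy sketch for the simplest case: Gaussian observations with a quadratic penalty $\lambda\,\boldsymbol\beta^\top S\boldsymbol\beta$. Each candidate $\lambda$ is scored by $n\lVert \mathbf{y} - A\mathbf{y}\rVert^2/(n - \mathrm{tr}\,A)^2$, where $A$ is the influence matrix and $\mathrm{tr}\,A$ the effective degrees of freedom. The Gaussian-bump basis, the second-difference penalty and the grid search are assumptions made for brevity, and `gcv_and_edf` is a hypothetical helper; PGAM's dGCV criterion is applied to Poisson counts, so treat this only as an illustration of the trade-off that $\lambda$ controls.

```python
import numpy as np

def gcv_and_edf(X, y, S, lam):
    """Return (GCV score, effective degrees of freedom) of the penalized
    least-squares fit beta = (X'X + lam * S)^-1 X'y."""
    n = len(y)
    A = X @ np.linalg.solve(X.T @ X + lam * S, X.T)   # influence (hat) matrix
    resid = y - A @ y
    edf = np.trace(A)
    return n * (resid @ resid) / (n - edf) ** 2, edf

rng = np.random.default_rng(0)
n, k = 200, 20
x = np.sort(rng.uniform(0, 1, n))
y = np.sin(2 * np.pi * x) + rng.normal(scale=0.3, size=n)

# Gaussian-bump basis, a stand-in for the B-spline basis
centers = np.linspace(0, 1, k)
X = np.exp(-0.5 * ((x[:, None] - centers[None, :]) / 0.05) ** 2)

# second-order difference penalty on adjacent coefficients: S = D'D
D = np.diff(np.eye(k), n=2, axis=0)
S = D.T @ D

lams = 10.0 ** np.arange(-4, 5)
scores, edfs = zip(*(gcv_and_edf(X, y, S, lam) for lam in lams))
best_lam = lams[int(np.argmin(scores))]
print('effective degrees of freedom:', np.round(edfs, 1))
print('lambda selected by GCV:', best_lam)
```

A small $\lambda$ keeps nearly all 20 degrees of freedom and tends to chase the noise, while a large $\lambda$ pushes the coefficients toward the null space of the penalty, so the effective degrees of freedom approach 2.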

## 2.1 Define the B-spline via the *smooths_handler* class <a name="sm-handl"></a>


The *smooths_handler* class will construct the appropriate B-spline for all the covariates of interest. 

Each task variable needs to be added to the *smooths_handler* object one at a time via the method
    
        smooths_handler.add_smooth

The method requires the following inputs:
* **name**: string, the label of the task variable

* **x_cov**: list containing the input variable (the list will contain 1 vector per dimension of the variable)
    
* **is_temporal_kernel**: boolean, True if the variable is "temporal", False if "spatial" (see <a href="#spatial-temporal">below</a> for the definition of temporal and spatial variables)

* **kernel_direction**: int or None. None when "is_temporal_kernel == False". When "is_temporal_kernel == True": 0 = acausal (bidirectional), 1 = causal (i.e., the firing changes after the event happens), -1 = anticipatory (i.e., the firing changes before the event happens). See the <a href="#tempcov">temporal covariates</a> section for examples.

* **kernel_length**: int or None. None when "is_temporal_kernel == False". When "is_temporal_kernel == True", the number of time points used for the kernel. An odd number of samples is suggested.

* **ord**: integer, the order of the B-spline

* **knots**: list or None. None when "is_temporal_kernel == True". If list, each element of the list is a vector of knot locations for a specific dimension of the variable. 

* **knots_num**: integer or None. If integer, the number of equispaced knots over the x_cov range (for spatial variables) or over the temporal kernel range (for temporal variables); knots_num must be smaller than the number of time points of the filter.

* **penalty_type**: 'der' for a derivative-based penalty matrix, or 'diff' for a difference-based penalty matrix. See <a href='#der-diff-pen'>above</a>.

* **der**: int or None. None if the 'diff' penalty is used. The order of the derivative used for the penalization. Default is 2 for a smoother penalty.

* **is_cyclic**: list of bool, `is_cyclic[i] = True` if the i-th dimension of the task variable is cyclic

* **lam**: float, initial smoothing parameter $\lambda$.

* **samp_period**: float, the sampling period in seconds.

* **trial_idx**: vector with one entry per sample, containing the trial ID of each sample
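The direction convention of **kernel_direction** can be illustrated without PGAM. The NumPy sketch below borrows the zero-padding trick used for `filter_used_conv` in section 2.2.1: an odd-length kernel on lags $-M, \dots, M$ whose first half is zero only affects time points after the event (causal, 1), its mirror image only affects earlier time points (anticipatory, -1), and a two-sided kernel affects both (acausal, 0). The kernel shape `pos_part` is an arbitrary assumption, and this is not PGAM's internal construction of the directional B-spline bases, only a sketch of the behavior the three settings describe.

```python
import numpy as np

M = 50                                          # kernel spans lags -M..M (odd length 2M+1)
pos_part = np.exp(-np.arange(1, M + 1) / 10.0)  # arbitrary kernel shape on lags 1..M

kernels = {
    'causal (1)': np.hstack((np.zeros(M + 1), pos_part)),              # zero for lags <= 0
    'anticipatory (-1)': np.hstack((pos_part[::-1], np.zeros(M + 1))), # zero for lags >= 0
    'acausal (0)': np.hstack((pos_part[::-1], [1.0], pos_part)),       # both sides
}

events = np.zeros(300)
events[150] = 1                                 # a single event at t = 150

lag_range = {}
for name, h in kernels.items():
    # with an odd-length kernel, mode='same' keeps lag 0 aligned with the event
    resp = np.convolve(events, h, mode='same')
    nz = np.nonzero(resp)[0] - 150              # affected time points, relative to the event
    lag_range[name] = (int(nz.min()), int(nz.max()))
    print(name, lag_range[name])
```

The odd kernel length matters here: with $2M+1$ samples, the center sample is exactly lag 0, so the causal and anticipatory halves do not leak into each other.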

### 2.1.1 Spatial vs. temporal covariates <a name="spatial-temporal"></a>
We label the covariates as "spatial" or "temporal" to specify two different types of response functions. In a neuroscience application, a "spatial" variable has a traditional tuning function, where the x-axis spans a range of stimuli, for example the position of an animal in an arena or the orientation of gratings. A "temporal" variable, instead, describes the response to events, such as stimulus onset.  

1. Responses to **spatial variables** are instantaneous non-linear effects (the task variable $x_t$ immediately affects the rate in a non-linear way):
    \begin{align}
    f(x_t) = \sum_j\beta_j b_j(x_t).
    \end{align}

2. The response to a **temporal variable** is assumed to be the convolution of a kernel function (described in terms of B-spline) and the variable:

    \begin{align}
    f(x_t) &= \int_{-\infty}^{\infty} x(\tau) h(t-\tau) d \tau \\
    h(t) &= \sum_j \beta_j b_j(t)
    \end{align}

    where $b_j$ are spline basis elements. This means that past and/or future values of the variable $x_t$ will affect the current firing rate with a linear contribution weighted by $h$.
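Both formulas can be written out with plain NumPy. The sketch below uses a hypothetical helper `bspline_design` (the Cox-de Boor recursion, not PGAM's own `splineDesign`) and a clamped knot vector like the one built in section 2.1.3. The spatial response is a matrix-vector product; the temporal response replaces the integral with a discrete convolution, where the kernel sample `h[k]` sits at lag `k - 60`. The coefficients are random placeholders.

```python
import numpy as np

def bspline_design(x, knots, order):
    """Hypothetical helper: evaluate every B-spline b_j at x (Cox-de Boor recursion).

    knots is the full knot vector (boundary knots repeated), order = degree + 1.
    Returns a (len(x), len(knots) - order) design matrix.
    """
    x = np.asarray(x, dtype=float)
    t = np.asarray(knots, dtype=float)
    # order 1: indicator of each knot interval [t_i, t_{i+1})
    B = np.array([(t[i] <= x) & (x < t[i + 1]) for i in range(len(t) - 1)], dtype=float).T
    # close the last non-empty interval so that x == t[-1] is covered
    last = np.nonzero(t[:-1] < t[1:])[0].max()
    B[x == t[-1], last] = 1.0
    for k in range(2, order + 1):
        B_next = np.zeros((len(x), len(t) - k))
        for i in range(len(t) - k):
            left, right = t[i + k - 1] - t[i], t[i + k] - t[i + 1]
            if left > 0:
                B_next[:, i] += (x - t[i]) / left * B[:, i]
            if right > 0:
                B_next[:, i] += (t[i + k] - x) / right * B[:, i + 1]
        B = B_next
    return B

order = 4
rng = np.random.default_rng(0)

# spatial: f(x_t) = sum_j beta_j b_j(x_t), clamped knots on [-2, 2]
int_knots = np.linspace(-2, 2, 10)
knots = np.hstack(([int_knots[0]] * (order - 1), int_knots, [int_knots[-1]] * (order - 1)))
x = rng.uniform(-2, 2, size=500)
B = bspline_design(x, knots, order)             # (500, 12): one column per basis element
beta = rng.normal(size=B.shape[1])
f_spatial = B @ beta

# temporal: f(x_t) = sum_tau x(tau) h(t - tau), with h(t) = sum_j beta_j b_j(t)
kernel_length = 121
lags = np.arange(kernel_length) - kernel_length // 2     # -60, ..., 60
lag_int = np.linspace(lags[0], lags[-1], 12)
lag_knots = np.hstack(([lag_int[0]] * (order - 1), lag_int, [lag_int[-1]] * (order - 1)))
h = bspline_design(lags, lag_knots, order) @ rng.normal(size=len(lag_knots) - order)
events = np.zeros(1000)
events[[100, 200, 600, 900]] = 1
# discrete version of the convolution integral
f_temporal = np.convolve(events, h, mode='same')
```

With clamped knots, the rows of `B` sum to 1 inside the knot range (partition of unity), which is one reason out-of-range values are excluded in the examples that follow.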



<!--Fitting a PGAM will entail learning the appropriate $\mathbf{\beta}$ coefficients characterizing the response function.

Below we will create an example of three synthetic covariates (an event indicator, a 1D continuous variable and a 2D continuous variable) for an experiment with 2 trials of 500 time points per trial.-->

### 2.1.2 Temporal covariates <a name="tempcov"></a>

Below we define and plot an acausal filter and the two directional (causal and anticipatory) filters.

In [None]:

# Define a series of event markers
tot_tp = 10**3

# set some trial ids
trial_ids = np.zeros(tot_tp)
trial_ids[400:] = 1

# place a few event markers at fixed time points (two per trial)
event = np.zeros(tot_tp)
event[[100, 200, 600, 900]] = 1

# define the B-spline params
kernel_h_length = 121  # duration of the kernel h(t) in time points
num_int_knots = 12  # number of internal knots used to represent h
order = 4
dict_kernel = {0: 'Acausal', 1: 'Direction 1', -1: 'Direction -1'}


for kernel_direction in [0, 1, -1]:
    # define the "smooths_handler" container
    sm_handler = gdh.smooths_handler()

    # add the covariate & evaluate the convolution
    sm_handler.add_smooth('this_event',
                          [event],
                          is_temporal_kernel=True,
                          ord=order,
                          knots_num=num_int_knots,
                          trial_idx=trial_ids,
                          kernel_length=kernel_h_length,
                          kernel_direction=kernel_direction)

    # sm_handler['varname'] processes and stores the B-spline for the variable;
    # below we retrieve the B-spline convolved with the "event" variable
    convolved_ev = sm_handler['this_event'].X.toarray()

    # retrieve the B-spline used for the convolution
    basis = sm_handler['this_event'].basis_kernel.toarray()

    # plot the basis & the convolved events
    plt.figure(figsize=(8, 3))
    plt.suptitle('%s Filter' % dict_kernel[kernel_direction])

    # basis for the kernel h, plotted against the lag (centered at 0)
    plt.subplot(121)
    plt.title('kernel basis')
    lags = np.arange(kernel_h_length) - kernel_h_length // 2
    plt.plot(lags, basis)
    plt.xlabel('time points')

    plt.subplot(122)
    plt.title('convolved events')

    # select an interval around the third event (100 points before, 400 after)
    ev_idx = np.where(event == 1)[0][2]
    idx0, idx1 = ev_idx - 100, ev_idx + 400

    # extract the events convolved with each of the B-spline elements
    conv = convolved_ev[idx0:idx1, :]

    tps = np.arange(idx1 - idx0) - 100
    plt.plot(tps, conv)
    plt.vlines(tps[0] + np.where(event[idx0:idx1])[0], 0, 1.5, 'k', linestyles='--', label='event')
    plt.xlabel('time points')
    plt.legend()
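Each column of `convolved_ev` is the event train convolved with one basis element of the kernel. Because convolution is linear, weighting those columns by $\beta_j$ gives the events convolved with $h = \sum_j \beta_j b_j$, which is why the temporal response stays linear in $\boldsymbol\beta$ and can share one model matrix with the spatial terms. A NumPy check of this identity, with a random stand-in for `basis` (the sketch convolves the whole series at once and ignores the trial boundaries that `trial_idx` provides):

```python
import numpy as np

rng = np.random.default_rng(0)
kernel_length, n_basis = 121, 8
basis = rng.normal(size=(kernel_length, n_basis))   # stand-in for the kernel basis b_j(t)
beta = rng.normal(size=n_basis)

event = np.zeros(1000)
event[[100, 200, 600, 900]] = 1

# design matrix: column j is the event train convolved with b_j
X = np.column_stack([np.convolve(event, basis[:, j], mode='same') for j in range(n_basis)])

# X @ beta equals the events convolved with the full kernel h = sum_j beta_j b_j
h = basis @ beta
print(np.allclose(X @ beta, np.convolve(event, h, mode='same')))
```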


### 2.1.3 Spatial variables: 1D and 2D

Below we add examples of 1D and 2D "spatial" variables.

In [None]:
# generate three covariates
x = np.random.normal(size=tot_tp)
y = np.random.normal(size=tot_tp)
z = np.random.normal(size=tot_tp)

# define the knots for the 1D spatial variable (boundary knots repeated order-1 times)
int_knots = np.linspace(-2,2,10)
order = 4
knots = np.hstack(([int_knots[0]]*(order-1), int_knots, [int_knots[-1]]*(order-1)))

# set values outside the knot range [-2, 2] to NaN
x[np.abs(x)>2] = np.nan
y[np.abs(y)>2] = np.nan
z[np.abs(z)>2] = np.nan

# add the variable (removing it first if the cell is re-run)
if 'spatial_1D' in sm_handler.smooths_var:
    sm_handler.smooths_var.remove('spatial_1D')
    sm_handler.smooths_dict.pop('spatial_1D')
    
sm_handler.add_smooth('spatial_1D', [x], 
                      knots=[knots], 
                      ord=order, 
                      is_temporal_kernel=False,
                      trial_idx=trial_ids, 
                      is_cyclic=[False])


# retrieve the B-spline evaluated at x
X_1D = sm_handler['spatial_1D'].X.toarray()


# sort the rows by x for plotting
plt.figure()
plt.title('1D spatial response')
idx_srt = np.argsort(x)
X_srt = X_1D[idx_srt]
p = plt.plot(X_srt)

# add a 2D response with one cyclic variable and one acyclic
if 'spatial_2D' in sm_handler.smooths_var:
    sm_handler.smooths_var.remove('spatial_2D')
    sm_handler.smooths_dict.pop('spatial_2D')

sm_handler.add_smooth('spatial_2D', [y,z], knots=[knots, knots], ord=order, is_temporal_kernel=False,
                     trial_idx=trial_ids, is_cyclic=[False,True])
X_2D = sm_handler['spatial_2D'].X.toarray()


# the number of basis functions grows as n^m, where n is the number of basis functions in the 1D case and m is the number of dimensions
print('Size of X_1D',X_1D.shape)
print('Size of X_2D',X_2D.shape)
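The growth in size comes from the tensor-product construction: each 2D basis function is the product of one basis function per dimension, so the 2D design matrix is a row-wise Kronecker product of the 1D design matrices. A NumPy sketch, with random non-negative rows standing in for the two 1D B-spline matrices (`tensor_design` is a hypothetical helper; PGAM's cyclic handling and any identifiability constraints may change the exact column count it reports):

```python
import numpy as np

def tensor_design(B1, B2):
    """Row-wise Kronecker product: column (i, j) holds b_i(y_t) * c_j(z_t)."""
    T = B1.shape[0]
    return np.einsum('ti,tj->tij', B1, B2).reshape(T, -1)

rng = np.random.default_rng(0)
T, n = 1000, 12
# stand-ins for two 1D B-spline design matrices (non-negative rows summing to 1)
B_y = rng.dirichlet(np.ones(n), size=T)
B_z = rng.dirichlet(np.ones(n), size=T)
X_2D_sketch = tensor_design(B_y, B_z)
print(X_2D_sketch.shape)   # (1000, 144): n**m columns with n = 12, m = 2
```

Since each 1D row sums to 1, so does each tensor-product row; the partition-of-unity property carries over to the 2D basis.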

## 2.2 Model fit <a name="model-fit"></a>
Putting all the pieces together, here we will create a *smooths_handler* object containing the appropriate covariates, fit the model, evaluate the fit quality, and save the post-processed outputs in a standard numpy structured array or as a MATLAB structure.
    
The class **general_additive_model** is used for defining the GAM. It requires the following inputs:
    
* **sm_handler**: the smooths_handler object
* **var_list**: list of variable names
* **y**: the array with the spike counts (all trials must be stacked in a 1D array)
* **family**: a statsmodels.genmod.families family object (e.g. Poisson with a log link) describing the observation noise and the link function; the library allows fitting any exponential-family observation noise
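With a log link, the model predicts the expected spike count per bin as $\mu_t = \exp(\eta_t)$, where $\eta_t$ is the sum of all the response functions, and the Poisson family scores the counts by the log-likelihood $\sum_t \left(y_t \eta_t - \mu_t - \log y_t!\right)$. A small NumPy sketch of that quantity; the function name `poisson_loglik` is hypothetical, and in PGAM this role is played by the statsmodels family object passed to the constructor:

```python
import math
import numpy as np

def poisson_loglik(y, eta):
    """Poisson log-likelihood with log link: mu = exp(eta), eta = linear predictor."""
    y = np.asarray(y, dtype=float)
    eta = np.asarray(eta, dtype=float)
    mu = np.exp(eta)                      # the inverse link keeps the rate positive
    log_y_fact = np.array([math.lgamma(k + 1.0) for k in y])
    return float(np.sum(y * eta - mu - log_y_fact))

# three bins with expected count 1 (eta = 0) and observed counts 0, 1, 2
print(poisson_loglik([0, 1, 2], np.zeros(3)))   # -3 - log(2) ≈ -3.6931
```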


You can fit the GAM with the method **general_additive_model.fit_full_and_reduced**, which fits a model with all the variables in **var_list**, selects the minimal subset of variables that drive the neural activity by statistical testing, and re-fits the model with the significant variables only.
    
The inputs of **fit_full_and_reduced** are the following:

* **var_list**: list with the subset of variables to be used for model fitting 
* **th_pval**: float between 0 and 1, the significance level for task variable inclusion
* **max_iter**: int, max number of iterations of the optimization routine
* **use_dgcv**: True for learning the smoothing constants via dGCV
* **trial_num_vec**: vector with one entry per sample, containing the trial ID of each sample
* **filter_trials**: boolean vector of the same length as *trial_num_vec*, indicating which time points should be used for training the model
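The selection step can be pictured as a threshold on per-variable p-values from the full fit, followed by a refit on the variables that pass. The sketch below shows only that thresholding logic; `select_variables` and the p-values are hypothetical, and the statistical test PGAM uses to obtain each p-value is described in <a href="#1">[1]</a>, not reproduced here.

```python
def select_variables(p_values, th_pval):
    """Keep the variables whose test p-value is below the significance threshold."""
    return [name for name, p in p_values.items() if p < th_pval]

# hypothetical p-values from a full-model fit
p_values = {'spatial': 1e-12, 'nuisance': 0.42, 'temporal': 3e-8}
print(select_variables(p_values, th_pval=0.001))   # ['spatial', 'temporal']
```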

    
In the following subsections we provide an example where spike counts are generated according to the PGAM assumptions, and then estimate the response functions within the GAM framework.



### 2.2.1 Generate synthetic data
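Below we generate a synthetic dataset: a 1D spatial variable and an event train drive the neuron, while a second "nuisance" variable does not.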

In [None]:

## input parameters
num_events = 6000
time_points = 3 * 10 ** 5  # 30 min at 6 ms resolution
rate = 5. * 0.006  # mean spike count per 6 ms bin (5 Hz mean firing rate)
variance = 5.  # variance of the spatial and nuisance inputs
int_knots_num = 20  # num of internal knots for the spline basis
order = 4  # spline order

## assume 200 trials
trial_ids = np.repeat(np.arange(200),time_points//200)

## create temporal input
idx = np.random.choice(np.arange(time_points), num_events, replace=False)
events = np.zeros(time_points)
events[idx] = 1

rv = sts.multivariate_normal(mean=[0, 0], cov= variance * np.eye(2))
samp = rv.rvs(time_points)
spatial_var = samp[:, 0]
nuisance_var = samp[:, 1]

# keep only samples with |spatial_var| < 5 (the knot range of the response function), redrawing until time_points samples remain
sele_idx = np.abs(spatial_var) < 5
spatial_var = spatial_var[sele_idx]
nuisance_var = nuisance_var[sele_idx]
while spatial_var.shape[0] < time_points:
    tmpX = rv.rvs(10 ** 4)
    sele_idx = np.abs(tmpX[:, 0]) < 5
    tmpX = tmpX[sele_idx, :]

    spatial_var = np.hstack((spatial_var, tmpX[:, 0]))
    nuisance_var = np.hstack((nuisance_var, tmpX[:, 1]))
spatial_var = spatial_var[:time_points]
nuisance_var = nuisance_var[:time_points]

# create the response function: a cubic (order 4) B-spline on [-5, 5] with coefficients beta
knots = np.hstack(([-5]*3, np.linspace(-5,5,8),[5]*3))
beta = np.arange(10)
beta = beta / np.linalg.norm(beta)
beta = np.hstack((beta[5:], beta[:5][::-1]))
resp_func = lambda x : np.dot(gdh.splineDesign(knots, x, order, der=0),beta)

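# temporal filter: difference of two gamma pdfs, zero-padded on the left (101 zeros) so that,
# with np.convolve(..., mode='same'), each event only affects later time points (causal)
filter_used_conv = sts.gamma.pdf(np.linspace(0,20,100),a=2) - sts.gamma.pdf(np.linspace(0,20,100),a=5)
filter_used_conv = np.hstack((np.zeros(101),filter_used_conv))*2
# log-mean of the spike counts depending on spatial_var and events (events convolved within each trial)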
log_mu0 = resp_func(spatial_var)
for tr in np.unique(trial_ids):
    log_mu0[trial_ids == tr] = log_mu0[trial_ids == tr] + np.convolve(events[trial_ids == tr], filter_used_conv, mode='same')

# shift the log-rate so that the mean count per bin equals rate
const = np.log(np.mean(np.exp(log_mu0)) / rate)
log_mu0 = log_mu0 - const

# generate spikes
spk_counts = np.random.poisson(np.exp(log_mu0))
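The constant `const` shifts the log-rate so that the average expected count per bin equals `rate`. Since $\frac{1}{T}\sum_t e^{\log\mu_t - c} = e^{-c}\,\overline{\mu}$, where $\overline{\mu}$ is the mean of $e^{\log\mu_t}$, choosing $c = \log(\overline{\mu}/\text{rate})$ gives a mean of exactly `rate`, i.e. 0.03 counts per 6 ms bin, or 5 Hz. The shift only changes the baseline, not the shape of the response functions. A standalone check, with a random stand-in for `log_mu0`:

```python
import numpy as np

dt = 0.006                          # bin size in seconds (6 ms)
target_rate = 5.0 * dt              # 5 Hz expressed as expected counts per bin

rng = np.random.default_rng(0)
log_mu0 = rng.normal(size=10_000)   # stand-in for resp_func(...) + convolved events

# mean(exp(log_mu0 - const)) = mean(exp(log_mu0)) / exp(const) = target_rate
const = np.log(np.mean(np.exp(log_mu0)) / target_rate)
mu = np.exp(log_mu0 - const)
print(mu.mean(), mu.mean() / dt)    # ~0.03 counts per bin, ~5 Hz
```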

In [None]:
# plot the firing rate and the spike counts generated
plt.figure(figsize=(8,3))
plt.subplot(121)
plt.plot(np.arange(1000) * 0.006, np.exp(log_mu0)[:1000]/0.006)
plt.title('firing rate [Hz]', fontsize=16)
plt.xlabel('time [s]', fontsize=12)

plt.subplot(122)
plt.plot(np.arange(1000) * 0.006, spk_counts[:1000])
plt.title('6 ms binned spike counts', fontsize=16)
plt.xlabel('time [s]', fontsize=12)

### 2.2.2 Create the *smooths_handler* object and fit the model
Below we create the smooths_handler object and run a fit. We include a "nuisance" spatial variable that does not drive the neuron; the fit should learn to discard it.

In [None]:
import statsmodels.api as sm

# create the smooths_handler container
sm_handler = gdh.smooths_handler()
# create the knots (note the repeated edge knots: order-1 = 3 extra copies at each boundary)
knots = np.hstack(([-5]*3, np.linspace(-5,5,15),[5]*3))
# use the smooths_handler class to add the variables
sm_handler.add_smooth('spatial', [spatial_var], knots=[knots], ord=4, is_temporal_kernel=False,
                     trial_idx=trial_ids, is_cyclic=[False],penalty_type='der', der=2)

sm_handler.add_smooth('nuisance', [nuisance_var], knots=[knots], ord=4, is_temporal_kernel=False,
                     trial_idx=trial_ids, is_cyclic=[False],penalty_type='der', der=2)

sm_handler.add_smooth('temporal', [events], knots=None, ord=4, is_temporal_kernel=True,
                     trial_idx=trial_ids, is_cyclic=[False],penalty_type='der', der=2,
                     knots_num=10, kernel_length=500, kernel_direction=1)


# split trials into train and eval sets (every 10th trial is held out for evaluation)
train_trials = trial_ids % 10 != 0
eval_trials = ~train_trials


link = sm.genmod.families.links.log()
poissFam = sm.genmod.families.family.Poisson(link=link)

# create the pgam model
pgam = general_additive_model(sm_handler,
                              sm_handler.smooths_var, # list of covariates to include in the model
                              spk_counts, # vector of spike counts
                              poissFam # poisson family with exponential link from statsmodels.api
                             )

# fit with all covariates, remove the non-significant ones by statistical testing, then refit
full, reduced = pgam.fit_full_and_reduced(sm_handler.smooths_var, 
                                          th_pval=0.001, # p-value threshold for covariate inclusion
                                          max_iter=10 ** 2, # max number of iteration
                                          use_dgcv=True, # learn the smoothing penalties by dgcv
                                          trial_num_vec=trial_ids,
                                          filter_trials=train_trials)

print('Minimal subset of variables driving the activity:')
print(reduced.var_list)

## 2.3 Post processing<a name="post-proc"></a>
After a fit, the model output can be post-processed into an easy-to-parse result in the form of a numpy structured array. 

Each row represents the results for a specific input variable. Additional information about the neuron (e.g. channel ID, electrode ID, or anything else) can be provided as a dictionary; each dictionary value is stored in the structured array with type "object".

The output structure can be saved either as a ".npy" file via *numpy.save(\<filename\>, res)* or as a ".mat" file (for MATLAB) via *scipy.io.savemat(\<filename\>, mdict)*.
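A tiny structured array in the same spirit, with a hypothetical `dtype` (the actual field names of `res` are printed by the post-processing cell below), shows how arrays and dictionaries live in object fields and checks an `.npy` round trip in memory. Because object fields are pickled, `numpy.load` needs `allow_pickle=True` to read the file back. The MATLAB route via `savemat` is not exercised here.

```python
import io
import numpy as np

# hypothetical fields: one row per variable, object fields for arrays and dicts
dtype = [('variable', 'U20'), ('pval', float), ('x_kernel', object), ('info', object)]
res_sketch = np.zeros(2, dtype=dtype)
res_sketch['variable'] = ['spatial', 'temporal']
res_sketch['pval'] = [1e-12, 3e-8]
res_sketch['x_kernel'][0] = np.linspace(-5, 5, 100)   # arrays may differ in length per row
res_sketch['x_kernel'][1] = np.arange(500)
res_sketch['info'][0] = {'brain_region': 'V1'}
res_sketch['info'][1] = {'brain_region': 'V1'}

# in-memory .npy round trip; object fields are pickled, hence allow_pickle=True on load
buf = io.BytesIO()
np.save(buf, res_sketch)
buf.seek(0)
loaded = np.load(buf, allow_pickle=True)
print(loaded['variable'], loaded['x_kernel'][1].shape)
```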

Below is an example of the post-processing applied to the fit just obtained.

In [None]:
# string with the neuron identifier
neuron_id = 'neuron_000_session_1_monkey_001'
# dictionary with additional information about the neuron: keys must be strings and values
# can be anything, since they are stored with type object
info_save = {'x':100,
             'y':801.2,
             'z':301,
             'brain_region': 'V1',
             'subject':'monkey_001'
            }

# assume that we used 90% of the trials for training, 10% for evaluation
res = postprocess_results(neuron_id, spk_counts, full, reduced, train_trials,
                        sm_handler, poissFam, trial_ids, var_zscore_par=None,info_save=info_save,bins=100)

# each row of res contains the info about one variable;
# some fields are shared by all the variables (e.g. p-rsquared, a goodness-of-fit measure
# of the model, not a property of a single variable), while others, like the parameters
# of the B-splines, are variable specific
print('\n\n')
print('Result structarray types\n========================\n')
for name in res.dtype.names:
    print('%s: \t %s'%(name, type(res[name][0])))
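The comment above mentions p-rsquared, a goodness-of-fit measure shared by all the rows. Its exact definition in `postprocess_results` is not shown in this tutorial; as an assumption, one common Poisson pseudo-R² is the fraction of the null (constant-rate) deviance explained on held-out data, sketched below with hypothetical rates in place of the model predictions on `eval_trials`.

```python
import numpy as np

def poisson_deviance(y, mu):
    """Poisson deviance 2 * sum(y log(y/mu) - (y - mu)), with 0 log 0 = 0."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    ratio_term = np.where(y > 0, y * np.log(np.where(y > 0, y, 1.0) / mu), 0.0)
    return 2.0 * np.sum(ratio_term - (y - mu))

def pseudo_r2(y, mu_model):
    """Fraction of the null-model deviance explained (null = constant mean rate)."""
    mu_null = np.full(len(y), np.mean(y))
    return 1.0 - poisson_deviance(y, mu_model) / poisson_deviance(y, mu_null)

rng = np.random.default_rng(0)
mu_true = np.exp(rng.normal(size=5000) - 2.0)   # hypothetical held-out rates
y_eval = rng.poisson(mu_true)
print(pseudo_r2(y_eval, mu_true))                # > 0: the true rates beat a constant rate
```

A value of 0 means the model does no better than a constant firing rate; values close to 1 are rare for spike counts because of the Poisson noise itself.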



In [None]:

# plot tuning functions (top: log-space kernels with confidence bands, bottom: rate space)
plt.figure(figsize=(8,4))

for k in range(3):
    plt.subplot(2,3,k+1)
    plt.title('log-space %s'%res['variable'][k])

    # flatten to 1D for plotting
    x_kernel = res['x_kernel'][k].reshape(-1)
    y_kernel = res['y_kernel'][k]
    ypCI_kernel = res['y_kernel_pCI'][k]
    ymCI_kernel = res['y_kernel_mCI'][k]

    plt.plot(x_kernel, y_kernel, color='r')
    plt.fill_between(x_kernel, ymCI_kernel, ypCI_kernel, color='r', alpha=0.3)

    x_firing = res['x_rate_Hz'][k]
    y_firing_model = res['model_rate_Hz'][k]
    y_firing_raw = res['raw_rate_Hz'][k]

    plt.subplot(2,3,k+4)
    plt.title('rate-space %s'%res['variable'][k])
    plt.plot(x_firing, y_firing_raw, color='k',label='raw')
    plt.plot(x_firing, y_firing_model, color='r',label='model')
    plt.legend()

plt.tight_layout()

In [None]:
# saving the output for further analysis
#np.save('/notebooks/result_pgam.npy', res)
#savemat('/notebooks/result_pgam.mat', mdict = {'result_pgam':res})

# References <a name="refs"></a>
<a id="1">[1]</a> 
<a href="https://proceedings.neurips.cc/paper/2020/hash/94d2a3c6dd19337f2511cdf8b4bf907e-Abstract.html">
Balzani, Edoardo, et al., 
"Efficient estimation of neural tuning during naturalistic behavior."
Advances in Neural Information Processing Systems 33 (2020): 12604-12614.</a>