# Lesson 3 : Mean function & Gaussian processes 

Below are the packages that will be used in this lesson.

The cell below prevents the output area from scrolling when a plot is created within the IPython notebook.

In [1]:
%%javascript
// Replace the output area's `_should_scroll` check with a version that
// ignores its `lines` argument and always returns false, so that plots
// produced in this notebook are not placed in a scrolling output box.
IPython.OutputArea.prototype._should_scroll = function(lines){
    return false;
}

<IPython.core.display.Javascript object>

In [2]:
# Standard packages for manipulating arrays,
# plotting, and building interactive plots.
import pylab as plt
from matplotlib import gridspec
import numpy as np
import scipy
import ipywidgets as widgets
from ipywidgets import interact
import itertools
import copy
import sys
import pickle
import treecorr
import treegp
from treegp import AnisotropicRBF, eval_kernel


def load_pickle(pickle_file):
    """Load and return the object stored in a pickle file.

    Parameters
    ----------
    pickle_file : str
        Path of the pickle file to read.

    Returns
    -------
    object
        The unpickled content (a dictionary in this notebook).

    Notes
    -----
    ``encoding='latin1'`` is passed to ``pickle.load`` so that byte strings
    written by Python 2 are decoded as latin-1.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that come
    # from a trusted source.
    # The context manager guarantees the file handle is closed, even when
    # unpickling raises.
    with open(pickle_file, 'rb') as handle:
        dico = pickle.load(handle, encoding='latin1')
    return dico

## Exercise 6) Adding a mean function on real SNIa data, impact on GP interpolation (1D):

In [3]:
##########################################################################################
# EXERCISE 6: Adding a mean function on real SNIa data, impact on GP interpolation (1D): #
##########################################################################################

# Load the pickled SNIa dataset used in this exercise. As used by
# plot_samples below, it is indexed by supernova name, and each entry is
# read through the keys 'y', 'y_err', 'y_time', 'y0' and 'y0_time'.
dic = load_pickle('data/snia_gaussian_process_de_school.pkl')

def gp_regression(x, new_x, y, kernel, y_err=None):
    """Run a treegp Gaussian-process interpolation of ``y`` onto ``new_x``.

    Parameters
    ----------
    x : array
        Coordinates of the observations, passed to ``GPInterpolation.initialize``.
    new_x : array
        Coordinates where the prediction is requested.
    y : array
        Observed values at ``x``.
    kernel : str
        Kernel description forwarded to ``treegp.GPInterpolation``.
    y_err : array, optional
        Errors on ``y``. When omitted, an error of 1e-10 is used for every
        point.

    Returns
    -------
    tuple
        ``(prediction, std, log_likelihood)``: the predicted values at
        ``new_x``, the square root of the diagonal of the predicted
        covariance, and the value returned by ``return_log_likelihood()``.
    """
    # Without user-supplied errors, fall back on a negligible constant error.
    errors = np.ones_like(y) * 1e-10 if y_err is None else y_err

    # The interpolator is built with optimizer='none' and fixed settings.
    gp_settings = dict(
        kernel=kernel,
        optimizer='none',
        normalize=False,
        white_noise=0.,
        p0=[3000., 0., 0.],
        n_neighbors=4,
        average_fits=None,
        nbins=20,
        min_sep=None,
        max_sep=None,
    )
    interpolator = treegp.GPInterpolation(**gp_settings)
    interpolator.initialize(x, y, y_err=errors)

    prediction, covariance = interpolator.predict(new_x, return_cov=True)
    prediction_std = np.sqrt(np.diag(covariance))
    return prediction, prediction_std, interpolator.return_log_likelihood()


def spline_1D(old_binning, mean_function, new_binning):
    """Resample a 1D mean function onto a new binning with a spline.

    Parameters
    ----------
    old_binning : array-like
        Abscissae at which ``mean_function`` is sampled.
    mean_function : array-like
        Values of the mean function on ``old_binning``.
    new_binning : array-like
        Abscissae at which the mean function must be evaluated.

    Returns
    -------
    numpy.ndarray
        The interpolating spline evaluated on ``new_binning``.
    """
    # Import the class explicitly: a bare ``import scipy`` does not guarantee
    # that the ``scipy.interpolate`` submodule is loaded on every SciPy version.
    from scipy.interpolate import InterpolatedUnivariateSpline

    # The spline degree is left at SciPy's default (cubic, k=3), and the
    # spline passes through every input point.
    cubic_spline = InterpolatedUnivariateSpline(old_binning, mean_function)
    return cubic_spline(new_binning)


@interact(sigma = widgets.FloatSlider(value=0.5, min=0.1, max=0.8, step=0.01, description=r'$\sigma$:',
          disabled=False,
          continuous_update=False,
          orientation='horizontal',
          readout=True,
          readout_format='.2f'), 
          l = widgets.FloatSlider(value=3., min=1., max=15, step=0.1, description='$l$:',
          disabled=False,
          continuous_update=False,
          orientation='horizontal',
          readout=True,
          readout_format='.2f'),
          add_mean=widgets.Checkbox(value=False,
                                   description='Add mean function',
                                   disabled=False),
          sn_name = widgets.Dropdown(options=['SNF20080514-002', 'SNF20050821-007', 'SNF20070802-000'],
                                     value='SNF20080514-002',
                                     description='SNIa name:',
                                     disabled=False,))
def plot_samples(sigma, l, add_mean, sn_name):
    """Plot the GP interpolation of one SNIa data set, with or without a mean function.

    The function is driven by the ``interact`` widgets declared in the
    decorator above.

    Parameters
    ----------
    sigma : float
        Kernel amplitude; its square multiplies the RBF kernel.
    l : float
        Value passed as the argument of the RBF kernel.
    add_mean : bool
        When True, the curve ``y0`` is splined onto the observation epochs,
        subtracted from the data before the GP regression, and added back to
        the prediction. When False, a zero mean is used.
    sn_name : str
        Key of the supernova to read from the module-level ``dic``.

    Returns
    -------
    None
        The result is drawn on a new matplotlib figure.
    """
    sn_data = dic[sn_name]
    obs_time = sn_data['y_time']
    obs_err = sn_data['y_err']
    mean_time = sn_data['y0_time']

    # Prediction grid: 80 points between -12 and 48, kept both as a flat
    # array (for the spline and the plots) and as a column (for the GP).
    grid = np.linspace(-12, 48, 80)
    new_x = grid.reshape((80, 1))
    Kernel = "%f * %s(%f)"%((sigma**2, "RBF", l))

    # Work on copies so that the loaded data set is never modified.
    y = copy.deepcopy(sn_data['y'])
    y0 = copy.deepcopy(sn_data['y0'])

    # Mean function evaluated on the observation epochs and on the
    # prediction grid (zero when no mean function is requested).
    if add_mean:
        y0_on_y = spline_1D(mean_time, y0, obs_time)
        y0_on_ypredict = spline_1D(mean_time, y0, grid)
    else:
        y0_on_y = 0
        y0_on_ypredict = 0

    epoch = obs_time.reshape((len(obs_time), 1))

    # The GP is fitted on the residuals to the mean function, and the mean
    # function is added back to the prediction afterwards.
    y_pred, y_std, log_L = gp_regression(epoch, new_x, y - y0_on_y,
                                         Kernel, y_err=obs_err)
    y_pred = y_pred + y0_on_ypredict

    plt.figure(figsize=(14,8))

    # Data
    plt.scatter(obs_time, y,
                c='b', label = 'data')
    plt.errorbar(obs_time, y,
                 linestyle='', yerr=obs_err, ecolor='b',
                 alpha=0.7,marker='.',zorder=0)

    # GP prediction
    plt.plot(new_x, y_pred, 'r', lw =3, label = 'GP prediction')
    plt.fill_between(grid, y_pred-y_std, y_pred+y_std, color='r', alpha=0.3)

    # Mean function actually used by the regression
    if not add_mean:
        plt.plot(new_x, np.zeros_like(new_x),'k--', label='used mean function')
    else:
        plt.plot(mean_time, sn_data['y0'],
                 'k--', label='used mean function')
    plt.xlim(-12,48)
    plt.ylim(y.min()-1,
             y.max()+1)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel('epoch relative to SALT2 $t_0$ (days)', fontsize=20)
    plt.ylabel('Mag AB + cst.', fontsize=20)
    # The LaTeX part is a raw string (no invalid escape sequences); the
    # second literal is a normal string so that "\n" stays a real newline.
    title_format = r"$\log({\cal{L}}) = %.2f$" " \n(kernel used: RBF)"
    plt.title(title_format%(log_L), fontsize=20)
    # The y axis is inverted after the limits are set.
    plt.gca().invert_yaxis()
    plt.legend(fontsize=18, loc=3)

interactive(children=(FloatSlider(value=0.5, continuous_update=False, description='$\\sigma$:', max=0.8, min=0…