# Lesson 3 : Mean function & Gaussian processes 

The packages imported below are used throughout this lesson.

The cell below prevents the output area from scrolling when a plot is created within the IPython notebook.

In [1]:
%%javascript
// Replace the notebook's scroll check with a function that always returns
// false, whatever number of lines is passed in. The intent (per the note
// above this cell) is that plots are shown in full instead of in a scroll box.
IPython.OutputArea.prototype._should_scroll = function(lines){
    return false;
}

<IPython.core.display.Javascript object>

In [2]:
# Classical packages for manipulating
# arrays, for plotting and for interactive plots.
import pylab as plt
from matplotlib import gridspec
import numpy as np
import scipy
import ipywidgets as widgets
from ipywidgets import interact
import itertools
import copy
import sys
import pickle

# Gaussian processes from scikit-learn are used for this lesson.
# Other packages exist (e.g. george), but scikit-learn is probably the
# best choice for this course because many people in DESC already use it.
# The only suggestion concerns how to fit the hyperparameters in a more
# efficient way; this will be done with another package that is also
# broadly used within DESC: TreeCorr.
from sklearn import gaussian_process as skl_gp
from sklearn.gaussian_process.kernels import Kernel

# Special implementation of an anisotropic squared exponential kernel
# for scikit-learn. It is not part of scikit-learn itself.
import sys
sys.path.append('/home/nbuser/project')
from kernel import AnisotropicRBF

# treecorr is a package to compute 2-point correlation functions.
# It will be used as an alternative to the Maximum Likelihood method described
# in Rasmussen & Williams 2006 to estimate hyperparameters.
try:
    import treecorr
except:
    !pip install treecorr
    import treecorr

# Some import trickery to get all subclasses of sklearn.gaussian_process.kernels.Kernel
# into the local namespace without doing "from sklearn.gaussian_process.kernels import *"
# and without importing them all manually. Originally developed by Josh Meyers within Piff.
# Example:
# kernel = eval_kernel("RBF(1)") instead of
# kernel = sklearn.gaussian_process.kernels.RBF(1)
def eval_kernel(kernel):
    """Build a kernel object from a Python expression given as a string.

    Every direct or indirect subclass of ``Kernel`` that is currently loaded
    is made available under its class name while the string is evaluated,
    together with ``array`` (``numpy.array``) and the module globals. This
    lets ``eval_kernel("RBF(1)")`` work without importing ``RBF`` explicitly.

    Parameters
    ----------
    kernel : str
        Python expression that evaluates to a kernel instance,
        e.g. ``"0.25 * RBF(3.0)"``.

    Returns
    -------
    The object obtained by evaluating ``kernel``.

    Raises
    ------
    RuntimeError
        If evaluating the string raises an exception.
    TypeError
        If the result exposes ``theta`` as a raw ``property`` object, which is
        what happens when the string names a kernel class instead of
        instantiating it (e.g. ``"RBF"`` rather than ``"RBF(1)"``).

    Notes
    -----
    The string is handed to ``eval``, so it can run arbitrary code: never call
    this function with untrusted input.
    """
    def recurse_subclasses(cls):
        # Depth-first list of all subclasses of ``cls``.
        out = []
        for c in cls.__subclasses__():
            out.append(c)
            out.extend(recurse_subclasses(c))
        return out

    from numpy import array

    # Use an explicit namespace rather than exec()-ing assignments into
    # locals(). Since Python 3.13 (PEP 667) each locals() call inside a
    # function returns an independent snapshot, so names injected that way
    # are no longer visible to a later eval().
    namespace = {cls.__name__: cls for cls in recurse_subclasses(Kernel)}
    namespace["array"] = array

    try:
        # Module globals stay reachable; kernel class names resolve through
        # the local namespace built above.
        k = eval(kernel, globals(), namespace)
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception as e:  # pragma: no cover
        raise RuntimeError("Failed to evaluate kernel string {0!r}.  "
                               "Original exception: {1}".format(kernel, e))

    # A class (as opposed to an instance) still exposes the unbound property.
    if type(k.theta) is property:
        raise TypeError("String provided was not initialized properly")
    return k

def load_pickle(pickle_file):
    """Load and return the object stored in a pickle file.

    Works under both Python 2 and Python 3. On Python 3 the file is decoded
    with ``encoding='latin1'``.

    Parameters
    ----------
    pickle_file : str
        Path of the pickle file to read.

    Returns
    -------
    The unpickled object.

    Notes
    -----
    Unpickling can execute arbitrary code: only load files from a trusted
    source.
    """
    # Pickles are binary data, so open in 'rb' on every Python version. The
    # context manager guarantees the handle is closed even if loading fails.
    with open(pickle_file, 'rb') as stream:
        if sys.version_info[0] < 3:
            dico = pickle.load(stream)
        else:
            dico = pickle.load(stream, encoding='latin1')

    return dico

Collecting treecorr
  Using cached https://files.pythonhosted.org/packages/06/72/0b86c778e815a0611a7fc7bd5239d17ed346d9f382b8733f6cab1b38a06e/TreeCorr-4.0.4.tar.gz
Collecting LSSTDESC.Coord>=1.1 (from treecorr)
  Using cached https://files.pythonhosted.org/packages/c4/28/7175cb1c0df002b4435ff25f6f2d92c5ad7417e80f4bdf436783205760cb/LSSTDESC.Coord-1.1.2.tar.gz
Building wheels for collected packages: treecorr, LSSTDESC.Coord
  Running setup.py bdist_wheel for treecorr ... [?25ldone
[?25h  Stored in directory: /home/nbuser/.cache/pip/wheels/22/6a/04/c7b238b2e07633907026191f5bea54cb27035183045279e173
  Running setup.py bdist_wheel for LSSTDESC.Coord ... [?25ldone
[?25h  Stored in directory: /home/nbuser/.cache/pip/wheels/5f/d7/aa/627da57d6a75fe0bf63e03d8bb0e8767e804bb9be2c7a05bb7
Successfully built treecorr LSSTDESC.Coord
Installing collected packages: LSSTDESC.Coord, treecorr
Successfully installed LSSTDESC.Coord-1.1.2 treecorr-4.0.4


## Exercise 6) Adding a mean function on real SNIa data, impact on GP interpolation (1D):

In [3]:
##########################################################################################
# EXERCISE 6: Adding a mean function on real SNIa data, impact on GP interpolation (1D): #
##########################################################################################

# Load the data used by this exercise. plot_samples() below indexes this
# dictionary by supernova name and reads the 'y', 'y0', 'y_time', 'y0_time'
# and 'y_err' entries of the selected object.
# NOTE(review): load_pickle() unpickles the file; only use a trusted data file.
dic = load_pickle('data/snia_gaussian_process_de_school.pkl')

def gp_regression(x, new_x, y, kernel, y_err=None):
    """Fit a Gaussian process with fixed hyperparameters and predict on new_x.

    The kernel hyperparameters are not optimised (``optimizer=None``) and the
    targets are not normalised; the kernel is used exactly as given.

    Parameters
    ----------
    x : array of shape (n_samples, n_features)
        Training inputs.
    new_x : array of shape (n_new, n_features)
        Inputs at which the prediction is evaluated.
    y : array of shape (n_samples,)
        Training targets.
    kernel : scikit-learn kernel instance
        Covariance function of the Gaussian process.
    y_err : array of shape (n_samples,), optional
        1-sigma uncertainties on ``y``. When omitted, a small jitter of 1e-10
        is added to the diagonal of the kernel matrix for numerical stability.

    Returns
    -------
    y_predict : array
        Predicted mean at ``new_x``.
    y_std : array
        Predicted standard deviation at ``new_x``.
    log_L : float
        Log marginal likelihood of the training data for the given kernel.
    """
    if y_err is None:
        noise_variance = 1e-10
    else:
        # scikit-learn's ``alpha`` is added to the diagonal of the kernel
        # matrix, i.e. it is a *variance*; 1-sigma errors must be squared.
        noise_variance = np.asarray(y_err, dtype=float) ** 2

    # normalize_y must be a real boolean: recent scikit-learn releases
    # validate parameters and reject None.
    gp = skl_gp.GaussianProcessRegressor(kernel=kernel, alpha=noise_variance,
                                         optimizer=None,
                                         normalize_y=False)
    gp.fit(x, y)
    y_predict, y_std = gp.predict(new_x, return_std=True)
    log_L = gp.log_marginal_likelihood()
    return y_predict, y_std, log_L

def spline_1D(old_binning, mean_function, new_binning):
    """Resample a 1D curve onto a new grid with an interpolating spline.

    Parameters
    ----------
    old_binning : 1D array
        Abscissae at which ``mean_function`` is known.
        NOTE(review): SciPy expects these to be increasing; verify for new data.
    mean_function : 1D array
        Values of the curve at ``old_binning``.
    new_binning : 1D array
        Abscissae at which the curve is wanted.

    Returns
    -------
    1D array with the spline evaluated at ``new_binning``. The spline is built
    with SciPy's default settings (degree 3 unless SciPy's default changes).
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.interpolate`` is loaded on every SciPy version.
    from scipy.interpolate import InterpolatedUnivariateSpline

    cubic_spline = InterpolatedUnivariateSpline(old_binning, mean_function)
    mean_interpolate = cubic_spline(new_binning)
    return mean_interpolate


# Backslashes in the LaTeX labels below are escaped explicitly: '\s', '\l' and
# '\c' are invalid escape sequences and trigger warnings on recent Python
# versions. The resulting strings are unchanged.
@interact(sigma = widgets.FloatSlider(value=0.5, min=0.1, max=0.8, step=0.01, description='$\\sigma$:',
          disabled=False,
          continuous_update=False,
          orientation='horizontal',
          readout=True,
          readout_format='.2f'),
          l = widgets.FloatSlider(value=3., min=1., max=15, step=0.1, description='$l$:',
          disabled=False,
          continuous_update=False,
          orientation='horizontal',
          readout=True,
          readout_format='.2f'),
          add_mean=widgets.Checkbox(value=False,
                                   description='Add mean function',
                                   disabled=False),
          sn_name = widgets.Dropdown(options=['SNF20080514-002', 'SNF20050821-007', 'SNF20070802-000'],
                                     value='SNF20080514-002',
                                     description='SNIa name:',
                                     disabled=False,))
def plot_samples(sigma, l, add_mean, sn_name):
    """Plot the GP interpolation of one light curve, with or without a mean function.

    Reads the selected object from the module-level ``dic``. The kernel is
    ``sigma**2 * RBF(l)``. When ``add_mean`` is set, the template ``y0`` is
    splined onto the observation epochs and subtracted before the fit, then
    splined onto the prediction grid and added back to the prediction.

    Parameters
    ----------
    sigma : float
        Square root of the amplitude multiplying the RBF kernel.
    l : float
        Length scale passed to the RBF kernel.
    add_mean : bool
        Use the stored ``y0`` curve as mean function instead of zero.
    sn_name : str
        Key of the object to plot in ``dic``.
    """
    sn = dic[sn_name]
    obs_time = sn['y_time']

    # Prediction grid: 80 epochs between -12 and 48, as a column vector.
    new_x = np.linspace(-12, 48, 80).reshape((80, 1))
    grid = new_x[:, 0]

    # Lower-case name so the imported ``Kernel`` base class is not shadowed.
    kernel = eval_kernel("%f * %s(%f)" % (sigma**2, "RBF", l))

    y = copy.deepcopy(sn['y'])
    y0 = copy.deepcopy(sn['y0'])

    if add_mean:
        # Mean function evaluated at the observed epochs and on the grid.
        y0_on_y = spline_1D(sn['y0_time'], y0, obs_time)
        y0_on_ypredict = spline_1D(sn['y0_time'], y0, grid)
    else:
        y0_on_y = 0
        y0_on_ypredict = 0

    epoch = obs_time.reshape((len(obs_time), 1))

    # The GP is fitted on the residuals to the mean function...
    y_pred, y_std, log_L = gp_regression(epoch, new_x, y - y0_on_y,
                                         kernel, y_err=sn['y_err'])
    # ...and the mean function is added back to the prediction.
    y_pred = y_pred + y0_on_ypredict

    plt.figure(figsize=(14,8))

    # Data
    plt.scatter(obs_time, y,
                c='b', label = 'data')
    plt.errorbar(obs_time, y,
                 linestyle='', yerr=sn['y_err'], ecolor='b',
                 alpha=0.7,marker='.',zorder=0)

    # GP prediction
    plt.plot(new_x, y_pred, 'r', lw =3, label = 'GP prediction')
    plt.fill_between(grid, y_pred-y_std, y_pred+y_std, color='r', alpha=0.3)

    # Mean function that was used for the fit
    if not add_mean:
        plt.plot(new_x, np.zeros_like(new_x),'k--', label='used mean function')
    else:
        plt.plot(sn['y0_time'], sn['y0'],
                 'k--', label='used mean function')
    plt.xlim(-12,48)
    plt.ylim(y.min()-1,
             y.max()+1)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel('epoch relative to SALT2 $t_0$ (days)', fontsize=20)
    plt.ylabel('Mag AB + cst.', fontsize=20)
    plt.title("$\\log({\\cal{L}}) = %.2f$ \n(kernel used: RBF)"%(log_L), fontsize=20)
    # Magnitudes: brighter is smaller, so the y axis is flipped.
    plt.gca().invert_yaxis()
    plt.legend(fontsize=18, loc=3)

aW50ZXJhY3RpdmUoY2hpbGRyZW49KEZsb2F0U2xpZGVyKHZhbHVlPTAuNSwgY29udGludW91c191cGRhdGU9RmFsc2UsIGRlc2NyaXB0aW9uPXUnJFxcc2lnbWEkOicsIG1heD0wLjgsIG1pbj3igKY=
