# Curso básico de Python aplicado à Astronomia
### Laboratório Interinstitucional de e-Astronomia
# Aula XX - Relação Massa-riqueza de Aglomerados de Galáxias
Michel Aguena, LAPP/IN2P3 & LIneA

## Objetivo


Aglomerados de galáxias podem ser ferramentas poderosas para se obter informação sobre a cosmologia,
mas é necessário associar os aglomerados observados aos halos de matéria escura que os hospedam. Um dos ingredientes dessa associação é a relação entre a riqueza dos aglomerados e a massa dos halos.

## Índice
1. [O que são aglomerados de galáxias](#cluster)
2. [Calibrando uma relação (fazendo um "fit")](#fit)
3. [O espaço de parâmetros](#param)
4. [Calibrando a relação massa-riqueza](#mr)

# 1. O que são aglomerados de galáxias <a class="anchor" id="cluster"></a>

# 2. Calibrando uma relação (fazendo um "fit") <a class="anchor" id="fit"></a>
Como determinar quais valores dos parâmetros de uma função fazem com que ela se ajuste melhor a um conjunto de dados?
Ex:

In [None]:
# computation
import numpy as np
# display
from IPython.display import Markdown as md
from IPython.display import display, Math
# plots
import pylab as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import MultipleLocator
# widgets plot
import ipywidgets as widgets
%matplotlib widget
# ignore warnings
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Toy data set: five points (x, y) and the uncertainties on y (``err``),
# used throughout the chi^2 / likelihood examples below.
x = np.array([1, 2, 3, 4, 5])
y = np.array([3, 5, 7, 9, 11])
err = np.array([3.12, 7.40, 5.61, 1.48 , 3.50])

In [None]:
%matplotlib inline
# Plot the toy data points with their error bars on y.
plt.errorbar(x, y, err, ls='', fmt='.', capsize=3)
plt.grid()
plt.xlabel('x')
plt.ylabel('y')
plt.show()

Qual a melhor função que descreve esses dados?
Uma reta:
\begin{equation}
f(x) = a\; x + b
\end{equation}

In [None]:
def func(x, a, b):
    '''
    Straight line f(x) = a x + b.

    Works element-wise when ``x`` is a numpy array.
    '''
    slope_term = a*x
    return slope_term + b

Que tal testar alguns parâmetros:

In [None]:
%matplotlib widget


avals = widgets.FloatSlider(
    value=0,
    min=-10,
    max=10.0,
    step=0.01,
    description='a:',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)
bvals = widgets.FloatSlider(
    value=5,
    min=0,
    max=20.0,
    step=0.01,
    description='b:',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

# Data points with their error bars; the fit line is added by ``update``.
plt.figure(figsize=(8, 6))
ax = plt.axes()
ax.scatter(x, y)
ax.errorbar(x, y, err,
            ls='', fmt='.',
            lw=1,
            capsize=5)
ax.grid(lw=.5)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_ylim(-2, 15)


@widgets.interact(a=avals, b=bvals)
def update(a, b):
    '''
    Redraw the line f(x) = a x + b for the current slider values.

    Only the previous fit line (tagged with ``gid='fit'``) is removed:
    assigning ``ax.lines = []`` raises on recent Matplotlib versions and,
    on older ones, also erased the error-bar markers and caps.
    '''
    for line in [ln for ln in ax.lines if ln.get_gid() == 'fit']:
        line.remove()
    ax.plot(x, func(x, a, b), c='r', gid='fit',
            label=f'$f(x) = {a:.2f} x + {b:.2f}$')
    ax.legend(loc=2)

In [None]:
# Candidate (a, b) parameter sets for the straight line, displayed as a
# Markdown table (c_0 ... c_5).
parametros = [(0, 7), (1, 2), (3, 0),
              (2.5, 1), (2, 6), (-3, 15)]
md('Dados alguns parâmetros:\n\n'+
   'Conjunto | a | b \n---|---|---\n'+
    '\n'.join([f'$c_{i}$ | {a} | {b}' for i, (a, b) in enumerate(parametros)])+
   '\n\nQual se ajusta melhor?'
  )

In [None]:
%matplotlib inline
# Overlay the line f(x) = a x + b of every candidate set c_i on the data.
# Note: the loop rebinds the module-level names ``a`` and ``b``.
plt.errorbar(x, y, err, ls='', fmt='.', capsize=3)
for i, (a, b) in enumerate(parametros):
    plt.plot(x, func(x, a, b),
            zorder=0, label=f'$c_{i}$ (a={a}, b={b})')
plt.grid()
plt.legend(ncol=2)
plt.xlabel('x')
plt.ylabel('y')
plt.show()

Como avaliar quantitativamente qual é o melhor ajuste?

## Método do $\chi^2$:

\begin{equation}
\chi^2 = \sum_i \frac{(data_i-modelo_i)^2}{(erro_i)^2}
\end{equation}

In [None]:
def chi2(data, modelo, erro):
    '''
    Chi-squared between data and model.

    chi^2 = sum_i (data_i - modelo_i)^2 / erro_i^2

    Parameters
    ----------
    data: array_like or float
        Measured values
    modelo: array_like or float
        Model prediction at the same points as ``data``
    erro: array_like or float
        Uncertainties on ``data``

    Returns
    -------
    float
        The chi^2 summed over all points. Inputs are converted with
        ``np.asarray``, so scalars and plain lists are accepted as well as
        numpy arrays (the builtin ``sum`` used before failed on both).
    '''
    data = np.asarray(data)
    modelo = np.asarray(modelo)
    erro = np.asarray(erro)
    return np.sum((data-modelo)**2/erro**2)
# Print computation:
def show_chi2_numbers(x, y, err, parametros):
    '''
    Render, as Markdown with LaTeX, the term-by-term chi^2 computation of
    every candidate parameter set.

    Parameters
    ----------
    x, y, err: array
        Data points and uncertainties on y
    parametros: list of (a, b)
        Candidate parameters of the straight line ``func``

    Returns
    -------
    Markdown object with one line per candidate set c_i.
    '''
    def _terms(a, b):
        # One LaTeX fraction (y_i - f(x_i))^2 / err_i^2 per data point
        return ' + '.join(
            rf'\frac{{({yi}-{func(xi, a, b)})^2}}{{{ei:.2f}^2}}'
            for xi, yi, ei in zip(x, y, err))

    rows = [f'$c_{i}:' + _terms(a, b) + f'= {chi2(y, func(x, a, b), err):.2f}$\n\n'
            for i, (a, b) in enumerate(parametros)]
    return md(''.join(rows))
show_chi2_numbers(x, y, err, parametros)

In [None]:
%matplotlib inline
# Same comparison as before, now labelling each candidate c_i with its chi^2.
plt.errorbar(x, y, err, ls='', fmt='.', capsize=3)
for i, (a, b) in enumerate(parametros):
    plt.plot(x, func(x, a, b),
            zorder=0, label=f'$c_{i}$ ($\chi^2={chi2(y, func(x, a, b), err):.0f}$)')
plt.grid()
plt.legend(ncol=2)
plt.xlabel('x')
plt.ylabel('y')
plt.show()

In [None]:
def plot_with_line(x, y, err):
    %matplotlib widget


    avals = widgets.FloatSlider(
        value=0,
        min=-10,
        max=10.0,
        step=0.01,
        description='a:',
        disabled=False,
        continuous_update=True,
        orientation='horizontal',
        readout=True,
        readout_format='.2f',
        layout=widgets.Layout(width='500px')
    )
    bvals = widgets.FloatSlider(
        value=5,
        min=-10,
        max=10.0,
        step=0.01,
        description='b:',
        disabled=False,
        continuous_update=True,
        orientation='horizontal',
        readout=True,
        readout_format='.2f',
        layout=widgets.Layout(width='500px')
    )

    plt.figure(figsize=(8, 6))
    ax = plt.axes()
    ax.scatter(x, y)
    ax.errorbar(x, y, err, 
                ls='', fmt='.',
                lw=1,
                capsize=5)
    ax.grid(lw=.5)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_ylim(-2, 15)

    @widgets.interact(a=avals, b=bvals)
    def update(a, b):
        ax.lines = []
        signal = '+' if b>=0 else ''
        ax.plot(x, func(x, a, b), c='r',
                label=f'$f(x) = {a:.2f}\,x {"+" if b>=0 else ""} {b:.2f}$\n'+
                f'$\chi^2 = {chi2(y, func(x, a, b), err):.4f}$')
        ax.legend(loc=2)
plot_with_line(x, y, err)

# 3. O espaço de parâmetros <a class="anchor" id="param"></a>

Avaliar a qualidade do ajuste no espaço de parâmetros

In [None]:
%matplotlib inline
# Each candidate set c_i as a point in the (a, b) parameter plane,
# annotated with its chi^2 value.
for i, (a, b) in enumerate(parametros):
    plt.scatter(a, b, color=f'C{i}')
    plt.text(a, b, f'$c_{i}$({chi2(y, func(x, a, b), err):.0f})')
plt.xlabel('a')
plt.ylabel('b')
plt.grid()
plt.show()

E se calculássemos os valores em uma grade de parâmetros?

In [None]:
# Calcular valores
def compute_grid(vals1, vals2, func):
    '''
    Evaluate a function of two parameters on a 2D grid.

    Parameters
    ----------
    vals1: array_like
        Values of parameter 1 (grid axis 0)
    vals2: array_like
        Values of parameter 2 (grid axis 1)
    func: function
        Function to evaluate, called as ``func(p1, p2)`` with single values

    Returns
    -------
    grid1: array 2D
        Parameter 1 on the grid, ``grid1[i, j] = vals1[i]``
    grid2: array 2D
        Parameter 2 on the grid, ``grid2[i, j] = vals2[j]``
    values: array 2D
        Function on the grid, ``values[i, j] = func(vals1[i], vals2[j])``
    '''
    # Accept plain lists as well as numpy arrays (``vals*0+1`` failed on lists).
    vals1 = np.asarray(vals1)
    vals2 = np.asarray(vals2)
    grid1 = np.outer(vals1, np.ones_like(vals2))
    grid2 = np.outer(vals2, np.ones_like(vals1)).T
    values = np.array([[func(p1, p2) for p2 in vals2]
                       for p1 in vals1])
    return grid1, grid2, values
def compute_chi2_grid(vals1, vals2, func, x, y, err):
    '''
    Compute the chi^2 of a two-parameter model on a 2D parameter grid.

    Parameters
    ----------
    vals1: array
        Values of parameter 1 (grid axis 0)
    vals2: array
        Values of parameter 2 (grid axis 1)
    func: function
        Model to be fitted, called as ``func(x, parameter1, parameter2)``
    x: array
        Values of the x component
    y: array
        Values of the y component
    err: array
        Errors on the y component

    Returns
    -------
    grid1: array 2D
        Parameter 1 on the grid
    grid2: array 2D
        Parameter 2 on the grid
    chi2_grid: array 2D
        chi^2 of the model against (x, y, err) on the grid
    '''
    def chi2_at(p1, p2):
        # chi^2 of the model evaluated at this pair of parameters
        model = func(x, p1, p2)
        return chi2(y, model, err)

    return compute_grid(vals1, vals2, func=chi2_at)

Definir os valores para a grade

In [None]:
# Parameter grid: 101 values of a in [-3, 7] and 99 values of b in [-14, 16].
# chi2_grid[i, j] is the chi^2 of the line with a = a_vals[i], b = b_vals[j].
a_vals = np.linspace(-3, 7, 101)
b_vals = np.linspace(-14, 16, 99)
a_grid, b_grid, chi2_grid = compute_chi2_grid(a_vals, b_vals, func, x, y, err)

* Gráfico 3D do $\chi^2$:

In [None]:
%matplotlib widget
# 3D wireframe of chi^2 over the (a, b) grid, with the candidate sets c_i
# drawn as coloured points at their own chi^2 value.
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(a_grid, b_grid, chi2_grid,
                 lw=.5)

for i, (a, b) in enumerate(parametros): 
    ax.scatter(a, b, chi2(y, func(x, a, b), err),
               color=f'C{i}', label=f'$c_{i}$')

ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('$\chi^2$')

ax.legend(ncol=2)

# Hide the ipympl toolbar/header and let the canvas be resized.
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.resizable = True

plt.show()

## Likelihood

A *likelihood* funciona como uma probabilidade no espaço de parâmetros. No caso de uma distribuição Gaussiana de dados, ela pode ser definida em termos do $\chi^2$:

\begin{equation}
\mathcal{L} = \frac{1}{\sqrt{(2\pi)^N \det(C)}} \exp{\left(-\frac{1}{2}\chi^2\right)}
=\frac{1}{\prod_i\sqrt{2\pi\,{erro}^2_i}}\exp{\left(-\frac{1}{2}\sum_i \frac{(data_i-modelo_i)^2}{(erro_i)^2}\right)}
\end{equation}
onde $N$ é o número de pontos e $C=\mathrm{diag}({erro}^2_i)$ é a matriz de covariância (diagonal) dos dados.

In [None]:
def like_chi2(chi2_, err):
    '''
    Gaussian likelihood from chi^2 value(s).

    L = exp(-chi2/2) / prod_i(sqrt(2 pi) err_i)

    The normalisation is applied in log space: with many data points the
    product of the normalisation factors can underflow to 0 (or overflow)
    on its own, yielding 0/0 = nan. Combining the logarithms first keeps the
    result finite whenever the likelihood itself is representable.

    Parameters
    ----------
    chi2_: float or array
        chi^2 value(s), e.g. a grid from ``compute_chi2_grid``
    err: array
        Errors on the data, used for the normalisation

    Returns
    -------
    float or array
        Likelihood, same shape as ``chi2_``
    '''
    log_norm = np.sum(np.log(np.sqrt(2*np.pi)*np.asarray(err)))
    return np.exp(-.5*np.asarray(chi2_) - log_norm)
def like(data, modelo, erro):
    '''
    Gaussian likelihood of ``data`` given ``modelo`` and the errors ``erro``.

    The errors passed in are used both in the chi^2 and in the normalisation.
    The previous version normalised with the module-level ``err`` whatever
    errors were passed, which was wrong for any other data set or error scale.

    Parameters
    ----------
    data: array
        Measured values
    modelo: array
        Model prediction at the same points
    erro: array
        Uncertainties on ``data``

    Returns
    -------
    float
        The normalised likelihood
    '''
    return like_chi2(chi2(data, modelo, erro), erro)
# Likelihood on the (a, b) grid computed above.
like_grid = like_chi2(chi2_grid, err)
# Table of chi^2 and likelihood for each candidate set c_i. The likelihood
# column uses ``like`` (the normalised Gaussian likelihood defined above), so
# the values match the heights shown in the 3D likelihood plot below.
md('Para o nosso conjunto de dados, temos:\n\n'+
   r'Conjunto | $\chi^2$ | $\mathcal{L}$'+'\n---|---|---\n'+
   '\n'.join([f'c{i}|${chi2(y, func(x, a, b), err):.0f}$|'+
              f'${like(y, func(x, a, b), err):.2e}$'
              for i, (a, b) in enumerate(parametros)]))

* Gráfico 3D da $\mathcal{L}$:

In [None]:
%matplotlib widget
# 3D wireframe of the likelihood over the (a, b) grid, with the candidate
# sets c_i drawn at their likelihood value.
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(a_grid, b_grid, like_grid,
                 lw=.5)

ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('$\mathcal{L}$')

# NOTE(review): labels are assigned but this cell never calls ax.legend().
for i, (a, b) in enumerate(parametros): 
    ax.scatter(a, b, like(y, func(x, a, b), err),
               color=f'C{i}', label=f'c{i}')


fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.resizable = True

plt.show()

# Fixar *VS* marginalizar parâmetros
Como encontrar definitivamente qual conjunto de parâmetros melhor se ajusta aos dados?

* Fixar parâmetros?
\begin{equation}
P_f(a) = P(a, b=b_0)
\end{equation}

In [None]:
%matplotlib inline
# Likelihood as a function of one parameter with the other one fixed:
# left panel varies a for several fixed b, right panel varies b for fixed a.
f, axes = plt.subplots(1, 2, figsize=(8, 4))

for b in [-4, -2, 0, 2, 4]:
    vals = np.linspace(0, 5, 100)
    axes[0].plot(vals, [like(y, func(x, a, b), err) for a in vals],
            label=f'b={b}')

for a in [1, 1.5, 2, 2.5, 3]:
    vals = np.linspace(-5, 10, 100)
    axes[1].plot(vals, [like(y, func(x, a, b), err) for b in vals],
            label=f'a={a}')
for ax in axes:
    ax.legend()
    ax.grid()
axes[0].set_xlabel('a')
axes[1].set_xlabel('b')
plt.show()

In [None]:
%matplotlib widget

# 3D likelihood surface; the widget below draws the 1D slice obtained by
# fixing one of the parameters.
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(a_grid, b_grid, like_grid,
                 lw=.5)

ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel(r'$\mathcal{L}$')


fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.resizable = True

plt.show()


vals = widgets.FloatSlider(
    value=0,
    min=-10,
    max=10.0,
    step=0.01,
    description='value:',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    layout=widgets.Layout(width='500px')
)


@widgets.interact(par=['a', 'b'], v=vals)
def update(par, v):
    '''Draw the likelihood slice obtained by fixing parameter ``par`` at ``v``.'''
    # Remove only the previous slice; assigning ``ax.lines = []`` is not
    # supported by recent Matplotlib versions.
    for line in [ln for ln in ax.lines if ln.get_gid() == 'slice']:
        line.remove()
    if par == 'a':
        profile = [like(y, func(x, v, b), err) for b in b_vals]
        ax.plot(b_vals*0+v, b_vals, profile, c='r', gid='slice')
    elif par == 'b':
        profile = [like(y, func(x, a, v), err) for a in a_vals]
        ax.plot(a_vals, a_vals*0+v, profile, c='r', gid='slice')


* Marginalizar parâmetros
\begin{equation}
P_m(a) = \int_{-\infty}^{\infty} db P(a, b)
\approx \sum_i \Delta b P(a, b=b_i)
\end{equation}

In [None]:
%matplotlib inline
# Compare the marginalised likelihood (black) with fixed-parameter slices,
# every curve normalised to unit area.
f, axes = plt.subplots(2, figsize=(6, 9))

# Normalise samples x taken with spacing dx so that sum(x*dx) == 1.
norm = lambda x, dx: np.array(x)/sum(np.array(x)*dx)

# like_grid[i, j] corresponds to (a_vals[i], b_vals[j]): summing over axis 1
# marginalises over b, summing over axis 0 marginalises over a.
da, db = a_vals[1]-a_vals[0], b_vals[1]-b_vals[0]
axes[0].plot(a_vals, norm(np.sum(like_grid, axis=1), da), c='0', label='Marginalized')
axes[1].plot(b_vals, norm(np.sum(like_grid, axis=0), db), c='0', label='Marginalized')

axes[0].set_xlabel('a')
axes[0].xaxis.tick_top()
axes[0].xaxis.set_label_position('top') 

axes[1].set_xlabel('b')



# Fixed-parameter slices, resampled on 200 points and normalised the same way.
for b in [-4, -2, 0, 2, 4]:
    vals = np.linspace(a_vals[0], a_vals[-1], 200)
    dv = (vals[1]-vals[0])
    axes[0].plot(vals, norm([like(y, func(x, a, b), err) for a in vals], dv),
                label=f'b={b}', zorder=0, lw=.8)
for a in [1, 1.5, 2, 2.5, 3]:
    vals = np.linspace(b_vals[0], b_vals[-1], 200)
    dv = (vals[1]-vals[0])
    axes[1].plot(vals, norm([like(y, func(x, a, b), err) for b in vals], dv),
                label=f'a={a}', zorder=0, lw=.8)
    
for ax in axes:
    ax.grid()
    ax.grid(which='minor', lw=.5)
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.legend()
axes[0].set_xlim(-1, 5)
axes[1].set_xlim(-10, 10)
#axes[0].xaxis.set_major_locator(MultipleLocator(2))
axes[1].xaxis.set_major_locator(MultipleLocator(5))
plt.show()

### Degenerescência

In [None]:
%matplotlib inline
# Filled contours of the likelihood in the (a, b) plane (100 levels between
# its minimum and maximum), showing the correlation between a and b.
print(like_grid.min(), like_grid.max())
fig = plt.figure(figsize=(10, 7))
cb = plt.contourf(a_grid, b_grid, like_grid,
            levels=np.linspace(like_grid.min(), like_grid.max(), 100)
            )
plt.colorbar(cb)

In [None]:
%matplotlib widget

# Current state of the interactive widgets: which data points are used,
# the fit parameters (a, b) and the multiplicative scale applied to the errors.
plot_data = {
    'datapoints': [True for i in range(5)],
    'fitpars': {'a': 2.5, 'b': .5},
    'err': {'v': 1},
    }

avals = widgets.FloatSlider(
    value=plot_data['fitpars']['a'],
    min=a_vals[0]+.1,
    max=a_vals[-1]-.1,
    step=0.01,
    description='a:',
    disabled=False,
    continuous_update=True,
    orientation='vertical',
    readout=True,
    readout_format='.2f',
)
bvals = widgets.FloatSlider(
    value=plot_data['fitpars']['b'],
    min=b_vals[0]+.1,
    max=b_vals[-1]-.1,
    step=0.01,
    description='b:',
    disabled=False,
    continuous_update=True,
    orientation='vertical',
    readout=True,
    readout_format='.2f',
)
# Error scale starts at 0.1: a scale of 0 makes every error zero and the
# chi^2 divides by zero.
evals = widgets.FloatSlider(
    value=plot_data['err']['v'],
    min=0.1,
    max=10,
    step=0.1,
    description='err:',
    disabled=False,
    continuous_update=True,
    orientation='vertical',
    readout=True,
    readout_format='.2f',
)
output = widgets.Output()
with output:
    fig, axes = plt.subplots(2, figsize=(6, 8))

# Artists that depend on the fit parameters (red point and red line). They are
# removed explicitly before each redraw: assigning to ``Axes.lines`` or
# ``Axes.collections`` is not supported by recent Matplotlib versions.
fit_artists = []


def _current_state():
    '''Return (a, b, use, err_use) for the current widget state.'''
    a, b = plot_data['fitpars']['a'], plot_data['fitpars']['b']
    use = np.array(plot_data['datapoints'], dtype=bool)
    err_use = err[use]*plot_data['err']['v']
    return a, b, use, err_use


def _clear_fit():
    '''Remove the red point and red line drawn by ``_draw_fit``.'''
    while fit_artists:
        fit_artists.pop().remove()


def _draw_background():
    '''Redraw the likelihood map (top) and the selected data points (bottom).'''
    _, _, use, err_use = _current_state()
    like_temp = like_chi2(compute_chi2_grid(a_vals, b_vals, func,
                                            x[use], y[use], err_use)[2],
                          err_use)
    axes[0].cla()
    # contourf needs increasing levels: skip a flat (or nan) map, e.g. when
    # no data point is selected.
    if like_temp.max() > like_temp.min():
        axes[0].contourf(a_grid, b_grid, like_temp,
                         levels=np.linspace(like_temp.min(), like_temp.max(), 100))
    axes[1].cla()
    axes[1].scatter(x[use], y[use], color='C0')
    axes[1].errorbar(x[use], y[use], err_use,
                     ls='', fmt='.', color='C0',
                     lw=1, capsize=5)
    axes[1].grid(lw=.5)
    axes[1].set_xlabel('x')
    axes[1].set_ylabel('y')
    axes[1].set_ylim(-2, 15)


def _draw_fit():
    '''Draw the current (a, b) point (top) and the line f(x) = a x + b with its chi^2 (bottom).'''
    a, b, use, err_use = _current_state()
    fit_artists.append(axes[0].scatter(a, b, c='r'))
    fit_artists.extend(axes[1].plot(
        x, func(x, a, b), c='r',
        label=rf'$f(x) = {a:.2f}\,x {"+" if b>=0 else ""} {b:.2f}$' + '\n' +
        rf'$\chi^2 = {chi2(y[use], func(x[use], a, b), err_use):.4f}$'))
    axes[1].legend(loc=2)


# Initial drawing (``use`` was previously referenced here before being defined).
_draw_background()
_draw_fit()


points = [widgets.Checkbox(
            description=f'point {i+1}',
            value=True
                )
          for i in range(5)]


def update(change, argtype, argname):
    '''
    Widget callback: store the new value in ``plot_data`` and redraw.

    Parameters
    ----------
    change: traitlets change
        Notification from ``observe``; ``change.new`` is the new value
    argtype: str
        Key of ``plot_data`` to update ('datapoints', 'fitpars' or 'err')
    argname: int or str
        Entry of ``plot_data[argtype]`` to update
    '''
    plot_data[argtype][argname] = change.new
    _clear_fit()
    # The likelihood map and the data only change with the selected points or
    # the error scale; the fit parameters only move the red point/line.
    if argtype in ('datapoints', 'err'):
        _draw_background()
    _draw_fit()


# Bind each checkbox to its own index (default argument avoids late binding).
for i, point in enumerate(points):
    point.observe(lambda c, i=i: update(c, 'datapoints', i), 'value')

avals.observe(lambda c: update(c, 'fitpars', 'a'), 'value')
bvals.observe(lambda c: update(c, 'fitpars', 'b'), 'value')
evals.observe(lambda c: update(c, 'err', 'v'), 'value')

widgets.HBox([output,
              widgets.VBox(
                  [widgets.HBox([avals, bvals])]+points+[evals])])

In [None]:
# Re-assigns y[2]. With the data defined at the top of the notebook y[2] is
# already 7, so this only matters if y was modified in between.
# NOTE(review): looks like a leftover reset — confirm intent.
y[2] = 7

A correlação entre os parâmetros surge quando eles aparecem combinados na likelihood (por exemplo, em um termo cruzado como $ab$).

Ex:

\begin{equation}
\mathcal{L} \propto \exp{\left[-(a^2+b^2+2\theta ab)\right]}
\end{equation}

In [None]:
%matplotlib widget

# Toy likelihood L ∝ exp[-(a^2 + b^2 + 2 θ a b)]: the slider controls the
# cross term θ, i.e. how strongly a and b are correlated.
fig = plt.figure(figsize=(10, 7))
ax = plt.axes()
ax.set_xlabel('a')
ax.set_ylabel('b')


fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.resizable = True

plt.show()


vals = widgets.FloatSlider(
    value=0,
    min=-1,
    max=1,
    step=0.1,
    description=r'$\theta$:',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    layout=widgets.Layout(width='500px')
)

vals_pars = np.linspace(-5, 5, 100)


@widgets.interact(t=vals)
def update(t):
    '''Redraw the toy likelihood for correlation parameter ``t``.'''
    grid1, grid2, like2_grid = compute_grid(vals_pars, vals_pars,
                        func=lambda a, b: np.exp(-(a**2+b**2+2*t*a*b))
                                           )
    # Clear and redraw instead of assigning ``ax.collections = []``, which is
    # not supported by recent Matplotlib versions.
    ax.cla()
    ax.set_xlabel('a')
    ax.set_ylabel('b')
    ax.contourf(grid1, grid2, like2_grid,
            levels=np.linspace(like2_grid.min(), like2_grid.max(), 100)
            )
    # The sign of the cross term depends on t (previously it tested the
    # unrelated global ``b``).
    sign = '+' if t >= 0 else ''
    ax.set_title(rf'$\mathcal{{L}} \propto \exp{{\left[-(a^2+b^2{sign}{2*t:.2f}\; ab)\right]}}$')
    ax.figure.canvas.draw_idle()

# 4. Calibrando a relação massa-riqueza <a class="anchor" id="mr"></a>

A relação entre a massa dos halos e a riqueza dos aglomerados é aproximada por uma relação de escala (lei de potência):

\begin{equation}
\left(\frac{M}{M^0}\right) \approx
\left(\frac{N_{gals}}{N_{gals}^0}\right)^\alpha 
\end{equation}

In [None]:
# Read the comma-separated cluster catalogue; ``names=True`` takes the column
# names from the header row. Columns used below: 'Richness', 'Mass_Mpc' and
# 'MassError_Mpc'.
data = np.genfromtxt('data.txt', delimiter=',', names=True)

In [None]:
%matplotlib widget
# Mass versus richness of the clusters with mass error bars; the checkboxes
# toggle logarithmic scales on each axis.
f = plt.figure(figsize=(7, 7))
ax = plt.axes()
ax.errorbar(data['Richness'], data['Mass_Mpc'],
             data['MassError_Mpc'], ls='', lw=.5,
            fmt='.', markersize=1, capsize=2)
ax.set_xlabel('Richness')
ax.set_ylabel('Mass [$M_{\odot}$]')
ax.grid(which='both')
ax.grid(which='minor',  linewidth=.5)

xlog = widgets.Checkbox(
    value=False,
    description='log scale (x)',
    disabled=False,
    indent=False
)
ylog = widgets.Checkbox(
    value=False,
    description='log scale (y)',
    disabled=False,
    indent=False
)

@widgets.interact(xlog=xlog, ylog=ylog)
def update(xlog, ylog):
    '''Switch each axis between linear and logarithmic scale.'''
    xscale = 'log' if xlog else 'linear'
    yscale = 'log' if ylog else 'linear'
    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
    plt.show()

### Pode ser definida em escala logarítmica:

\begin{equation}
\log M=\alpha\,\log N_{gals}+\beta,
\end{equation}

onde $\beta= \log M^0-\alpha\,\log N_{gals}^0$.

* Detalhe:
\begin{equation}
Err_{\log M}=\frac{Err_{M}}{M\,\ln(10)}
\end{equation}

In [None]:
%matplotlib widget


# Sliders for the log-space line log M = alpha log N + beta.
avals = widgets.FloatSlider(
    value=0,
    min=-5,
    max=5,
    step=0.01,
    description=r'$\alpha$:',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    layout=widgets.Layout(width='500px')
)
bvals = widgets.FloatSlider(
    value=15,
    min=0,
    max=20,
    step=0.01,
    description=r'$\beta$:',
    disabled=False,
    continuous_update=True,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    layout=widgets.Layout(width='500px')
)

xlog = widgets.Checkbox(
    value=False,
    description='log scale (x)',
    disabled=False,
    indent=False,
)
ylog = widgets.Checkbox(
    value=False,
    description='log scale (y)',
    disabled=False,
    indent=False,
)

f = plt.figure(figsize=(7, 7))
ax = plt.axes()
ax.errorbar(data['Richness'], data['Mass_Mpc'],
             data['MassError_Mpc'], ls='', lw=1,
            fmt='.', markersize=1, capsize=5)
ax.set_xlabel('Richness')
ax.set_ylabel(r'Mass [$M_{\odot}$]')
ax.grid(which='both')
ax.grid(which='minor',  linewidth=.5)

# Format a number as LaTeX "m\times 10^{e}" (e.g. 1.00\times 10^{15}).
powfmt = lambda x: '%s\\times 10^{%s}'%tuple(f'{x:.2e}'.replace('+','').split('e'))


@widgets.interact(a=avals, b=bvals, xlog=xlog, ylog=ylog)
def update(a, b, xlog, ylog):
    '''Redraw the model M = 10^b N^a over the data and apply the chosen axis scales.'''
    # Remove only the previous model curve: assigning ``ax.lines = []`` is not
    # supported by recent Matplotlib and also erased the error-bar markers/caps.
    for line in [ln for ln in ax.lines if ln.get_gid() == 'model']:
        line.remove()
    r_ = np.linspace(np.log10(data['Richness'].min()),
                     np.log10(data['Richness'].max()),
                     100)
    ax.plot(10**r_, 10**func(r_, a, b), c='r', gid='model')
    ax.set_xscale('log' if xlog else 'linear')
    ax.set_yscale('log' if ylog else 'linear')
    if ylog:
        ax.set_ylim(1e13, 5e15)
    else:
        ax.set_ylim(1e10, 5e15)
    # ax.set_title (not plt.title) so the title lands on this figure even if
    # another figure has become the current one.
    ax.set_title(rf'$logM(N_{{gals}}) = {a:.2f}\,logN_{{gals}}+ {b:.2f}$' + '\n' +
                 rf'$M(N_{{gals}}) = {powfmt(10**b)}\,(N_{{gals}})^{{{a:.2f}}}$' + '\n')

In [None]:
def plot_likelihood(alpha_vals, beta_vals, func, x, y, err):
    '''
    Calcula o chi^2 em uma grade 2D
    
    Parameters
    ----------
    vals1: array
        Valores para o parâmetro 1
    vals2: array
        Valores para o parâmetro 2
    func: function
        Função a ser ajustada, deve ter como input (x, parâmetro1, parâmetro2)
    x: array
        Valores da componente x
    y: array
        Valores da componente y
    err: array
        Erros na componente y
        
    Returns
    -------
    grid1: array 2D
        Valores do parâmetro 1 na grade
    grid2: array 2D
        Valores do parâmetro 2 na grade
    chi2_grid: array 2D
        Valores do chi^2 na grade
    '''
    grid1, grid2, chi2_grid = compute_chi2_grid(alpha_vals, beta_vals,
                                                func, x, y, err)
    like2_grid = like_chi2(chi2_grid, err)
    
    %matplotlib widget
    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_wireframe(grid1, grid2, like2_grid,
                     lw=.5)

    ax.set_xlabel(r'$\alpha$')
    ax.set_ylabel(r'$\beta$')
    ax.set_zlabel(r'$\mathcal{L}$')

    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.resizable = True

    plt.show()

In [None]:
# Fill below

# Work in log space: log10 of richness and mass, with the mass error
# propagated as Err_logM = Err_M / (M ln 10).
logN = np.log10(data['Richness'])
logM = np.log10(data['Mass_Mpc'])
siglogM = data['MassError_Mpc']/(data['Mass_Mpc']*np.log(10))
# Grid of slopes (alpha) and intercepts (beta) where the likelihood is evaluated.
alpha_vals = np.linspace(1, 1.5, 99)
beta_vals = np.linspace(12, 12.5, 99)

# Plot
plot_likelihood(
    alpha_vals=alpha_vals, beta_vals=beta_vals,
    func=func, x=logN, y=logM, err=siglogM
)