In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, jit, vmap
from jax import random
from jax import jacfwd
import jax_cosmo as jc
import numpy
import numpy as np
import scipy
import matplotlib.pyplot as plt
from scipy.stats import exponnorm
import sncosmo
from astropy.table import Table
import pandas
import pickle

In [2]:
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_debug_nans", True)

# Calculate the Fisher Matrix of SN programs

The data are m and measurement uncertainty dm

The model parameters are

Omega: Cosmological
c: calibration
Av, Rv: per supernova dust

The likelihood p(y| Omega, c, Av, Rv) is
N(c,dc) P(Av, Rv) Prod_SN N(m-(M+mu+C(c)+D(Av, Rv), dm_I^2 + dm^2)

The terms are
<partial ^2 lnL / partial alpha partial beta> = -partial X / partial alpha C^{-1} partial X /beta + delta_alpha,beta d2(ln p_alpha)


In [3]:
# Standard candle spectrum convolved with a nominal logarithmic filter
# historical used for first implementation
@jit
def M_sn(wl):
    wl0=5500.
    sig2 = 5e6
    return (wl-wl0)**2/2/sig2/2.5

In [4]:
class SNModel(object):
    
    def __init__(self, R):
        # self.wave=numpy.array([4000,5000,6000,7000,10000])
        # self.values = numpy.array([4,5,6,7,0])
        # the extreme wavelengths of this file are [2000,9200]
        phase, self.wave, values = sncosmo.io.read_griddata_ascii('../data/salt2-k21-frag/salt2_template_0.dat')
        self.wave = jnp.concatenate((numpy.array([1000]), self.wave, numpy.array([20000])))
        self.values=jnp.concatenate((numpy.array([0]),values[numpy.where(phase==0)[0][0],:],numpy.array([0])))
        # self.wave=jnp.array(self.wave)
        self.R = R
        
    @staticmethod
    def interp_(x, xp, fp):
        """
        Simple equivalent of np.interp that compute a linear interpolation.
    
        We are not doing any checks, so make sure your query points are lying
        inside the array.
    
        TODO: Implement proper interpolation!
    
        x, xp, fp need to be 1d arrays
        """
        # First we find the nearest neighbour
        ind = jnp.argmin((x - xp) ** 2)
    
        # Perform linear interpolation
        ind = jnp.clip(ind, 1, len(xp) - 2)
        xi = jnp.asarray(xp)[ind]
    
        # Figure out if we are on the right or the left of nearest
        s = jnp.sign(jnp.clip(x, jnp.asarray(xp)[1], jnp.asarray(xp)[-2]) - xi).astype(np.int64)
        a = (jnp.asarray(fp)[ind + jnp.copysign(1, s).astype(np.int64)] - jnp.asarray(fp)[ind]) / (
            jnp.asarray(xp)[ind + jnp.copysign(1, s).astype(np.int64)] - jnp.asarray(xp)[ind]
        )
        b = jnp.asarray(fp)[ind] - a * jnp.asarray(xp)[ind]
        
        return a * x + b
        
    @staticmethod
    def interp(x,xp,fp):
        return vmap(interp_,(0, None, None))(x,xp,fp)
        
    def flux(self,wave_in):
        return interp(wave_in,self.wave,self.values)

    def bandmag_one(self,wl):

        delta = numpy.arcsinh(1/2/self.R)
        lmin= wl*numpy.exp(-delta)
        lmax= wl*numpy.exp(delta)
        #left edge
        vall = SNModel.interp_(lmin,self.wave,self.values)
        valr = SNModel.interp_(lmax,self.wave,self.values)

        logic = jnp.logical_and(self.wave > lmin, self.wave < lmax)
        _flux = self.values * logic
        # plt.plot(self.wave,logic)
        _wave  = jnp.concatenate((jnp.array([lmin]), self.wave, jnp.array([lmax])))

        
        _flux = jnp.concatenate((jnp.array([vall]), _flux, jnp.array([valr])))


        
        args = jnp.argsort(_wave)
        _wave=_wave[args]
        _flux=_flux[args]
        
        # plt.plot(_wave,_flux)
        # print(wl, jax.scipy.integrate.trapezoid(_flux,_wave))
        return -2.5*jnp.log10(jax.scipy.integrate.trapezoid(_flux,_wave))

In [5]:
sn = SNModel(4.5)

I0000 00:00:1709776204.615930       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [6]:
@jit
def CCM_a(wl):
    def infrared(x):
        a = 0.574*x**1.61
        return a

    def optical(x):
        y=x-1.82
        a = 1 + 0.17699*y - 0.50447*y**2 - 0.02427*y**3 + 0.72085*y**4+ 0.01979*y**5 - 0.77530*y**6 + 0.32999*y**7
        return a

    def ultraviolet(x):
        a = 1.752 - 0.316*x - 0.104/((x - 4.67)**2 + 0.341)
        return a
        
    x = 1e4/wl
    return jnp.select(condlist=[(x >= 0.3) & (x<=1.1), (x >= 1.1) & (x<=3.3), (x >= 3.3) & (x<=5.9)],
                                choicelist=[infrared(x), optical(x), ultraviolet(x)], default=0.)
def CCM_a_multi(wl):
    return vmap(CCM_a,(0))(wl)

In [7]:
@jit
def CCM_b(wl):
    def infrared(x):
        b = -0.527*x**1.61
        return b

    def optical(x):
        y=x-1.82
        b = 1.41338*y + 2.28305*y**2 + 1.07233*y**3 - 5.38434*y**4 - 0.62251*y**5 + 5.30260*y**6 - 2.09002*y**7
        return b

    def ultraviolet(x):
        b = -3.090 + 1.825*x + 1.206/((x - 4.62)**2 + 0.263)
        return b
        
    x = 1e4/wl
    return jnp.select(condlist=[(x >= 0.3) & (x<=1.1), (x >= 1.1) & (x<=3.3), (x >= 3.3) & (x<=5.9)],
                                choicelist=[infrared(x), optical(x), ultraviolet(x)], default=0.)

def CCM_b_multi(wl):
    return vmap(CCM_b,(0))(wl)

In [8]:
def color(wl, Av, Rv):
    DeltaLnLambda=0.26  # rough separation between filters 
    # in dm/dlnl
    # return grad(lambda x: M_sn(x)+Av*(CCM_a(x)+CCM_b(x)/Rv), (0))(wl)/wl
    return grad(lambda x: sn.bandmag_one(x)+Av*(CCM_a(x)+CCM_b(x)/Rv), (0))(wl)*wl*DeltaLnLambda

def color_run():
    print (color(4400., 0.1, 3.1))

# color_run()

In [9]:
# distribution of Av Exponentially modified Gaussian distribution

# Reasonable values in wikipedia notation
Avmu = 0.
Avlambda = 5.
Avsigma = 0.15

# in scipy
AvK = 1./Avlambda/Avsigma
Avscale = Avsigma

def pAvrvs(size=1):
    return  exponnorm.rvs(size=size, K=AvK, loc=0, scale=Avscale)
    
def d2lnpAv(x):
    # d2 ln efrc((mu + lambda sigma^2 - x)/sqrt2/sigma)
    # z = (mu + lambda sigma^2 - x)/sqrt2/sigma
    # d1 = 1/erfc(z) (-2/sqrt(pi) exp(-z^2) (-1/sqrt2/sigma)
    # d2 = 
    z = (Avmu + Avlambda*Avsigma*Avsigma - x)/numpy.sqrt(2)/Avsigma
    erfc_ = scipy.special.erfc(z)
    exp_ = numpy.exp(-z**2)
    return 4*exp_/numpy.sqrt(numpy.pi)/Avscale/Avscale/erfc_ * (z - exp_/numpy.sqrt(numpy.pi)/erfc_)

## Chi-sq term (not including dm_int term which is another part of the code)

In [10]:
# For one object one band
def mbar_perfilt(Omega_c, w0, wa, dC, Av, Rv, z, efflam):
    # mu not efficient
    cosmo = jc.Planck15(Omega_c=Omega_c, w0=w0, wa=wa)
    restlam=(efflam+dC[0])/(1+z)
    
    # mu = M_sn(restlam) + 10*jnp.log10(1+z) + 5*jnp.log10(jc.background.angular_diameter_distance(cosmo, 1/(1+z)))[0]
    mu =  10*jnp.log10(1+z) + 5*jnp.log10(jc.background.angular_diameter_distance(cosmo, 1/(1+z)))[0]
    # extiction
    A = Av * (CCM_a(restlam)+CCM_b(restlam)/Rv)

    # SN color
    c = color(restlam, Av, Rv)
    
    # calibration
    return sn.bandmag_one(restlam) + mu + A + dC[1] + dC[2]*c

def dmbar_perfilt(Omega_c, w0, wa, dC, Av, Rv, z, efflam):
    return grad(mbar_perfilt,(0,1,2,3,4,5))(Omega_c, w0, wa, dC, Av, Rv, z, efflam)

def dmbar(Omega_c, w0, wa, dCs, Av, Rv, z, efflams):
    return vmap(dmbar_perfilt, (None, None, None, 0,  None, None, None, 0))(Omega_c, w0, wa, dCs, Av, Rv, z, efflams)
    

def mbar(Omega_c, w0, wa, dCs, Av, Rv, z, efflams):
    return vmap(mbar_perfilt, (None, None, None, 0,  None, None, None, 0))(Omega_c, w0, wa, dCs, Av, Rv, z, efflams)

def mbar_check():
    nsn=10
    efflams = numpy.array([4700., 6420.,  7849.])
    zs = numpy.linspace(0.1,1,nsn)
    Avs = pAV().rvs(size=nsn, scale=AVscale)
    Rvs = numpy.zeros(nsn)+3.1
    dCs= numpy.zeros((3,3))
    Omega_c=0.3
    w0=-1.
    wa=0.
    # ans = mbar_perfilt(Omega_c, w0, wa,dCs[0],Avs[0],Rvs[0],zs[0],efflams[0])
    # dans = dmbar_perfilt(Omega_c, w0, wa ,dCs[0],Avs[0],Rvs[0],zs[0],efflams[0])
    # print(ans)
    # for _ in dans:
    #     print(_)
    ans = mbar(Omega_c, w0, wa,dCs,Avs[0],Rvs[0],zs[0],efflams)
    dans = dmbar(Omega_c, w0, wa ,dCs,Avs[0],Rvs[0],zs[0],efflams)
    # print(ans)
    for _ in dans:
        print(_)

def F_chisq(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_datas):

    nfilt = len(efflams)
    nsn = len(Avs)
    ncal = 3
    ncosmo = 3
    npars = 3 + ncal*nfilt + 2*nsn  # 3 cosmology, ncal * nfilt, Av, Rv
                       
    F=numpy.zeros((npars,npars))
    
    # loop over SNe
    snind=0
    for Av , Rv, z, dm_data in zip( Avs, Rvs, zs, dm_datas):
        dmbars = dmbar(Omega_c, w0, wa, dCs, Av, Rv, z, efflams)
        
        cov = numpy.zeros((nfilt,nfilt))+dm_int**2
        diag = numpy.identity(nfilt) * dm_data**2
        cov = cov + diag
        # Despite the documentation, the example shows L.T L = cov
        L, lower = jsp.linalg.cho_factor(cov)
    
        #Cosmo
        for i in range(ncosmo):
            rhs = jsp.linalg.solve_triangular(L, dmbars[i],trans=1,lower=lower)
            F[i,i] += jnp.dot(rhs,rhs)
            #Cosmo
            for j in range(i+1,ncosmo):
                lhs = jsp.linalg.solve_triangular(L, dmbars[j],trans=1,lower=lower)
                F[i,j] += jnp.dot(lhs,rhs)
            # Calibrations
            for j in range(nfilt):
                for k in range(ncal):
                    partial = numpy.zeros(nfilt)
                    partial[j] = dmbars[ncosmo][j][k]
                    lhs = jsp.linalg.solve_triangular(L, partial,trans=1,lower=lower)
                    F[i,ncosmo+ncal*j+k] += jnp.dot(lhs,rhs)
            # Av
            lhs = jsp.linalg.solve_triangular(L, dmbars[ncosmo+1],trans=1,lower=lower)
            F[i,ncosmo+ncal*nfilt+snind] += jnp.dot(lhs,rhs)
            
            # Rv
            lhs = jsp.linalg.solve_triangular(L, dmbars[ncosmo+2],trans=1,lower=lower)
            F[i,ncosmo+ncal*nfilt+nsn+snind] += jnp.dot(lhs,rhs)
        
        #Calibrations
        for i in range(nfilt):
            for l in range(ncal):
                partial = numpy.zeros(nfilt)
                partial[i] = dmbars[ncosmo][i][l]
                rhs = jsp.linalg.solve_triangular(L, partial,trans=1,lower=lower)
                F[ncosmo+ncal*i+l,ncosmo+ncal*i+l] += jnp.dot(rhs,rhs)
                
                # Calibrations
                for j in range(i,nfilt):
                    for k in range(ncal):
                        partial = numpy.zeros(nfilt)
                        partial[j] = dmbars[ncosmo][j][k]
                        lhs = jsp.linalg.solve_triangular(L, partial,trans=1,lower=lower)
                        F[ncosmo+ncal*i+l,ncosmo+ncal*j+k] += jnp.dot(lhs,rhs)
                # Av
                lhs = jsp.linalg.solve_triangular(L, dmbars[ncosmo+1],trans=1,lower=lower)
                F[ncosmo+ncal*i+l,ncosmo+ncal*nfilt+snind] += jnp.dot(lhs,rhs)
                
                # Rv
                lhs = jsp.linalg.solve_triangular(L, dmbars[ncosmo+2],trans=1,lower=lower)
                F[ncosmo+ncal*i+l,ncosmo+ncal*nfilt+nsn+snind] += jnp.dot(lhs,rhs)
        

        #Av
        rhs = jsp.linalg.solve_triangular(L, dmbars[ncosmo+1],trans=1,lower=lower)
        F[ncosmo+ncal*nfilt+snind, ncosmo+ncal*nfilt+snind] += jnp.dot(rhs,rhs)

            #Rv
        lhs = jsp.linalg.solve_triangular(L, dmbars[ncosmo+2],trans=1,lower=lower)
        F[ncosmo+ncal*nfilt+snind, ncosmo+ncal*nfilt+nsn+snind] += jnp.dot(lhs,rhs)

        #Rv
        F[ncosmo+ncal*nfilt+nsn+snind,ncosmo+ncal*nfilt+nsn+snind] += jnp.dot(lhs,lhs)
        
        snind+=1

    for i in range(npars):
        for j in range(i+1,npars):
            F[j,i]=F[i,j]

    return F
        
def F_chisq_check():
    ncal=3
    nsn=10
    efflams = numpy.array([4700., 6420.,  7849.])
    zs = numpy.linspace(0.1,1,nsn)
    Avs = pAvrvs(size=nsn1)
    Rvs = numpy.zeros(nsn)+3.1
    dCs= numpy.zeros((3,ncal))
    dCs=numpy.array([[1.,2,3],[4,5,6],[6,7,8]])
    Omega_c=0.3
    w0=-1.
    wa=0.
    dm_int=0.1
    dm_data=0.02
    F = F_chisq(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)
    print(F[3:12,3:12])

# F_chisq_check()

## Chi-sq term for dm_int (Not implemented correctly so intrinsic dispersion is not a free parameter yet)

In [11]:
def lnL_dm_int(dm_int, dm_data, nfilt):
    
    # For now assume all data have same incertainties
    cov = numpy.zeros((nfilt,nfilt))+dm_int**2
    diag = numpy.identity(nfilt) * dm_data**2
    cov = cov + diag
    # Despite the documentation, the example shows L.T L = cov
    L, lower = jsp.linalg.cho_factor(cov)    
    return -0.5/jnp.diagonal(L).prod()**2

def d2ln_dm_int(dm_int, dm_data, nfilt):
    return grad(grad(lnL_dm_int,(0)),(0))(dm_int,dm_data, nfilt)

def F_chisq_dm_int(efflams, dm_int, dm_data, nsn):
    nfilt = len(efflams)

## A term

In [12]:
def F_Av(Avs):
    ans = []
    for Av in Avs:
        ans.append(-d2lnpAv(Av))
    return ans

def F_Av_test():
    nsn=10
    Avs = pAv().rvs(size=nsn, scale=Avscale)    
    print(Avs, F_Av(Avs))

# F_Av_test()

## Log det term

In [13]:
def lnL_logdet(dm_int, dm_data, efflams):
    nfilt=len(efflams)
    # For now assume all data have same incertainties
    cov = numpy.zeros((nfilt,nfilt))+dm_int**2
    diag = numpy.identity(nfilt) * dm_data**2
    cov = cov + diag
    # Despite the documentation, the example shows L.T L = cov
    L, lower = jsp.linalg.cho_factor(cov)

    # logdet term
    ans = - jnp.log(jnp.diagonal(L)).sum()
    return ans

def F_logdet(dm_int, dm_data, efflams):
    return grad(grad(lnL_logdet,(0)),(0))(dm_int, dm_data, efflams)

def F_logdet_check():
    dm_int=0.1
    dm_data=0.02
    efflams = numpy.array([4700., 6420.,  7849.])
    print(F_logdet(dm_int, dm_data, efflams))

In [14]:
def addCalibrationUncertainty(ncosmo, nfilt, ncal,F):
    # I checked the ordering of the C variables
    ncalinvsig2 = [1./64,1/0.15/0.15, 1/0.15/0.15]
    for i in range(nfilt):
        for j in range(ncal):
            F[ncosmo+i*nfilt+j,ncosmo+i*nfilt+j] += ncalinvsig2[j]

## Combining all contributions to F (for now leaving out dm_int)

In [15]:
def F_all(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data):

    nsn = len(Avs)
    # prior uncertainty in the calibration parameters
    ncal=3
    ncosmo =3

    # lnlam = numpy.log(efflams)
    # deltalnlam = (lnlam[-1] - lnlam[0])/(nfilt-1)
    # print(deltalnlam)
    # wef
    
    nfilt = len(efflams)
    
    # chisq term fills everything
    F = F_chisq(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)
    
    Favs = F_Av(Avs)
    for i in range(nsn):
        F[ncosmo+ncal*nfilt+i, ncosmo+ncal*nfilt+i]+=Favs[i]

    for i in range(nsn):
        F[ncosmo+ncal*nfilt+nsn+i, ncosmo+ncal*nfilt+nsn+i]+= 1/0.3**3  # by hand the underlying distribution of Rv

    return F

In [16]:
def F_merge(ncosmo, F1, F2):
    dim1 = F1.shape[0]
    dim2 = F2.shape[0]
    npars = dim1 + dim2 - ncosmo

    F = numpy.zeros((npars,npars))
    F[0:ncosmo, 0:ncosmo] = F1[0:ncosmo, 0:ncosmo]+ F2[0:ncosmo, 0:ncosmo]
    F[0:ncosmo, ncosmo:dim1] = F1[0:ncosmo,ncosmo:]
    F[0:ncosmo, dim1:] = F2[0:ncosmo,ncosmo:]
    F[ncosmo:dim1,0:ncosmo] = F1[ncosmo:,0:ncosmo]
    F[dim1:,0:ncosmo] = F2[ncosmo:,0:ncosmo]
    F[ncosmo:dim1,ncosmo:dim1] = F1[ncosmo:,ncosmo:]
    F[dim1:, dim1:] = F2[ncosmo:,ncosmo:]
    return F

# Pantheon Dataset

In [17]:
data = Table.read('../data/apjac8e04t7_mrt.txt', format="ascii.cds")
df = data.to_pandas()
# print(df['CID'].unique())

In [18]:
df = df[["CID", "zHD", "e_mBcorr", "c"]]
df = df.groupby(['CID']).mean()  # not quite right for error propagation
df =df.sort_values(by='zHD')

In [None]:
cosmo = jc.Planck15()
Omega_c=cosmo.Omega_c
w0=cosmo.w0
wa=cosmo.wa
dm_int =0.14
ncal=3
ncosmo=3
efflams_all = numpy.array([4700., 6420.,  7849., 8500, 12500, 16000 ])

# Consider 4 Surveys
# z<0.15
dfs = df.loc[df['zHD'] < 0.15]
efflams=efflams_all[0:3]

zs = dfs['zHD'].values
dm_data =  dfs['e_mBcorr'].values
Avs = dfs['c'].values
Rvs = numpy.zeros(len(zs))+2.8
dCs= numpy.zeros((3,ncal))

nsn1 = len(zs)
nfilt1 = len(efflams)

F =  F_all(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)
addCalibrationUncertainty(ncosmo, nfilt1, ncal, F)

# 0.15 < z < 0.55
dfs = df.loc[(df['zHD'] > 0.15) & (df['zHD'] < 0.55)]
efflams=efflams_all[1:4]

zs = dfs['zHD'].values
dmags =  dfs['e_mBcorr'].values
Avs = dfs['c'].values
Rvs = numpy.zeros(len(zs))+2.8
dCs= numpy.zeros((3,ncal))

nsn2 = len(zs)
nfilt2 = len(efflams)

F_ =  F_all(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)
addCalibrationUncertainty(ncosmo, nfilt2, ncal, F_)

F = F_merge(ncosmo, F, F_)

# # 0.55 < z < 0.9
dfs = df.loc[(df['zHD'] > 0.55) & (df['zHD'] < 0.9)]
efflams=efflams_all[2:5]
zs = dfs['zHD'].values
dmags =  dfs['e_mBcorr'].values
Avs = dfs['c'].values
Rvs = numpy.zeros(len(zs))+2.8
dCs= numpy.zeros((3,ncal))

nsn3 = len(zs)
nfilt3 = len(efflams)

F_ =  F_all(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)
addCalibrationUncertainty(ncosmo, nfilt3, ncal, F_)

F = F_merge(ncosmo, F, F_)

# # 0.9 < z
dfs = df.loc[(df['zHD'] > 0.9)]
efflams=efflams_all[3:6]
zs = dfs['zHD'].values
dmags =  dfs['e_mBcorr'].values
Avs = dfs['c'].values
Rvs = numpy.zeros(len(zs))+2.8
dCs= numpy.zeros((3,ncal))

nsn4 = len(zs)
nfilt4 = len(efflams)

F_ =  F_all(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)
addCalibrationUncertainty(ncosmo, nfilt4, ncal, F_)

F = F_merge(ncosmo, F, F_)

# print(nsn1,nsn2,nsn3,nsn4)

In [None]:
import pickle
file = open('test.pkl', 'wb')
pickle.dump([F, ncosmo, ncal, nfilt1, nsn1, nfilt2, nsn2, nfilt3, nsn3, nfilt4, nsn4] , file)
file.close()

In [None]:
file = open('test.pkl', 'rb')
F, ncosmo, ncal, nfilt1, nsn1, nfilt2, nsn2, nfilt3, nsn3, nfilt4, nsn4 =pickle.load(file)
file.close()
F_hold=numpy.array(F)

In [None]:
F = numpy.array(F_hold)

# indeces to delete
ind0=[1,2]
_F = numpy.delete(F,ind0,0)
_F = numpy.delete(_F,ind0,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0,0])

ind1=[1,2]  #REMOVE wa
ind2=[1,2]
for i in range(nfilt1):
    ind1.append(ncosmo+i*nfilt1+2)
    ind2.append(ncosmo+i*nfilt1)
    ind2.append(ncosmo+i*nfilt1+2)

for i in range(nfilt2):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2+2)

for i in range(nfilt3):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3+2)

for i in range(nfilt4):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4+2)

_F = numpy.delete(F,ind1,0)
_F = numpy.delete(_F,ind1,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0,0], numpy.sqrt(Finv[0,0]))

_F = numpy.delete(F,ind2,0)
_F = numpy.delete(_F,ind2,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0,0], numpy.sqrt(Finv[0,0]))


# indeces to delete
ind0=[2]
_F = numpy.delete(F,ind0,0)
_F = numpy.delete(_F,ind0,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0,0], numpy.sqrt(Finv[0,0]))
# print(numpy.linalg.det(Finv[0:2,0:2]))


ind1=[2]  #REMOVE wa
ind2=[2]
for i in range(nfilt1):
    ind1.append(ncosmo+i*nfilt1+2)
    ind2.append(ncosmo+i*nfilt1)
    ind2.append(ncosmo+i*nfilt1+2)

for i in range(nfilt2):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2+2)

for i in range(nfilt3):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3+2)

for i in range(nfilt4):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4+2)
    
_F = numpy.delete(F,ind1,0)
_F = numpy.delete(_F,ind1,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0:1,0:1])
print(Finv[0,0], numpy.sqrt(Finv[0:1,0:1]).diagonal())
# print(numpy.linalg.det(Finv[0:2,0:2]))

_F = numpy.delete(F,ind2,0)
_F = numpy.delete(_F,ind2,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0:1,0:1])
print(Finv[0,0], numpy.sqrt(Finv[0:1,0:1]).diagonal())
# print(numpy.linalg.det(Finv[0:2,0:2]))

ind1=[2]  #REMOVE wa and R
ind2=[2]
for i in range(nfilt1):
    ind1.append(ncosmo+i*nfilt1+2)
    ind2.append(ncosmo+i*nfilt1)
    ind2.append(ncosmo+i*nfilt1+2)
    for j in range(nsn1):
        ind1.append(ncosmo+ncal*nfilt1 + nsn1+j)
        ind2.append(ncosmo+ncal*nfilt1 + nsn1+j)    

for i in range(nfilt2):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 +i*nfilt2+2)
    for j in range(nsn2):
        ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + nsn2 +j)
        ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + nsn2 + j)   
        
for i in range(nfilt3):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + i*nfilt3+2)
    for j in range(nsn3):
        ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + nsn3 +j)
        ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + nsn3 + j)   
    
for i in range(nfilt4):
    ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4+2)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4)
    ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3  + i*nfilt4+2)
    for j in range(nsn4):
        ind1.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3 + nsn3 +j)
        ind2.append(ncosmo+ncal*nfilt1 + 2*nsn1 + ncal*nfilt2 + 2*nsn2 + ncal*nfilt3 + 2*nsn3 + nsn3 + j)   
        

 

_F = numpy.delete(F,ind1,0)
_F = numpy.delete(_F,ind1,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0:1,0:1])
print(Finv[0,0], numpy.sqrt(Finv[0:1,0:1]).diagonal())
# print(numpy.linalg.det(Finv[0:2,0:2]))

_F = numpy.delete(F,ind2,0)
_F = numpy.delete(_F,ind2,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0:1,0:1])
print(Finv[0,0], numpy.sqrt(Finv[0:1,0:1]).diagonal())
# print(numpy.linalg.det(Finv[0:2,0:2]))



# # Omega_M prior
# F[0,0]  += 1/0.03**2 # Omega_M prior
# # indeces to delete
# Finv= numpy.linalg.inv(F)
# print(Finv[0:3,0:3])
# # print(numpy.linalg.det(Finv[0:2,0:2]))

# ind1=[]  #REMOVE wa
# ind2=[]
# for i in range(nfilt):
#     ind1.append(ncosmo+i*nfilt+2)
#     ind1.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
#     ind2.append(ncosmo+i*nfilt)
#     ind2.append(ncosmo+i*nfilt+2)
#     ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt)
#     ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
    
# _F = numpy.delete(F,ind1,0)
# _F = numpy.delete(_F,ind1,1)
# Finv= numpy.linalg.inv(_F)
# print(Finv[0:3,0:3])
# # print(numpy.linalg.det(Finv[0:2,0:2]))

# _F = numpy.delete(F,ind2,0)
# _F = numpy.delete(_F,ind2,1)
# Finv= numpy.linalg.inv(_F)
# print(Finv[0:3,0:3])
# # print(numpy.linalg.det(Finv[0:2,0:2]))

## Run a real example

In [21]:
cosmo = jc.Planck15()
Omega_c=cosmo.Omega_c
w0=cosmo.w0
wa=cosmo.wa
dm_int =0.1
ncal=3
ncosmo=3
efflams = numpy.array([4700., 6420.,  7849.])
ffwhm = numpy.array([1500., 1480.,  1470.])
nfilt=len(efflams)

#survey 1
nsn1=100 # 1300
dm_data =0.15*numpy.sqrt(2)
zs = numpy.linspace(0.3,1,nsn1)
zs = numpy.linspace(0.3**3,1**3,nsn1)
zs = zs**(1/3)

Avs = pAvrvs(size=nsn1)
Rvs = numpy.random.normal(2.8, 0.3,size=nsn1)
dCs= numpy.zeros((3,ncal))

F1 =  F_all(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)

# #survey 2
nsn2=100 # 227
dm_data =0.15*numpy.sqrt(2)
zs = numpy.linspace(0.023**3,0.15**3,nsn2)
zs = zs**(1/3)

Avs = pAvrvs(size=nsn2)
Rvs = numpy.random.normal(2.8, 0.3,size=nsn2)
dCs= numpy.zeros((3,ncal))

F2 =  F_all(Omega_c, w0, wa, dCs, Avs, Rvs, zs, efflams, dm_int, dm_data)

F = F_merge(ncosmo, F1,F2)

# print(numpy.linalg.slogdet(F))
# Finv= numpy.linalg.inv(F)
# print(Finv[0:ncosmo,0:ncosmo])


In [30]:
import pickle
file = open('temp.pkl', 'wb')
pickle.dump(F, file)
file.close()

# Play with result

In [31]:
file = open('temp.pkl', 'rb')
F =pickle.load(file)
file.close()
F_hold=numpy.array(F)

## Calibration error priors and stripping

In [32]:
F = numpy.array(F_hold)

# examine effect if calibration errors
# calibration axes
# by hand put in priors for calibration
ncalinvsig2 = [1./64,1/0.15/0.15, 1/0.15/0.15]
# I checked the ordering of the C variables
for i in range(nfilt):
    for j in range(ncal):
        F[ncosmo+i*nfilt+j,ncosmo+i*nfilt+j] += ncalinvsig2[j]
        F[ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+j,ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+j] += ncalinvsig2[j]


# indeces to delete
ind0=[1,2]
_F = numpy.delete(F,ind0,0)
_F = numpy.delete(_F,ind0,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0,0])

ind1=[1,2]  #REMOVE wa
ind2=[1,2]
for i in range(nfilt):
    ind1.append(ncosmo+i*nfilt+2)
    ind1.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
    ind2.append(ncosmo+i*nfilt)
    ind2.append(ncosmo+i*nfilt+2)
    ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt)
    ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
    
_F = numpy.delete(F,ind1,0)
_F = numpy.delete(_F,ind1,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0,0])

_F = numpy.delete(F,ind2,0)
_F = numpy.delete(_F,ind2,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0,0])


# indeces to delete
ind0=[2]
_F = numpy.delete(F,ind0,0)
_F = numpy.delete(_F,ind0,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0:2,0:2])
# print(numpy.linalg.det(Finv[0:2,0:2]))

ind1=[2]  #REMOVE wa
ind2=[2]
for i in range(nfilt):
    ind1.append(ncosmo+i*nfilt+2)
    ind1.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
    ind2.append(ncosmo+i*nfilt)
    ind2.append(ncosmo+i*nfilt+2)
    ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt)
    ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
    
_F = numpy.delete(F,ind1,0)
_F = numpy.delete(_F,ind1,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0:2,0:2])
# print(numpy.linalg.det(Finv[0:2,0:2]))

_F = numpy.delete(F,ind2,0)
_F = numpy.delete(_F,ind2,1)
Finv= numpy.linalg.inv(_F)
print(Finv[0:2,0:2])
# print(numpy.linalg.det(Finv[0:2,0:2]))

# # Omega_M prior
# F[0,0]  += 1/0.03**2 # Omega_M prior
# # indeces to delete
# Finv= numpy.linalg.inv(F)
# print(Finv[0:3,0:3])
# # print(numpy.linalg.det(Finv[0:2,0:2]))

# ind1=[]  #REMOVE wa
# ind2=[]
# for i in range(nfilt):
#     ind1.append(ncosmo+i*nfilt+2)
#     ind1.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
#     ind2.append(ncosmo+i*nfilt)
#     ind2.append(ncosmo+i*nfilt+2)
#     ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt)
#     ind2.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+2)
    
# _F = numpy.delete(F,ind1,0)
# _F = numpy.delete(_F,ind1,1)
# Finv= numpy.linalg.inv(_F)
# print(Finv[0:3,0:3])
# # print(numpy.linalg.det(Finv[0:2,0:2]))

# _F = numpy.delete(F,ind2,0)
# _F = numpy.delete(_F,ind2,1)
# Finv= numpy.linalg.inv(_F)
# print(Finv[0:3,0:3])
# # print(numpy.linalg.det(Finv[0:2,0:2]))

7.995366490723224e-05
7.732273581370408e-05
7.671439134000466e-05
[[ 0.00488361 -0.01289917]
 [-0.01289917  0.03463791]]
[[ 0.00460984 -0.01219521]
 [-0.01219521  0.0328125 ]]
[[ 0.00458559 -0.01213609]
 [-0.01213609  0.03266551]]


In [None]:
print(numpy.sqrt(8.060455772927836e-05),numpy.sqrt(0.035))

In [33]:
F = numpy.array(F_hold)

# examine effect if calibration errors
# calibration axes
# by hand put in priors for calibration
ncalinvsig2 = [1./64,1/0.15/0.15, 1/0.15/0.15]
# I checked the ordering of the C variables
for i in range(nfilt):
    for j in range(ncal):
        F[ncosmo+i*nfilt+j,ncosmo+i*nfilt+j] += ncalinvsig2[j]
        F[ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+j,ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+j] += ncalinvsig2[j]

Finv= numpy.linalg.inv(F)

mask=[]
for i in range(3): #nfilt):
    for j in range(ncal):    
        mask.append(ncosmo+i*nfilt+j)
for i in range(3): #nfilt):
    for j in range(ncal):    
        mask.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+j)

print(numpy.sqrt(Finv[numpy.ix_(mask,mask)].diagonal()))


[3.37338016e+00 5.57083839e-03 3.02822944e-03 7.21333481e+00
 5.33863241e-03 9.51546192e-03 7.88118686e+00 5.30194247e-03
 1.99371006e-02 7.89840362e+00 1.25002221e-02 3.70081345e-02
 7.98391911e+00 1.24053701e-02 6.41245871e-02 7.91464934e+00
 1.26313894e-02 2.47461949e-02]


In [None]:
mask=[0]
for i in range(3): #nfilt):
    for j in range(ncal):    
        mask.append(ncosmo+i*nfilt+j)
for i in range(3): #nfilt):
    for j in range(ncal):    
        mask.append(ncosmo+ncal*nfilt + 2*nsn1 + i*nfilt+j)

print(F[numpy.ix_([0,1],mask)])
Finv= numpy.linalg.inv(F)
FinvCor = Finv / (numpy.outer(numpy.sqrt(Finv.diagonal()),numpy.sqrt(Finv.diagonal()))) 
print(FinvCor[numpy.ix_([0,1],mask)])

In [None]:
# Omega_M prior
F[0,0]  += 1/0.03**2 # Omega_M prior

## Properties of the DES filter system

In [None]:
efflams = numpy.array([4700., 6420.,  7849.])
ffwhm = numpy.array([1500., 1480.,  1470.])
nfilt=len(efflams)
# separation between filters in log-lambda
lnlam = numpy.log(efflams)
print(lnlam)
deltalnlam = (lnlam[-1] - lnlam[0])/(nfilt-1)
print(deltalnlam)

# lambda/Delta lambda
print (efflams/ffwhm)

delta = numpy.arcsinh(1/2/4.5)
print(efflams*(numpy.exp(delta)-numpy.exp(-delta)))