# Setting up a latent function Gaussian Process for H(z)

## GP modelling - priors

In [None]:
# Latent-GP grid: 15 equally spaced redshifts covering [0, 3].
zlat = range(0, stop=3.0, length=15);
# Redshift grid for integrated / derived quantities, taken from the cosmology settings.
z_integ = cosmo1.settings.zs;

In [None]:
"""
    model_latent_GP(eta, l, v; omega_m, sigma_8, h_0, x=zlat, z=z_obs, data_cov=covariance_obs)

Draw one latent-GP realisation on the grid `x` and carry it onto the grid `z`.

The GP mean is `Ez` (presumably the dimensionless expansion rate E(z)) evaluated on `x`
for a `CosmoPar` built from `omega_m`, `h_0` and `sigma_8`, with `Ωb` taken from `cosmo1`.
`latent_GP` combines that mean, the vector `v` and a `sqexp_cov_fn` kernel with
hyperparameters `eta`, `l`. `conditional` then maps the result from `x` to `z` with the
same hyperparameters.

`v` should have one entry per point of `x`; the callers in this notebook draw it from
`MvNormal(zeros(n), ones(n))`.

Returns the tuple `(ex, ez)`: the realisation on `x` and its counterpart on `z`.

`data_cov` is accepted for call compatibility but is not used.
"""
function model_latent_GP(eta, l, v; omega_m=cosmo1.cpar.Ωm, sigma_8=cosmo1.cpar.σ8, h_0 = cosmo1.cpar.h,
                         x=zlat, z=z_obs, data_cov=covariance_obs)
    # Squared-exponential covariance over the latent points.
    K = sqexp_cov_fn(x; eta, l)
    # Mean function: Ez of the requested cosmology, sampled on the latent grid.
    cosmology = CosmoPar(Ωm=omega_m, Ωb=cosmo1.cpar.Ωb, h=h_0, σ8=sigma_8)
    prior_mean = Ez(cosmology, x)
    # Realisation on x, then its projection onto z using the same kernel settings.
    ex = latent_GP(prior_mean, v, K)
    ez = conditional(x, z, ex, sqexp_cov_fn; eta, l)
    return (ex, ez)
end

`ex` is the GP realisation of the expansion rate E(z) on the latent grid `zlat`. `ez` transforms `ex` from the latent space to another redshift space: either the integral grid `z_integ` or the observation redshifts `z_obs`.

In [None]:
N_samples = 100
# Prior draws at the observed redshifts. Each sample draws a fresh standard-normal vector
# `v` and propagates it through the latent GP with FIXED hyperparameters (eta = 0.1,
# l = 0.3) and the default (cosmo1) cosmology of `model_latent_GP`; only `v` is random.
exs1 = zeros(N_samples, length(zlat))   # latent-grid realisations, one row per sample
ezs1 = zeros(N_samples, length(z_obs))  # realisations at the observed redshifts
for i in 1:N_samples
    v = rand(MvNormal(zeros(length(zlat)), ones(length(zlat))))
    exs1[i, :], ezs1[i, :] = model_latent_GP(0.1, 0.3, v)
end
# Pointwise mean / standard deviation across samples (1×N row matrices).
y_m1, y_s1 = mean(ezs1, dims=1), std(ezs1, dims=1); # observed space
gp_m1, gp_s1 = mean(exs1, dims=1), std(exs1, dims=1); # latent space


In [None]:
# Prior draws on the integral grid `z_integ`, plus the derived observables.
# Each sample draws a fresh standard-normal `v` and fresh h, σ8, Ωm and r_d values from
# their prior distributions.
# NOTE(review): only σ8 is passed to the GP mean; Ωm and h keep their cosmo1 defaults
# inside `model_latent_GP` (as in `stats_model`), so Ωm only affects fs8 here — confirm
# this is intended.
exs2 = zeros(N_samples, length(zlat))
ezs2 = zeros(N_samples, length(z_integ))
hz_pr = zeros(N_samples, length(z_integ))
aiso_pr = zeros(N_samples, length(z_integ))
aap_pr = zeros(N_samples, length(z_integ))
fs8_pr = zeros(N_samples, length(z_integ))
#mu_pr=zeros(N_samples, length(z_integ))

for i in 1:N_samples
    v = rand(MvNormal(zeros(length(zlat)), ones(length(zlat))))
    h_pr = rand(h0_dist)
    s8_pr = rand(s8_dist)
    omegam_pr = rand(omegam_dist)
    rd_pr = rand(rd_dist)
    #M_pr = rand(M_dist)

    exs2[i, :], ezs2[i, :] = model_latent_GP(0.1, 0.3, v; z=z_integ, sigma_8=s8_pr)
    # Copy the stored row once rather than re-slicing (and re-copying) it per observable.
    ez_row = ezs2[i, :]
    hz_pr[i, :] = hz_from_ez(z_integ, h_pr, ez_row)
    aiso_pr[i, :] = a_iso(z_integ, h_pr, rd_pr, ez_row)
    aap_pr[i, :] = a_ap(z_integ, h_pr, ez_row)
    # eltype of the whole matrix equals that of any row; no copy needed to query it.
    fs8_pr[i, :] = fs8_from_ez(z_integ, omegam_pr, s8_pr, h_pr, ez_row, eltype(ezs2))
    #mu_pr[i, :] = mu(z_integ, h_pr, M_pr, ez_row)
end

# Pointwise mean / standard deviation across samples (1×N row matrices).
y_m2, y_s2 = mean(ezs2, dims=1), std(ezs2, dims=1); #integral space
gp_m2, gp_s2 = mean(exs2, dims=1), std(exs2, dims=1); #latent space

hz_m2, hz_s2 = mean(hz_pr, dims=1), std(hz_pr, dims=1);
aiso_m2, aiso_s2 = mean(aiso_pr, dims=1), std(aiso_pr, dims=1);
aap_m2, aap_s2 = mean(aap_pr, dims=1), std(aap_pr, dims=1);
fs8_m2, fs8_s2 = mean(fs8_pr, dims=1), std(fs8_pr, dims=1);
#mu_m2, mu_s2 = mean(mu_pr, dims=1), std(mu_pr, dims=1);

In [None]:
# Prior-mean predictions as linear interpolants over `z_integ`, extrapolated linearly
# beyond the grid ends.
hz_itp2, aiso_itp2, aap_itp2, fs8_itp2 = map(
    m -> LinearInterpolation(z_integ, vec(m), extrapolation_bc=Line()),
    (hz_m2, aiso_m2, aap_m2, fs8_m2));
#mu_itp2 = LinearInterpolation(z_integ, vec(mu_m2), extrapolation_bc=Line());

In [None]:
# Integral-space H(z): prior mean and ±1σ band over the mock H(z) data.
let m = vec(hz_m2), s = vec(hz_s2)
    plot(fakedatahz.z, fakedatahz.data, yerr=hz_err, label="Data hz", ms=2, seriestype=:scatter)
    plot!(z_integ, m, label="hz mean")
    plot!(z_integ, m .- s, fillrange=m .+ s, fillalpha=0.2, c=1, label="hz standard deviation")
end

In [None]:
# Integral-space a_iso: prior mean and ±1σ band over the mock a_iso data.
let m = vec(aiso_m2), s = vec(aiso_s2)
    plot(fakedataaiso.z, fakedataaiso.data, yerr=aiso_err, label="Data a_iso", ms=2, seriestype=:scatter)
    plot!(z_integ, m, label="a_iso mean")
    plot!(z_integ, m .- s, fillrange=m .+ s, fillalpha=0.2, c=1, label="a_iso standard deviation")
end

In [None]:
# Integral-space a_ap: prior mean and ±1σ band over the mock a_ap data.
# (The band is drawn as the lower curve filled up to the upper curve.)
plot(fakedataaap.z, fakedataaap.data, yerr=aap_err, label="Data a_ap", ms=2, seriestype=:scatter)
plot!(z_integ, vec(aap_m2), label="a_ap mean")
plot!(z_integ, vec(aap_m2 .- aap_s2),  fillrange = vec(aap_m2 .+ aap_s2), fillalpha=0.2, c=1, label="a_ap standard deviation")

In [None]:
# Integral-space fσ8: prior mean and ±1σ band over the mock fs8 data.
let m = vec(fs8_m2), s = vec(fs8_s2)
    plot(fakedatafs8.z, fakedatafs8.data, yerr=fs8_err, label="Data fs8", ms=2, seriestype=:scatter)
    plot!(z_integ, m, label="fs8 mean")
    plot!(z_integ, m .- s, fillrange=m .+ s, fillalpha=0.2, c=1, label="fs8 standard deviation")
end

In [None]:
# plotting supernovae data 
# plot(fakedatamu.z, fakedatamu.data, yerr=mu_err, label="Data mu", ms=2, seriestype=:scatter)
# plot!(z_integ, vec(mu_m2), label="mu mean")
# plot!(z_integ, vec(mu_m2 .- mu_s2),  fillrange = vec(mu_m2 .+ mu_s2), fillalpha=0.2, c=1, label="mu standard deviation")

## GP modelling - posteriors

In [None]:
# Turing posterior model.
#
# A latent GP on `X` is mapped to `int_grid`. From it, H(z), a_iso, a_ap and fs8 are
# derived, interpolated to the mock-data redshifts (`fakedata*.z`, read as globals) and
# stacked into `datay`. The observations `y` are then scored under MvNormal(datay, data_cov).
#
#   y        – stacked observations; presumably ordered [H(z); a_iso; a_ap; fs8] to match
#              `datay` below — TODO confirm against how `data_obs` / `covariance_obs`
#              are built.
#   int_grid – redshift grid on which the derived observables are computed.
#   X        – latent-GP grid; `v_po` has one entry per point of `X`.
#   data_x   – NOTE(review): not referenced in the body.
#   data_cov – covariance matrix used in the likelihood for `y`.
#
# Quantities declared with `:=` (hdata, aisodata, aapdata, fs8data) are recorded in the
# chain next to the sampled parameters; later cells read them back with `group`.
@model function stats_model(y; int_grid=z_integ, X=zlat, data_x=z_obs, data_cov=covariance_obs)
    # Priors, parameters
    # Kernel hyperparameters are fixed, not sampled.
    # NOTE(review): they differ from the (0.1, 0.3) used in the prior-predictive cells.
    eta2 = 50
    l2 = 0.3
    v_po ~ MvNormal(zeros(length(X)), ones(length(X)))
    omegam_po ~ omegam_dist
    s8_po ~ s8_dist
    h_po ~ h0_dist
    r_po ~ rd_dist
    #M_po ~ M_dist
    
    kernel = sqexp_cov_fn(X, eta=eta2, l=l2)
    # GP mean uses the fiducial Ωm, Ωb and h of cosmo1 with the sampled σ8 only.
    # h_po, omegam_po and r_po enter through the derived observables below instead.
    cpar = CosmoPar(Ωm=cosmo1.cpar.Ωm, Ωb=cosmo1.cpar.Ωb, h=cosmo1.cpar.h, σ8=s8_po)
    mean_ez = Ez(cpar, X)
    ez_latent = latent_GP(mean_ez, v_po, kernel)
    ez_gp = conditional(X, int_grid, ez_latent, sqexp_cov_fn; eta=eta2, l=l2) # converting from latent space to integral space
    
    # Derived observables on int_grid (h_po enters all four; r_po only a_iso;
    # omegam_po only fs8).
    hz = hz_from_ez(int_grid, h_po, ez_gp);
    aiso = a_iso(int_grid, h_po, r_po, ez_gp);
    aap = a_ap(int_grid, h_po, ez_gp);
    # eltype(ez_gp) is forwarded, presumably so fs8_from_ez can allocate AD-compatible
    # (dual-number) buffers under NUTS — TODO confirm.
    fs8 = fs8_from_ez(int_grid, omegam_po, s8_po, h_po, ez_gp, eltype(ez_gp));
    #mu_mod = mu(int_grid, h_po, M_po, ez_gp);

    # Linear interpolants (linear extrapolation outside int_grid) to reach the data redshifts.
    hz_interp = LinearInterpolation(int_grid, hz, extrapolation_bc=Line())
    aiso_interp = LinearInterpolation(int_grid, aiso, extrapolation_bc=Line())
    aap_interp = LinearInterpolation(int_grid, aap, extrapolation_bc=Line())
    fs8_interp = LinearInterpolation(int_grid, fs8, extrapolation_bc=Line())
    #mu_interp = LinearInterpolation(int_grid, mu_mod, extrapolation_bc=Line())

    # Model predictions at each dataset's redshifts, tracked in the chain via `:=`.
    hdata := hz_interp(fakedatahz.z)
    aisodata := aiso_interp(fakedataaiso.z)
    aapdata := aap_interp(fakedataaap.z)
    fs8data := fs8_interp(fakedatafs8.z)
    #mudata := mu_interp(fakedatamu.z)
   

    # Stacked prediction vector; its order must match the layout of y and data_cov.
    datay = [hdata;aisodata;aapdata;fs8data]

    # `y` is a model argument, so this statement is the likelihood of the observations.
    y ~ MvNormal(datay, data_cov) 

end

In [None]:
#init_params = [h_po => 0.7, omegam_po => 0.3]
# Condition stats_model on data_obs and draw nsamp posterior samples with NUTS
# (nadapts adaptation steps, target acceptance 0.7).
chain_gp = let model = stats_model(data_obs), sampler = NUTS(nadapts, 0.7)
    sample(model, sampler, nsamp)
end

In [None]:
# Posterior chains of the tracked observables at the data redshifts
# (iterations × variables matrix of the first chain).
hdata_p, aisodata_p, aapdata_p, fs8data_p =
    (group(chain_gp, name).value.data[:, :, 1] for name in (:hdata, :aisodata, :aapdata, :fs8data));
#mudata_p =  group(chain_gp, :mudata).value.data[:,:,1];

# Posterior chains of the sampled parameters (first chain only).
v_p, omegam_p, s8_p =
    (group(chain_gp, name).value.data[:, :, 1] for name in (:v_po, :omegam_po, :s8_po));

In [None]:
# Posterior mean and standard deviation of each observable, taken over the iterations
# (dims=1), giving one value per data redshift.
(hpmean, hps), (aisopmean, aisops), (aappmean, aapps), (fs8pmean, fs8ps) =
    ((mean(d, dims=1), std(d, dims=1)) for d in (hdata_p, aisodata_p, aapdata_p, fs8data_p));
#mupmean, mups = mean(mudata_p, dims=1), std(mudata_p, dims=1);

In [None]:
# Posterior H(z) at the data redshifts: GP mean ± 1σ versus the ΛCDM results
# (hz_lcdm_m / hz_lcdm_s, computed elsewhere), over the mock data.
# Each mean line and its band share one explicit colour (GP: 2, ΛCDM: 3) so they pair up.
plot(fakedatahz.z, fakedatahz.data, yerr=hz_err, label="Data hz", ms=2, seriestype=:scatter)
plot!(fakedatahz.z, vec(hpmean), c=2, label="hz GP mean")
plot!(fakedatahz.z, vec(hpmean .- hps),  fillrange = vec(hpmean .+ hps), fillalpha=0.2, c=2, label="hz standard deviation")
plot!(fakedatahz.z, vec(hz_lcdm_m), c=3, label="H(z) LCDM mean")
plot!(fakedatahz.z, vec(hz_lcdm_m .- hz_lcdm_s),  fillrange = vec(hz_lcdm_m .+ hz_lcdm_s), fillalpha=0.2, c=3, label="H(z) LCDM standard deviation")

In [None]:
# Posterior a_iso at the data redshifts: GP mean and ±1σ band over the mock data.
let m = vec(aisopmean), s = vec(aisops), z = fakedataaiso.z
    plot(z, fakedataaiso.data, yerr=aiso_err, label="Data a_iso", ms=2, seriestype=:scatter)
    plot!(z, m, label="a_iso GP mean")
    plot!(z, m .- s, fillrange=m .+ s, fillalpha=0.2, c=1, label="a_iso standard deviation")
end

In [None]:
# Posterior a_ap at the data redshifts: GP mean and ±1σ band over the mock data.
let m = vec(aappmean), s = vec(aapps), z = fakedataaap.z
    plot(z, fakedataaap.data, yerr=aap_err, label="Data a_ap", ms=2, seriestype=:scatter)
    plot!(z, m, label="a_ap mean")
    plot!(z, m .- s, fillrange=m .+ s, fillalpha=0.2, c=1, label="a_ap standard deviation")
end

In [None]:
# Posterior fσ8 at the data redshifts: GP mean and ±1σ band over the mock data.
let m = vec(fs8pmean), s = vec(fs8ps), z = fakedatafs8.z
    plot(z, fakedatafs8.data, yerr=fs8_err, label="Data fs8", ms=2, seriestype=:scatter)
    plot!(z, m, label="fs8 mean")
    plot!(z, m .- s, fillrange=m .+ s, fillalpha=0.2, c=1, label="fs8 standard deviation")
end

In [None]:
# Posterior fσ8: GP mean ± 1σ alongside the ΛCDM mean ± 1σ (fs8_lcdm_m / fs8_lcdm_s,
# computed elsewhere), over the mock data.
let z = fakedatafs8.z,
    gm = vec(fs8pmean), gs = vec(fs8ps),
    lm = vec(fs8_lcdm_m), ls = vec(fs8_lcdm_s)

    plot(z, fakedatafs8.data, yerr=fs8_err, label="Data fs8", ms=2, seriestype=:scatter)
    plot!(z, gm, label="fs8 mean")
    plot!(z, gm .- gs, fillrange=gm .+ gs, fillalpha=0.2, c=1, label="fs8 gp standard deviation")
    plot!(z, lm, label="fs8 lcdm")
    plot!(z, lm .- ls, fillrange=lm .+ ls, fillalpha=0.2, c=4, label="fs8 lcdm standard deviation")
end

In [None]:
# plot(fakedatamu.z, fakedatamu.data, yerr=mu_err, label="Data mu", ms=2, seriestype=:scatter)
# plot!(fakedatamu.z, vec(mupmean), label="mu mean")
# plot!(fakedatamu.z, vec(mupmean .- mups),  fillrange = vec(mupmean .+ mups), fillalpha=0.2, c=1, label="mu standard deviation")

In [None]:
# Summary figure with one panel per observable. Each panel shows the mock data, the GP
# posterior mean ± 1σ and the ΛCDM mean ± 1σ. Each mean line and its band share an
# explicit colour (GP: 2, ΛCDM: 3) so the pairs are unambiguous.
# The μ(z) supernova panel is disabled like every other μ(z) cell in this notebook;
# re-enable p4 (and widen the layout) together with them.
p0 = plot(fakedatahz.z, fakedatahz.data, yerr=hz_err, seriestype=:scatter, label="H(z)")
p1 = plot(fakedataaiso.z, fakedataaiso.data, yerr=aiso_err, seriestype=:scatter, label="aiso(z)")
p2 = plot(fakedataaap.z, fakedataaap.data, yerr=aap_err, seriestype=:scatter, label="aap(z)")
p3 = plot(fakedatafs8.z, fakedatafs8.data, yerr=fs8_err, seriestype=:scatter, label="fs8(z)")
# p4 = plot(fakedatamu.z, fakedatamu.data, yerr=mu_err, seriestype=:scatter, label="mu(z)")


plot!(p0, fakedatahz.z, vec(hpmean), c=2, label="hz GP mean")
plot!(p0, fakedatahz.z, vec(hpmean .- hps),  fillrange = vec(hpmean .+ hps), fillalpha=0.2, c=2, label="hz standard deviation")
plot!(p0, fakedatahz.z, vec(hz_lcdm_m), c=3, label="H(z) LCDM mean")
plot!(p0, fakedatahz.z, vec(hz_lcdm_m .- hz_lcdm_s),  fillrange = vec(hz_lcdm_m .+ hz_lcdm_s), fillalpha=0.2, c=3, label="H(z) LCDM standard deviation")


plot!(p1, fakedataaiso.z, vec(aisopmean), c=2, label="a_iso GP mean")
plot!(p1, fakedataaiso.z, vec(aisopmean .- aisops),  fillrange = vec(aisopmean .+ aisops), fillalpha=0.2, c=2, label="a_iso standard deviation")
plot!(p1, fakedataaiso.z, vec(aiso_lcdm_m), c=3, label="a_iso LCDM mean")
plot!(p1, fakedataaiso.z, vec(aiso_lcdm_m .- aiso_lcdm_s),  fillrange = vec(aiso_lcdm_m .+ aiso_lcdm_s), fillalpha=0.2, c=3, label="a_iso LCDM standard deviation")


plot!(p2, fakedataaap.z, vec(aappmean), c=2, label="a_ap GP mean")
plot!(p2, fakedataaap.z, vec(aappmean .- aapps),  fillrange = vec(aappmean .+ aapps), fillalpha=0.2, c=2, label="a_ap standard deviation")
plot!(p2, fakedataaap.z, vec(aap_lcdm_m), c=3, label="a_ap LCDM mean")
plot!(p2, fakedataaap.z, vec(aap_lcdm_m .- aap_lcdm_s),  fillrange = vec(aap_lcdm_m .+ aap_lcdm_s), fillalpha=0.2, c=3, label="a_ap LCDM standard deviation")


plot!(p3, fakedatafs8.z, vec(fs8pmean), c=2, label="fs8 GP mean")
plot!(p3, fakedatafs8.z, vec(fs8pmean .- fs8ps),  fillrange = vec(fs8pmean .+ fs8ps), fillalpha=0.2, c=2, label="fs8 standard deviation")
plot!(p3, fakedatafs8.z, vec(fs8_lcdm_m), c=3, label="fs8 LCDM mean")
plot!(p3, fakedatafs8.z, vec(fs8_lcdm_m .- fs8_lcdm_s),  fillrange = vec(fs8_lcdm_m .+ fs8_lcdm_s), fillalpha=0.2, c=3, label="fs8 LCDM standard deviation")


# plot!(p4, fakedatamu.z, vec(mupmean), label="mu GP mean")
# plot!(p4, fakedatamu.z, vec(mupmean .- mups),  fillrange = vec(mupmean .+ mups), fillalpha=0.2, c=4, label="mu standard deviation")
# plot!(p4, fakedatamu.z, vec(mu_lcdm_m), label="mu LCDM mean")
# plot!(p4, fakedatamu.z, vec(mu_lcdm_m .- mu_lcdm_s),  fillrange = vec(mu_lcdm_m .+ mu_lcdm_s), fillalpha=0.2, c=1, label="mu LCDM standard deviation")

# Four active panels → a 2×2 grid (no empty slots).
plot(p0, p1, p2, p3, layout=(2,2), size=(1600,1000))