In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from pop_cosmos.constants import COSMOS_FILTERS_LATEX
from pop_cosmos.catalogue import CatalogueGenerator

from astropy.cosmology import Planck18

# Use a larger default font size for every figure produced in this notebook.
plt.rcParams["font.size"] = 15

In [None]:
# Serialized CatalogueGenerator, path relative to this notebook.
MODEL_PATH = "../trained_models/catalogueModelCPU.pt"
# weights_only=False unpickles the full Python object, which can execute
# arbitrary code -- only load model files from a trusted source.
catalogue_generator = torch.load(MODEL_PATH, weights_only=False)

In [None]:
# Draw base random variates for N_SAMPLES mock galaxies. torch is seeded first
# so that Restart & Run All reproduces the same catalogue (this cell and the
# generator call in the next cell then consume the RNG in a fixed order).
# NOTE(review): assumes generate_base_samples and the generator's forward pass
# draw from torch's global RNG -- confirm; seed any other RNG they use as well.
SEED = 42
N_SAMPLES = 10000
torch.manual_seed(SEED)
base_noise, base_sigma, base_phi = catalogue_generator.generate_base_samples(N_SAMPLES)

In [None]:
# Push the base samples through the generator and unpack its six outputs in
# return order. NOTE(review): the meaning of each output is inferred from the
# variable names only.
(
    noisy_fluxes,
    noisy_magnitudes,
    noisy_asinh_magnitudes,
    flux_sigmas,
    theta_samples,
    model_fluxes,
) = catalogue_generator(base_noise, base_sigma, base_phi)

In [None]:
# Apply the sample selection, then convert the arrays used by later cells to NumPy.
# Only one band gets a finite magnitude limit; every other band is unconstrained (inf).
# NOTE(review): index -2 is the second-to-last band; which filter that is is
# defined by COSMOS_FILTERS_LATEX -- confirm.
SELECTION_BAND = -2
SELECTION_MAG_LIMIT = 26.0


def _to_numpy(values):
    """Return ``values`` as a NumPy array, detaching (and moving to CPU) torch tensors.

    NumPy input is returned unchanged, which makes this cell safe to re-run
    after the names below have already been converted.
    """
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# One limit per band, sized from the data instead of a hard-coded 26.
# NOTE(review): assumes the last axis of noisy_fluxes indexes bands.
lims = torch.full((noisy_fluxes.shape[-1],), torch.inf)
lims[SELECTION_BAND] = SELECTION_MAG_LIMIT

# torch.as_tensor returns tensors unchanged and re-wraps arrays that a previous
# run of this cell already converted to NumPy.
selection = _to_numpy(
    catalogue_generator.selection_cut(
        torch.as_tensor(noisy_fluxes),
        torch.as_tensor(noisy_magnitudes),
        torch.as_tensor(flux_sigmas),
        lims,
    )
)

noisy_magnitudes = _to_numpy(noisy_magnitudes)
noisy_fluxes = _to_numpy(noisy_fluxes)
flux_sigmas = _to_numpy(flux_sigmas)
theta_samples = _to_numpy(theta_samples)

In [None]:
# Keep only the rows that pass the selection cut computed in the previous cell.
noisy_magnitudes_selected, noisy_flux_sigmas_selected, thetas_selected = (
    array[selection] for array in (noisy_magnitudes, flux_sigmas, theta_samples)
)

In [None]:
# Density-normalised magnitude histograms of the selected sample, one band per
# panel on a 3x4 grid with touching panels.
# NOTE(review): mag_list picks the first ten bands plus the last two; the
# filter names come from COSMOS_FILTERS_LATEX.
mag_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -2, -1]

x_bins = np.linspace(16, 32, 33)
fig, ax = plt.subplots(3, 4, figsize=(12.8, 6.4))
n_cols = ax.shape[1]
ax = ax.flatten()
assert len(mag_list) == len(ax), "expected exactly one panel per band"
for i, band in enumerate(mag_list):
    panel = ax[i]
    panel.hist(noisy_magnitudes_selected[:, band], color="#EE6677", bins=x_bins, histtype="step", linewidth=1.2, density=True)

    # Use panel.text, not plt.text: plt.text draws on pyplot's *current* Axes,
    # which is not this panel, so the label would belong to the wrong subplot.
    panel.text(0.05, 0.95, COSMOS_FILTERS_LATEX[band], transform=panel.transAxes, verticalalignment='top')
    if i == 3:
        # Legend-style label in the top-right panel.
        panel.text(0.95, 0.95, "pop-cosmos", transform=panel.transAxes, verticalalignment='top', horizontalalignment="right", color="#EE6677")
    if i in (0, 3):
        # Extra headroom above the histogram, presumably so the text labels do not overlap it.
        panel.set_ylim(0, panel.get_ylim()[1] * 1.4)
    if i % n_cols == 0:
        panel.set_ylabel('density')
    panel.yaxis.set_ticks([])
    if i >= len(ax) - n_cols:
        # Bottom row: x label and ticks.
        panel.set_xlabel('mag')
        panel.set_xticks([20, 25, 30])
    else:
        panel.xaxis.set_ticks([])
    panel.set_xlim(x_bins[0], x_bins[-1])
plt.subplots_adjust(hspace=0, wspace=0)
plt.show()

In [None]:
# Density-normalised colour histograms (mag[x] - mag[y] for consecutive entries
# of mag_list) of the selected sample, on a 3x3 grid with touching panels.
mag_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -2]
# Position k in mag_list whose pair (mag_list[k], mag_list[k + 1]) is not plotted.
# NOTE(review): the reason for skipping this pair is not stated -- confirm.
SKIPPED_PAIR = 5

x_bins = np.linspace(-2.5, 2.5, 51)
fig, ax = plt.subplots(3, 3, figsize=(12.8, 6.4))
n_cols = ax.shape[1]
ax = ax.flatten()
color_pairs = [
    (mag_list[k], mag_list[k + 1])
    for k in range(len(mag_list) - 1)
    if k != SKIPPED_PAIR
]
assert len(color_pairs) == len(ax), "expected exactly one panel per colour"
for j, (x, y) in enumerate(color_pairs):
    panel = ax[j]
    panel.hist(noisy_magnitudes_selected[:, x] - noisy_magnitudes_selected[:, y], color="#EE6677", bins=x_bins, histtype="step", linewidth=1.2, density=True)

    # Use panel.text, not plt.text: plt.text draws on pyplot's *current* Axes,
    # which is not this panel, so the label would belong to the wrong subplot.
    panel.text(0.05, 0.95, COSMOS_FILTERS_LATEX[x] + " $-$ " + COSMOS_FILTERS_LATEX[y], transform=panel.transAxes, verticalalignment='top')
    if j == len(ax) - 1:
        # Legend-style label plus headroom for it in the last (bottom-right) panel.
        panel.text(0.95, 0.95, "pop-cosmos", transform=panel.transAxes, verticalalignment='top', horizontalalignment="right", color="#EE6677")
        panel.set_ylim(0, panel.get_ylim()[1] * 1.25)
    if j % n_cols == 0:
        panel.set_ylabel('density')
    panel.yaxis.set_ticks([])
    if j >= len(ax) - n_cols:
        # Bottom row: x label and ticks.
        panel.set_xlabel('color')
        panel.xaxis.set_ticks([-2, -1, 0, 1, 2])
    else:
        panel.xaxis.set_ticks([])
    panel.set_xlim(x_bins[0], x_bins[-1])
plt.subplots_adjust(hspace=0, wspace=0)
plt.show()

In [None]:
# Redshift distribution n(z) of the selected sample. The last column of the
# parameter samples is plotted as redshift, per the axis label.
# An explicit figure/Axes is used instead of pyplot's implicit current figure.
fig, ax = plt.subplots()
ax.hist(thetas_selected[:, -1], bins=61, color="#EE6677", histtype="step", density=True)
ax.set_xlabel('redshift, $z$')
ax.set_ylabel("$n(z)$")
ax.set_title("pop-cosmos: redshift distribution of selected galaxies")
plt.show()