In [None]:
import os
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from glob import glob
from numpy.random import normal
from scipy.stats import sigmaclip
from astropy.coordinates import Distance
from astropy.cosmology import Planck18_arXiv_v2 as cosmo

from mockFRBhosts import draw_galaxies, observed_bands
from mockFRBhosts.observable import beck_mag_cuts, estimate_photo_err

# Uncomment the magic below for interactive (ipympl) figures; otherwise the default backend is used.
#%matplotlib widget
# Global seaborn theme ('whitegrid' style) applied to every figure in this notebook.
sns.set_theme(style='whitegrid')

In [None]:
# Number of simulated FRBs kept from the chosen population file.
n_frbs = 1000

# Every pickle holds one simulated FRB population. The file stem is split at its
# first underscore into the survey model and the redshift-distribution model.
pickles = sorted(glob('../Simulated_FRBs/*.pickle'))

survey_models = []
z_models = []
for pickle_path in pickles:
    stem = os.path.splitext(os.path.basename(pickle_path))[0]
    name_parts = stem.split('_', 1)
    survey_models.append(name_parts[0])
    z_models.append(name_parts[1])

In [None]:
# Select one simulated population (survey model + redshift model) by its position
# in the sorted file list; change the index to analyse a different combination.
chosen_index = 4
chosen = pickles[chosen_index]
print(chosen)

In [None]:
# Choose the galaxy weighting from the file name: populations whose stem ends in
# 'sfr' draw hosts weighted by 'mstardot' (presumably the star-formation rate),
# all others by 'mstars_total'.
chosen_stem = os.path.splitext(chosen)[0]
weights = 'mstardot' if chosen_stem.endswith('sfr') else 'mstars_total'

# The file is a pickled DataFrame. NOTE: unpickling can execute arbitrary code,
# so only load trusted simulation output.
frbs = pd.read_pickle(chosen)
print(frbs.shape[0], "FRBs in file, using only first", n_frbs)
frbs = frbs.iloc[:n_frbs].copy()

galaxies, snapnum = draw_galaxies(frbs['z'], weights=weights, seed=42)

# Order the FRBs such that they correspond to the galaxies at the same positions.
# The galaxies are presumably ordered by ascending snapshot number (the sort below
# assumes so). A stable sort keeps the original FRB order within each snapshot,
# so the pairing does not depend on the sort algorithm.
frbs.loc[:, 'snapnum'] = snapnum
frbs = frbs.sort_values('snapnum', ascending=True, kind='stable')

# Everything downstream pairs frbs and galaxies row by row (positionally).
assert len(frbs) == len(galaxies), "FRB and host-galaxy tables must have the same length"

n_bands_obs_SDSS, n_bands_obs_LSST, n_bands_obs_Euclid, n_bands_obs_DES = observed_bands(frbs, galaxies)

# .to_numpy() assigns positionally, i.e. by the row pairing established above.
frbs['n_bands_SDSS'] = n_bands_obs_SDSS.to_numpy()
frbs['n_bands_LSST'] = n_bands_obs_LSST.to_numpy()
frbs['n_bands_Euclid'] = n_bands_obs_Euclid.to_numpy()
frbs['n_bands_DES'] = n_bands_obs_DES.to_numpy()

In [None]:
# Apparent SDSS magnitudes of the simulated hosts, placed at their FRB's redshift:
#   m = M + distance modulus + 5 log10(h) - 2.5 log10(1 + z)
# NOTE(review): the 5 log10(h) term presumably converts catalogue magnitudes quoted
# as M - 5 log10(h); confirm against the galaxy catalogue.
frb_zs = frbs['z'].to_numpy()
dist = Distance(z=frb_zs, cosmology=cosmo)
# One offset per FRB, as a column vector so it broadcasts across the five bands.
mag_offset = (dist.distmod.value + 5*np.log10(cosmo.h)
              - 2.5*np.log10(1+frb_zs))[:, np.newaxis]
apparent_mag_SDSS = mag_offset + galaxies.loc[:, 'mag_SDSS-u_tot':'mag_SDSS-z_tot']

# Per-band magnitude limits in u, g, r, i, z order.
mag_limits_SDSS = np.array([22.0, 22.2, 22.2, 21.3, 20.5])

In [None]:
# Shorten the SDSS column names to the bare band letters u, g, r, i, z.
apparent_mag_SDSS = apparent_mag_SDSS.rename(
    columns={f'mag_SDSS-{band}_tot': band for band in 'ugriz'})

In [None]:
# Real SDSS galaxies (SkyServer SQL export) supply the empirical magnitude/error
# statistics from which the simulated hosts' errors are drawn.
sdss = pd.read_csv('Skyserver_SQL3_4_2022 4 21 09 PM.csv', delimiter=',', header=1)
ngal = len(sdss)

# Magnitudes and their errors as separate frames. The step of 2 picks every other
# column. NOTE(review): this assumes magnitude and error columns are interleaved
# (u, err_u, g, ...); confirm against the CSV header.
bands = sdss.loc[:, 'u':'z':2]
errs = sdss.loc[:, 'err_u':'err_z':2]

# Drop galaxies with a magnitude at or below -100 (e.g. missing-data sentinels) in any band.
valid_rows = (bands > -100).all(axis=1)
bands, errs = bands[valid_rows], errs[valid_rows]

# For each band, draw an error for every simulated magnitude from the statistics
# of the real errors in magnitude bins. After the loop, med_mag, std_mag and bins
# hold the values for the last band (z).
sim_errs = pd.DataFrame(index=apparent_mag_SDSS.index)
for mag_col, err_col in zip(bands.columns, errs.columns):
    sim_errs[err_col], med_mag, std_mag, bins = estimate_photo_err(
        apparent_mag_SDSS[mag_col], bands[mag_col], errs[err_col], bins=30)


In [None]:
# Apply the SDSS magnitude cut?
apply_mag_cut = True

# Median 5-sigma depths in u, g, r, i, z
# https://www.sdss.org/dr14/imaging/other_info/
# Defined outside the `if` because later cells compare against these depths
# regardless of the flag.
max_mag = np.array([22.15, 23.13, 22.70, 22.20, 20.71])

if apply_mag_cut:
    # A host counts only if it is brighter than the depth in every band.
    bright_enough = (apparent_mag_SDSS < max_mag).all(1)
else:
    # No magnitude cut: keep every simulated host. Without this branch
    # `bright_enough` would be undefined (or stale from an earlier run) below.
    bright_enough = pd.Series(True, index=apparent_mag_SDSS.index)

# Determine which simulated hosts pass the Beck colour cuts, using the simulated errors.
beck_passed = beck_mag_cuts(apparent_mag_SDSS, sim_errs, verbose=False)

reliable_photometry = bright_enough & beck_passed
print(f"{bright_enough.sum()} out of {bright_enough.shape[0]} galaxies have sufficient magnitude, "
      f"{beck_passed.sum()} pass the color tests, {reliable_photometry.sum()} pass both.")

In [None]:
# Repeat the selection on the real SDSS galaxies (with their catalogue errors)
# as a reference for the simulated numbers above.
beck_passed_real = beck_mag_cuts(bands, errs, verbose=False)

# Brighter than the median 5-sigma depth in every band.
bright_enough_real = (bands < max_mag).all(axis=1)

reliable_photometry_real = bright_enough_real & beck_passed_real
print(f"{bright_enough_real.sum()} out of {bright_enough_real.shape[0]} galaxies have sufficient magnitude, "
      f"{beck_passed_real.sum()} pass the color tests, {reliable_photometry_real.sum()} pass both.")

In [None]:
# One figure per band: 2-D histograms of magnitude vs. error, real SDSS catalogue
# (left) and simulated hosts with drawn errors (right), on a shared log error axis.
for band_col, err_col in zip(bands.columns, errs.columns):
    fig, (ax_real, ax_sim) = plt.subplots(ncols=2, sharex=True, sharey=True)
    ax_real.set_yscale('log')
    ax_sim.set_yscale('log')

    sns.histplot(x=bands[band_col], y=errs[err_col], ax=ax_real)
    sns.histplot(x=apparent_mag_SDSS[band_col].to_numpy(), y=sim_errs[err_col].to_numpy(), ax=ax_sim)

    ax_real.set_title('SDSS catalogue')
    ax_sim.set_title('Simulated hosts')
    for panel in (ax_real, ax_sim):
        panel.set_xlabel(f'{band_col}-band magnitude')
    ax_real.set_ylabel(f'Error {band_col}-band magnitude')

In [None]:
def _add_in_quadrature(err_a, err_b):
    """Return sqrt(err_a**2 + err_b**2): the error of a difference of two independent values."""
    return np.sqrt(err_a**2 + err_b**2)


# Adjacent-band colours (bluer minus redder magnitude).
_color_pairs = (('u', 'g'), ('g', 'r'), ('r', 'i'), ('i', 'z'))

# Real SDSS galaxies.
ugcolor, grcolor, ricolor, izcolor = (bands[blue] - bands[red] for blue, red in _color_pairs)
err_gr = _add_in_quadrature(errs['err_g'], errs['err_r'])
err_ri = _add_in_quadrature(errs['err_r'], errs['err_i'])
err_iz = _add_in_quadrature(errs['err_i'], errs['err_z'])

# Simulated hosts.
ugcolor_sim, grcolor_sim, ricolor_sim, izcolor_sim = (
    apparent_mag_SDSS[blue] - apparent_mag_SDSS[red] for blue, red in _color_pairs)
err_gr_sim = _add_in_quadrature(sim_errs['err_g'], sim_errs['err_r'])
err_ri_sim = _add_in_quadrature(sim_errs['err_r'], sim_errs['err_i'])
err_iz_sim = _add_in_quadrature(sim_errs['err_i'], sim_errs['err_z'])

In [None]:
# Masks of galaxies passing the Beck colour cuts: simulated hosts (simulated errors)
# and real SDSS galaxies (catalogue errors). Both were already computed by the
# selection cells above, so reuse them instead of calling beck_mag_cuts again.
vinds_est = beck_passed
vinds_real = beck_passed_real

In [None]:
# Colour vs. colour error, one figure per colour: all real SDSS galaxies (black),
# real galaxies passing the colour cuts (red), and simulated hosts passing them
# (yellow). The horizontal line is a per-colour error threshold; presumably the one
# applied inside beck_mag_cuts (TODO confirm).
color_panels = [
    ('gr', grcolor, err_gr, grcolor_sim, err_gr_sim, 0.225),
    ('ri', ricolor, err_ri, ricolor_sim, err_ri_sim, 0.15),
    ('iz', izcolor, err_iz, izcolor_sim, err_iz_sim, 0.25),
]
for color_name, color_real, err_real, color_sim, err_sim, err_limit in color_panels:
    fig, ax = plt.subplots()
    ax.plot(color_real, err_real, 'ko', label='SDSS, all')
    ax.plot(color_real[vinds_real], err_real[vinds_real], 'ro', label='SDSS, pass cuts')
    ax.plot(color_sim[vinds_est], err_sim[vinds_est], 'y+', label='simulated, pass cuts')
    ax.axhline(err_limit)
    ax.set_yscale('log')
    ax.set_xlabel(f'{color_name} color')
    ax.set_ylabel(f'{color_name} color error')
    ax.legend()

In [None]:
# Compare the binned error statistics from estimate_photo_err with the real u-band
# data. The real u magnitudes are passed both as the magnitudes to draw errors for
# and as the reference sample.
merr, med_mag, std_mag, bins = estimate_photo_err(bands['u'], bands['u'], errs['err_u'], bins=30)

# TODO: a per-galaxy comparison of the real scatter against the binned medians
# still needs rethinking (the earlier draft of that plot did not make sense).

fig, ax = plt.subplots()
ax.errorbar(bins, med_mag, yerr=2*std_mag, fmt='o', label='median error ± 2 std')
ax.errorbar(bins, med_mag, yerr=std_mag, fmt='o', label='median error ± 1 std')
ax.axhline(0.15, color='k')
ax.set_xlabel('u-band magnitude')
ax.set_ylabel('median u-band error')
ax.legend();

In [None]:
# Hosts brighter than the limit in every band, for both sets of per-band limits.
# NOTE(review): this re-binds bright_enough with the depth cut; with
# apply_mag_cut = True it equals the value from the selection cell.
dr16_bright = (apparent_mag_SDSS < mag_limits_SDSS).all(axis='columns')
bright_enough = (apparent_mag_SDSS < max_mag).all(axis='columns')

In [None]:
# Number of simulated hosts passing each selection, labelled instead of a bare tuple.
pd.Series({
    'dr16_bright': dr16_bright.sum(),
    'bright_enough': bright_enough.sum(),
    'dr16_bright | bright_enough': (dr16_bright | bright_enough).sum(),
    'reliable_photometry': reliable_photometry.sum(),
}, name='n_galaxies')


In [None]:
# Per-FRB number of SDSS bands reported by observed_bands, and the largest such count.
# (The stale commented-out histogram line referencing undefined z_max / n_z_bins was removed.)
n_bands_obs = frbs['n_bands_SDSS']
n_bands = n_bands_obs.max()

In [None]:
# Redshifts of the FRBs whose host reaches the maximum SDSS band count. The mask is
# a plain NumPy array, so selection is positional with no index alignment.
max_bands_mask = n_bands_obs.to_numpy() == n_bands
frbs.loc[max_bands_mask, 'z']

In [None]:
# Redshift distribution of all FRBs vs. those whose host passes each selection.
# The selection masks are turned into arrays so they apply positionally, relying on
# frbs and the host tables sharing row order. FRBs beyond z = 3 fall outside the bins.
z_bins = np.linspace(0, 3, 30)

fig, ax = plt.subplots()
ax.hist(frbs['z'], bins=z_bins, density=False, alpha=0.5, color='blue', label='all FRBs')
ax.hist(frbs.loc[bright_enough.to_numpy(), 'z'], bins=z_bins, density=False, alpha=0.5,
        color='orange', label='host bright enough')
ax.hist(frbs.loc[reliable_photometry.to_numpy(), 'z'], bins=z_bins, density=False, alpha=0.5,
        color='red', label='reliable host photometry')
ax.set_xlabel('redshift z')
ax.set_ylabel('number of FRBs')
ax.legend();