In [1]:
import redmagic
import numpy as np
from matplotlib import pyplot as plt
from astroML.density_estimation import XDGMM
from sklearn.mixture import GaussianMixture
import fitsio as fio
import glob
import os
import gc
import numpy as np
from scipy import special
from sklearn.model_selection import train_test_split
import scipy.optimize
import ezgal
import matplotlib.colors as clr
from scipy import interpolate
from plotly.subplots import make_subplots
from sklearn.inspection import DecisionBoundaryDisplay
from redmagic.utils import CubicSpline,make_nodes
from redmagic.fitters import MedZFitter
import plotly.graph_objects as go
# Inward-pointing ticks drawn on all four sides of every matplotlib axes.
plt.rcParams.update({
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.top': True,
    'ytick.right': True,
})

# Truth redshifts (indexed by the truth catalog's 'ind' column downstream)
# and the Roman truth / detection catalogs, read through fitsio.
Roman_z = redmagic.load_data.load_true_gal_z(use_fio=True)
roman_truth, roman_detection = redmagic.load_data.load_roman_output(use_fio=True)

In [3]:
# load SED matched gals
# NOTE: this cell consumes and then deletes the catalogs loaded in the first
# cell (roman_truth, roman_detection, Roman_z), so it can only run once per
# kernel -- re-run the first cell before re-running this one.
Red_gal_ind = np.load('redmagic/data/roman_red_sed_bulge_disk_match_truth_ind.npy')

# Info is stored as the index in the roman truth catalog (same for measurement catalog)
# Matched to the index in total truth file to get redshift
matched_inds_in_truth = roman_truth['ind'][Red_gal_ind]

# Measured quantities of SED matched LRGs (galaxies only: gal_star == 0)
mms, = np.where(roman_truth['gal_star'][Red_gal_ind]==0)
measured_qts = roman_detection[Red_gal_ind][mms]

# Redshift of SED matched LRGs from truth catalog
zs = Roman_z[matched_inds_in_truth][mms]

# Exclude identified LRGs from sample
roman_detection = np.delete(roman_detection, Red_gal_ind)
roman_truth = np.delete(roman_truth, Red_gal_ind)

mm = roman_truth['gal_star'] == 0
# .copy(): a bare field access is a view that would keep the whole
# roman_truth buffer alive after the `del` below, defeating the memory cleanup.
roman_ind_all = roman_truth['ind'].copy()
# Compose the indices before gathering (same result as
# Roman_z[roman_ind_all][mm]) to avoid a full-catalog temporary.
z_all = Roman_z[roman_ind_all[mm]]

# Measured quantities of the remaining galaxies (SED-matched objects and
# stars excluded)
measured_qts_all = roman_detection[mm]

# Discard used variables to free up memory
del roman_detection,roman_truth,Roman_z,mm,mms,Red_gal_ind,matched_inds_in_truth
gc.collect();

In [4]:
# EzGal stellar-population model (per the file name: BC03 SSP, Z = 0.02,
# Salpeter IMF) used to build a redshift-dependent magnitude limit.
model = ezgal.model( 'bc03_ssp_z_0.02_salp.model' )
model.set_ab_output()
# Normalise the model to apparent magnitude 17.85 in SDSS i at z = 0.2.
model.set_normalization( 'sloan_i', 0.2, 17.85,apparent=True)
# Formation redshift of the model.
zf = 6
# Luminosity fraction (relative to the normalised model) defining the cut.
LUM_FRAC = 0.4
# fetch an array of redshifts out to given formation redshift
zss = model.get_zs( zf )
# m_lim(z) = m_model(z) - 2.5*log10(LUM_FRAC), written as log10(f)/(-0.4):
# objects brighter than LUM_FRAC times the model luminosity pass a
# `mag < m_lim` cut.
# NOTE(review): 'Roman_F' is presumably the Roman F184 response -- confirm.
l_cut = np.log10(LUM_FRAC)/(-0.4) + model.get_apparent_mags( zf, filters='Roman_F', zs=zss )

# Limiting magnitude as a smooth function of redshift.
spl_limmag_ref = CubicSpline(zss, l_cut)

/root/anaconda3/envs/Redmagic/lib/python3.10/site-packages/ezgal/data/models/bc03_ssp_z_0.02_salp.model


  return 5. * num.log10(self.Dl(z) / self.pc / 10)


In [5]:
# Bands read from the catalogs, in the column order returned by select_galaxy.
_SELECT_BANDS = ('Y106', 'J129', 'H158', 'F184')


def _mags_in_zbin(z, qts, z_low, z_high, limmag, bands=_SELECT_BANDS,
                  detect_band='F184'):
    """Return mag_auto values for objects in a redshift bin that pass a flux cut.

    Selects rows with ``z_low < z < z_high`` whose ``mag_auto_<detect_band>``
    is strictly fainter-limited by ``limmag(z)`` (i.e. ``mag < limmag(z)``).

    Returns:
        float64 array of shape (n_selected, len(bands)); column j holds
        ``qts['mag_auto_<bands[j]>']`` for the selected rows.
    """
    in_bin, = np.where((z > z_low) & (z < z_high))
    bright, = np.where(qts['mag_auto_%s' % detect_band][in_bin] < limmag(z[in_bin]))
    rows = in_bin[bright]  # positions in the full catalog that pass both cuts
    mags = np.zeros(shape=(rows.shape[0], len(bands)))
    for col, band in enumerate(bands):
        # Index the single field first so only the selected values are copied.
        mags[:, col] = qts['mag_auto_%s' % band][rows]
    return mags


def select_galaxy(z_low, z_high, z_gal=None, qts_gal=None, z_red=None,
                  qts_red=None, limmag=None):
    """Magnitude arrays for both samples in the bin ``z_low < z < z_high``.

    Each sample is cut to objects with ``mag_auto_F184 < limmag(z)``.

    Args:
        z_low, z_high: open redshift-bin bounds.
        z_gal, qts_gal: redshifts and measured-quantity structured array of
            the general sample; default to the notebook globals ``z_all`` and
            ``measured_qts_all``.
        z_red, qts_red: same for the SED-matched red sample; default to the
            globals ``zs`` and ``measured_qts``.
        limmag: callable giving the limiting F184 magnitude at redshift z;
            defaults to the global ``spl_limmag_ref``.

    Returns:
        ``(mag_array, mag_array_sub)``: float64 arrays of shape (N, 4) with
        columns Y106, J129, H158, F184 for the general and red samples.
    """
    # Fall back to notebook-level state so existing two-argument calls work.
    if z_gal is None:
        z_gal = z_all
    if qts_gal is None:
        qts_gal = measured_qts_all
    if z_red is None:
        z_red = zs
    if qts_red is None:
        qts_red = measured_qts
    if limmag is None:
        limmag = spl_limmag_ref

    mag_array = _mags_in_zbin(z_gal, qts_gal, z_low, z_high, limmag)
    mag_array_sub = _mags_in_zbin(z_red, qts_red, z_low, z_high, limmag)
    return mag_array, mag_array_sub

def gen_training(data, Dim=2, three_color=False):
    """Turn an (N, n_bands) magnitude array into a classifier feature matrix.

    With ``cij = data[:, i] - data[:, j]`` the layouts are:

    * ``Dim=2``                        -> ``[data[:, -1], c01]``
    * ``Dim=3, three_color == False``  -> ``[data[:, -1], c01, c12]``
    * ``Dim=3, three_color == True``   -> ``[c01, c12, c23]``

    For the (N, 4) Y106/J129/H158/F184 arrays from select_galaxy,
    ``data[:, -1]`` is the F184 magnitude.

    Returns:
        float64 array of shape (N, Dim).

    Raises:
        ValueError: for any other ``Dim`` / ``three_color`` combination.
    """
    # Resolve the layout before touching `data`, so an invalid request
    # raises ValueError regardless of the data's shape.
    if Dim == 2:
        lead_with_mag, n_colors = True, 1
    elif Dim == 3 and three_color == False:
        lead_with_mag, n_colors = True, 2
    elif Dim == 3 and three_color == True:
        lead_with_mag, n_colors = False, 3
    else:
        raise ValueError('Invalid dimension!')

    columns = [data[:, -1]] if lead_with_mag else []
    columns += [data[:, k] - data[:, k + 1] for k in range(n_colors)]
    # astype(float) yields float64 output whatever the input dtype.
    return np.column_stack(columns).astype(float)

SVM: linear separation of the SED-matched red galaxies (label 1) from the remaining galaxies (label 0)

In [76]:
from sklearn.svm import LinearSVC

# Training set in 1.60 < z < 1.61: SED-matched red galaxies (`sub`) are
# labelled 1, the remaining galaxies (`all`) 0. Features from
# gen_training(Dim=3): F184 magnitude, Y106-J129, J129-H158.
# NOTE(review): `all` shadows the Python builtin for the rest of the notebook.
all,sub = select_galaxy(z_low=1.6,z_high=1.61)

X_red = gen_training(sub,Dim=3,three_color = False)
X_blue = gen_training(all,Dim=3,three_color = False)
X = np.vstack((X_red,X_blue))
Y_red = np.ones(len(sub))
Y_blue = np.zeros(len(all))
Y = np.concatenate((Y_red,Y_blue))

# NOTE(review): features are unscaled (magnitudes vs. colours); the large
# max_iter is presumably there to help LinearSVC converge.
clf = LinearSVC(C=1,max_iter=100000)
clf.fit(X, Y)
SVM_pred = clf.predict(X)


def Z(x, y):
    """Third-feature value on the SVM plane w.x + b = 0 at features (x, y).

    Reads the global `clf` at call time; undefined (division by zero) if the
    fitted plane has no component along the third feature.
    """
    w = clf.coef_[0]
    return (-clf.intercept_[0] - w[0] * x - w[1] * y) / w[2]


fig = make_subplots(rows=1, cols=2,
                    specs=[[{'type': 'surface'}, {'type': 'surface'}]],
                    subplot_titles=('Labels (red = SED-matched)', 'LinearSVC prediction'))
# Left panel: separating plane + true labels; right panel: predicted labels.
trace1 = go.Mesh3d(x=X[:,0], y=X[:,1], z=Z(X[:,0],X[:,1]), opacity=0.66,
                   color='cyan', name='SVM plane')
trace2 = go.Scatter3d(x=X[:,0], y=X[:,1], z=X[:,2], mode='markers', name='labels',
                      marker=dict(size=2, color=Y, colorscale='Bluered'))
trace_pred = go.Scatter3d(x=X[:,0], y=X[:,1], z=X[:,2], mode='markers', name='SVM prediction',
                          marker=dict(size=2, color=SVM_pred, colorscale='Bluered'))
fig.add_trace(trace1,row = 1, col = 1)
fig.add_trace(trace2,row = 1, col = 1)
fig.add_trace(trace_pred,row = 1, col = 2)
fig.update_scenes(xaxis_title='F184', yaxis_title='Y106 - J129', zaxis_title='J129 - H158')
fig.update_layout(
    autosize=False,
    width=1200,
    height=600
    )

In [77]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Training set in 1.60 < z < 1.61: SED-matched red galaxies (`sub`) are
# labelled 1, the remaining galaxies (`all`) 0. Features from
# gen_training(Dim=3): F184 magnitude, Y106-J129, J129-H158.
# NOTE(review): `all` shadows the Python builtin for the rest of the notebook.
all,sub = select_galaxy(z_low=1.60,z_high=1.61)

X_red = gen_training(sub,Dim=3,three_color = False)
X_blue = gen_training(all,Dim=3,three_color = False)
X = np.vstack((X_red,X_blue))
Y_red = np.ones(len(sub))
Y_blue = np.zeros(len(all))
Y = np.concatenate((Y_red,Y_blue))
clf = LinearDiscriminantAnalysis()
clf.fit(X, Y)
X_LDA_pred = clf.predict(X)

fig = make_subplots(rows=1, cols=2,
                    specs=[[{'type': 'surface'}, {'type': 'surface'}]],
                    subplot_titles=('Labels (red = SED-matched)', 'LDA prediction'))

# Left panel: true labels; right panel: LDA-predicted labels.
trace1 = go.Scatter3d(x=X[:,0], y=X[:,1], z=X[:,2], mode='markers', name='labels',
                      marker=dict(size=2, color=Y, colorscale='Bluered'))
fig.add_trace(trace1,row = 1, col = 1)
trace2 = go.Scatter3d(x=X[:,0], y=X[:,1], z=X[:,2], mode='markers', name='LDA prediction',
                      marker=dict(size=2, color=X_LDA_pred, colorscale='Bluered'))
fig.add_trace(trace2,row = 1, col = 2)
fig.update_scenes(xaxis_title='F184', yaxis_title='Y106 - J129', zaxis_title='J129 - H158')
fig.update_layout(
    autosize=False,
    width=1200,
    height=600
    )