In [1]:
"""
For every bulk flow direction and amplitude, recompute the redshifts of the full
sample and refit the scaling-relation power law. The (direction, amplitude) pair
that gives the least scatter is the best-fit bulk flow.

Quantities affected by the bulk flow:
- the observed redshift, and through it the distance-dependent quantities
  LX, YSZ and Mgas.
There is no need to recompute these from scratch: just rescale them accordingly.
"""

# -----------------------IMPORTS------------------------------------------------
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numba import njit, prange, set_num_threads

# Make the project's `tools` package importable.
sys.path.append('/data1/yujiehe/anisotropy-flamingo')
import tools.constants as const
import tools.clusterfit as cf
import tools.xray_correct as xc

from astropy.cosmology import FlatLambdaCDM
# FLAMINGO fiducial cosmology: flat LambdaCDM with these H0, Om0 and Ob0 values.
cosmo = FlatLambdaCDM(Om0=0.306, Ob0=0.0486, H0=68.1)
# -----------------------CONFIGURATION------------------------------------------

# Input file is a halo catalog with lightcone data.
INPUT_FILE = '/data1/yujiehe/data/samples_in_lightcone0_with_trees_duplicate_excision_outlier_excision.csv'
# NOTE(review): OUTPUT_FILE and OVERWRITE are not used by the code shown in this
# cell; nothing is written to disk yet.
OUTPUT_FILE = '/data1/yujiehe/data/fits/bulk_flow_lightcone0.csv'
OVERWRITE = True

# Relations to fit; pick from 'LX-T', 'YSZ-T', 'M-T', 'LX-YSZ', 'LX-M', 'YSZ-M'.
RELATIONS = ['LX-T', 'YSZ-T', 'M-T', 'LX-YSZ', 'LX-M', 'YSZ-M']

# Bulk flow amplitude grid in km/s (same units as C below):
# UBFMIN <= ubf < UBFMAX in steps of UBF_STEP.
UBFMIN = 0
UBFMAX = 1000

UBF_STEP = 10
# Bulk flow direction grid steps in degrees:
# longitude in [-180, 180), latitude in [-90, 90).
LON_STEP = 4
LAT_STEP = 2

# Number of threads for numba; only has an effect on numba-parallel code.
N_THREADS = 8

# Step sizes passed to cf.run_fit as B_step, logA_step and scat_step.
B_STEP       = 0.003
LOGA_STEP    = 0.003
SCAT_STEP    = 0.003

C = 299792.458                  # the speed of light in km/s
# Per-relation parameter ranges (B/scat/logA min and max) forwarded to fit_bulk_flow.
FIT_RANGE = const.FIVE_MAX_RANGE_TIGHT_SCAT

# -----------------------END CONFIGURATION--------------------------------------


# Not numba-compiled: the body calls astropy and tools.clusterfit helpers, which
# presumably cannot be compiled in nopython mode.
def fit_bulk_flow(Y, X, z_obs, phi_lc, theta_lc, yname, xname,
                  B_min, B_max, scat_min, scat_max, logA_min, logA_max,
                  ubf_min=None, ubf_max=None, ubf_step=None,
                  lon_step=None, lat_step=None,
                  ):
    """Grid-search the bulk flow that minimises the scatter of a scaling relation.

    For every amplitude ``ubf`` (km/s) in ``range(ubf_min, ubf_max, ubf_step)``
    and every direction ``(vlon, vlat)`` (degrees) in
    ``range(-180, 180, lon_step)`` x ``range(-90, 90, lat_step)``:
    - the cluster redshifts are corrected for the bulk flow;
    - ``Y`` is rescaled by the squared ratio of angular diameter distances;
    - the power law is refitted with ``cf.run_fit``.
    The grid point with the smallest fitted scatter is returned as the best fit.

    Parameters
    ----------
    Y, X : array_like
        Dependent and independent quantities of the relation ``yname-xname``.
    z_obs : array_like
        Observed redshift of each cluster.
    phi_lc, theta_lc : array_like
        Cluster positions on the lightcone, passed to ``cf.angular_separation``
        together with the flow direction.
    yname, xname : str
        Joined into the relation key ``f'{yname}-{xname}'``.
    B_min, B_max, scat_min, scat_max, logA_min, logA_max : float
        Parameter ranges forwarded to ``cf.run_fit``.
    ubf_min, ubf_max, ubf_step, lon_step, lat_step : int, optional
        Grid definition. ``None`` falls back (at call time) to the module-level
        ``UBFMIN``, ``UBFMAX``, ``UBF_STEP``, ``LON_STEP`` and ``LAT_STEP``.

    Returns
    -------
    fit_ubf, fit_vlon, fit_vlat, min_scat
        Best-fit amplitude, direction and the corresponding scatter.
    ubf_arr, vlon_arr, vlat_arr, scat_arr : ndarray
        Every grid point and its scatter (kept for debugging). The ordering is
        ``ubf`` slowest, then ``vlon``, then ``vlat`` fastest.

    Raises
    ------
    ValueError
        If the grid is empty (e.g. ``ubf_max <= ubf_min``).
    """
    # Resolve the grid at call time so edits to the config constants are honoured.
    ubf_min = UBFMIN if ubf_min is None else ubf_min
    ubf_max = UBFMAX if ubf_max is None else ubf_max
    ubf_step = UBF_STEP if ubf_step is None else ubf_step
    lon_step = LON_STEP if lon_step is None else lon_step
    lat_step = LAT_STEP if lat_step is None else lat_step

    scaling_relation = f'{yname}-{xname}'

    ubf_grid = range(ubf_min, ubf_max, ubf_step)
    lon_grid = range(-180, 180, lon_step)
    lat_grid = range(-90, 90, lat_step)
    # len(range(...)) is exact even when the step does not divide the span.
    n_ubf, n_lon, n_lat = len(ubf_grid), len(lon_grid), len(lat_grid)
    n_steps = n_ubf * n_lon * n_lat
    if n_steps == 0:
        raise ValueError('The bulk flow grid is empty; check ubf_min, ubf_max '
                         'and the step sizes.')

    ubf_arr = np.empty(n_steps, dtype=np.float64)
    vlon_arr = np.empty(n_steps, dtype=np.float64)
    vlat_arr = np.empty(n_steps, dtype=np.float64)
    scat_arr = np.empty(n_steps, dtype=np.float64)

    # These do not depend on the bulk flow, so compute them once.
    DA_zobs = cosmo.angular_diameter_distance(z_obs).value
    logX_ = cf.logX_(X, relation=scaling_relation)

    # Direction loops are outermost so the angle is computed once per direction.
    # The explicit index keeps the ubf -> lon -> lat storage order, so argmin
    # resolves ties exactly as with ubf as the outer loop.
    for i_lon, vlon in enumerate(lon_grid):
        for i_lat, vlat in enumerate(lat_grid):
            angle = cf.angular_separation(phi_lc, theta_lc, vlon, vlat)
            cos_angle = np.cos(angle)
            for i_ubf, ubf in enumerate(ubf_grid):
                beta = ubf * cos_angle / C
                # Solving z_bf = z_obs - ubf * (1 + z_bf) * cos(angle) / C gives
                # (z_obs - beta) / (1 + beta). The sign of ubf is flipped here on
                # purpose: the ubf convention is opposite to the one in the paper.
                z_bf = (z_obs + beta) / (1 - beta)

                DA_zbf = cosmo.angular_diameter_distance(z_bf).value

                # NOTE(review): every Y is rescaled by (D_A(z_bf)/D_A(z_obs))**2,
                # and X is never rescaled, even for relations whose X is itself
                # distance-dependent (YSZ or M). Confirm this is intended and
                # consistent with cf.logY_.
                logY_ = cf.logY_(Y * DA_zbf**2 / DA_zobs**2, z=z_bf,
                                 relation=scaling_relation)
                params = cf.run_fit(logY_, logX_, scat_step=SCAT_STEP,
                                    B_step=B_STEP, logA_step=LOGA_STEP,
                                    B_min=B_min,
                                    B_max=B_max,
                                    scat_min=scat_min,
                                    scat_max=scat_max,
                                    logA_min=logA_min,
                                    logA_max=logA_max,
                                    )

                idx = (i_ubf * n_lon + i_lon) * n_lat + i_lat
                ubf_arr[idx] = ubf
                vlon_arr[idx] = vlon
                vlat_arr[idx] = vlat
                scat_arr[idx] = params['scat']

    # The best-fit grid point is the one with the least scatter.
    fit_idx = np.argmin(scat_arr)
    fit_ubf = ubf_arr[fit_idx]
    fit_vlon = vlon_arr[fit_idx]
    fit_vlat = vlat_arr[fit_idx]
    min_scat = scat_arr[fit_idx]

    return fit_ubf, fit_vlon, fit_vlat, min_scat, \
        ubf_arr, vlon_arr, vlat_arr, scat_arr




# Numba's thread count only matters for numba-parallel code; harmless otherwise.
set_num_threads(N_THREADS)

# Load the sample
halo_data = pd.read_csv(INPUT_FILE)

# Redshift cuts: upper limits from ZMAX_START in steps of ZMAX_STEP, stopping
# just past the sample's highest redshift.
ZMAX_START = 0.03
ZMAX_STEP = 0.05
# Debug mode: fit and plot only the first relation and the first redshift cut,
# because the full grid search is very expensive. Set to False for a full run.
DEBUG_FIRST_ONLY = True

# TODO: results are only plotted; OUTPUT_FILE / OVERWRITE are not used yet.
fig, ax = plt.subplots()
for scaling_relation in RELATIONS:

    n_clusters = cf.CONST[scaling_relation]['N']

    yname, xname = scaling_relation.split('-', 1)
    Y = np.array(halo_data[cf.COLUMNS[yname]][:n_clusters])
    X = np.array(halo_data[cf.COLUMNS[xname]][:n_clusters])

    # Also load the position data
    phi_lc   = np.array(halo_data['phi_on_lc'][:n_clusters])
    theta_lc = np.array(halo_data['theta_on_lc'][:n_clusters])

    # NOTE(review): the column is 'ObservedRedshift', but the original comment
    # called it the cosmological redshift without peculiar velocity; confirm.
    z_obs = np.array(halo_data['ObservedRedshift'][:n_clusters])

    for zmax in np.arange(ZMAX_START, np.max(z_obs) + 0.01, ZMAX_STEP):
        zmask = z_obs < zmax
        (ubf, lon, lat, min_scat,
         ubf_arr, vlon_arr, vlat_arr, scat_arr) = fit_bulk_flow(
            Y=Y[zmask], X=X[zmask],
            z_obs=z_obs[zmask],
            phi_lc=phi_lc[zmask], theta_lc=theta_lc[zmask],
            yname=yname, xname=xname,
            **FIT_RANGE[scaling_relation])

        # Scatter for every grid point; red lines mark the best fit.
        ax.scatter(ubf_arr, scat_arr, s=1)
        ax.axhline(min_scat, color='red')
        ax.axvline(ubf, color='red')
        if DEBUG_FIRST_ONLY:
            break

    if DEBUG_FIRST_ONLY:
        break

ax.set_xlabel('ubf')
ax.set_ylabel('scat')
ax.set_title('Scatter vs bulk flow amplitude')
plt.show()

KeyboardInterrupt: 

In [None]:
# Best-fit bulk flow amplitude and direction from the last fit above.
(ubf, lon, lat)

(500.0, 80.0, 30.0)