In [1]:
import torch as tc
tc.set_default_dtype(tc.float64)
# tc.set_default_tensor_type(tc.DoubleTensor)

import numpy as np
import scipy.constants as sc
from scipy.interpolate import interp1d
import camb
from copy import deepcopy

import time

import matplotlib.pyplot as plt
%matplotlib qt5

In [6]:
class Cl_kSZ2_HI2_test():

    def __init__(self, z_array, Tb = 1.8e-4, H0 = 67.75, ombh2 = 0.022):
        
        ##################################################s
        # Define the cosmological parameters
        params = camb.CAMBparams()
        params.set_cosmology(H0=H0, ombh2=ombh2)
        params.set_matter_power(redshifts = z_array, kmax=10, nonlinear=True)
        results = camb.get_results(params)
        backgrounds = camb.get_background(params)

        # Calculate the background evolution and results
        kh, z, Pm = results.get_matter_power_spectrum(minkh=1e-4, maxkh=10, npoints = 500, var1='delta_tot', var2='delta_tot')
        Xe_of_z = np.array(backgrounds.get_background_redshift_evolution(z_array, ['x_e'], format='array')).flatten()
        chi_of_z = np.array(results.comoving_radial_distance(z_array))

        ##################################################
        # Store the variables that we are interested in

        # Constant scalars and arrays
        self.TCMB = params.TCMB     # CMB temperature 2.7K
        self.Tb = Tb                # HI brightness temperature, in unite mK
        self.kh_array = kh          # Total kh array that we are interested in
        self.z_list = z             # Total redshift array that we are interested in
        
        # Adjust order for interpolations
        if len(z)<=5 : itp_order = 'linear'
        else: itp_order = 'cubic'

        # Interpolation functions of z
        self.H_of_z = backgrounds.hubble_parameter                  # Hubble parameter, in unit 
        self.f_of_z = self.Growth_Rate_of_z(backgrounds, itp_order) # Logarithmic growth rate
        self.Xe = interp1d(z_array, Xe_of_z, kind = itp_order)      # Ionized franction Xe
        self.chi = interp1d(z_array, chi_of_z, kind = itp_order)    # Comoving distance chi

        # interpolation functions of k and z
        self.Pm = Pm # interp2d_tc(kh, z, Pm)             # Matter power spectrum

        # save the cosmological model, for checking the result
        self.results = results
        self.BGEvolution = backgrounds
        
    def Growth_Rate_of_z(self, backgrounds, itp_order):
        '''
        Get the interpolation function for logarithmic growth rate f, 
        defined as f:=d(ln D)/d(ln a)
        '''
        # Since the growth rate almost does not vary with momentum scale, we fix kh=0.01 to get f
        f_of_z = backgrounds.get_redshift_evolution([0.01], self.z_list, ['growth'])
        return interp1d(self.z_list, np.array(f_of_z).flatten(), kind = itp_order)

    def Pm_interpolation(self, x, y, Mode='bilinear'):
        return interp2d_torch(tc.tensor(self.kh_array), tc.tensor(self.z_list), tc.tensor(self.Pm), x, y, mode=Mode)


    def dCl(self, z, l, l1, l_min = 1, l_max = 1000, N_l = 1000, N_theta = 81):
        """Evaluare the integrand, dCl, as a function of z, l and l_1.

        Here we sum over theta_1, l_2, and theta_2. To get the final C_l result, one has to integrate dCl over chi and l_1, for a given l.

        Input
        -----
        `z` : float. 
            The redshift. 

        `l` : float. 
            The moment for C_l. Don't need to be an integer since we are in flat-sky approximation.

        `l1` : float.
            The norm of \\vec{l}_1.

        """
        ##################################################
        # Redefine the inputs as tc.tensors
        z = tc.tensor([z], dtype=tc.float64)
        l = tc.tensor([l], dtype=tc.float64)
        l1 = tc.tensor([l1], dtype=tc.float64)

        # # Make the mesh grid for theta_1, |l_2|, and theta_2
        # t1_list = tc.arange(N_theta, dtype=tc.float64) * tc.pi / N_theta
        # t2_list = deepcopy(t1_list)
        # l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64)
        # l2, t1, t2 = tc.meshgrid(l2_list, t1_list, t2_list, indexing='ij')

        # Make the mesh grid for theta_1, |l_2|, and theta_2
        t1 = tc.tensor([tc.pi / 3.], dtype=tc.float64)
        t2_list = tc.arange(N_theta, dtype=tc.float64) * 2 * tc.pi / N_theta
        l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64) # 10**tc.linspace(np.log10(l_min), np.log10(l_max), N_l, dtype=tc.float64)
        l2, t2 = tc.meshgrid(l2_list, t2_list, indexing='ij')

        # Pre-define useful varibales and constants
        lsquare = l**2
        l1square = l1**2
        l2square = l2**2

        l_dot_l1 = Polar_dot(l, 0., l1, t1)
        l_dot_l2 = Polar_dot(l, 0., l2, t2)
        l1_dot_l2 = Polar_dot(l1, t1, l2, t2)

        l_m_l1_norm = tc.sqrt( lsquare + l1square - 2*l_dot_l1 )
        l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 )
        l1_p_l2_norm = tc.sqrt( l1square + l2square + 2*l1_dot_l2 )
        l_m_l1_p_l2_norm = tc.sqrt( lsquare + l1square + l2square - 2*l_dot_l1 + 2*l_dot_l2 - 2*l1_dot_l2 )

        theta_l_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l1_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l_m_l1_p_l2 = Evaluate_angle(3, l, tc.tensor([0.]), -l1, t1, l2, t2)


        Z_MEAN = 0.45 # mean redshift for HI observation
        FREQ_HI = 1420. # in unit MHz
        SIGMA_HI = 0.0115 * 1000. * (1. + Z_MEAN) / FREQ_HI
        SIGMA_KSZ = deepcopy(SIGMA_HI)

        ##################################################
        # Evaluate the integrand
        # Initialization
        dCl_tot = tc.zeros_like(t2)

        # Contribution originate from each term in Wick Theorem
        # Term 5 
        dCl = - tc.cos(theta_l1_p_l2 - t2) # - (l_dot_l2 + l2square) / l_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'e')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 6 
        dCl = - tc.cos(theta_l1_p_l2 - t2) # - (l_dot_l2 + l2square) / l_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'e')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 8 
        dCl = tc.cos(theta_l_p_l2 - theta_l1_p_l2) # (l_dot_l1 + l_dot_l2 + l1_dot_l2 + l2square) / l_p_l2_norm / l1_p_l2_norm
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 9
        dCl = tc.cos(theta_l_m_l1_p_l2 - t2) # (l_dot_l2 - l1_dot_l2 + l2square) / l_m_l1_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l2, 'e', 'HI')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI')
        dCl_tot += dCl
        # Term 10
        dCl = tc.cos(theta_l1_p_l2 - t2)# (l1_dot_l2 + l2square) / l1_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 11
        dCl = -1.
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'v', 'v')
        dCl *= self.Cross_Power(z, l2, 'e', 'HI')
        dCl_tot += dCl
        # Term 13
        dCl = tc.cos(theta_l_m_l1_p_l2 - theta_l_p_l2) # (lsquare + l2square + 2*l_dot_l2 - l_dot_l1 - l1_dot_l2) / l_p_l2_norm / l_m_l1_p_l2_norm
        dCl *= self.Cross_Power(z, l2, 'e', 'HI')
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'v', 'HI')
        dCl_tot += dCl
        # Term 14
        dCl = -1.
        dCl = self.Cross_Power(z, l2, 'e', 'HI')
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'v', 'v')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'HI')
        dCl_tot += dCl

        dCl_terms = dCl_tot.clone().detach()

        # The beam functions
        dCl_tot *= self.Beam_kSZ(l_m_l1_norm, SIGMA_KSZ) * self.Beam_kSZ(l1, SIGMA_KSZ) * self.Beam_HI(l_p_l2_norm, SIGMA_HI) * self.Beam_HI(l2, SIGMA_HI)
        # The window functions and the metric determinant contribution 
        dCl_tot *= l1 * l2 * self.F_kSZ(z)**2 * self.G_HI(z)**2 * self.dchi_by_dz(z)

        dCl_res = tc.sum(dCl_tot) * t2_list[1] * (l_max - l_min) / (N_l - 1)

        return l2, t2, dCl_res, dCl_tot, dCl_terms
    

    def dchi_by_dz(self, z):
        return sc.c / self.H_of_z(z)

    def F_kSZ(self, z):
        return tc.tensor(self.Xe(z)) * (1+z)**2 / tc.tensor(self.chi(z))**2
    
    def G_HI(self, z):
        return 1 / (self.z_list[-1] - self.z_list[0]) / self.chi(z)**2

    def Beam_kSZ(self, l, singma_kSZ):
        return tc.exp(-l**2 * singma_kSZ**2 / 2)
    
    def Beam_HI(self, l, singma_HI):
        return tc.exp(-l**2 * singma_HI**2 / 2)

    def Cross_Power(self, z, L, b1, b2, cut_off= tc.tensor([2.])):
        
        chi = self.chi(z)
        kh = L / chi
        kh_cutoff = cut_off / chi
        shape = kh.shape

        if b1 not in ['e', 'v', 'HI'] or b2 not in ['e', 'v', 'HI']:
            print('b1 and b2 must be "e", "v" or "HI"')
            raise
        else:
            if b1 == 'e': B1 = self.bias_electron
            elif b1 == 'v': B1 = self.bias_velocity
            elif b1 == 'HI': B1 = self.bias_HI

            if b2 == 'e': B2 = self.bias_electron
            elif b2 == 'v': B2 = self.bias_velocity
            elif b2 == 'HI': B2 = self.bias_HI

        mesh = tc.where(L <= cut_off)
        P = (self.Pm_interpolation(kh.flatten().clone().detach(), tc.tensor([z]))).reshape(shape) * B1(kh, z) * B2(kh, z)

        if b1=='v' and b2=='v' :
            P[mesh] = 2. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        elif b1=='v' or b2=='v' :
            P[mesh] = self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        else:
            P[mesh] = 2./3. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)

        return P


    def bias_electron(self, kh, z): # TO BE REVISED
        return kh/kh
    
    def bias_velocity(self, kh, z):
        b = 1/(1+z) * self.H_of_z(z) * self.f_of_z(z) / kh
        return b
    
    def bias_HI(self, kh, z): # TO BE REVISED
        return kh/kh
    
    


def Polar_dot(lx, thetax, ly, thetay):
    return lx * ly * np.cos(thetax - thetay)

def Evaluate_angle(N_vec, *vectors):
    if 2*N_vec != len(vectors):
        print('The input N_vec does not match the number of input vectors')
        raise
    else:
        # We need to do some adjustment on vectors to match the broadcast rule
        # In order to keep vectors unchanged, make a copy of them for calculation
        vec = deepcopy(vectors)

        l_x = 0.
        l_y = 0.
        for i in range(N_vec):
            # if len(vectors[2*i]) == 1: vec[2*i] = np.array(vectors[2*i])
            # if len(vectors[2*i+1]) == 1: vec[2*i+1] = np.array(vectors[2*i+1])
            # print('shape1', vec[2*i].shape, '   shape2', vec[2*i+1].shape, '   shape3', np.cos(vec[2*i+1]).shape)
            l_x = l_x + vec[2*i] * np.cos(vec[2*i+1])
            l_y = l_y + vec[2*i] * np.sin(vec[2*i+1])
            # print(l_x.shape, '   ', l_y.shape)
        
        return np.arctan2(l_y, l_x)
    
def interp2d_torch(x, y, z, x_new, y_new, mode='bilinear'):
    '''
    Interpolates 2D data over a grid using PyTorch, mimicking `scipy.interpolate.interp2d`.
    
    Parameters:
        x (torch.Tensor): 1D tensor of x coordinates (size: N).
        y (torch.Tensor): 1D tensor of y coordinates (size: M).
        z (torch.Tensor): 2D tensor of shape (M, N) representing the grid values.
        x_new (torch.Tensor): 1D tensor of new x coordinates for interpolation (size: N').
        y_new (torch.Tensor): 1D tensor of new y coordinates for interpolation (size: M').
        mode (str): Interpolation mode ('bilinear', 'nearest'). Defaults to 'bilinear'.
        
    Returns:
        torch.Tensor: Interpolated values at new (x_new, y_new) grid points.
    '''
    
    # Ensure the input tensors are of the correct shape
    z = z.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions (1, 1, M, N)
    
    # Create the meshgrid for new points (x_new, y_new)
    x_new_grid, y_new_grid = tc.meshgrid(x_new, y_new, indexing='ij')
    
    # Normalize new grid coordinates to range [-1, 1] (for grid_sample)
    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()
    
    x_new_norm = 2 * (x_new_grid - x_min) / (x_max - x_min) - 1
    y_new_norm = 2 * (y_new_grid - y_min) / (y_max - y_min) - 1
    
    # Stack and reshape the new coordinates into (1, H', W', 2) for grid_sample
    grid = tc.stack((x_new_norm, y_new_norm), dim=-1).unsqueeze(0)
    
    # Perform the interpolation using grid_sample
    interpolated = tc.nn.functional.grid_sample(z, grid, mode=mode, align_corners=True)
    
    # Remove the batch and channel dimensions and return the result
    return interpolated.squeeze()

In [2]:
class Cl_kSZ2_HI2():

    def __init__(self, z_array, Tb = 1.8e-4, H0 = 67.75, ombh2 = 0.022):
        
        ##################################################s
        # Define the cosmological parameters
        params = camb.CAMBparams()
        params.set_cosmology(H0=H0, ombh2=ombh2)
        params.set_matter_power(redshifts = z_array, kmax=10, nonlinear=True)
        results = camb.get_results(params)
        backgrounds = camb.get_background(params)

        # Calculate the background evolution and results
        kh, z, Pm = results.get_matter_power_spectrum(minkh=1e-4, maxkh=10, npoints = 500, var1='delta_tot', var2='delta_tot')
        Xe_of_z = np.array(backgrounds.get_background_redshift_evolution(z_array, ['x_e'], format='array')).flatten()
        chi_of_z = np.array(results.comoving_radial_distance(z_array))

        ##################################################
        # Store the variables that we are interested in

        # Constant scalars and arrays
        self.TCMB = params.TCMB         # CMB temperature 2.7K
        self.Tb = Tb                    # HI brightness temperature, in unite mK
        self.kh_list = kh               # Total kh array that we are interested in
        self.kh_array = tc.tensor(kh)
        self.z_list = z                 # Total redshift array that we are interested in
        self.z_array = tc.tensor(z)
        
        # Functions of redshift
        self.H_of_z = tc.tensor(backgrounds.hubble_parameter(z))        # Hubble parameter over c, in unit h/Mpc
        self.f_of_z = tc.tensor(                                        # Logarithmic growth rate
            backgrounds.get_redshift_evolution([0.01], z, ['growth'])
            ).flatten()
        self.Xe_of_z = tc.tensor(Xe_of_z)                               # Ionized franction Xe
        self.chi_of_z = tc.tensor(chi_of_z)                             # Comoving distance chi
        self.F_kSZ = self.Xe_of_z * (1+z)**2 / self.chi_of_z**2         # F_kSZ, propto visibility function of kSZ
        self.G_HI = 1 / (z[-1] - z[0]) / self.chi_of_z**2               # G_HI, proptp window function of HI

        # interpolation functions of k and z
        self.Pm = tc.tensor(Pm) # interp2d_torch(kh, z, Pm)             # Matter power spectrum

        # Interpolation functions for matter power spectrum
        # adding infrared asymptotic behavior (P proportional to k)
        N_add = 5
        self.kh_array_itp = tc.hstack([tc.linspace(0., kh[0], N_add), tc.tensor(kh[1:])])
        Pm_infared = tc.linspace(0., kh[0], N_add).repeat(len(z)).reshape([len(z), N_add]) * Pm[:, :1] / kh[0]
        self.Pm_itp = tc.hstack([Pm_infared, tc.tensor(Pm[:, 1:])])

        # save the cosmological model, for checking the result
        self.results = results
        self.BGEvolution = backgrounds
        
    def Growth_Rate_of_z(self, backgrounds, itp_order):
        '''
        Get the interpolation function for logarithmic growth rate f, 
        defined as f:=d(ln D)/d(ln a)
        '''
        # Since the growth rate almost does not vary with momentum scale, we fix kh=0.01 to get f
        f_of_z = backgrounds.get_redshift_evolution([0.01], self.z_list, ['growth'])
        return interp1d(self.z_list, np.array(f_of_z).flatten(), kind = itp_order)

    def Power_matter_1d(self, kh, zindex, Mode='cubic'):
        return torch_interp1d(self.kh_array_itp, (self.Pm_itp)[zindex], kh)

    def Pm_interpolation(self, x, y, Mode='bilinear'):
        return interp2d_torch(self.kh_array, self.z_array, self.Pm, x, y, mode=Mode)

    def dCl(self, z, l, l1, l_min = 1, l_max = 1000, N_l = 1000, N_theta = 81):
        """Evaluare the integrand, dCl, as a function of z, l and l_1.

        Here we sum over theta_1, l_2, and theta_2. To get the final C_l result, one has to integrate dCl over chi and l_1, for a given l.

        Input
        -----
        `z` : float. 
            The redshift. 

        `l` : float. 
            The moment for C_l. Don't need to be an integer since we are in flat-sky approximation.

        `l1` : float.
            The norm of \\vec{l}_1.

        """
        ##################################################
        # Redefine the inputs as tc.tensors
        z = tc.tensor([z], dtype=tc.float64)
        l = tc.tensor([l], dtype=tc.float64)
        l1 = tc.tensor([l1], dtype=tc.float64)

        # Make the mesh grid for theta_1, |l_2|, and theta_2
        t1_list = tc.arange(N_theta, dtype=tc.float64) * tc.pi / N_theta
        t2_list = deepcopy(t1_list)
        l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64)
        l2, t1, t2 = tc.meshgrid(l2_list, t1_list, t2_list, indexing='ij')

        # Make the mesh grid for theta_1, |l_2|, and theta_2
        # t1 = tc.tensor([tc.pi / 3.], dtype=tc.float64)
        # t2_list = tc.arange(N_theta, dtype=tc.float64) * 2 * tc.pi / N_theta
        # l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64) # 10**tc.linspace(np.log10(l_min), np.log10(l_max), N_l, dtype=tc.float64)
        # l2, t2 = tc.meshgrid(l2_list, t2_list, indexing='ij')

        # Pre-define useful varibales and constants
        lsquare = l**2
        l1square = l1**2
        l2square = l2**2

        l_dot_l1 = Polar_dot(l, 0., l1, t1)
        l_dot_l2 = Polar_dot(l, 0., l2, t2)
        l1_dot_l2 = Polar_dot(l1, t1, l2, t2)

        l_m_l1_norm = tc.sqrt( lsquare + l1square - 2*l_dot_l1 )
        l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 )
        l1_p_l2_norm = tc.sqrt( l1square + l2square + 2*l1_dot_l2 )
        l_m_l1_p_l2_norm = tc.sqrt( lsquare + l1square + l2square - 2*l_dot_l1 + 2*l_dot_l2 - 2*l1_dot_l2 )

        del(l_dot_l1)
        del(l_dot_l2)
        del(l1_dot_l2)

        theta_l_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l1_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l_m_l1_p_l2 = Evaluate_angle(3, l, tc.tensor([0.]), -l1, t1, l2, t2)

        

        Z_MEAN = 0.45 # mean redshift for HI observation
        FREQ_HI = 1420. # in unit MHz
        SIGMA_HI = 0.0115 * 1000. * (1. + Z_MEAN) / FREQ_HI
        SIGMA_KSZ = deepcopy(SIGMA_HI)

        ##################################################
        # Evaluate the integrand
        # Initialization
        dCl_tot = tc.zeros_like(t2)

        # Contribution originate from each term in Wick Theorem
        # Term 5 
        dCl = - tc.cos(theta_l1_p_l2 - t2) # - (l_dot_l2 + l2square) / l_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'e')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 6 
        dCl = - tc.cos(theta_l1_p_l2 - t2) # - (l_dot_l2 + l2square) / l_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'e')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 8 
        dCl = tc.cos(theta_l_p_l2 - theta_l1_p_l2) # (l_dot_l1 + l_dot_l2 + l1_dot_l2 + l2square) / l_p_l2_norm / l1_p_l2_norm
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 9
        dCl = tc.cos(theta_l_m_l1_p_l2 - t2) # (l_dot_l2 - l1_dot_l2 + l2square) / l_m_l1_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l2, 'e', 'HI')
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI')
        dCl_tot += dCl
        # Term 10
        dCl = tc.cos(theta_l1_p_l2 - t2)# (l1_dot_l2 + l2square) / l1_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l2, 'v', 'HI')
        dCl_tot += dCl
        # Term 11
        dCl = -1.
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'v', 'v')
        dCl *= self.Cross_Power(z, l2, 'e', 'HI')
        dCl_tot += dCl
        # Term 13
        dCl = tc.cos(theta_l_m_l1_p_l2 - theta_l_p_l2) # (lsquare + l2square + 2*l_dot_l2 - l_dot_l1 - l1_dot_l2) / l_p_l2_norm / l_m_l1_p_l2_norm
        dCl *= self.Cross_Power(z, l2, 'e', 'HI')
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'v')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'v', 'HI')
        dCl_tot += dCl
        # Term 14
        dCl = -1.
        dCl = self.Cross_Power(z, l2, 'e', 'HI')
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'v', 'v')
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'HI')
        dCl_tot += dCl

        # The beam functions
        dCl_tot *= self.Beam_kSZ(l_m_l1_norm, SIGMA_KSZ) * self.Beam_kSZ(l1, SIGMA_KSZ) * self.Beam_HI(l_p_l2_norm, SIGMA_HI) * self.Beam_HI(l2, SIGMA_HI)
        # The window functions and the metric determinant contribution 
        dCl_tot *= l1 * l2 * self.F_kSZ(z)**2 * self.G_HI(z)**2 * self.dchi_by_dz(z)

        print(dCl_tot.shape)

        dCl_res = tc.sum(dCl_tot) * t2_list[1] * (l_max - l_min) / (N_l - 1)

        return dCl_res
    
    def dchi_by_dz(self, z):
        return sc.c / self.H_of_z(z)

    # def F_kSZ(self, z):
    #     return tc.tensor(self.Xe(z)) * (1+z)**2 / tc.tensor(self.chi(z))**2
    
    # def G_HI(self, z):
    #     return 1 / (self.z_list[-1] - self.z_list[0]) / self.chi(z)**2

    def Beam_kSZ(self, l, singma_kSZ):
        return tc.exp(-l**2 * singma_kSZ**2 / 2)
    
    def Beam_HI(self, l, singma_HI):
        return tc.exp(-l**2 * singma_HI**2 / 2)

    '''def Cross_Power(self, z, L, b1, b2, cut_off= tc.tensor([2.])):
        
        chi = self.chi(z)
        kh = L / chi
        kh_cutoff = cut_off / chi
        shape = kh.shape

        if b1 not in ['e', 'v', 'HI'] or b2 not in ['e', 'v', 'HI']:
            print('b1 and b2 must be "e", "v" or "HI"')
            raise
        else:
            if b1 == 'e': B1 = self.bias_electron
            elif b1 == 'v': B1 = self.bias_velocity
            elif b1 == 'HI': B1 = self.bias_HI

            if b2 == 'e': B2 = self.bias_electron
            elif b2 == 'v': B2 = self.bias_velocity
            elif b2 == 'HI': B2 = self.bias_HI

        mesh = tc.where(L <= cut_off)
        P = (self.Pm_interpolation(kh.flatten().clone().detach(), tc.tensor([z]))).reshape(shape) * B1(kh, z) * B2(kh, z)

        if b1=='v' and b2=='v' :
            P[mesh] = 2. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        elif b1=='v' or b2=='v' :
            P[mesh] = self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        else:
            P[mesh] = 2./3. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)

        return P'''

    def bias_electron(self, kh, z): # TO BE REVISED
        return kh/kh
    
    def bias_velocity(self, kh, z, cut_off = tc.tensor([1e-6], dtype=tc.float64)):
        z_dependence = 1/(1+z) * self.H_of_z(z) * self.f_of_z(z)
        b = tc.where(kh > cut_off, z_dependence / kh, z_dependence / cut_off)
        return b
    
    def bias_HI(self, kh, z): # TO BE REVISED
        return kh/kh


    
    def dCl_test(self, z, l, l1, l_min = 1, l_max = 1000, N_l = 1000, N_theta = 81):
        """Evaluare the integrand, dCl, as a function of z, l and l_1.

        Here we sum over theta_1, l_2, and theta_2. To get the final C_l result, one has to integrate dCl over chi and l_1, for a given l.

        Input
        -----
        `z` : float. 
            The redshift. 

        `l` : float. 
            The moment for C_l. Don't need to be an integer since we are in flat-sky approximation.

        `l1` : float.
            The norm of \\vec{l}_1.

        """
        ##################################################

        ti = time.time()

        Z_MEAN = 0.45 # mean redshift for HI observation
        FREQ_HI = 1420. # in unit MHz
        SIGMA_HI = 0.0115 * 1000. * (1. + Z_MEAN) / FREQ_HI
        SIGMA_KSZ = deepcopy(SIGMA_HI)

        # Redefine the inputs as tc.tensors
        z = tc.tensor([z], dtype=tc.float64)
        l = tc.tensor([l], dtype=tc.float64)
        l1 = tc.tensor([l1], dtype=tc.float64)

        # Make the mesh grid for theta_1, |l_2|, and theta_2
        t1_list = tc.arange(N_theta, dtype=tc.float64) * tc.pi / N_theta
        t2_list = deepcopy(t1_list)
        l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64)
        l2, t1, t2 = tc.meshgrid(l2_list, t1_list, t2_list, indexing='ij')

        print('Making meshgrid', time.time() - ti)

        # Make the mesh grid for theta_1, |l_2|, and theta_2
        # t1 = tc.tensor([tc.pi / 3.], dtype=tc.float64)
        # t2_list = tc.arange(N_theta, dtype=tc.float64) * 2 * tc.pi / N_theta
        # l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64) # 10**tc.linspace(np.log10(l_min), np.log10(l_max), N_l, dtype=tc.float64)
        # l2, t2 = tc.meshgrid(l2_list, t2_list, indexing='ij')

        # Pre-define useful varibales and constants
        lsquare = l**2
        l1square = l1**2
        l2square = l2**2

        l_dot_l1 = Polar_dot(l, 0., l1, t1)
        l_dot_l2 = Polar_dot(l, 0., l2, t2)
        l1_dot_l2 = Polar_dot(l1, t1, l2, t2)
        print('Evaluate dot product ', time.time() - ti)

        l_m_l1_norm = tc.sqrt( lsquare + l1square - 2*l_dot_l1 )
        l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 )
        l1_p_l2_norm = tc.sqrt( l1square + l2square + 2*l1_dot_l2 )
        l_m_l1_p_l2_norm = tc.sqrt( lsquare + l1square + l2square - 2*l_dot_l1 + 2*l_dot_l2 - 2*l1_dot_l2 )
        print('Evaluate norm ', time.time() - ti)

        del(l_dot_l1)
        del(l_dot_l2)
        del(l1_dot_l2)
        print('Delete the redundant ', time.time() - ti)

        theta_l_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l1_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l_m_l1_p_l2 = Evaluate_angle(3, l, tc.tensor([0.]), -l1, t1, l2, t2)
        print('Evaluate theta ', time.time() - ti)

        # define the mesh
        cut_off = tc.tensor([2.])
        mesh_l1_p_l2_norm = tc.where(l1_p_l2_norm <= cut_off)
        mesh_l_p_l2_norm = tc.where(l_p_l2_norm <= cut_off)
        mesh_l2 = tc.where(l2 <= cut_off)
        mesh_l_m_l1_p_l2_norm = tc.where(l_m_l1_p_l2_norm <= cut_off)
        print('Define cut-off mesh ', time.time() - ti)
        print('   ')

        ##################################################
        # Evaluate the integrand
        # Initialization
        dCl_tot = tc.zeros_like(t2)

        # Contribution originate from each term in Wick Theorem
        print('Term 5')# Term 5 
        dCl = - tc.cos(theta_l1_p_l2 - t2) # - (l_dot_l2 + l2square) / l_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'e', mesh=mesh_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI', mesh=mesh_l_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l2, 'v', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl_tot += dCl
        print('Term 6')# Term 6 
        dCl = - tc.cos(theta_l1_p_l2 - t2) # - (l_dot_l2 + l2square) / l_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'e', mesh=mesh_l_m_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI', mesh=mesh_l_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l2, 'v', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl_tot += dCl
        print('Term 8')# Term 8 
        dCl = tc.cos(theta_l_p_l2 - theta_l1_p_l2) # (l_dot_l1 + l_dot_l2 + l1_dot_l2 + l2square) / l_p_l2_norm / l1_p_l2_norm
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'v', mesh=mesh_l_m_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI', mesh=mesh_l_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l2, 'v', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl_tot += dCl
        print('Term 9')# Term 9
        dCl = tc.cos(theta_l_m_l1_p_l2 - t2) # (l_dot_l2 - l1_dot_l2 + l2square) / l_m_l1_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'v', mesh=mesh_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l2, 'e', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'v', 'HI', mesh=mesh_l_p_l2_norm, cut_off=cut_off)
        dCl_tot += dCl
        print('Term 10')# Term 10
        dCl = tc.cos(theta_l1_p_l2 - t2)# (l1_dot_l2 + l2square) / l1_p_l2_norm / l2
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI', mesh=mesh_l_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'v', mesh=mesh_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l2, 'v', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl_tot += dCl
        print('Term 11')# Term 11
        dCl = -1.
        dCl *= self.Cross_Power(z, l_p_l2_norm, 'e', 'HI', mesh=mesh_l_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'v', 'v', mesh=mesh_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l2, 'e', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl_tot += dCl
        print('Term 13')# Term 13
        dCl = tc.cos(theta_l_m_l1_p_l2 - theta_l_p_l2) # (lsquare + l2square + 2*l_dot_l2 - l_dot_l1 - l1_dot_l2) / l_p_l2_norm / l_m_l1_p_l2_norm
        dCl *= self.Cross_Power(z, l2, 'e', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'e', 'v', mesh=mesh_l_m_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'v', 'HI', mesh=mesh_l1_p_l2_norm, cut_off=cut_off)
        dCl_tot += dCl
        print('Term 14')# Term 14
        dCl = -1.
        dCl = self.Cross_Power(z, l2, 'e', 'HI', mesh=mesh_l2, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l_m_l1_p_l2_norm, 'v', 'v', mesh=mesh_l_m_l1_p_l2_norm, cut_off=cut_off)
        dCl *= self.Cross_Power(z, l1_p_l2_norm, 'e', 'HI', mesh=mesh_l1_p_l2_norm, cut_off=cut_off)
        dCl_tot += dCl

        tf = time.time()
        # The beam functions
        dCl_tot *= self.Beam_kSZ(l_m_l1_norm, SIGMA_KSZ) * self.Beam_kSZ(l1, SIGMA_KSZ) * self.Beam_HI(l_p_l2_norm, SIGMA_HI) * self.Beam_HI(l2, SIGMA_HI)
        # The window functions and the metric determinant contribution 
        dCl_tot *= l1 * l2 * self.F_kSZ(z)**2 * self.G_HI(z)**2 * self.dchi_by_dz(z)
        print('Taking in beams ', time.time() - tf)

        # print(dCl_tot.shape)

        dCl_res = tc.sum(dCl_tot) * t2_list[1] * (l_max - l_min) / (N_l - 1)
        print('Summation ', time.time() - tf)

        return dCl_res
    
    def Cross_Power_i(self, z, L, b1, b2, mesh, cut_off):
        
        chi = self.chi(z)
        kh = L / chi
        kh_cutoff = cut_off / chi
        shape = kh.shape

        if b1 not in ['e', 'v', 'HI'] or b2 not in ['e', 'v', 'HI']:
            print('b1 and b2 must be "e", "v" or "HI"')
            raise
        else:
            if b1 == 'e': B1 = self.bias_electron
            elif b1 == 'v': B1 = self.bias_velocity
            elif b1 == 'HI': B1 = self.bias_HI

            if b2 == 'e': B2 = self.bias_electron
            elif b2 == 'v': B2 = self.bias_velocity
            elif b2 == 'HI': B2 = self.bias_HI

        P = (self.Pm_interpolation(kh.flatten().clone().detach(), tc.tensor([z]))).reshape(shape) * B1(kh, z) * B2(kh, z)

        if b1=='v' and b2=='v' :
            P[mesh] = 2. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        elif b1=='v' or b2=='v' :
            P[mesh] = self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        else:
            P[mesh] = 2./3. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)

        return P
          
    def Cross_Power(self, z, L, b1, b2, mesh, cut_off):
        
        t0 = time.time()
        
        chi = self.chi(z)
        kh = L / chi
        kh_cutoff = cut_off / chi
        shape = kh.shape

        if b1 not in ['e', 'v', 'HI'] or b2 not in ['e', 'v', 'HI']:
            print('b1 and b2 must be "e", "v" or "HI"')
            raise
        else:
            if b1 == 'e': B1 = self.bias_electron
            elif b1 == 'v': B1 = self.bias_velocity
            elif b1 == 'HI': B1 = self.bias_HI

            print('load in b1', time.time() - t0)

            if b2 == 'e': B2 = self.bias_electron
            elif b2 == 'v': B2 = self.bias_velocity
            elif b2 == 'HI': B2 = self.bias_HI

            print('load in b2', time.time() - t0)

        itp = self.Pm_interpolation(kh.flatten().clone().detach(), tc.tensor([z])).reshape(shape)
        print('Interpolation', time.time() - t0)

        P = itp * B1(kh, z) * B2(kh, z)
        print('Multiply ', time.time() - t0)

        if b1=='v' and b2=='v' :
            P[mesh] = 2. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
            print('Mesh re-compute time', time.time() - t0)
        elif b1=='v' or b2=='v' :
            P[mesh] = self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
            print('Mesh re-compute time', time.time() - t0)
        else:
            P[mesh] = 2./3. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
            print('Mesh re-compute time', time.time() - t0)

        print('   ')
        return P



def Polar_dot(lx, thetax, ly, thetay):
    return lx * ly * np.cos(thetax - thetay)

def Evaluate_angle(N_vec, *vectors):

    if 2*N_vec != len(vectors):
        print('The input N_vec does not match the number of input vectors')
        raise
    else:
        # We need to do some adjustment on vectors to match the broadcast rule
        # In order to keep vectors unchanged, make a copy of them for calculation
        vec = deepcopy(vectors)

        l_x = 0.
        l_y = 0.
        for i in range(N_vec):
            # if len(vectors[2*i]) == 1: vec[2*i] = np.array(vectors[2*i])
            # if len(vectors[2*i+1]) == 1: vec[2*i+1] = np.array(vectors[2*i+1])
            # print('shape1', vec[2*i].shape, '   shape2', vec[2*i+1].shape, '   shape3', np.cos(vec[2*i+1]).shape)
            l_x = l_x + vec[2*i] * np.cos(vec[2*i+1])
            l_y = l_y + vec[2*i] * np.sin(vec[2*i+1])
            # print(l_x.shape, '   ', l_y.shape)
        
        return np.arctan2(l_y, l_x)
    
def interp2d_torch(x, y, z, x_new, y_new, mode='bilinear'):
    '''
    Interpolates 2D data over a grid using PyTorch, mimicking `scipy.interpolate.interp2d`.
    
    Parameters:
        x (torch.Tensor): 1D tensor of x coordinates (size: N).
        y (torch.Tensor): 1D tensor of y coordinates (size: M).
        z (torch.Tensor): 2D tensor of shape (M, N) representing the grid values.
        x_new (torch.Tensor): 1D tensor of new x coordinates for interpolation (size: N').
        y_new (torch.Tensor): 1D tensor of new y coordinates for interpolation (size: M').
        mode (str): Interpolation mode ('bilinear', 'nearest'). Defaults to 'bilinear'.
        
    Returns:
        torch.Tensor: Interpolated values at new (x_new, y_new) grid points.
    '''
    
    # Ensure the input tensors are of the correct shape

    z = z.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions (1, 1, M, N)
    
    # Create the meshgrid for new points (x_new, y_new)
    x_new_grid, y_new_grid = tc.meshgrid(x_new, y_new, indexing='ij')
    
    # Normalize new grid coordinates to range [-1, 1] (for grid_sample)
    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()
    
    x_new_norm = 2 * (x_new_grid - x_min) / (x_max - x_min) - 1
    y_new_norm = 2 * (y_new_grid - y_min) / (y_max - y_min) - 1
    
    # Stack and reshape the new coordinates into (1, H', W', 2) for grid_sample
    grid = tc.stack((x_new_norm, y_new_norm), dim=-1).unsqueeze(0)
    
    # Perform the interpolation using grid_sample
    interpolated = tc.nn.functional.grid_sample(z, grid, mode=mode, align_corners=True)
    
    # Remove the batch and channel dimensions and return the result
    return interpolated.squeeze()

def torch_interp1d(x, y, x_query):

    indices = tc.searchsorted(x, x_query) - 1
    indices = tc.clamp(indices, 0, len(x) - 2)

    x0 = x[indices]
    x1 = x[indices + 1]
    y0 = y[indices]
    y1 = y[indices + 1]

    slope = (y1 - y0) / (x1 - x0)
    y_query = y0 + slope * (x_query - x0)
    
    return y_query

In [76]:
class Cl_kSZ2_HI2():

    def __init__(self, z_array, Tb = 1.8e-4, H0 = 67.75, ombh2 = 0.022):
        
        ##################################################s
        # Define the cosmological parameters
        params = camb.CAMBparams()
        params.set_cosmology(H0=H0, ombh2=ombh2)
        params.set_matter_power(redshifts = z_array, kmax=10, nonlinear=True)
        results = camb.get_results(params)
        backgrounds = camb.get_background(params)

        # Calculate the background evolution and results
        kh, z, Pm = results.get_matter_power_spectrum(minkh=1e-4, maxkh=10, npoints = 500, var1='delta_tot', var2='delta_tot')
        Xe_of_z = np.array(backgrounds.get_background_redshift_evolution(z_array, ['x_e'], format='array')).flatten()
        chi_of_z = np.array(results.comoving_radial_distance(z_array))

        ##################################################
        # Store the variables that we are interested in

        # Instruments' properties
        Z_MEAN = 0.45 # mean redshift for HI observation
        FREQ_HI = 1420. # in unit MHz
        self.SIGMA_HI = 0.0115 * 1000. * (1. + Z_MEAN) / FREQ_HI
        self.SIGMA_KSZ = deepcopy(self.SIGMA_HI)

        # Constant scalars and arrays
        self.TCMB = params.TCMB     # CMB temperature 2.7K
        self.Tb = Tb                # HI brightness temperature, in unite mK
        self.kh_list = kh           # Total kh array that we are interested in
        self.kh_array = tc.tensor(kh)
        self.z_list = z             # Total redshift array that we are interested in
        self.z_array = tc.tensor(z)
        self.Pm = tc.tensor(Pm)     # Matter power spectrum

        # Functions of redshift
        self.H_of_z = tc.tensor(backgrounds.hubble_parameter(z)) / sc.c     # Hubble parameter over c, in unit h/Mpc
        self.f_of_z = tc.tensor(                                            # Logarithmic growth rate
            backgrounds.get_redshift_evolution([0.01], z, ['growth']) ).flatten()
        self.Xe_of_z = tc.tensor(Xe_of_z)                                   # Ionized franction Xe
        self.chi_of_z = tc.tensor(chi_of_z)                                 # Comoving distance chi, in unit Mpc/h
        self.dchi_by_dz = 1. / self.H_of_z                                  # Comoving distance growth rate dchi/dz
        self.F_kSZ = self.Xe_of_z * (1+self.z_array)**2 / self.chi_of_z**2  # F_kSZ, propto visibility function of kSZ
        self.G_HI = 1 / (z[-1] - z[0]) / self.chi_of_z**2                   # G_HI, proptp window function of HI

        # Interpolation functions for matter power spectrum
        # adding infrared asymptotic behavior (P proportional to k)
        N_add = 5
        self.kh_array_itp = tc.hstack([tc.linspace(0., kh[0], N_add), tc.tensor(kh[1:])])
        Pm_infared = tc.linspace(0., kh[0], N_add).repeat(len(z)).reshape([len(z), N_add]) * Pm[:, :1] / kh[0]
        self.Pm_itp = tc.hstack([Pm_infared, tc.tensor(Pm[:, 1:])])

        # save the cosmological model, for checking the result
        self.results = results
        self.BGEvolution = backgrounds
        
    def Growth_Rate_of_z(self, backgrounds, itp_order):
        '''
        Get the interpolation function for logarithmic growth rate f, 
        defined as f:=d(ln D)/d(ln a)
        '''
        # Since the growth rate almost does not vary with momentum scale, we fix kh=0.01 to get f
        f_of_z = backgrounds.get_redshift_evolution([0.01], self.z_list, ['growth'])
        return interp1d(self.z_list, np.array(f_of_z).flatten(), kind = itp_order)
    
    def Power_matter_1d(self, kh, zindex):
        return torch_interp1d(self.kh_array_itp, (self.Pm_itp)[zindex], kh)

    def Beam_kSZ(self, l):
        return tc.exp(-l**2 * self.SIGMA_KSZ**2 / 2)
    
    def Beam_HI(self, l):
        return tc.exp(-l**2 * self.SIGMA_HI**2 / 2)

    def Cross_Power(self, z, L, b1, b2, cut_off= tc.tensor([2.])):
        
        chi = self.chi(z)
        kh = L / chi
        kh_cutoff = cut_off / chi
        shape = kh.shape

        if b1 not in ['e', 'v', 'HI'] or b2 not in ['e', 'v', 'HI']:
            print('b1 and b2 must be "e", "v" or "HI"')
            raise
        else:
            if b1 == 'e': B1 = self.bias_electron
            elif b1 == 'v': B1 = self.bias_velocity
            elif b1 == 'HI': B1 = self.bias_HI

            if b2 == 'e': B2 = self.bias_electron
            elif b2 == 'v': B2 = self.bias_velocity
            elif b2 == 'HI': B2 = self.bias_HI

        mesh = tc.where(L <= cut_off)
        P = (self.Pm_interpolation(kh.flatten().clone().detach(), tc.tensor([z]))).reshape(shape) * B1(kh, z) * B2(kh, z)

        if b1=='v' and b2=='v' :
            P[mesh] = 2. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        elif b1=='v' or b2=='v' :
            P[mesh] = self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)
        else:
            P[mesh] = 2./3. * self.Pm_interpolation(kh_cutoff, tc.tensor([z])) * B1(kh_cutoff, z) * B2(kh_cutoff, z)

        return P

    def bias_electron(self, kh, zindex): # TO BE REVISED
        return kh/kh
    
    def bias_velocity(self, kh, zindex, cut_off = tc.tensor([1e-6], dtype=tc.float64)):
        z_dependence = 1/(1+self.z_array[zindex]) * self.H_of_z[zindex] * self.f_of_z[zindex]
        # cut off the divergence at infrared
        return tc.where(kh > cut_off, z_dependence / kh, z_dependence / cut_off)
    
    def bias_HI(self, kh, zindex): # TO BE REVISED
        return kh/kh
    
    
    def dCl(self, zi, l, l1, l_min = 1, l_max = 800, N_l = 1600, N_theta = 243):
        """Evaluare the integrand, dCl, as a function of z, l and l_1.

        Here we sum over theta_1, l_2, and theta_2. To get the final C_l result, one has to integrate dCl over chi and l_1, for a given l.

        Input
        -----
        `z` : float. 
            The redshift. 

        `l` : float. 
            The moment for C_l. Don't need to be an integer since we are in flat-sky approximation.

        `l1` : float.
            The norm of \\vec{l}_1.

        """
        ##################################################
        # Redefine the inputs as tc.tensors
        # z = tc.tensor([self.z_list[zindex]], dtype=tc.float64)
        l = tc.tensor([l], dtype=tc.float64)
        l1 = tc.tensor([l1], dtype=tc.float64)

        # Make the mesh grid for theta_1, |l_2|, and theta_2
        t1_list = tc.arange(N_theta, dtype=tc.float64) * tc.pi / N_theta
        t2_list = deepcopy(t1_list)
        l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64)
        l2, t1, t2 = tc.meshgrid(l2_list, t1_list, t2_list, indexing='ij')

        # # Make the mesh grid for theta_1, |l_2|, and theta_2
        # t1 = tc.tensor([tc.pi / 3.], dtype=tc.float64)
        # t2_list = tc.arange(N_theta, dtype=tc.float64) * 2 * tc.pi / N_theta
        # l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64) # 10**tc.linspace(np.log10(l_min), np.log10(l_max), N_l, dtype=tc.float64)
        # l2, t2 = tc.meshgrid(l2_list, t2_list, indexing='ij')

        # Pre-define useful varibales and constants
        chi = self.chi_of_z[zi]
        lsquare = l**2
        l1square = l1**2
        l2square = l2**2

        l_dot_l1 = Polar_dot(l, 0., l1, t1)
        l_dot_l2 = Polar_dot(l, 0., l2, t2)
        l1_dot_l2 = Polar_dot(l1, t1, l2, t2)

        k_l1_p_l2_norm = tc.sqrt( l1square + l2square + 2*l1_dot_l2 ) / chi
        k_l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 ) / chi
        k_l_m_l1_p_l2_norm = tc.sqrt( lsquare + l1square + l2square - 2*l_dot_l1 + 2*l_dot_l2 - 2*l1_dot_l2 ) / chi
        k_l2 = l2 / chi

        # Delete redundant variables to save memory
        del(l_dot_l1, l_dot_l2, l1_dot_l2)

        theta_l_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l1_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l_m_l1_p_l2 = Evaluate_angle(3, l, tc.tensor([0.]), -l1, t1, l2, t2)

        # Pre-calculate the matter power spectrum
        P_l1_p_l2_norm = self.Power_matter_1d(k_l1_p_l2_norm, zi)
        P_l_p_l2_norm = self.Power_matter_1d(k_l_p_l2_norm, zi)
        P_l2 = self.Power_matter_1d(k_l2, zi)
        P_l_m_l1_p_l2_norm = self.Power_matter_1d(k_l_m_l1_p_l2_norm, zi)
       

        ##################################################
        # Evaluate the integrand
        # Initialization
        dCl_tot = tc.zeros_like(t2)

        # Contribution originate from each term in Wick Theorem
        # Term 5 and Term 6
        dCl = - tc.cos(theta_l1_p_l2 - t2)
        dCl *= P_l1_p_l2_norm * self.bias_electron(k_l1_p_l2_norm,zi)**2 + P_l_m_l1_p_l2_norm * self.bias_electron(k_l_m_l1_p_l2_norm,zi)**2
        dCl *= P_l_p_l2_norm        * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        # Term 8 
        dCl = tc.cos(theta_l_p_l2 - theta_l1_p_l2)
        dCl *= P_l_m_l1_p_l2_norm   * self.bias_electron(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        # Term 9
        dCl = tc.cos(theta_l_m_l1_p_l2 - t2)
        dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl *= P_l_p_l2_norm        * self.bias_velocity(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl_tot += dCl
        # Term 10
        dCl = tc.cos(theta_l1_p_l2 - t2)
        dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        # Term 11
        dCl = -1.
        dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        # Term 13
        dCl = tc.cos(theta_l_m_l1_p_l2 - theta_l_p_l2)
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl *= P_l_m_l1_p_l2_norm   * self.bias_electron(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        dCl_tot += dCl
        # Term 14
        dCl = -1.
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl *= P_l_m_l1_p_l2_norm   * self.bias_velocity(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        dCl_tot += dCl

        # Delete redundant variables to save memory
        del(P_l1_p_l2_norm, P_l_p_l2_norm, P_l2, P_l_m_l1_p_l2_norm)
        # The beam functions
        l_m_l1_norm = tc.sqrt( lsquare + l1square - 2*l_dot_l1 )
        l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 )
        dCl_tot *= self.Beam_kSZ(l_m_l1_norm) * self.Beam_kSZ(l1) * self.Beam_HI(l_p_l2_norm) * self.Beam_HI(l2)
        # The window functions and the metric determinant contribution 
        dCl_tot *= l1 * l2 * self.F_kSZ[zi]**2 * self.G_HI[zi]**2 * self.dchi_by_dz[zi]

        dCl_res = tc.sum(dCl_tot) * t2_list[1]**2 * ((l_max - l_min) / (N_l - 1))**2

        return dCl_res

    
    def dCl_timer(self, zi, l, l1, l_min = 1, l_max = 800, N_l = 1600, N_theta = 243):
        """Evaluare the integrand, dCl, as a function of z, l and l_1.

        Here we sum over theta_1, l_2, and theta_2. To get the final C_l result, one has to integrate dCl over chi and l_1, for a given l.

        Input
        -----
        `z` : float. 
            The redshift. 

        `l` : float. 
            The moment for C_l. Don't need to be an integer since we are in flat-sky approximation.

        `l1` : float.
            The norm of \\vec{l}_1.

        """
        ##################################################
        # Redefine the inputs as tc.tensors

        t0 = time.time()
        # z = tc.tensor([self.z_list[zindex]], dtype=tc.float64)
        l = tc.tensor([l], dtype=tc.float64)
        l1 = tc.tensor([l1], dtype=tc.float64)

        # Make the mesh grid for theta_1, |l_2|, and theta_2
        t1_list = tc.arange(N_theta, dtype=tc.float64) * tc.pi / N_theta
        t2_list = deepcopy(t1_list)
        l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64)
        l2, t1, t2 = tc.meshgrid(l2_list, t1_list, t2_list, indexing='ij')

        print('Meshgrid finished ', time.time() - t0)
        # # Make the mesh grid for theta_1, |l_2|, and theta_2
        # t1 = tc.tensor([tc.pi / 3.], dtype=tc.float64)
        # t2_list = tc.arange(N_theta, dtype=tc.float64) * 2 * tc.pi / N_theta
        # l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64) # 10**tc.linspace(np.log10(l_min), np.log10(l_max), N_l, dtype=tc.float64)
        # l2, t2 = tc.meshgrid(l2_list, t2_list, indexing='ij')

        # Pre-define useful varibales and constants
        chi = self.chi_of_z[zi]
        lsquare = l**2
        l1square = l1**2
        l2square = l2**2

        l_dot_l1 = Polar_dot(l, 0., l1, t1)
        l_dot_l2 = Polar_dot(l, 0., l2, t2)
        l1_dot_l2 = Polar_dot(l1, t1, l2, t2)

        k_l1_p_l2_norm = tc.sqrt( l1square + l2square + 2*l1_dot_l2 ) / chi
        k_l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 ) / chi
        k_l_m_l1_p_l2_norm = tc.sqrt( lsquare + l1square + l2square - 2*l_dot_l1 + 2*l_dot_l2 - 2*l1_dot_l2 ) / chi
        k_l2 = l2 / chi

        print('k norms have been evaluated ', time.time() - t0)
        # Delete redundant variables to save memory
        del(l1_dot_l2)

        theta_l_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l1_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l_m_l1_p_l2 = Evaluate_angle(3, l, tc.tensor([0.]), -l1, t1, l2, t2)
        print('Angles have been evaluated ', time.time() - t0)

        # Pre-calculate the matter power spectrum
        P_l1_p_l2_norm = self.Power_matter_1d(k_l1_p_l2_norm, zi)
        P_l_p_l2_norm = self.Power_matter_1d(k_l_p_l2_norm, zi)
        P_l2 = self.Power_matter_1d(k_l2, zi)
        P_l_m_l1_p_l2_norm = self.Power_matter_1d(k_l_m_l1_p_l2_norm, zi)
        print('Powers have been evaluated ', time.time() - t0)
       

        ##################################################
        # Evaluate the integrand
        # Initialization
        dCl_tot = tc.zeros_like(t2)

        # Contribution originate from each term in Wick Theorem
        # Term 5 and Term 6
        dCl = - tc.cos(theta_l1_p_l2 - t2)
        dCl *= P_l1_p_l2_norm * self.bias_electron(k_l1_p_l2_norm,zi)**2 + P_l_m_l1_p_l2_norm * self.bias_electron(k_l_m_l1_p_l2_norm,zi)**2
        dCl *= P_l_p_l2_norm        * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        print('Term 5 and 6', time.time() - t0)
        # Term 8 
        dCl = tc.cos(theta_l_p_l2 - theta_l1_p_l2)
        dCl *= P_l_m_l1_p_l2_norm   * self.bias_electron(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        print('Term 8', time.time() - t0)
        # Term 9
        dCl = tc.cos(theta_l_m_l1_p_l2 - t2)
        dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl *= P_l_p_l2_norm        * self.bias_velocity(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl_tot += dCl
        print('Term 9', time.time() - t0)
        # Term 10
        dCl = tc.cos(theta_l1_p_l2 - t2)
        dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        print('Term 10', time.time() - t0)
        # Term 11
        dCl = -1.
        dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot += dCl
        print('Term 11', time.time() - t0)
        # Term 13
        dCl = tc.cos(theta_l_m_l1_p_l2 - theta_l_p_l2)
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl *= P_l_m_l1_p_l2_norm   * self.bias_electron(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        dCl_tot += dCl
        print('Term 13', time.time() - t0)
        # Term 14
        dCl = -1.
        dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl *= P_l_m_l1_p_l2_norm   * self.bias_velocity(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        dCl_tot += dCl
        print('Term 14', time.time() - t0)

        # Delete redundant variables to save memory
        del(P_l1_p_l2_norm, P_l_p_l2_norm, P_l2, P_l_m_l1_p_l2_norm)
        # The beam functions
        l_m_l1_norm = tc.sqrt( lsquare + l1square - 2*l_dot_l1 )
        l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 )
        dCl_tot *= self.Beam_kSZ(l_m_l1_norm) * self.Beam_kSZ(l1) * self.Beam_HI(l_p_l2_norm) * self.Beam_HI(l2)
        print('Beams ', time.time() - t0)
        # The window functions and the metric determinant contribution 
        dCl_tot *= l1 * l2 * self.F_kSZ[zi]**2 * self.G_HI[zi]**2 * self.dchi_by_dz[zi]
        print('Window functions', time.time() - t0)

        dCl_res = tc.sum(dCl_tot) * t2_list[1] * (l_max - l_min) / (N_l - 1)

        return dCl_res, dCl, dCl_tot

    def dCl_timer_2d(self, zi, l, l1, l_min = 1, theta_option = True, l_max = 800, N_l = 1600, N_theta = 243):
        """Evaluare the integrand, dCl, as a function of z, l and l_1.

        Here we sum over theta_1, l_2, and theta_2. To get the final C_l result, one has to integrate dCl over chi and l_1, for a given l.

        Input
        -----
        `z` : float. 
            The redshift. 

        `l` : float. 
            The moment for C_l. Don't need to be an integer since we are in flat-sky approximation.

        `l1` : float.
            The norm of \\vec{l}_1.

        """
        ##################################################
        # Redefine the inputs as tc.tensors

        t0 = time.time()
        # z = tc.tensor([self.z_list[zindex]], dtype=tc.float64)
        l = tc.tensor([l], dtype=tc.float64)
        l1 = tc.tensor([l1], dtype=tc.float64)

        # # Make the mesh grid for theta_1, |l_2|, and theta_2
        # t1_list = tc.arange(N_theta, dtype=tc.float64) * tc.pi / N_theta
        # t2_list = deepcopy(t1_list)
        # l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64)
        # l2, t1, t2 = tc.meshgrid(l2_list, t1_list, t2_list, indexing='ij')
        
        # Make the mesh grid for theta_1, |l_2|, and theta_2
        t1 = tc.tensor([tc.pi / 3.], dtype=tc.float64)
        if theta_option:
            t2_list = tc.arange(N_theta, dtype=tc.float64) * 2 * tc.pi / N_theta + tc.pi / N_theta
        else:
            t2_list = tc.arange(N_theta, dtype=tc.float64) * 2 * tc.pi / N_theta
        l2_list = tc.linspace(l_min, l_max, N_l, dtype=tc.float64) - (l_max - l_min) / (N_l - 1) / 2
        l2, t2 = tc.meshgrid(l2_list, t2_list, indexing='ij')
        print('Meshgrid finished ', time.time() - t0)

        # Pre-define useful varibales and constants
        chi = self.chi_of_z[zi]
        lsquare = l**2
        l1square = l1**2
        l2square = l2**2

        l_dot_l1 = Polar_dot(l, 0., l1, t1)
        l_dot_l2 = Polar_dot(l, 0., l2, t2)
        l1_dot_l2 = Polar_dot(l1, t1, l2, t2)

        k_l1_p_l2_norm = tc.sqrt( l1square + l2square + 2*l1_dot_l2 ) / chi
        k_l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 ) / chi
        k_l_m_l1_p_l2_norm = tc.sqrt( lsquare + l1square + l2square - 2*l_dot_l1 + 2*l_dot_l2 - 2*l1_dot_l2 ) / chi
        k_l2 = l2 / chi

        print('k norms have been evaluated ', time.time() - t0)
        # Delete redundant variables to save memory
        del(l1_dot_l2)

        theta_l_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l1_p_l2 = Evaluate_angle(2, l, tc.tensor([0.]), l2, t2)
        theta_l_m_l1_p_l2 = Evaluate_angle(3, l, tc.tensor([0.]), -l1, t1, l2, t2)
        print('Angles have been evaluated ', time.time() - t0)

        # Pre-calculate the matter power spectrum
        P_l1_p_l2_norm = self.Power_matter_1d(k_l1_p_l2_norm, zi)
        P_l_p_l2_norm = self.Power_matter_1d(k_l_p_l2_norm, zi)
        P_l2 = self.Power_matter_1d(k_l2, zi)
        P_l_m_l1_p_l2_norm = self.Power_matter_1d(k_l_m_l1_p_l2_norm, zi)
        print('Powers have been evaluated ', time.time() - t0)
       

        ##################################################
        # Evaluate the integrand
        # Initialization
        dCl_tot = tc.zeros_like(t2)

        # Contribution originate from each term in Wick Theorem
        # Term 5 and Term 6
        dCl0 = - tc.cos(theta_l1_p_l2 - t2)
        dCl1 = P_l1_p_l2_norm * self.bias_electron(k_l1_p_l2_norm,zi)**2 + P_l_m_l1_p_l2_norm * self.bias_electron(k_l_m_l1_p_l2_norm,zi)**2
        dCl2 = P_l_p_l2_norm        * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        dCl3 = P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        dCl_tot = dCl_tot + dCl1 * dCl2 * dCl3 * dCl0
        print('Term 5 and 6', time.time() - t0)
        # # Term 8 
        # dCl = tc.cos(theta_l_p_l2 - theta_l1_p_l2)
        # dCl *= P_l_m_l1_p_l2_norm   * self.bias_electron(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        # dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        # dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        # dCl_tot += dCl
        # print('Term 8', time.time() - t0)
        # # Term 9
        # dCl = tc.cos(theta_l_m_l1_p_l2 - t2)
        # dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        # dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        # dCl *= P_l_p_l2_norm        * self.bias_velocity(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        # dCl_tot += dCl
        # print('Term 9', time.time() - t0)
        # # Term 10
        # dCl = tc.cos(theta_l1_p_l2 - t2)
        # dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        # dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        # dCl *= P_l2                 * self.bias_velocity(k_l2,zi)               * self.bias_HI(k_l2,zi)
        # dCl_tot += dCl
        # print('Term 10', time.time() - t0)
        # # Term 11
        # dCl = -1.
        # dCl *= P_l_p_l2_norm        * self.bias_electron(k_l_p_l2_norm,zi)      * self.bias_HI(k_l_p_l2_norm,zi)
        # dCl *= P_l1_p_l2_norm       * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_velocity(k_l1_p_l2_norm,zi)
        # dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        # dCl_tot += dCl
        # print('Term 11', time.time() - t0)
        # # Term 13
        # dCl = tc.cos(theta_l_m_l1_p_l2 - theta_l_p_l2)
        # dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        # dCl *= P_l_m_l1_p_l2_norm   * self.bias_electron(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        # dCl *= P_l1_p_l2_norm       * self.bias_velocity(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        # dCl_tot += dCl
        # print('Term 13', time.time() - t0)
        # # Term 14
        # dCl = -1.
        # dCl *= P_l2                 * self.bias_electron(k_l2,zi)               * self.bias_HI(k_l2,zi)
        # dCl *= P_l_m_l1_p_l2_norm   * self.bias_velocity(k_l_m_l1_p_l2_norm,zi) * self.bias_velocity(k_l_m_l1_p_l2_norm,zi)
        # dCl *= P_l1_p_l2_norm       * self.bias_electron(k_l1_p_l2_norm,zi)     * self.bias_HI(k_l1_p_l2_norm,zi)
        # dCl_tot += dCl
        # print('Term 14', time.time() - t0)

        print(tc.sum(l1 * l2 * dCl_tot)  * (l_max - l_min) / (N_l - 1) * 2 * tc.pi / N_theta)

        # Delete redundant variables to save memory
        del(P_l1_p_l2_norm, P_l_p_l2_norm, P_l2, P_l_m_l1_p_l2_norm)
        # The beam functions
        l_m_l1_norm = tc.sqrt( lsquare + l1square - 2*l_dot_l1 )
        l_p_l2_norm = tc.sqrt( lsquare + l2square + 2*l_dot_l2 )
        dCl_tot *= self.Beam_kSZ(l_m_l1_norm) * self.Beam_kSZ(l1) * self.Beam_HI(l_p_l2_norm) * self.Beam_HI(l2)
        print('Beams ', time.time() - t0)
        # The window functions and the metric determinant contribution 
        dCl_tot *= l1 * l2 * self.F_kSZ[zi]**2 * self.G_HI[zi]**2 * self.dchi_by_dz[zi]
        print('Window functions', time.time() - t0)

        dCl_res = tc.sum(dCl_tot)  * (l_max - l_min) / (N_l - 1) * 2 * tc.pi / N_theta

        return l2, t2, dCl_res, dCl_tot, dCl1, dCl2, dCl3, dCl0



def Polar_dot(lx, thetax, ly, thetay):
    return lx * ly * np.cos(thetax - thetay)

def Evaluate_angle(N_vec, *vectors):

    if 2*N_vec != len(vectors):
        print('The input N_vec does not match the number of input vectors')
        raise
    else:
        # We need to do some adjustment on vectors to match the broadcast rule
        # In order to keep vectors unchanged, make a copy of them for calculation
        vec = deepcopy(vectors)

        l_x = 0.
        l_y = 0.
        for i in range(N_vec):
            l_x = l_x + vec[2*i] * tc.cos(vec[2*i+1])
            l_y = l_y + vec[2*i] * tc.sin(vec[2*i+1])
        
        return tc.atan2(l_y, l_x)
    
def torch_interp1d(x, y, x_query):

    indices = tc.searchsorted(x, x_query) - 1
    indices = tc.clamp(indices, 0, len(x) - 2)

    x0, x1 = x[indices], x[indices + 1]
    y0, y1 = y[indices], y[indices + 1]
    
    slope = (y1 - y0) / (x1 - x0)
    y_query = y0 + slope * (x_query - x0)
    
    return y_query

def torch_interp2d(x, y, z, x_new, y_new, mode='bilinear'):
    '''
    Interpolates 2D data over a grid using PyTorch, mimicking `scipy.interpolate.interp2d`.
    
    Parameters:
        x (torch.Tensor): 1D tensor of x coordinates (size: N).
        y (torch.Tensor): 1D tensor of y coordinates (size: M).
        z (torch.Tensor): 2D tensor of shape (M, N) representing the grid values.
        x_new (torch.Tensor): 1D tensor of new x coordinates for interpolation (size: N').
        y_new (torch.Tensor): 1D tensor of new y coordinates for interpolation (size: M').
        mode (str): Interpolation mode ('bilinear', 'nearest'). Defaults to 'bilinear'.
        
    Returns:
        torch.Tensor: Interpolated values at new (x_new, y_new) grid points.
    '''
    
    # Ensure the input tensors are of the correct shape

    z = z.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions (1, 1, M, N)
    
    # Create the meshgrid for new points (x_new, y_new)
    x_new_grid, y_new_grid = tc.meshgrid(x_new, y_new, indexing='ij')
    
    # Normalize new grid coordinates to range [-1, 1] (for grid_sample)
    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()
    
    x_new_norm = 2 * (x_new_grid - x_min) / (x_max - x_min) - 1
    y_new_norm = 2 * (y_new_grid - y_min) / (y_max - y_min) - 1
    
    # Stack and reshape the new coordinates into (1, H', W', 2) for grid_sample
    grid = tc.stack((x_new_norm, y_new_norm), dim=-1).unsqueeze(0)
    
    # Perform the interpolation using grid_sample
    interpolated = tc.nn.functional.grid_sample(z, grid, mode=mode, align_corners=True)
    
    # Remove the batch and channel dimensions and return the result
    return interpolated.squeeze()


In [77]:
testclass = Cl_kSZ2_HI2(np.linspace(0.3, 0.6, 10))

Note: redshifts have been re-sorted (earliest first)


In [78]:
l2, t2, dClres, dCl_tot, dCl1, dCl2, dCl3, dCl0 = testclass.dCl_timer_2d(0, 200, 200, theta_option=True)#, l_max=800, N_l=1600, N_theta=243)
dClres

Meshgrid finished  0.0
k norms have been evaluated  0.012997865676879883
Angles have been evaluated  0.018900632858276367
Powers have been evaluated  0.03425025939941406
Term 5 and 6 0.04226112365722656
tensor(-117712.7615)
Beams  0.04226112365722656
Window functions 0.04226112365722656


tensor(1.6846e-14)

In [72]:
fig = plt.figure()
ax3 = plt.axes(projection='3d')

# xi = 150; xe = 200; yi = 38; ye = 44
# xi = 150; xe = 250; yi = 25; ye = 38
# xi = 150; xe = 250; yi = 48; ye = 60
# xi = 100; xe = 300; yi = 115; ye = 130
# xi = 0; xe = 800; yi = 0; ye = -1
xi = 350; xe = 450; yi = 152; ye = 170

# ax3.plot_surface(l2[xi:xe,yi:ye], t2[xi:xe,yi:ye], dCl1[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl1[xi:xe,yi:ye],cmap='rainbow')

ax3.plot_surface(l2[xi:xe,yi:ye], t2[xi:xe,yi:ye], dCl1[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl_tot[xi:xe,yi:ye],cmap='rainbow')


plt.show()

In [4]:
dClres = testclass.dCl(0.4, 200, 200, l_max=800, N_l=1600, N_theta=243)
dClres

torch.Size([1600, 243, 243])


tensor(707.1861)

In [53]:
test1 = tc.arange(10)
test2 = tc.arange(13)
test3 = tc.arange(20)


In [54]:
del(test1, test2, test3)

In [28]:
l2, t2, dClres, dCl_tot, dCl_terms = testclass.dCl(0.4, 200, 200, l_max=800, N_l=1600, N_theta=729)
dClres

tensor(0.4943)

In [144]:
fig = plt.figure()
ax3 = plt.axes(projection='3d')

# xi = 150; xe = 200; yi = 38; ye = 44
# xi = 150; xe = 250; yi = 25; ye = 38
# xi = 150; xe = 250; yi = 48; ye = 60
# xi = 100; xe = 300; yi = 115; ye = 130
xi = 200; xe = 250; yi = 0; ye = -1
# xi = 150; xe = 250; yi = 50; ye = 70

# ax3.plot_surface(l2[xi:xe,yi:ye], t2[xi:xe,yi:ye], dCl1[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl1[xi:xe,yi:ye],cmap='rainbow')

ax3.plot_surface(l2[xi:xe,yi:ye], t2[xi:xe,yi:ye], dCl_tot[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl_tot[xi:xe,yi:ye],cmap='rainbow')


plt.show()

In [145]:
dClres

tensor(-0.1948)

In [126]:
dClres

tensor(-0.2574)

In [31]:
test1 = testclass.Cross_Power(0.4, tc.linspace(0,100,300), b1='v', b2='v')

mesh = np.isnan(np.array(test1))
mesh.any()

False

In [29]:
mesh.all()

False

In [101]:
l2, t2, dCl_tot, dCl0, dCl1, dCl2, dCl3, dCl_Beam, dCl_FG, l1_p_l2_norm, l_p_l2_norm, l_m_l1_p_l2_norm = testclass.dCl(0.4, 200, 200, N_theta=81)

In [104]:
fig = plt.figure()
ax3 = plt.axes(projection='3d')

# xi = 150; xe = 250; yi = 38; ye = 44
# xi = 100; xe = 300; yi = 115; ye = 130
# xi = 0; xe = -1; yi = 0; ye = -1
xi = 150; xe = 250; yi = 50; ye = 70

# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl0[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl1[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl2[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl3[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl_Beam[xi:xe,yi:ye],cmap='rainbow')

# ax3.plot_surface(l2[xi:xe,yi:ye], t2[xi:xe,yi:ye], dCl_tot[xi:xe,yi:ye],cmap='rainbow')
ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], dCl_tot[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], tc.log10(tc.abs(dCl_tot[xi:xe,yi:ye])),cmap='rainbow')

# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], l1_p_l2_norm[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], l_p_l2_norm[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(tc.log10(l2[xi:xe,yi:ye]), t2[xi:xe,yi:ye], l_m_l1_p_l2_norm[xi:xe,yi:ye],cmap='rainbow')
# ax3.plot_surface(l2[xi:xe,yi:ye], t2[xi:xe,yi:ye], l_m_l1_p_l2_norm[xi:xe,yi:ye],cmap='rainbow')

plt.show()

In [56]:
dCl2

tensor([[1168108.4772, 1168108.4772, 1168108.4772,  ..., 1168108.4772,
         1168108.4772, 1168108.4772],
        [1168108.4772, 1168108.4772, 1168108.4772,  ..., 1168108.4772,
         1168108.4772, 1168108.4772],
        [1168108.4772, 1168108.4772, 1168108.4772,  ..., 1168108.4772,
         1168108.4772, 1168108.4772],
        ...,
        [1168108.4772, 1168108.4772, 1168108.4772,  ..., 1168108.4772,
         1168108.4772, 1168108.4772],
        [1168108.4772, 1168108.4772, 1168108.4772,  ..., 1168108.4772,
         1168108.4772, 1168108.4772],
        [1168108.4772, 1168108.4772, 1168108.4772,  ..., 1168108.4772,
         1168108.4772, 1168108.4772]])

In [7]:
test = tc.where(l_p_l2_norm <= 2)

In [89]:
l_p_l2_norm.max()

tensor(1200.)

In [27]:
l_m_l1_p_l2_norm[190:210, 80:84]

tensor([[1.0322e+01, 9.0000e+00, 1.0322e+01, 1.3533e+01],
        [9.4695e+00, 8.0000e+00, 9.4695e+00, 1.2910e+01],
        [8.6490e+00, 7.0000e+00, 8.6490e+00, 1.2337e+01],
        [7.8701e+00, 6.0000e+00, 7.8701e+00, 1.1821e+01],
        [7.1465e+00, 5.0000e+00, 7.1465e+00, 1.1370e+01],
        [6.4967e+00, 4.0000e+00, 6.4967e+00, 1.0991e+01],
        [5.9448e+00, 3.0000e+00, 5.9448e+00, 1.0693e+01],
        [5.5203e+00, 2.0000e+00, 5.5203e+00, 1.0482e+01],
        [5.2543e+00, 1.0000e+00, 5.2543e+00, 1.0364e+01],
        [5.1712e+00, 2.6974e-06, 5.1712e+00, 1.0342e+01],
        [5.2797e+00, 1.0000e+00, 5.2797e+00, 1.0415e+01],
        [5.5685e+00, 2.0000e+00, 5.5685e+00, 1.0584e+01],
        [6.0119e+00, 3.0000e+00, 6.0119e+00, 1.0842e+01],
        [6.5785e+00, 4.0000e+00, 6.5785e+00, 1.1184e+01],
        [7.2395e+00, 5.0000e+00, 7.2395e+00, 1.1603e+01],
        [7.9714e+00, 6.0000e+00, 7.9714e+00, 1.2089e+01],
        [8.7566e+00, 7.0000e+00, 8.7566e+00, 1.2637e+01],
        [9.581

In [5]:
results = testclass.results
background = testclass.BGEvolution

kh_list = testclass.kh_array

In [77]:
testv = background.get_time_evolution(q=kh_list, eta=testclass.chi(0.45), vars='v_newtonian_baryon').flatten()
testm = background.get_time_evolution(q=kh_list, eta=testclass.chi(0.45), vars='delta_baryon').flatten()
testbias = (testclass.Growth_Rate_of_z(background, itp_order='linear'))(0.45)

In [79]:
plt.plot(kh_list, testv)
plt.plot(kh_list, testm)
plt.xscale('log')
plt.yscale('log')

In [9]:
testkh, testz, testPv = results.get_matter_power_spectrum(npoints=500, var1='v_newtonian_baryon', var2='v_newtonian_baryon')
testkh, testz, testPb = results.get_matter_power_spectrum(npoints=500, var1='delta_baryon', var2='delta_baryon')

In [13]:
zindex = 3

plt.plot(kh_list, testPv[zindex])
plt.plot(kh_list, testPb[zindex])
plt.plot(kh_list, testPb[zindex] * testclass.f_of_z(testz[zindex])**2, '--')
plt.xscale('log')
plt.yscale('log')
plt.show()

In [16]:
plt.plot(kh_list, testPv[zindex] / kh_list**2)
plt.xscale('log')
plt.yscale('log')
plt.show()

In [145]:
testclass = Cl_kSZ2_HI2(np.linspace(0.3, 0.6, 10))
z0 = 0.4; chi0 = testclass.chi(z0)

Note: redshifts have been re-sorted (earliest first)


In [146]:
dCl2_linear, l2_linear, t2_linear, l_p_l2_norm_linear, l1_p_l2_norm_linear, l_m_l1_p_l2_norm_linear = testclass.test(z0, 200, 200, method='linear', l_max=1000, N_l=1000, N_theta=18)
dCl2_log, l2_log, t2_log, l_p_l2_norm_log, l1_p_l2_norm_log, l_m_l1_p_l2_norm_log = testclass.test(z0, 200, 200, method='log', l_max=1000, N_l=1000, N_theta=18)

In [147]:
dCl2_linear[:10,0]

tensor([110119.8540, 109650.8207, 109186.4085, 108726.5493, 108271.1765,
        107820.2249, 107373.6302, 106931.3298, 106493.2619, 106059.3660])

In [148]:
dCl2_log[:10,0]

tensor([110119.8540, 110116.5835, 110113.2905, 110109.9749, 110106.6365,
        110103.2752, 110099.8908, 110096.4832, 110093.0521, 110089.5976])

In [42]:
l_p_l2_norm_linear[:10,0]

tensor([201., 202., 203., 204., 205., 206., 207., 208., 209., 210.])

In [43]:
l_p_l2_norm_log[:10,0]

tensor([201.0000, 201.0069, 201.0139, 201.0210, 201.0280, 201.0352, 201.0424,
        201.0496, 201.0569, 201.0642])

In [119]:
testk1 = l_p_l2_norm_linear / chi0
testk2 = l_p_l2_norm_log / chi0
P1, Pnew1, k1, bv1, bHI1 = testclass.Cross_Power_test(z0, testk1, 'v', 'HI')
P2, Pnew2, k2, bv2, bHI2 = testclass.Cross_Power_test(z0, testk2, 'v', 'HI')

testk1.shape, testk2.shape

  kh_new = tc.tensor(kh.flatten())


(torch.Size([1000, 18]), torch.Size([1000, 18]))

In [120]:
testk1[:10, :2], testk2[:10, :2]

(tensor([[0.1258, 0.1257],
         [0.1264, 0.1263],
         [0.1270, 0.1269],
         [0.1276, 0.1275],
         [0.1283, 0.1281],
         [0.1289, 0.1287],
         [0.1295, 0.1293],
         [0.1302, 0.1299],
         [0.1308, 0.1305],
         [0.1314, 0.1310]]),
 tensor([[0.1258, 0.1257],
         [0.1258, 0.1257],
         [0.1258, 0.1257],
         [0.1258, 0.1257],
         [0.1258, 0.1258],
         [0.1258, 0.1258],
         [0.1258, 0.1258],
         [0.1258, 0.1258],
         [0.1258, 0.1258],
         [0.1258, 0.1258]]))

In [121]:
testk1.flatten()

tensor([0.1258, 0.1257, 0.1256,  ..., 0.6968, 0.7261, 0.7446])

In [122]:
k1[:10], k2[:10]

(tensor([0.1258, 0.1257, 0.1256, 0.1255, 0.1253, 0.1250, 0.1248, 0.1247, 0.1246,
         0.1245]),
 tensor([0.1258, 0.1257, 0.1256, 0.1255, 0.1253, 0.1250, 0.1248, 0.1247, 0.1246,
         0.1245]))

In [124]:
testres1 = testclass.Pm_interpolation(k1[:-10], tc.tensor([z0]))
# testres2 = testclass.Pm(k2,z0)
testres1[:10]#, testres2[:10]

tensor([319.7195, 319.7062, 319.6678, 319.6089, 319.5365, 319.4594, 319.3868,
        319.3275, 319.2888, 319.2753])

In [125]:
P1[:10], P2[:10]

(tensor([319.7195, 319.7062, 319.6678, 319.6089, 319.5365, 319.4594, 319.3868,
         319.3275, 319.2888, 319.2753]),
 tensor([319.7195, 319.7062, 319.6678, 319.6089, 319.5365, 319.4594, 319.3868,
         319.3275, 319.2888, 319.2753]))

In [126]:
Pnew1[:10, 0], Pnew2[:10, 0]

(tensor([319.7195, 319.9416, 320.1637, 320.3858, 320.6079, 320.8300, 321.0521,
         321.2741, 321.4962, 321.7183]),
 tensor([319.7195, 319.7211, 319.7226, 319.7242, 319.7258, 319.7273, 319.7289,
         319.7305, 319.7322, 319.7338]))

In [127]:
bv1[:10,0], bv2[:10,0]

(tensor([344.4264, 342.7213, 341.0331, 339.3613, 337.7059, 336.0666, 334.4431,
         332.8352, 331.2426, 329.6653]),
 tensor([344.4264, 344.4145, 344.4026, 344.3905, 344.3784, 344.3662, 344.3539,
         344.3415, 344.3290, 344.3164]))

In [128]:
l_p_l2_norm_log / chi0

tensor([[0.1258, 0.1257, 0.1256,  ..., 0.1255, 0.1256, 0.1257],
        [0.1258, 0.1257, 0.1256,  ..., 0.1255, 0.1256, 0.1257],
        [0.1258, 0.1257, 0.1256,  ..., 0.1255, 0.1256, 0.1257],
        ...,
        [0.7423, 0.7360, 0.7175,  ..., 0.6883, 0.7175, 0.7360],
        [0.7466, 0.7403, 0.7218,  ..., 0.6925, 0.7218, 0.7403],
        [0.7509, 0.7446, 0.7261,  ..., 0.6968, 0.7261, 0.7446]])