# Notebook to calculate $P_\times$ and $P_{1D}$ from ASTRID Lyman-alpha skewers

In [None]:
import h5py
import numpy as np
import os
from matplotlib import pyplot as plt
import astrid as st

## Read the data from both files

In [None]:
# Paths to the two Lyman alpha skewer files: with (ON) and without (OFF) self-shielding.
astrid_dir = '/data/desi/common/astrid'
fname_off = f'{astrid_dir}/spectra_ASTRID_self-shield_off_z2.5_500x500x2500.hdf5'
fname_on = f'{astrid_dir}/spectra_ASTRID_z2.5_500x500x2500.hdf5'

In [None]:
# Read the self-shielding ON file: optical depth, column density and skewer geometry.
with h5py.File(fname_on, 'r') as f:
    tau_on = f['tau/H/1/1215'][:]
    colden = f['colden/H/1'][:]
    axes = f['spectra/axes'][:]
    # (x,y,z) start of the skewers, in kpc/h
    cofm_hkpc = f['spectra/cofm'][:]

In [None]:
# Read the self-shielding OFF file.
# Column densities were not stored for this file (they should be the same as
# in the ON file), so only the optical depth and skewer geometry are read.
with h5py.File(fname_off, 'r') as f:
    tau_off = f['tau/H/1/1215'][:]
    axes_off = f['spectra/axes'][:]
    cofm_off = f['spectra/cofm'][:]

In [None]:
# Constants
save_path = '/data/desi/scratch/jlopez/'

# Skewer grid: Nsk x Nsk skewers with Np pixels each
Nsk, Np = 500, 2500

# The skewer grid is split into block_grid_size x block_grid_size sub-boxes
block_grid_size = 5
block_size = Nsk // block_grid_size  # = 100

# Box length and the resulting pixel / skewer spacings
L_hMpc = 250
dz_hMpc = L_hMpc / Np
dxy_hMpc = L_hMpc / Nsk  # = 0.5 Mpc/h

## Break down the simulation into 25 boxes

In [None]:
# View each flat list of skewers as an (Nsk, Nsk, Np) cube.
# NOTE(review): assumes the skewers are stored in row-major order over the
# transverse grid — TODO confirm against spectra/cofm.
cube_shape = (Nsk, Nsk, Np)
D2tau_on, D2tau_off, D2colden = (
    flat.reshape(cube_shape) for flat in (tau_on, tau_off, colden)
)

In [None]:
# Split the (Nsk, Nsk, Np) cube into a block_grid_size x block_grid_size grid of
# (block_size, block_size, Np) sub-cubes. Reshaping to
# (grid, block, grid, block, Np) and swapping the two middle axes gives
# element [i, j, a, b] == D2tau_on[i*block_size + a, j*block_size + b].
tau_on_grid_view = D2tau_on.reshape(
    block_grid_size, block_size, block_grid_size, block_size, Np
).transpose(0, 2, 1, 3, 4)

D2tau_on_blocks = np.empty((block_grid_size, block_grid_size, block_size, block_size, Np))
D2tau_on_blocks[...] = tau_on_grid_view

# Same data with the two transverse axes of every block flattened into one
# skewer index (a * block_size + b).
tau_on_blocks = np.empty((block_grid_size, block_grid_size, block_size**2, Np))
tau_on_blocks[...] = D2tau_on_blocks.reshape(
    block_grid_size, block_grid_size, block_size**2, Np
)

In [None]:
# Split the self-shielding OFF cube into the same grid of sub-cubes:
# element [i, j, a, b] == D2tau_off[i*block_size + a, j*block_size + b].
tau_off_grid_view = D2tau_off.reshape(
    block_grid_size, block_size, block_grid_size, block_size, Np
).transpose(0, 2, 1, 3, 4)

D2tau_off_blocks = np.empty((block_grid_size, block_grid_size, block_size, block_size, Np))
D2tau_off_blocks[...] = tau_off_grid_view

# Flattened version: one row per skewer inside each block.
tau_off_blocks = np.empty((block_grid_size, block_grid_size, block_size**2, Np))
tau_off_blocks[...] = D2tau_off_blocks.reshape(
    block_grid_size, block_grid_size, block_size**2, Np
)

In [None]:
# Split the column-density cube into the same grid of sub-cubes:
# element [i, j, a, b] == D2colden[i*block_size + a, j*block_size + b].
colden_grid_view = D2colden.reshape(
    block_grid_size, block_size, block_grid_size, block_size, Np
).transpose(0, 2, 1, 3, 4)

D2colden_blocks = np.empty((block_grid_size, block_grid_size, block_size, block_size, Np))
D2colden_blocks[...] = colden_grid_view

# Flattened version: one row per skewer inside each block.
colden_blocks = np.empty((block_grid_size, block_grid_size, block_size**2, Np))
colden_blocks[...] = D2colden_blocks.reshape(
    block_grid_size, block_grid_size, block_size**2, Np
)

## Mask skewers with very large DLAs

In [None]:
# Threshold on log10 of the maximum column density found along a skewer;
# skewers above this value are dropped in the masking cells below.
# NOTE(review): column-density units are not visible here — presumably cm^-2, TODO confirm.
max_logN_mask = 21.3

In [None]:
# Object grids: each cell will hold one block's array of surviving skewers,
# whose length can differ from block to block after masking.
block_grid_shape = (block_grid_size, block_grid_size)
new_tau_on_blocks = np.empty(block_grid_shape, dtype=object)
new_tau_off_blocks = np.empty(block_grid_shape, dtype=object)

In [None]:
# Keep, for every block, only the ON skewers whose peak log10 column density
# does not exceed max_logN_mask.
for i, j in np.ndindex(block_grid_size, block_grid_size):
    peak_logN = np.log10(D2colden_blocks[i, j].max(axis=2)).ravel()
    # Written as "not above" so the selection matches ~(x > threshold) exactly.
    kept_skewers = np.flatnonzero(~(peak_logN > max_logN_mask))
    new_tau_on_blocks[i][j] = tau_on_blocks[i, j][kept_skewers]

In [None]:
# Apply the same column-density selection to the OFF skewers, so both sets
# keep the same skewers in every block.
for i, j in np.ndindex(block_grid_size, block_grid_size):
    peak_logN = np.log10(D2colden_blocks[i, j].max(axis=2)).ravel()
    # Written as "not above" so the selection matches ~(x > threshold) exactly.
    kept_skewers = np.flatnonzero(~(peak_logN > max_logN_mask))
    new_tau_off_blocks[i][j] = tau_off_blocks[i, j][kept_skewers]

In [None]:
# Write every masked ON block to save_path as tau_on{i}{j}.npy.
# os.path.join is used (as in the rest of the notebook) so the path stays
# valid even if save_path has no trailing slash.
for i in range(block_grid_size):
    for j in range(block_grid_size):
        filename = f'tau_on{i}{j}.npy'
        full_path = os.path.join(save_path, filename)

        np.save(full_path, new_tau_on_blocks[i, j])

In [None]:
# Write every masked OFF block to save_path as tau_off{i}{j}.npy.
# os.path.join is used (as in the rest of the notebook) so the path stays
# valid even if save_path has no trailing slash.
for i in range(block_grid_size):
    for j in range(block_grid_size):
        filename = f'tau_off{i}{j}.npy'
        full_path = os.path.join(save_path, filename)

        np.save(full_path, new_tau_off_blocks[i, j])

## Checkpoint: compute $P_\times$ and $P_{1D}$ from the saved blocks

In [None]:
import numpy as np
import os
from matplotlib import pyplot as plt
# The cells after this checkpoint call st.load_all_tau_blocks, st.fields,
# st.delta_F, st.Px_sum and st.P1D_sum, so the helper module must be imported
# here for the checkpoint to work on a fresh kernel.
import astrid as st

In [None]:
# Constants (repeated so the notebook can be resumed from this checkpoint)
save_path = '/data/desi/scratch/jlopez/'

block_grid_size = 5   # sub-boxes per side
Nsk = 500             # skewers per side
Np = 2500             # pixels per skewer
L_hMpc = 250          # box length

block_size = Nsk // block_grid_size  # = 100
dxy_hMpc, dz_hMpc = L_hMpc / Nsk, L_hMpc / Np  # dxy = 0.5 Mpc/h

In [None]:
# Transverse-separation bins: 5 edges in Mpc/h delimit 4 bins
r_edges = np.array((0, 0.6, 0.8, 1.2, 1.6))
# Same edges expressed in units of the skewer spacing
r_edges_pix = np.divide(r_edges, dxy_hMpc)

In [None]:
# Load the ON / OFF tau blocks through the astrid helper.
# NOTE(review): presumably reads the tau_on{i}{j}.npy / tau_off{i}{j}.npy files
# written before the checkpoint — confirm in astrid.load_all_tau_blocks.
print("Loading tau blocks...")
loaded_tau_blocks = st.load_all_tau_blocks(save_path, block_grid_size, block_size, Np)
tau_on_blocks, tau_off_blocks = loaded_tau_blocks

In [None]:
# Cross-spectrum P_x for every block and every transverse-separation bin.
#
# The fields, their FFTs and the pair-distance matrix depend only on the
# block, not on the radial bin, so they are computed once per block and the
# radial bins are looped over inside. The files written
# ({key}_rbin{b+1}_{i}{j}.npy under save_path) are the same as before; only
# the order in which they are produced (and the progress messages) changed.
#
# Each entry: (output key, first field, second field, symmetrize).
# When symmetrize is True the term with the two fields swapped is added.
px_terms = [
    ('Px_F', 'tot', 'tot', False),
    ('Px_a', 'lya', 'lya', False),
    ('Px_H', 'hcd', 'hcd', False),
    ('Px_aH', 'lya', 'hcd', True),
    ('Px_a3', 'lya', 'lyahcd', True),
    ('Px_H3', 'hcd', 'lyahcd', True),
    ('Px_p4', 'lyahcd', 'lyahcd', False),
]
n_rbins = len(r_edges_pix) - 1

for i in range(block_grid_size):
    for j in range(block_grid_size):
        tau_on = tau_on_blocks[i, j]
        tau_off = tau_off_blocks[i, j]

        N_skewers = tau_on.shape[0]  # might be < block_size**2
        print(f"Processing block ({i}, {j}): {N_skewers} skewers, {n_rbins} radial bins")

        # Pairwise transverse distances (in pixel units) between skewers.
        # NOTE(review): coordinates are rebuilt from a contiguous
        # arange(N_skewers). If skewers were removed by the DLA mask before
        # saving, these are no longer their true grid positions — the kept
        # indices would have to be stored alongside tau. TODO confirm/fix.
        ix, iy = np.divmod(np.arange(N_skewers), block_size)
        dx = ix[:, None] - ix[None, :]
        dy = iy[:, None] - iy[None, :]
        distance_matrix = np.sqrt(dx**2 + dy**2)

        # Fields (semantics of st.fields / st.delta_F live in the astrid module)
        tau_tot, tau_lya, tau_hcd = st.fields(tau_on, tau_off)
        dF_tot = st.delta_F(tau_tot)
        dF_lya = st.delta_F(tau_lya)
        dF_hcd = st.delta_F(tau_hcd)
        dF_lyahcd = dF_lya * dF_hcd

        # FFTs along the last axis (np.fft.fft default)
        ffts = {
            'tot': np.fft.fft(dF_tot),
            'lya': np.fft.fft(dF_lya),
            'hcd': np.fft.fft(dF_hcd),
            'lyahcd': np.fft.fft(dF_lyahcd),
        }

        for b in range(n_rbins):
            rmin = r_edges_pix[b]
            rmax = r_edges_pix[b + 1]

            # Pairs whose separation falls in the annulus rmin <= r < rmax
            mask = (distance_matrix >= rmin) & (distance_matrix < rmax)

            # Px calculations
            Px_dict = {}
            for key, name_1, name_2, symmetrize in px_terms:
                Px_term, _ = st.Px_sum(ffts[name_1], ffts[name_2], mask)
                if symmetrize:
                    swapped, _ = st.Px_sum(ffts[name_2], ffts[name_1], mask)
                    Px_term += swapped
                Px_dict[key] = Px_term

            # Save results (real part only)
            for key, Px in Px_dict.items():
                filename = f"{key}_rbin{b+1}_{i}{j}.npy"
                np.save(os.path.join(save_path, filename), Px.real)

print("All radial bins processed.")


In [None]:
# P1D for every block, written to save_path as {key}_{i}{j}.npy.
#
# Each entry: (output key, first field, second field, symmetrize).
# When symmetrize is True the term with the two fields swapped is added.
p1d_terms = (
    ('P1D_F', 'tot', 'tot', False),
    ('P1D_a', 'lya', 'lya', False),
    ('P1D_H', 'hcd', 'hcd', False),
    ('P1D_aH', 'lya', 'hcd', True),
    ('P1D_a3', 'lya', 'lyahcd', True),
    ('P1D_H3', 'hcd', 'lyahcd', True),
    ('P1D_p4', 'lyahcd', 'lyahcd', False),
)

for i, j in np.ndindex(block_grid_size, block_grid_size):
    tau_on = tau_on_blocks[i, j]
    tau_off = tau_off_blocks[i, j]

    # Fields
    tau_tot, tau_lya, tau_hcd = st.fields(tau_on, tau_off)
    dF_tot = st.delta_F(tau_tot)
    dF_lya = st.delta_F(tau_lya)
    dF_hcd = st.delta_F(tau_hcd)
    dF_lyahcd = dF_lya * dF_hcd

    # FFTs
    fft_tot = np.fft.fft(dF_tot)
    fft_lya = np.fft.fft(dF_lya)
    fft_hcd = np.fft.fft(dF_hcd)
    fft_lyahcd = np.fft.fft(dF_lyahcd)
    fft_by_name = {'tot': fft_tot, 'lya': fft_lya, 'hcd': fft_hcd, 'lyahcd': fft_lyahcd}

    # P1D calculations
    P1D_dict = {}
    for key, name_1, name_2, symmetrize in p1d_terms:
        term = st.P1D_sum(fft_by_name[name_1], fft_by_name[name_2])
        if symmetrize:
            term += st.P1D_sum(fft_by_name[name_2], fft_by_name[name_1])
        P1D_dict[key] = term

    # Save results (real part only)
    for key, P1D in P1D_dict.items():
        np.save(os.path.join(save_path, f"{key}_{i}{j}.npy"), P1D.real)

## Checkpoint: load the saved spectra and plot

In [None]:
import h5py
import numpy as np
import os
from matplotlib import pyplot as plt
import astrid as st

In [None]:
# Constants (repeated so the notebook can be resumed from this checkpoint)
save_path = '/data/desi/scratch/jlopez/'
Nsk = 500
Np = 2500
block_grid_size = 5
block_size = Nsk // block_grid_size  # = 100
L_hMpc = 250
dz_hMpc=L_hMpc/Np
dxy_hMpc=L_hMpc/Nsk  # = 0.5 Mpc/h
r_bins=4
n_blocks = block_grid_size ** 2
# Wavenumbers of the Np FFT modes, spaced by the fundamental 2*pi/L_hMpc.
# Uses L_hMpc rather than a second hard-coded copy of the box length.
# Entries above index Np//2 correspond to negative frequencies; the plots
# below only show the low-k end.
k = 2*np.pi/L_hMpc*np.arange(Np)

In [None]:
# Transverse-separation bins: 5 edges in Mpc/h delimit 4 bins
r_edges = np.array((0, 0.6, 0.8, 1.2, 1.6))
# Same edges expressed in units of the skewer spacing
r_edges_pix = np.divide(r_edges, dxy_hMpc)

In [None]:
# Reload the ON / OFF tau blocks through the astrid helper.
# NOTE(review): presumably reads the tau_on{i}{j}.npy / tau_off{i}{j}.npy files
# under save_path; the returned layout is not visible here — TODO confirm.
tau_on_blocks, tau_off_blocks = st.load_all_tau_blocks(save_path, block_grid_size, block_size, Np)

In [None]:
# Load the saved P_x results for all blocks and radial bins via the astrid helper.
# NOTE(review): the name Px_all is reassigned by the plotting cells further down.
Px_all = st.load_all_Px(save_path, block_grid_size, Np, r_bins)

In [None]:
# Load the saved P1D results for all blocks via the astrid helper.
# The P1D_F cell below indexes this as P1D_all["P1D_F"] and reshapes it to
# (block_grid_size, block_grid_size, Np).
# NOTE(review): the name P1D_all is reassigned by the plotting cells further down.
P1D_all = st.load_all_P1D(save_path, block_grid_size, Np)

In [None]:
# Per-block mean fluxes and the correction C = <F_tot> / (<F_lya> <F_hcd>) - 1.
grid_shape = (block_grid_size, block_grid_size)
C, F_tot, F_lya, F_hcd = (np.empty(grid_shape) for _ in range(4))

for i, j in np.ndindex(*grid_shape):
    tau_on = tau_on_blocks[i, j]
    tau_off = tau_off_blocks[i, j]
    tau_tot, tau_lya, tau_hcd = st.fields(tau_on, tau_off)

    # Mean transmitted flux exp(-tau) of each field over the whole block
    avg_tot = np.mean(np.exp(-tau_tot))
    avg_lya = np.mean(np.exp(-tau_lya))
    avg_hcd = np.mean(np.exp(-tau_hcd))

    # Store results
    F_tot[i, j], F_lya[i, j], F_hcd[i, j] = avg_tot, avg_lya, avg_hcd
    C[i, j] = avg_tot / (avg_lya * avg_hcd) - 1

# Averages over blocks, kept with shape (1, 1) so they broadcast against the grids
F_lya_avg = np.mean(F_lya, keepdims=True)
F_hcd_avg = np.mean(F_hcd, keepdims=True)
F_tot_avg = np.mean(F_tot, keepdims=True)

# Weight each block's C by its mean fluxes relative to the global means
C_1 = C * (F_lya / F_lya_avg) * (F_hcd / F_hcd_avg)
Cavg = np.mean(C_1)

# Relative difference between the plain and the flux-weighted average of C
print((np.mean(C) - Cavg) / Cavg)

In [None]:
# Average of the per-block P1D_F weighted by each block's mean flux squared,
# normalised by the global mean flux squared.
P1D_F_blocks = P1D_all["P1D_F"].reshape(block_grid_size, block_grid_size, Np)

# The accumulator must start from zeros: np.empty returns uninitialised
# memory, so "+=" on it would add arbitrary values to the result.
P1D_F = np.zeros(Np)
for i in range(block_grid_size):
    for j in range(block_grid_size):
        P1D_F += F_tot[i][j]*F_tot[i][j]*P1D_F_blocks[i][j]
P1D_F /= n_blocks
print(P1D_F)
# F_tot_avg has shape (1, 1); F_tot_avg[0] has shape (1,) and broadcasts over P1D_F
P1D_F /= (F_tot_avg[0] * F_tot_avg[0])
print(min(P1D_F))

In [None]:

# Plot the block-averaged P1D of every component and the model sum.
keys = ['P1D_F', 'P1D_a', 'P1D_H', 'P1D_aH', 'P1D_a3', 'P1D_H3', 'P1D_p4']
colors = {
    'P1D_F': 'black',
    'P1D_a': 'tab:blue',
    'P1D_H': 'tab:orange',
    'P1D_aH': 'tab:green',
    'P1D_a3': 'tab:red',
    'P1D_H3': 'tab:purple',
    'P1D_p4': 'tab:brown'
}

# --- Load ---
# Rows start as NaN so that a missing block file is excluded from the
# statistics below instead of being averaged in as a row of zeros.
P1D_all = {key: np.full((n_blocks, Np), np.nan) for key in keys}
for key in keys:
    idx = 0
    for i in range(block_grid_size):
        for j in range(block_grid_size):
            filename = f"{key}_{i}{j}.npy"
            path = os.path.join(save_path, filename)
            if os.path.exists(path):
                P1D_all[key][idx] = np.load(path)
            else:
                print(f"Missing: {filename}")
            idx += 1

# --- Mean, std, and total sum (excluding P1D_F) ---
P1D_mean = {}
P1D_std = {}
P1D_sum = np.zeros(Np)
P1D_sum_std = np.zeros(Np)

for key in keys:
    # NaN-aware statistics over the blocks that were actually found
    P1D_mean[key] = np.nanmean(P1D_all[key], axis=0)
    P1D_std[key] = np.nanstd(P1D_all[key], axis=0)
    if key != 'P1D_F':
        P1D_sum += P1D_mean[key]
        P1D_sum_std += P1D_std[key]**2  # accumulate variances

P1D_sum_std = np.sqrt(P1D_sum_std)

# --- Plot ---
plt.figure(figsize=(8, 6))

# Plot all components with shaded std
for key in keys:
    plt.plot(k, P1D_mean[key], label=key, color=colors[key])
    plt.fill_between(k,
                     P1D_mean[key] - P1D_std[key],
                     P1D_mean[key] + P1D_std[key],
                     color=colors[key], alpha=0.2)

# Plot sum (excluding P1D_F), rescaled by (1 + Cavg)^2.
# Cavg comes from the mean-flux cell; it falls back to 0 if that cell was not run.
if 'Cavg' not in locals():
    Cavg = 0
P1D_sum_model = P1D_sum / (1 + Cavg)**2
P1D_sum_model_std = P1D_sum_std / (1 + Cavg)**2

plt.plot(k, P1D_sum_model, label=r'$P_\times^{model}$', color='gray', linestyle='--', linewidth=2)
plt.fill_between(k,
                 P1D_sum_model - P1D_sum_model_std,
                 P1D_sum_model + P1D_sum_model_std,
                 color='gray', alpha=0.2)

plt.xlabel(r"$k_{\parallel}$ [$h$/Mpc]")
plt.ylabel(r'$P_{1D}(k_\parallel)\quad[Mpc/h]$')
plt.title(r"Average $P_{1D}$ for all sub boxes")
plt.xscale('log')
plt.xlim(k[1], 6)
plt.legend(loc='upper right')
plt.tight_layout()
plt.savefig("p1d")
plt.show()

In [None]:
# One figure per radial bin: block-averaged P_x of every component and the model sum.

# --- Constants ---
n_blocks = block_grid_size ** 2
bins_to_plot = [1, 2, 3, 4]

# --- FFT positive frequencies ---
k = 2*np.pi/250*np.arange(Np)
k_positive = k[:Np//2]

# --- Keys to load ---
keys = ['Px_F', 'Px_a', 'Px_H', 'Px_aH', 'Px_a3', 'Px_H3', 'Px_p4']
colors = {
    'Px_F': 'black',
    'Px_a': 'tab:blue',
    'Px_H': 'tab:orange',
    'Px_aH': 'tab:green',
    'Px_a3': 'tab:red',
    'Px_H3': 'tab:purple',
    'Px_p4': 'tab:brown'
}

for bin_number in bins_to_plot:

    # --- Storage ---
    # Rows start as NaN so that a missing block file is excluded from the
    # statistics below instead of being averaged in as a row of zeros.
    Px_all = {key: np.full((n_blocks, Np), np.nan) for key in keys}

    # --- Load ---
    for key in keys:
        idx = 0
        for i in range(block_grid_size):
            for j in range(block_grid_size):
                filename = f"{key}_rbin{bin_number}_{i}{j}.npy"
                path = os.path.join(save_path, filename)
                if os.path.exists(path):
                    Px_all[key][idx] = np.load(path)
                else:
                    print(f"Missing: {filename}")
                idx += 1

    # --- Mean, std, and total sum (excluding Px_F) ---
    Px_mean = {}
    Px_std = {}
    Px_sum = np.zeros(Np)
    Px_sum_std = np.zeros(Np)

    for key in keys:
        # NaN-aware statistics over the blocks that were actually found
        Px_mean[key] = np.nanmean(Px_all[key], axis=0)
        Px_std[key] = np.nanstd(Px_all[key], axis=0)
        if key != 'Px_F':
            Px_sum += Px_mean[key]
            Px_sum_std += Px_std[key]**2  # accumulate variances

    Px_sum_std = np.sqrt(Px_sum_std)

    # --- Plot ---
    plt.figure(figsize=(8, 6))

    # Plot all components with shaded std
    for key in keys:
        plt.plot(k, Px_mean[key], label=key, color=colors[key])
        plt.fill_between(k,
                         Px_mean[key] - Px_std[key],
                         Px_mean[key] + Px_std[key],
                         color=colors[key], alpha=0.2)

    # Plot sum (excluding Px_F), rescaled by (1 + Cavg)^2.
    # Cavg comes from the mean-flux cell; it falls back to 0 if that cell was not run.
    if 'Cavg' not in locals():
        Cavg = 0  # or appropriate default
    Px_sum_model = Px_sum / (1 + Cavg)**2
    Px_sum_model_std = Px_sum_std / (1 + Cavg)**2

    plt.plot(k, Px_sum_model, label=r'$P_\times^{model}$', color='gray', linestyle='--', linewidth=2)
    plt.fill_between(k,
                     Px_sum_model - Px_sum_model_std,
                     Px_sum_model + Px_sum_model_std,
                     color='gray', alpha=0.2)

    plt.xlabel(r"$k_{\parallel}$ [$h$/Mpc]")
    plt.ylabel(r'$P_\times(k_\parallel)\quad[Mpc/h]$')
    plt.title(fr"Average $P_\times$ at bin {bin_number} for all sub boxes")
    plt.xscale('log')
    plt.xlim(k[1], 5)
    plt.legend(loc='upper right')
    plt.tight_layout()
    plt.show()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Five-panel figure: P1D on the left, then P_x for each of the four radial bins.
# NOTE: relies on Cavg having been defined by an earlier cell.

# --- Constants ---
save_path = '/data/desi/scratch/jlopez/'
block_grid_size = 5
Np = 2500
n_blocks = block_grid_size ** 2
dxy_hMpc = 0.5  # update if needed

# --- FFT positive frequencies ---
k = 2 * np.pi / 250 * np.arange(Np)
k_positive = k[:Np//2]

# --- Radial bins ---
r_edges = np.array([0, 0.6, 0.8, 1.2, 1.6])  # Mpc/h
bins_to_plot = [1, 2, 3, 4]

# --- Px keys and colors ---
keys = ['Px_F', 'Px_a', 'Px_H', 'Px_aH', 'Px_a3', 'Px_H3', 'Px_p4']
colors = {
    'Px_F': 'black',
    'Px_a': 'tab:blue',
    'Px_H': 'tab:orange',
    'Px_aH': 'tab:green',
    'Px_a3': 'tab:red',
    'Px_H3': 'tab:purple',
    'Px_p4': 'tab:brown'
}

# --- Load P1D data for all fields ---
# P1D results are stored under the Px_* keys so both share colors/labels;
# missing files are simply skipped.
P1D_all = {key: [] for key in keys}

for key in keys:
    for i in range(block_grid_size):
        for j in range(block_grid_size):
            fname = f'{key.replace("Px", "P1D")}_{i}{j}.npy'
            path = os.path.join(save_path, fname)
            if os.path.exists(path):
                data = np.load(path)
                P1D_all[key].append(data.real)

P1D_mean = {key: np.mean(P1D_all[key], axis=0) if len(P1D_all[key]) > 0 else np.zeros(Np) for key in keys}
P1D_std = {key: np.std(P1D_all[key], axis=0) if len(P1D_all[key]) > 0 else np.zeros(Np) for key in keys}

# --- Compute P1D_model (sum of all except P1D_F) ---
P1D_model = np.zeros(Np)
P1D_model_std2 = np.zeros(Np)

for key in keys:
    if key != 'Px_F':
        P1D_model += P1D_mean[key]
        P1D_model_std2 += P1D_std[key] ** 2

P1D_model_std = np.sqrt(P1D_model_std2)

# --- Set up subplots: 1 row, 5 columns (leftmost = P1D) ---
fig, axes = plt.subplots(1, 5, figsize=(24, 6), sharey=True)

# --- Left panel: Full P1D ---
ax = axes[0]
for key in keys:
    ax.plot(k, P1D_mean[key], label=key, color=colors[key])
    ax.fill_between(k,
                    P1D_mean[key] - P1D_std[key],
                    P1D_mean[key] + P1D_std[key],
                    color=colors[key], alpha=0.2)

# Plot P1D model, rescaled by (1 + Cavg)^2
P1D_model /= (1 + Cavg)**2
P1D_model_std /= (1 + Cavg)**2
ax.plot(k, P1D_model, label=r'$P_{1D}^{model}$', color='gray', linestyle='--', linewidth=2)
ax.fill_between(k,
                P1D_model - P1D_model_std,
                P1D_model + P1D_model_std,
                color='gray', alpha=0.2)

ax.set_title(r"$P_{1D}$")
ax.set_xlabel(r"$k_{\parallel}$ [$h$/Mpc]")
ax.set_xscale('log')
ax.set_xlim(k[1], 5)
ax.grid(True)

# --- Other panels: Px for each r_perp bin ---
for ax, bin_number in zip(axes[1:], bins_to_plot):
    # --- Storage ---
    # Rows start as NaN so that a missing block file is excluded from the
    # statistics (as it is for P1D above) instead of counting as zeros.
    Px_all = {key: np.full((n_blocks, Np), np.nan) for key in keys}

    # --- Load Px files ---
    for key in keys:
        idx = 0
        for i in range(block_grid_size):
            for j in range(block_grid_size):
                filename = f"{key}_rbin{bin_number}_{i}{j}.npy"
                path = os.path.join(save_path, filename)
                if os.path.exists(path):
                    Px_all[key][idx] = np.load(path)
                else:
                    print(f"Missing: {filename}")
                idx += 1

    # --- Compute mean, std, and sum ---
    Px_mean = {}
    Px_std = {}
    Px_sum = np.zeros(Np)
    Px_sum_std = np.zeros(Np)

    for key in keys:
        # NaN-aware statistics over the blocks that were actually found
        Px_mean[key] = np.nanmean(Px_all[key], axis=0)
        Px_std[key] = np.nanstd(Px_all[key], axis=0)
        if key != 'Px_F':
            Px_sum += Px_mean[key]
            Px_sum_std += Px_std[key]**2

    Px_sum_std = np.sqrt(Px_sum_std)

    # --- Plot all Px components ---
    for key in keys:
        ax.plot(k, Px_mean[key], label=key, color=colors[key])
        ax.fill_between(k, Px_mean[key] - Px_std[key], Px_mean[key] + Px_std[key], color=colors[key], alpha=0.2)

    # --- Plot model sum (excluding Px_F) ---
    Px_sum_model = Px_sum / (1 + Cavg)**2
    Px_sum_model_std = Px_sum_std / (1 + Cavg)**2
    ax.plot(k, Px_sum_model, label=r'$P_\times^{model}$', color='gray', linestyle='--', linewidth=2)
    ax.fill_between(k, Px_sum_model - Px_sum_model_std, Px_sum_model + Px_sum_model_std, color='gray', alpha=0.2)

    # --- Overlay P1D_F without error ---
    ax.plot(k, P1D_mean['Px_F'], label=r'$P_{1D}^F$', color='black', linestyle=':', alpha=0.2)

    # --- Title with r_perp range ---
    rmin, rmax = r_edges[bin_number - 1], r_edges[bin_number]
    ax.set_title(fr"${rmin:.1f} < r_\perp < {rmax:.1f}$ [Mpc/h]")
    ax.set_xscale('log')
    ax.set_xlim(k[1], 5)
    ax.grid(True)
    ax.set_xlabel(r"$k_{\parallel}$ [$h$/Mpc]")

# --- Shared y-axis ---
axes[0].set_ylabel(r"$P(k_\parallel)\quad [Mpc/h]$")

# --- Legend (taken from the last panel, which has every curve) ---
handles, labels_ = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels_, loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=plt.gcf().transFigure)

# --- Layout and show ---
plt.suptitle(r"$P_{1D}(k_\parallel)$ and $P_\times(k_\parallel)$ across radial bins", fontsize=16)
plt.tight_layout(rect=[0, 0, 0.95, 0.95])
plt.savefig("p1d_px")
plt.show()


In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

# Compare P_x^F in each radial bin with P1D^F.

# --- Constants ---
save_path = '/data/desi/scratch/jlopez/'
block_grid_size = 5
Np = 2500
k = 2 * np.pi / 250 * np.arange(Np)
n_blocks = block_grid_size ** 2

# --- Which bins to plot ---
bins_to_plot = [1, 2, 3, 4]  # bin indices are 1-based

# --- Load Px_F data and compute statistics ---
# Bins for which no file is found get no entry in Px_mean / Px_std.
Px_mean = {}
Px_std = {}

for b in bins_to_plot:
    all_Px = []
    for i in range(block_grid_size):
        for j in range(block_grid_size):
            fname = f'Px_F_rbin{b}_{i}{j}.npy'
            path = os.path.join(save_path, fname)
            if os.path.exists(path):
                Px = np.load(path)
                all_Px.append(Px.real)
    all_Px = np.array(all_Px)

    if all_Px.shape[0] == 0:
        print(f"No data found for bin {b}")
        continue

    Px_mean[b] = np.mean(all_Px, axis=0)
    Px_std[b] = np.std(all_Px, axis=0)

# --- Load and average P1D_F ---
# Rows start as NaN so that a missing block file is excluded from the
# statistics (as it is for Px above) instead of counting as zeros.
P1D_F_all = np.full((n_blocks, Np), np.nan)

idx = 0
for i in range(block_grid_size):
    for j in range(block_grid_size):
        fname = f'P1D_F_{i}{j}.npy'
        path = os.path.join(save_path, fname)
        if os.path.exists(path):
            P1D_F_all[idx] = np.load(path)
        else:
            print(f"Missing: {fname}")
        idx += 1

P1D_F_mean = np.nanmean(P1D_F_all, axis=0)
P1D_F_std = np.nanstd(P1D_F_all, axis=0)

# --- Plotting ---
plt.figure(figsize=(9, 6))

colors = ['C0', 'C1', 'C2', 'C3']
for idx, b in enumerate(bins_to_plot):
    if b not in Px_mean:
        # No files were found for this bin above; nothing to draw.
        continue
    mean = Px_mean[b]
    std = Px_std[b]
    label = f'$P_\\times^F$, bin {b}'

    plt.plot(k, mean, label=label, color=colors[idx])
    # Error band (+/- one std over blocks) for Px
    plt.fill_between(k, mean - std, mean + std,
                      color=colors[idx], alpha=0.3)

# Plot P1D_F mean and std
plt.plot(k, P1D_F_mean, color='black', label=r'$P_{1D}^F$')
plt.fill_between(k, P1D_F_mean - P1D_F_std, P1D_F_mean + P1D_F_std, color='black', alpha=0.2)

plt.xlabel(r"$k_{\parallel}$ [$h$/Mpc]")
plt.ylabel(r'$P(k_\parallel)\quad[Mpc/h]$')
plt.title(r"Comparison of $P_\times^F$ (bins 1–4) with $P_{1D}^F$")
plt.xscale('log')
plt.xlim(k[1], 6)
plt.legend(loc='upper right')
plt.tight_layout()
#plt.savefig('PxF_with_P1D')
plt.show()


In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

# Compare P_x^alpha in each radial bin with P1D^alpha.

# --- Constants ---
save_path = '/data/desi/scratch/jlopez/'
block_grid_size = 5
Np = 2500
k = 2 * np.pi / 250 * np.arange(Np)
n_blocks = block_grid_size ** 2

# --- Which bins to plot ---
bins_to_plot = [1, 2, 3, 4]

# --- Load Px_a data and compute statistics ---
# Bins for which no file is found get no entry in Px_mean / Px_std.
Px_mean = {}
Px_std = {}

for b in bins_to_plot:
    all_Px = []
    for i in range(block_grid_size):
        for j in range(block_grid_size):
            fname = f'Px_a_rbin{b}_{i}{j}.npy'
            path = os.path.join(save_path, fname)
            if os.path.exists(path):
                Px = np.load(path)
                all_Px.append(Px.real)
    all_Px = np.array(all_Px)

    if all_Px.shape[0] == 0:
        print(f"No data found for bin {b}")
        continue

    Px_mean[b] = np.mean(all_Px, axis=0)
    Px_std[b] = np.std(all_Px, axis=0)

# --- Load and average P1D_a ---
# Rows start as NaN so that a missing block file is excluded from the
# statistics (as it is for Px above) instead of counting as zeros.
P1D_a_all = np.full((n_blocks, Np), np.nan)

idx = 0
for i in range(block_grid_size):
    for j in range(block_grid_size):
        fname = f'P1D_a_{i}{j}.npy'
        path = os.path.join(save_path, fname)
        if os.path.exists(path):
            P1D_a_all[idx] = np.load(path)
        else:
            print(f"Missing: {fname}")
        idx += 1

P1D_a_mean = np.nanmean(P1D_a_all, axis=0)
P1D_a_std = np.nanstd(P1D_a_all, axis=0)

# --- Plotting ---
plt.figure(figsize=(9, 6))

colors = ['C0', 'C1', 'C2', 'C3']
for idx, b in enumerate(bins_to_plot):
    if b not in Px_mean:
        # No files were found for this bin above; nothing to draw.
        continue
    mean = Px_mean[b]
    std = Px_std[b]
    label = f'$P_\\times^\\alpha$, bin {b}'

    plt.plot(k, mean, label=label, color=colors[idx])
    # Error band (+/- one std over blocks), drawn for the first Np//2 modes only
    plt.fill_between(k[:Np//2], mean[:Np//2] - std[:Np//2], mean[:Np//2] + std[:Np//2],
                      color=colors[idx], alpha=0.3)

# Plot P1D_a mean and std
plt.plot(k, P1D_a_mean, color='black', label=r'$P_{1D}^\alpha$')
plt.fill_between(k, P1D_a_mean - P1D_a_std, P1D_a_mean + P1D_a_std, color='black', alpha=0.2)

plt.xlabel(r"$k_{\parallel}$ [$h$/Mpc]")
plt.ylabel(r'$P(k_\parallel)\quad[Mpc/h]$')
plt.title(r"Comparison of $P_\times^\alpha$ (bins 1–4) with $P_{1D}^\alpha$")
plt.xscale('log')
plt.xlim(k[1], 10)
plt.legend(loc='upper right')
plt.tight_layout()
#plt.savefig('Pxa_with_P1D')
plt.show()


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

# P_x as a function of r_perp at a few fixed k_parallel, with P1D as the r_perp = 0 point.

# --- Constants ---
save_path = '/data/desi/scratch/jlopez/'
block_grid_size = 5
Np = 2500
k = 2 * np.pi / 250 * np.arange(Np)
n_blocks = block_grid_size ** 2

# (i, j) index of every sub box, in row-major order
box_ids = [(i, j) for i in range(block_grid_size) for j in range(block_grid_size)]

# --- Effective r_perp values per bin ---
r_perp = {
    1: 0.5,
    2: 0.707,
    3: (1.0 + 1.118)/2,
    4: (1.414 + 1.5 + 1.581)/3
}
bins_to_use = [1, 2, 3, 4]
r_perp_vals = [r_perp[b] for b in bins_to_use]

# --- Target k_parallel values and the index of the closest mode for each ---
k_targets = [0.05, 0.1, 0.5]
k_indices = [np.argmin(np.abs(k - k_val)) for k_val in k_targets]

# --- Fields ---
fields = ['F', 'a', 'H', 'aH']
colors = {'F': 'black', 'a': 'tab:blue', 'H': 'tab:orange', 'aH': 'tab:green'}
labels = {
    'F': r'$P_\times^F$',
    'a': r'$P_\times^\alpha$',
    'H': r'$P_\times^H$',
    'aH': r'$P_\times^{\alpha H}$'
}

# --- Compute Px_mean from subboxes ---
# Px_mean[field][k_val] is a list with one block-averaged value per radial bin.
Px_mean = {field: {k_val: [] for k_val in k_targets} for field in fields}

for field in fields:
    for b in bins_to_use:
        candidates = [
            os.path.join(save_path, f'Px_{field}_rbin{b}_{i}{j}.npy')
            for i, j in box_ids
        ]
        # Only the files that exist are loaded; missing ones are skipped.
        Px_all_boxes = np.array(
            [np.load(path).real for path in candidates if os.path.exists(path)]
        )

        if Px_all_boxes.shape[0] == 0:
            print(f"No data for field {field}, bin {b}")
            for k_val in k_targets:
                Px_mean[field][k_val].append(np.nan)
            continue

        for k_val, k_idx in zip(k_targets, k_indices):
            Px_mean[field][k_val].append(np.mean(Px_all_boxes[:, k_idx]))

# --- Load and average P1D for r_perp = 0 ---
P1D = {field: {} for field in fields}

for field in fields:
    candidates = [
        os.path.join(save_path, f'P1D_{field}_{i}{j}.npy')
        for i, j in box_ids
    ]
    P1D_all_boxes = [np.load(path).real for path in candidates if os.path.exists(path)]

    if not P1D_all_boxes:
        print(f"No P1D data for field {field}")
        for k_val in k_targets:
            P1D[field][k_val] = np.nan
        continue

    P1D_all_boxes = np.array(P1D_all_boxes)

    for k_val, k_idx in zip(k_targets, k_indices):
        P1D[field][k_val] = np.mean(P1D_all_boxes[:, k_idx])

# --- Insert r_perp = 0 (P1D) point in front of the Px values ---
r_perp_vals_with_zero = [0.0, *r_perp_vals]
Px_mean_with_p1d = {}
for field in fields:
    Px_mean_with_p1d[field] = {
        k_val: [P1D[field][k_val], *Px_mean[field][k_val]]
        for k_val in k_targets
    }

# --- Plot: one subplot per k_parallel, share y-axis ---
n_panels = len(k_targets)
fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(5 * n_panels, 6), sharey=True)

for ax, k_val in zip(axes, k_targets):
    for field in ('a', 'H', 'F', 'aH'):
        curve = Px_mean_with_p1d[field].get(k_val)
        if curve is None:
            continue
        ax.plot(r_perp_vals_with_zero, curve, label=labels[field], color=colors[field])
    ax.set_xlabel(r"$r_\perp$ [Mpc/h]")
    ax.set_title(f"$k_\\parallel = {k_val}$")
    ax.legend()

axes[0].set_ylabel(r"$P(r_\perp)$ [Mpc/h]")

plt.suptitle(r"$P_\times$ vs $r_\perp$ for different $k_\parallel$")
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig('P_rperp_subplots.png')
plt.show()

