# In [3]:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Enhanced 21cmFAST Brightness Temperature Visualization Script

This script computes the brightness temperature using 21cmFAST across a range of redshifts
and generates comprehensive visualizations. It includes performance optimizations, robust
error handling, and enhanced plotting features.

Author: Adrita Khan
Date: 2024-12-07
"""

# Importing necessary libraries
import functools
import logging
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

# Silence warnings of every category, process-wide, so the notebook output stays
# uncluttered. Note this also hides library and numerical warnings; drop this
# line while debugging.
warnings.simplefilter("ignore")

# Importing 21cmFAST and related tools
import py21cmfast as p21c
from py21cmfast import plotting, cache_tools

# Importing tools21cm
import tools21cm as t2c

# Importing custom utilities (ensure 'utils.py' is in the same directory or adjust the path)
from utils import *

# Configure logging.
# NOTE(review): '21cmFAST' is presumably the logger py21cmfast itself uses, so
# this level would also apply to the library's own messages — confirm.
logger = logging.getLogger('21cmFAST')
logger.setLevel(logging.INFO)

# Define redshift range: 4.5, 4.7, ..., 14.9.
# np.arange with a non-integer step accumulates floating-point error, so the raw
# values are slightly off (e.g. the last one is not exactly 14.9). That noise
# leaks into f-string titles and log messages. Rounding to the step's precision
# (one decimal) gives exact grid values without changing the number of points.
z = np.round(np.arange(4.5, 15, 0.2), 1)
print(f"Redshift range: {z}")
print(f"Number of redshift points: {len(z)}")

@functools.lru_cache(maxsize=1)
def _cached_initial_conditions(box, hubble, matter, sigma8, random_seed):
    """Return 21cmFAST initial conditions, reusing the previous result for repeated inputs.

    Initial conditions do not depend on redshift. The redshift loop can therefore
    share one set instead of rebuilding (or reloading) an identical set for every
    redshift. ``maxsize=1`` keeps at most one set of IC boxes alive in memory.
    ``lru_cache`` does not cache exceptions, so a failed build is retried on the
    next call.
    """
    return p21c.initial_conditions(
        user_params={"HII_DIM": box, "BOX_LEN": box},
        cosmo_params=p21c.CosmoParams(SIGMA_8=sigma8, hlittle=hubble, OMm=matter),
        random_seed=random_seed,
    )


def brightness_temperature(box, redshift, hubble=0.7, matter=0.3, sigma8=0.8, random_seed=54321):
    """
    Compute the brightness temperature using 21cmFAST.

    Parameters:
    - box (int): Passed as both HII_DIM (cells per side) and BOX_LEN, so each
      cell is one length unit wide.
      NOTE(review): 21cmFAST's UserParams documents BOX_LEN in Mpc (not Mpc/h);
      confirm the unit used in plot labels.
    - redshift (float): Redshift at which to compute the brightness temperature.
    - hubble (float): Dimensionless Hubble parameter (CosmoParams.hlittle).
    - matter (float): Matter density parameter (CosmoParams.OMm).
    - sigma8 (float): Power-spectrum normalisation (CosmoParams.SIGMA_8).
    - random_seed (int): Seed for the initial conditions; a fixed seed gives the
      same realisation at every redshift.

    Returns:
    - brightness_temp (numpy.ndarray): 3D array of brightness temperatures, or
      None if any 21cmFAST step raised. The error and its traceback are logged.
    """
    try:
        # Initial conditions are redshift-independent: build once, reuse across calls.
        initial_conditions = _cached_initial_conditions(box, hubble, matter, sigma8, random_seed)

        # Generate perturbed fields
        perturbed_field = p21c.perturb_field(
            redshift=redshift,
            init_boxes=initial_conditions,
        )

        # Ionize the box
        ionized_field = p21c.ionize_box(
            perturbed_field=perturbed_field,
        )

        # Calculate brightness temperature
        brightness_temp = p21c.brightness_temperature(
            ionized_box=ionized_field,
            perturbed_field=perturbed_field,
        )

        return brightness_temp.brightness_temp

    except Exception as e:
        # Best-effort by design (the caller skips None results). logger.exception
        # keeps the traceback that logger.error would have discarded.
        logger.exception(f"Error in brightness_temperature function: {e}")
        return None

# Compute brightness temperature for each redshift.
# A failed redshift is dropped from BOTH the results and `z`. Otherwise the two
# fall out of step: every later z[i] would label the wrong cube, and indexing
# Yf by position in z would run past its end.
Yf = []
computed_z = []
for redshift in z:
    temp = brightness_temperature(box=120, redshift=redshift)
    if temp is None:
        logger.warning(f"Brightness temperature at z={redshift} could not be computed.")
        continue
    Yf.append(temp)
    computed_z.append(redshift)
z = np.asarray(computed_z, dtype=float)
Yf = np.array(Yf)
print(f"Brightness temperature array shape: {Yf.shape}")

# Verify the shape of Yf: one 3-D cube per redshift gives a 4-D array. If every
# redshift failed, Yf is an empty 1-D array and is rejected here too.
if Yf.ndim != 4:
    logger.error(f"Unexpected Yf dimensions: {Yf.ndim}. Expected 4D array (z, x, y, z).")
    sys.exit(1)

# Function to plot a single brightness temperature slice
def plot_brightness_temperature(slice_data, redshift, box_size, save=False, save_path=None):
    """
    Plot a single slice of brightness temperature.

    Parameters:
    - slice_data (numpy.ndarray): 2D array of brightness temperatures.
    - redshift (float): Redshift corresponding to the slice (shown in the title).
    - box_size (int): Side length used for the axis extent [0, box_size].
    - save (bool): Whether to save the plot to ``save_path`` instead of showing it.
    - save_path (str): Path to save the plot.

    The figure is saved and closed only when both ``save`` and ``save_path`` are
    given. If ``save`` is set without a path, the figure is shown and a warning
    is logged, so the missing file does not go unnoticed.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    # NOTE(review): imshow draws axis 0 vertically. If slice_data is indexed
    # [x, y], the X/Y labels below are swapped; confirm and transpose if needed.
    im = ax.imshow(slice_data, origin='lower', cmap='viridis', extent=[0, box_size, 0, box_size])
    # ':g' trims floating-point noise from arange-generated redshifts.
    ax.set_title(f"Brightness Temperature at z = {redshift:g}", fontsize=16)
    ax.set_xlabel('X-axis (Mpc/h)', fontsize=14)
    ax.set_ylabel('Y-axis (Mpc/h)', fontsize=14)
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Brightness Temperature (mK)', fontsize=14)
    fig.tight_layout()

    if save and save_path:
        fig.savefig(save_path)
        plt.close(fig)
        return

    if save:
        logger.warning("save=True but no save_path was given; showing the plot instead.")
    plt.show()

# Plotting brightness temperature for each redshift and saving to a PDF
pdf_filename = 'brightness_temperature_plots.pdf'
PLOT_BOX_SIZE = 120  # axis extent; keep equal to the `box` passed to brightness_temperature
with PdfPages(pdf_filename) as pdf:
    # zip pairs each cube with its own redshift and never indexes past the end of Yf.
    for redshift, cube in zip(z, Yf):
        # Take index 0 along the cube's last axis (a spatial index, not
        # redshift 0); adjust as needed.
        slice_data = cube[:, :, 0]

        # Plot the slice on an explicit figure so saving/closing target it.
        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(
            slice_data,
            origin='lower',
            cmap='viridis',
            extent=[0, PLOT_BOX_SIZE, 0, PLOT_BOX_SIZE],
        )
        # ':g' trims floating-point noise from arange-generated redshifts.
        ax.set_title(f"Brightness Temperature at z = {redshift:g}", fontsize=16)
        ax.set_xlabel('X-axis (Mpc/h)', fontsize=14)
        ax.set_ylabel('Y-axis (Mpc/h)', fontsize=14)
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('Brightness Temperature (mK)', fontsize=14)
        fig.tight_layout()

        # Append this figure as a PDF page, then free it.
        pdf.savefig(fig)
        plt.close(fig)

print(f"All plots have been saved to {pdf_filename}")

# Optional: Display a specific slice interactively
# Example: Display the first redshift slice
# plot_brightness_temperature(Yf[0, :, :, 0], z[0], box_size=120)


# Redshift range: [ 4.5  4.7  4.9  5.1  5.3  5.5  5.7  5.9  6.1  6.3  6.5  6.7  6.9  7.1
#   7.3  7.5  7.7  7.9  8.1  8.3  8.5  8.7  8.9  9.1  9.3  9.5  9.7  9.9
#  10.1 10.3 10.5 10.7 10.9 11.1 11.3 11.5 11.7 11.9 12.1 12.3 12.5 12.7
#  12.9 13.1 13.3 13.5 13.7 13.9 14.1 14.3 14.5 14.7 14.9]
# Number of redshift points: 53
# Brightness temperature array shape: (53, 120, 120, 120)
# All plots have been saved to brightness_temperature_plots.pdf
