In [1]:
# %%
# %%
import os
import time
import subprocess
import pathlib
import functools
import math
import numba
import typing as t
import numpy as np
import numpy.typing as npt
import astropy.units as u
import astropy.constants as ac

import pandas as pd
from astropy import units as u
from pathlib import Path
from scipy.special import erf
from concurrent.futures import ThreadPoolExecutor

# Number of transition-file rows parsed per pandas chunk.
_DEFAULT_CHUNK_SIZE = 10000000
# h*c/k_B built with c in cm/s (``.cgs``), so it pairs with energies in cm^-1 and temperatures in K.
ac_h_c_on_kB = ac.h * ac.c.cgs / ac.k_B
# 16*pi*c: the usual 8*pi*c of the absorption coefficient times the extra factor of 2 that
# comes from integrating a normalised Gaussian over a grid bin,
#   integral = 0.5 * [erf(upper edge) - erf(lower edge)],
# which is the form used by _continuum_band_profile_variable_width (it applies no 0.5 itself).
# The previous value, 8*pi*c, did not match the constant's name and doubled every cross section.
ac_16_pi_c = 16 * np.pi * ac.c.cgs
# Plain floats for use inside numba-compiled code, which cannot handle astropy Quantities.
const_h_c_on_kB = ac_h_c_on_kB.value
const_16_pi_c = ac_16_pi_c.value

# Root folder holding the continuum states/transitions files and the reference cross section.
# working_directory = r"/scratch/cbowesman/OH/"  # remote
working_directory = r"/mnt/c/PhD/OH/ExoMol/XABC11_Unbound_States_Trans/"  # local

# Progress messages are appended to this file rather than printed.
log_file = fr"{working_directory}OH_cont.log"

continuum_states_file = Path(fr"{working_directory}16O-1H__MYTHOS.states.cont")
continuum_trans_files = Path(fr"{working_directory}16O-1H__MYTHOS.trans.cont")

# Only the first four whitespace-separated columns of the states file are needed.
continuum_states = pd.read_csv(
    continuum_states_file,
    sep=r"\s+",
    usecols=[0, 1, 2, 3],
    names=["id", "energy", "g", "J"],
)

# The output wavenumber grid is taken from the first column of an existing cross-section file.
wn_grid = pd.read_csv(
    fr"{working_directory}OH_T1_P1.0000e+03.xsec",
    sep=r"\s+",
    usecols=[0],
    names=["wn"],
)["wn"].to_numpy()

with open(log_file, "a") as file:
    file.write(f"Computing continuum absorption on wavenumber grid with {len(wn_grid)} points...\n")


@numba.njit(parallel=True)
def _continuum_band_profile_variable_width(
    wn_grid: npt.NDArray[np.float64],
    n_frac_i: npt.NDArray[np.float64],
    a_fi: npt.NDArray[np.float64],
    g_f: npt.NDArray[np.float64],
    g_i: npt.NDArray[np.float64],
    energy_fi: npt.NDArray[np.float64],
    temperature: float,
    cont_broad: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
    """Accumulate bin-averaged Gaussian (erf-difference) profiles of all transitions onto a grid.

    Each grid point ``j`` owns the bin ``[wn_grid[j] - lower, wn_grid[j] + upper]`` where ``lower``
    and ``upper`` are half the spacing to the neighbouring points (zero beyond the grid ends).
    Each transition contributes ``abs_coef * (erf(hi) - erf(lo)) / (upper + lower)`` to every bin
    that has an edge within 1500 grid units of the transition wavenumber.

    The parallel loop runs over grid points, not transitions. Each parallel iteration therefore
    writes to exactly one output element. Parallelising over transitions while doing
    ``_abs_xsec[j] += ...`` lets several threads update the same element at once, which is a
    data race and silently loses contributions.

    Args:
        wn_grid: Output grid; must be ordered and hold at least two points so every bin has a
            non-zero width.
        n_frac_i: Fractional population of the initial state of each transition.
        a_fi: Einstein A coefficient of each transition.
        g_f: Degeneracy of the final state of each transition.
        g_i: Degeneracy of the initial state of each transition.
        energy_fi: Transition wavenumber of each transition (same units as ``wn_grid``).
        temperature: Temperature as a plain float, in the units expected by the module-level
            ``const_h_c_on_kB``.
        cont_broad: Per-transition width entering as ``sqrt(ln 2) / cont_broad``
            (presumably a Gaussian half width at half maximum - confirm against the data files).

    Returns:
        Array shaped like ``wn_grid`` holding the summed profile.
    """
    sqrtln2 = math.sqrt(math.log(2))
    # Transitions whose nearest bin edge is further away than this are ignored for that bin.
    cutoff = 1500.0
    num_trans = energy_fi.shape[0]
    num_grid = wn_grid.shape[0]
    _abs_xsec = np.zeros(wn_grid.shape)

    # bin_widths[k] is half the gap between grid points k-1 and k; both end entries stay zero.
    bin_widths = np.zeros(num_grid + 1)
    bin_widths[1:-1] = (wn_grid[:-1] + wn_grid[1:]) / 2.0 - wn_grid[:-1]

    # Per-transition strength, including the stimulated-emission factor.
    abs_coef = (
        g_f
        * n_frac_i
        * a_fi
        * (1 - np.exp(-const_h_c_on_kB * energy_fi / temperature))
        / (const_16_pi_c * g_i * energy_fi**2)
    )
    sqrtln2_on_alpha = sqrtln2 / cont_broad

    for j in numba.prange(num_grid):
        wn_j = wn_grid[j]
        upper_width = bin_widths[j + 1]
        lower_width = bin_widths[j]
        bin_width = upper_width + lower_width
        # Private accumulator: nothing shared is written until the single store below.
        total = 0.0
        for i in range(num_trans):
            wn_shift = wn_j - energy_fi[i]
            if min(abs(wn_shift - lower_width), abs(wn_shift + upper_width)) <= cutoff:
                total += (
                    abs_coef[i]
                    * (
                        math.erf(sqrtln2_on_alpha[i] * (wn_shift + upper_width))
                        - math.erf(sqrtln2_on_alpha[i] * (wn_shift - lower_width))
                    )
                    / bin_width
                )
        _abs_xsec[j] = total
    return _abs_xsec


def boltzmann_population(states: pd.DataFrame, temperature: u.Quantity) -> pd.DataFrame:
    """Attach Boltzmann weights and fractional populations to a states table.

    Two columns are written onto ``states`` itself, so the caller's frame is modified in place;
    pass a copy if the original must stay untouched:

    * ``q_lev`` = ``g * exp(-(h c / k_B) * energy / temperature)``
    * ``n`` = ``q_lev / sum(q_lev)``, so the populations of the supplied states sum to one.

    Args:
        states: Table with at least ``g`` and ``energy`` columns. ``energy`` is tagged with
            units of 1/cm before use.
        temperature: Temperature as an astropy Quantity (presumably kelvin - the visible caller
            passes values tagged with ``u.K``).

    Returns:
        The same ``states`` object, now carrying the ``q_lev`` and ``n`` columns.
    """
    # Unnormalised weight of each level.
    states["q_lev"] = states["g"] * np.exp(-ac_h_c_on_kB * (states["energy"] << 1 / u.cm) / temperature)
    # Normalise by the sum over the rows present in ``states``.
    states["n"] = states["q_lev"] / states["q_lev"].sum()
    return states

# numba.set_num_threads raises if asked for more threads than numba was launched with
# (NUMBA_NUM_THREADS, by default the number of CPU cores), so clamp the request. This keeps
# the notebook runnable on machines with fewer than 16 cores.
numba.set_num_threads(min(16, numba.config.NUMBA_NUM_THREADS))

# 40 evenly spaced temperatures from 1 K to 4000 K (inclusive).
temperature_list = np.linspace(1, 4000, 40) << u.K

# For every temperature: stream the transitions file in chunks, attach the lower/upper state
# data to each transition, accumulate the continuum profile and write it to disk.
for temperature in temperature_list:
    cont_trans_reader = pd.read_csv(
        continuum_trans_files,
        names=["id_c", "id_i", "a_ci", "broad"],
        sep=r"\s+",
        chunksize=_DEFAULT_CHUNK_SIZE,
    )
    with open(log_file, "a") as file:
        file.write(f"Computing continuum absorption for T={temperature}...\n")

    # The populations depend only on the temperature, so compute them once per temperature
    # instead of copying and re-evaluating the whole states table for every chunk.
    temp_states = boltzmann_population(states=continuum_states.copy(), temperature=temperature)

    abs_xsec = np.zeros_like(wn_grid)
    for chunk in cont_trans_reader:
        # Attach the initial (lower) state: degeneracy, energy and population.
        chunk = chunk.merge(
            temp_states[["id", "g", "energy", "n"]],
            left_on="id_i",
            right_on="id",
            how="inner",
        )
        # chunk = chunk.loc[chunk["n_frac"] > 0.0]  # Need to keep these for rates!
        chunk = chunk.rename(
            columns={
                "g": "g_i",
                "energy": "energy_i",
                "n": "n_i",
            }
        )
        chunk = chunk.drop(columns=["id_i", "id"])

        # Attach the continuum (upper) state: degeneracy and energy.
        chunk = chunk.merge(
            temp_states[["id", "g", "energy"]],
            left_on="id_c",
            right_on="id",
            how="inner",
        )
        chunk = chunk.rename(columns={"g": "g_c", "energy": "energy_c"})
        chunk = chunk.drop(columns=["id_c", "id"])

        chunk["energy_ci"] = chunk["energy_c"] - chunk["energy_i"]
        chunk = chunk.loc[
            (chunk["energy_c"] >= wn_grid[0])
            & (chunk["energy_i"] <= wn_grid[-1])  # TODO: IS this condition valid?
            & (chunk["energy_ci"] >= wn_grid[0])
            & (chunk["energy_ci"] <= wn_grid[-1])
        ]
        if chunk.empty:
            # Nothing in this chunk falls on the grid; skip the (expensive) kernel call.
            continue
        abs_xsec += _continuum_band_profile_variable_width(
            wn_grid=wn_grid,
            n_frac_i=chunk["n_i"].to_numpy(),
            a_fi=chunk["a_ci"].to_numpy(),
            g_f=chunk["g_c"].to_numpy(),
            g_i=chunk["g_i"].to_numpy(),
            energy_fi=chunk["energy_ci"].to_numpy(),
            temperature=temperature.value,
            cont_broad=chunk["broad"].to_numpy(),
        )

    # Write out the profile for that T.
    # NOTE(review): the file name carries the truncated integer temperature, while the
    # calculation above uses the un-truncated value from temperature_list - confirm that the
    # files this is later combined with were produced at the same temperatures.
    np.savetxt(
        fr"{working_directory}OH_T{int(temperature.value)}.cont.xsec",
        np.array([wn_grid, abs_xsec]).T,
        fmt="%17.8E",
    )

KeyboardInterrupt: 

In [26]:
"""
Add continuum Xsecs to outputs. Then run "run_trove_sup2.csh" with the first launch line (which runs exocross) commented out to transform the outputs into pickles, stored in /xsec_combined/.
"""
import numpy as np
import pandas as pd
import glob

cont_directory = r"/scratch/dp060/dc-bowe4/exomolop/OH/"  # Dial
# cont_directory = r"/mnt/c/PhD/OH/ExoMol/XABC11_Unbound_States_Trans/"
out_directory = r"/scratch/dp060/dc-bowe4/exomolop/working/output/"  # Dial
# out_directory = r"/mnt/c/PhD/OH/ExoCross/output/"

log_file = fr"{cont_directory}OH_cont.log"

# NOTE: every matching .out file is overwritten in place with (original + continuum), so
# running this cell twice adds the continuum twice. Regenerate the .out files before a re-run.
temperature_list = np.linspace(1, 4000, 40)
for temperature in temperature_list:
    # Continuum profile for this temperature; file names use the truncated integer temperature.
    cont_xsec = pd.read_csv(fr"{cont_directory}OH_T{int(temperature)}.cont.xsec", sep=r"\s+", names=["wn", "intens"])
    # cont_xsec = pd.read_csv(fr"{cont_directory}OH_T{int(temperature)}_P1.0000e+03.xsec", sep=r"\s+", names=["wn", "intens"])
    print(cont_xsec.head(4))
    # One output file per pressure at this temperature.
    output_files = glob.glob(fr"{out_directory}OH_T{int(temperature)}_P*.out")
    print(output_files)
    for out_file in output_files:
        out_xsec = pd.read_csv(out_file, sep=r"\s+", names=["wn", "intens"])
        if len(out_xsec) != len(cont_xsec):
            raise RuntimeError(f"Xsec length mismatch for file {out_file}.")
        # Equal length alone does not guarantee the same grid; the intensities are added row
        # by row, so make sure the rows refer to the same wavenumbers first.
        if not np.allclose(out_xsec["wn"].to_numpy(), cont_xsec["wn"].to_numpy()):
            raise RuntimeError(f"Xsec wavenumber grid mismatch for file {out_file}.")
        out_xsec["intens"] += cont_xsec["intens"]
        # print(out_xsec.to_numpy())
        np.savetxt(
            out_file,
            out_xsec.to_numpy(),
            fmt="%17.8E",
        )
        with open(log_file, "a") as file:
            file.write(f"Processed {out_file}\n")

         wn  intens
0  1.000000     0.0
1  1.000001     0.0
2  1.000002     0.0
3  1.000003     0.0
['/mnt/c/PhD/OH/ExoCross/output/OH_T1_P1.0000e+03.out', '/mnt/c/PhD/OH/ExoCross/output/OH_T1_P1.0000e-08.out']


In [28]:
# Truncated integer labels of the 40-point temperature grid.
print(np.linspace(1, 4000, 40).astype(int).tolist())

[1, 103, 206, 308, 411, 513, 616, 718, 821, 923, 1026, 1128, 1231, 1334, 1436, 1539, 1641, 1744, 1846, 1949, 2051, 2154, 2256, 2359, 2461, 2564, 2667, 2769, 2872, 2974, 3077, 3179, 3282, 3384, 3487, 3589, 3692, 3794, 3897, 4000]
