# Correlated-k Distribution

In [ ]:
from typing import NamedTuple

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import config

from exojax.opacity.ckd.api import OpaCKD
from exojax.opacity.ckd.core import compute_g_ordinates
from exojax.opacity.premodit.api import OpaPremodit
from exojax.test.emulate_mdb import mock_mdbExomol, mock_wavenumber_grid

# JAX uses 32-bit floats by default; switch to 64-bit before any of the
# opacity objects below are built.
config.update("jax_enable_x64", True)

# Mock wavenumber grid (Nx sample points between lambda0 and lambda1) and a
# mock ExoMol H2O line database.
nus, wav, res = mock_wavenumber_grid(
    lambda0=22930.0,
    lambda1=22940.0,
    Nx=20000,
)
mdb = mock_mdbExomol("H2O")

# PreMODIT opacity calculator; auto_trange brackets the temperatures
# (700-1300) used in the cells below.
opa = OpaPremodit(
    mdb,
    nus,
    auto_trange=[500.0, 1500.0],
)

In [ ]:
class GSample(NamedTuple):
    """Output of :func:`sample_g`.

    A tuple subclass: it unpacks positionally in the field order below (the
    order the plotting cells rely on) and also supports attribute access
    (``res.k_med``) and ``res._asdict()``.
    """

    k_g: jnp.ndarray  # cross-section values along g, as returned by compute_g_ordinates
    g: jnp.ndarray  # g ordinates paired with k_g
    edges: jnp.ndarray  # Ng + 1 equally spaced bin edges on [0, 1]
    cut_idx: jnp.ndarray  # position of each edge in g; bin i is cut_idx[i]:cut_idx[i+1]
    nus_segments: list  # per-bin wavenumbers
    xsv_segments: list  # per-bin cross-section values
    k_med: jnp.ndarray  # midpoint of the selected bin's [k_low, k_high] range
    mask: jnp.ndarray  # True where k_low <= xsv <= k_high for the selected bin


def sample_g(nus, xsv, j_pickup, Ng=10):
    """Split a cross-section into ``Ng`` equal-width g-bins and select one.

    Parameters
    ----------
    nus : array
        Wavenumber grid on which ``xsv`` is sampled.
    xsv : array
        Cross-section values on ``nus``.
    j_pickup : int
        Index of the g-bin to highlight; must satisfy ``0 <= j_pickup < Ng``.
    Ng : int, optional
        Number of equal-width bins on [0, 1] in g (default 10).

    Returns
    -------
    GSample
        ``(k_g, g, edges, cut_idx, nus_segments, xsv_segments, k_med, mask)``.
        Returned as a tuple subclass so that positional unpacking yields the
        arrays themselves (unpacking a dict would yield only its keys).

    Raises
    ------
    ValueError
        If ``j_pickup`` is outside ``[0, Ng)``. JAX clamps out-of-bounds
        indices (and Python wraps negative ones) instead of raising, so a bad
        index would otherwise silently select the wrong bin.
    """
    if not 0 <= j_pickup < Ng:
        raise ValueError(
            f"j_pickup must satisfy 0 <= j_pickup < Ng={Ng}, got {j_pickup}"
        )

    idx, k_g, g = compute_g_ordinates(xsv)

    edges = jnp.linspace(0.0, 1.0, Ng + 1)
    # NOTE(review): searchsorted uses side='left', so a g value exactly equal
    # to 1.0 would fall outside the last bin -- confirm the range of g
    # produced by compute_g_ordinates.
    cut_idx = jnp.searchsorted(g, edges)

    # Slice the index array once per bin and reuse it for both nus and xsv.
    bin_indices = [idx[cut_idx[i]:cut_idx[i + 1]] for i in range(Ng)]
    nus_segments = [nus[b] for b in bin_indices]
    xsv_segments = [xsv[b] for b in bin_indices]

    start, stop = cut_idx[j_pickup], cut_idx[j_pickup + 1]
    # Assumes k_g is ascending in g (hence the low/high naming) -- confirm
    # against compute_g_ordinates.
    k_low = k_g[start]
    k_high = k_g[stop - 1]  # last element of the bin (inclusive)
    k_med = (k_low + k_high) * 0.5
    # Inclusive upper bound: k_high is itself a member of the bin, so a strict
    # '<' would drop the bin's largest value(s) from the highlighted region.
    mask = (xsv >= k_low) & (xsv <= k_high)

    return GSample(k_g, g, edges, cut_idx, nus_segments, xsv_segments, k_med, mask)

In [ ]:
def plot_xsv(nus, xsv, nus_segments, xsv_segments, j, k_med, mask, ax=None):
    """Plot a cross-section spectrum on a log axis and highlight one g-bin.

    Draws sigma(nu), shades the wavenumbers selected by ``mask`` from the
    bottom of the axis up to ``k_med``, and overlays the points of bin ``j``.

    Parameters
    ----------
    nus : array
        Wavenumber grid (cm^-1).
    xsv : array
        Cross-section on ``nus`` (cm^2).
    nus_segments, xsv_segments : sequence of arrays
        Per-g-bin wavenumbers / cross-sections, as produced by ``sample_g``.
    j : int
        Index of the g-bin whose points are overlaid.
    k_med : float
        Upper level of the shaded band.
    mask : boolean array
        Same length as ``nus``; the band is shaded where it is True.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current pyplot axes (``plt.gca()``),
        which matches the behaviour of calling this after ``plt.sca(...)``.
    """
    if ax is None:
        ax = plt.gca()
    # One floor shared by the shaded band and the y-limit, so the band starts
    # exactly at the bottom of the log-scaled axis.
    y_floor = xsv.min() * 0.1
    ax.plot(nus, xsv, alpha=0.7)
    ax.fill_between(nus, y_floor, k_med, where=mask, alpha=0.3, color='orange')
    ax.plot(nus_segments[j], xsv_segments[j], '.', markersize=2)
    ax.set_yscale('log')
    ax.set_ylabel('σ(ν) (cm²)')
    ax.set_ylim(y_floor, xsv.max() * 2)

T, P, j_pickup = 1000.0, 0.01, 6
xsv = opa.xsvector(T, P)
k_g, g, edges, cut_idx, nus_segments, xsv_segments, k_med, mask = sample_g(nus, xsv, j_pickup)

# Shared log-scale y-range for both panels.
y_floor = xsv.min() * 0.1
y_ceil = xsv.max() * 2

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: sigma(nu) with the selected g-bin highlighted.
plt.sca(ax1)
plot_xsv(nus, xsv, nus_segments, xsv_segments, j_pickup, k_med, mask)
plt.xlabel('ν (cm⁻¹)')

# Right panel: the cross-section re-ordered as k(g), with the selected bin's
# k-range (horizontal lines) and g-range (vertical lines) marked.
g_lo, g_hi = edges[j_pickup], edges[j_pickup + 1]
seg_xsv = xsv_segments[j_pickup]
ax2.plot(g, k_g)
ax2.plot(g[cut_idx[j_pickup]:cut_idx[j_pickup + 1]], seg_xsv, '.', markersize=2)
for level in (seg_xsv.max(), seg_xsv.min()):
    ax2.axhline(level, alpha=0.3, color="gray")
for g_edge in (g_lo, g_hi):
    ax2.axvline(g_edge, alpha=0.3, color="gray")
ax2.text((g_lo + g_hi) / 2, y_floor, "$\\Delta g_j$",
         horizontalalignment="center", verticalalignment="bottom", fontsize=12)
ax2.set_yscale('log')
ax2.set_xlabel('g')
ax2.set_ylabel('σ(g) (cm²)')
ax2.set_ylim(y_floor, y_ceil)

plt.tight_layout()
plt.show()

In [ ]:
# (T, P) pairs to compare; one panel per condition.
conditions = [(700, 0.1), (1000, 0.01), (1000, 0.1), (1300, 0.1), (1000, 1.0)]
j_show = 6  # g-bin highlighted in every panel
ncols = 3
# Ceiling division: just enough rows for every condition (at least one).
nrows = max(1, (len(conditions) + ncols - 1) // ncols)

fig, axes = plt.subplots(nrows, ncols, figsize=(12, 3 * nrows))
axes = axes.flatten()

for i, (T, P) in enumerate(conditions):
    plt.sca(axes[i])
    xsv = opa.xsvector(T, P)
    k_g, g, edges, cut_idx, nus_segments, xsv_segments, k_med, mask = sample_g(nus, xsv, j_show)
    plot_xsv(nus, xsv, nus_segments, xsv_segments, j_show, k_med, mask)
    plt.title(f'T={T}K, P={P}bar', fontsize=10)

    # Label the x-axis only on panels with no populated panel below them
    # (the bottom row, plus any panel sitting above an unused grid cell).
    if i + ncols >= len(conditions):
        plt.xlabel('ν (cm⁻¹)')

# Hide grid cells left over when len(conditions) is not a multiple of ncols.
for ax in axes[len(conditions):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [ ]:
# Temperature / pressure nodes for the CKD lookup tables (presumably the same
# K / bar units as the T, P conditions plotted above -- confirm).
T_grid = jnp.array([700.0, 1000.0, 1300.0])
P_grid = jnp.array([0.01, 0.1, 1.0])

# Correlated-k opacity built on the PreMODIT calculator; Ng=16 is presumably
# the number of g-ordinates used by OpaCKD.
opa_ckd = OpaCKD(opa, Ng=16)
opa_ckd.precompute_tables(T_grid, P_grid)