In [None]:
using ArbNumerics, DifferentialEquations, DataInterpolations, Format, Integrals, MCIntegration, NaturalUnits

using CairoMakie, LaTeXStrings

In [None]:
# Use LaTeX-style fonts for every Makie figure produced in this notebook.
set_theme!(theme_latexfonts())

# Local helper scripts. Names used below but not defined in this notebook
# (`geomspace`, `CT_integrand`) presumably come from these files — TODO confirm,
# and keep the include order unless the scripts are known to be independent.
include("tool_script-directory.jl")
include("tool_script-geomspace.jl")

include("tool_script-collision_term_integrand.jl")

In [None]:
# Energy unit for the `EU(value, dim)` / `EUval(EU, x)` conversions used throughout;
# `NU` supplies derived constants such as `NU.M_Pl`.
EU = GeV
NU = NaturalUnit(EU)

# Precision switch: many numeric constants below are wrapped in `arb_float`.
# `identity` keeps plain Float64 arithmetic.
# NOTE(review): the original commented line read `arb_float = arb_float`, which is a
# no-op; presumably `arb_float = ArbFloat` (ArbNumerics) was intended — confirm.
# arb_float = ArbFloat
arb_float = identity

In [None]:
# Initial scale factor and initial Hubble rate (H_ini = 2e-5 · M_Pl).
a_ini, H_ini = arb_float(1), arb_float(2e-5) * NU.M_Pl

In [None]:
# NOTE(review): not read by any of the visible cells; presumably meant to gate writing
# figures to disk — confirm or remove.
save_plot_flag = false

In [None]:
"""
    evolution_system!(du_list, u_list, params, log_a)

In-place right-hand side of the background evolution, with `log a` as the
independent variable.

State layout (values in units of `EU`):
- `u_list[1]` — log of the comoving inflaton number density `n a³` (dimension 3)
- `u_list[2]` — log of the comoving total energy density `ρ a⁴` (dimension 4)

`params` must provide `m_inflaton` and `λ_inflaton_reheaton_reheaton`.
Writes `d u / d log a` into `du_list` and returns it.
"""
function evolution_system!(du_list, u_list, params, log_a)
    scale = exp(log_a)
    mφ = params.m_inflaton
    λ = params.λ_inflaton_reheaton_reheaton

    n_com_raw = EU(exp(u_list[1]), 3)
    ρ_tot_com = EU(exp(u_list[2]), 4)

    # The inflaton energy (comoving with a³) can never exceed the total energy ρ a³.
    ρφ_com = min(mφ * n_com_raw, ρ_tot_com / scale)
    nφ_com = ρφ_com / mφ

    hubble = sqrt((ρ_tot_com / scale^4) / (3 * NU.M_Pl^2))

    enhancement = exp(π * λ^2 / (mφ^4 * scale^3 * hubble) * nφ_com)
    dnφ_com_da = -(λ^2 / (32 * π * mφ * scale * hubble)) * enhancement * nφ_com

    # Convert dX/da into d log X / d log a = (dX/da) · a / X.
    # For the total comoving energy, dX/da is the comoving inflaton energy.
    du_list[1] = dnφ_com_da * scale / nφ_com
    du_list[2] = ρφ_com * scale / ρ_tot_com
    return du_list
end

In [None]:
"""
    solve_evolution_system(; a_fin=arb_float(1e6), m_inflaton=GeV(arb_float(1e13)),
                           λ_inflaton_reheaton_reheaton=GeV(arb_float(1e7)),
                           solver=Rosenbrock23(), initial_condition_fixing=false)

Integrate `evolution_system!` in `log a` from `a_ini` to `a_fin` and return the
ODE solution.

The initial total energy is `3 M_Pl² H_ini²`, held entirely by the inflaton. With
`initial_condition_fixing = true`, the initial inflaton number is capped at
`5 m⁴ H_ini / (π λ²) · log(32 π m H_ini / λ²)`, while the total energy is left unchanged.
This presumably bounds the initial exponent of the enhancement factor in
`evolution_system!`.

Emits a warning if the solver does not report `ReturnCode.Success`.
"""
function solve_evolution_system(;
    a_fin=arb_float(1e6),
    m_inflaton=GeV(arb_float(1e13)),
    λ_inflaton_reheaton_reheaton=GeV(arb_float(1e7)),
    solver=Rosenbrock23(),
    initial_condition_fixing=false
)
    λ² = λ_inflaton_reheaton_reheaton^2

    n_full = 3 * NU.M_Pl^2 * H_ini^2 / m_inflaton
    ρ_tot_com = m_inflaton * n_full * a_ini^4

    n_start = n_full
    if initial_condition_fixing
        n_cap = m_inflaton^4 * H_ini / (π * λ²) * log(32 * π * m_inflaton * H_ini / λ²) * 5
        n_start = min(n_cap, n_full)
    end

    u₀ = [log(EUval(EU, n_start)), log(EUval(EU, ρ_tot_com))]
    log_a_span = (log(a_ini), log(a_fin))
    model_params = (
        m_inflaton = m_inflaton,
        λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
    )

    sol = solve(ODEProblem(evolution_system!, u₀, log_a_span, model_params), solver)
    if sol.retcode != ReturnCode.Success
        @warn "ODE return code: $(sol.retcode)"
    end
    return sol
end

In [None]:
"""
    f_reheaton(a, p_reheaton, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution)

Reheaton distribution at scale factor `a` and momentum `p_reheaton`.

A reheaton observed at momentum `p` was emitted with momentum `m/2` at
`a′ = 2 p a / m`. The result is `exp(π λ² n′ / (m⁴ H′)) − 1`, where `n′` and `H′` are the
inflaton number density and Hubble rate read from `evolution_solution` at `a′`.

Returns zero outside `(m/2)(a_ini/a) ≤ p ≤ m/2`.
"""
function f_reheaton(a, p_reheaton, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution)
    p_max = m_inflaton / 2
    if !(p_max * (a_ini / a) ≤ p_reheaton ≤ p_max)
        return arb_float(0)
    end

    a_emit = (2 * p_reheaton * a) / m_inflaton

    # One interpolation gives both state components: [log n_com, log ρ_com].
    u_emit = evolution_solution(log(a_emit))
    n_emit = EU(exp(first(u_emit)), 3) / a_emit^3
    ρ_emit = EU(exp(last(u_emit)), 4) / a_emit^4

    H_emit = sqrt(ρ_emit / (3 * NU.M_Pl^2))

    return exp(π * λ_inflaton_reheaton_reheaton^2 * n_emit / (m_inflaton^4 * H_emit)) - 1
end

In [None]:
"""
    collision_term_h(a, kₕ, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution;
                     abstol=0, reltol=1e-10)

Collision term for `h` with momentum `kₕ` at scale factor `a`.

`CT_integrand` is weighted by the two reheaton distributions `f_reheaton(p₁)` and
`f_reheaton(Eq − p₁)`. It is integrated over:
- `Eq ∈ [max(kₕ, m_inflaton · a_ini / a), m_inflaton]`
- the range of `p₁` for which both `p₁` and `Eq − p₁` lie in the support of
  `f_reheaton`.

Both ranges are mapped onto the unit square, and the result is integrated with HCubature.
`abstol` and `reltol` are forwarded to the cubature solver.

Returns zero when `a ≤ a_ini` or `kₕ > m_inflaton`.
"""
function collision_term_h(a, kₕ, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution;
    abstol=0, reltol=1e-10
)
    a ≤ a_ini && return EU(arb_float(0))
    kₕ ≤ m_inflaton || return EU(arb_float(0))

    prefactor = inv(64 * (2 * π)^4 * kₕ)

    # The Eq range does not depend on the integration point, so it is computed once.
    Eqₘᵢₙ = max(kₕ, m_inflaton * (a_ini / a))
    Eqₘₐₓ = m_inflaton
    # An empty range means the integrand is identically zero; skip the cubature.
    Eqₘᵢₙ < Eqₘₐₓ || return prefactor * EU(arb_float(0.0), 2)
    Eq_length = Eqₘₐₓ - Eqₘᵢₙ

    function integrand(Eq, p₁)
        sₘₐₓ = 4 * min(kₕ * (Eq - kₕ), p₁ * (Eq - p₁))

        result = CT_integrand(kₕ, Eq, p₁, sₘₐₓ) / (Eq - kₕ)
        result *= f_reheaton(a, p₁, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution)
        result *= f_reheaton(a, Eq - p₁, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution)

        return result
    end

    # Map (x_Eq, x_p₁) ∈ [0, 1]² onto (Eq, p₁), including the Jacobian.
    function numerical_integrand(x_list, _)
        x_Eq, x_p₁ = x_list
        Eq = Eqₘᵢₙ + Eq_length * x_Eq

        # Both p₁ and Eq − p₁ must lie in [m a_ini / (2a), m / 2]. The allowed p₁
        # interval is centred on Eq / 2 and has this length.
        p₁_length = min(m_inflaton - Eq, Eq - m_inflaton * (a_ini / a))
        EUval(EU, p₁_length) ≤ arb_float(0) && return arb_float(0)

        p₁ = (Eq - p₁_length) / 2 + p₁_length * x_p₁

        return EUval(EU, integrand(Eq, p₁) * Eq_length * p₁_length)
    end

    # Two-dimensional domain: one unit interval each for x_Eq and x_p₁.
    domain = ([arb_float(0) for _ ∈ 1:2], [arb_float(1) for _ ∈ 1:2])
    integral_problem = IntegralProblem(numerical_integrand, domain)
    integral_solution = solve(integral_problem, HCubatureJL(); abstol=abstol, reltol=reltol)
    integral_solution.retcode == ReturnCode.Success || @warn "Integral return code: $(integral_solution.retcode)"

    return prefactor * EU(integral_solution.u, 2)
end

In [None]:
"""
    calculate_fₕ(a, kₕ; m_inflaton=GeV(arb_float(1e13)),
                 λ_inflaton_reheaton_reheaton=GeV(arb_float(1e7)),
                 solver=Rosenbrock23(), number_of_CT::Int=10^6,
                 verbose_flag::Bool=true, abstol=0, reltol=1e-10)

Compute `fₕ` at scale factor `a` and momentum `kₕ`.

The quantity integrated over `a′ ∈ [kₕ a / m_inflaton, a]` is
`collision_term_h(a′, kₕ′) / (a′⁴ H′)`, where:
- `kₕ′ = kₕ a / a′` is the momentum rescaled to `a′`;
- `H′` is read from the background solution of `solve_evolution_system`.

The collision term is sampled at `number_of_CT` geometrically spaced `a′` values and
interpolated linearly in log–log space, and the result is integrated with QuadGK.
`abstol` and `reltol` are forwarded to both the collision-term cubature and the final
quadrature.

Returns zero when `kₕ > m_inflaton`.
"""
function calculate_fₕ(a, kₕ;
    m_inflaton=GeV(arb_float(1e13)),
    λ_inflaton_reheaton_reheaton=GeV(arb_float(1e7)),
    solver=Rosenbrock23(),
    number_of_CT::Int=10^6,
    verbose_flag::Bool=true,
    abstol=0, reltol=1e-10
)
    # `zero(arb_float)` throws when `arb_float === identity`; build the zero explicitly.
    kₕ ≤ m_inflaton || return arb_float(0.0)

    evolution_solution = solve_evolution_system(
        a_fin = a,
        m_inflaton = m_inflaton,
        λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
        initial_condition_fixing = true,
        solver = solver,
    )

    # The lower limit is where kₕ′ reaches m_inflaton; collision_term_h vanishes beyond it.
    # NOTE(review): at that first sample the Eq range inside collision_term_h is empty
    # (or ~1 ulp wide), so its collision term is zero or tiny. The positivity assertion
    # below may then fire, depending on how `geomspace` treats its endpoints — confirm.
    a′_list = geomspace(kₕ * a / m_inflaton, a, number_of_CT)
    kₕ′_list = (kₕ * a) ./ a′_list

    verbose_flag && @info "Calculating collision term..."
    # Each sample must use its own rescaled momentum kₕ′ (U+2032 prime).
    # `kₕ'` with an ASCII apostrophe would be `adjoint(kₕ)`, i.e. the unscaled kₕ.
    CT_list = [
        collision_term_h(a′, kₕ′, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution;
            abstol=abstol, reltol=reltol
        ) for (a′, kₕ′) in zip(a′_list, kₕ′_list)
    ]
    verbose_flag && @info "Done for collision term calculation."

    ρ′_total_comoving_in_EU_list = (exp ∘ last ∘ evolution_solution ∘ log).(a′_list)
    ρ′_total_list = EU.(ρ′_total_comoving_in_EU_list, 4) ./ a′_list.^4
    H′_list = sqrt.(ρ′_total_list ./ (3 * NU.M_Pl^2))

    integrand_list = CT_list ./ (a′_list.^4 .* H′_list)
    # The log–log interpolation below requires strictly positive samples.
    @assert all(>(0), integrand_list) "Non-positive integrand encountered in fₕ calculation: $(integrand_list)"

    integrand_interpolation = LinearInterpolation(log.(integrand_list), log.(a′_list))
    integral_problem = IntegralProblem(
        (a′, _) -> exp(integrand_interpolation(log(a′))),
        (first(a′_list), last(a′_list)),
    )
    integral_solution = solve(integral_problem, QuadGKJL(); abstol=abstol, reltol=reltol)
    integral_solution.retcode == ReturnCode.Success || @warn "Integral return code: $(integral_solution.retcode)"

    return integral_solution.u
end

In [None]:
# Model parameters and final scale factor for the fₕ spectrum.
m_inflaton = GeV(arb_float(1e13))
λ = GeV(arb_float(1e7))
a_fin = arb_float(1e8) * a_ini

kₕ_list = geomspace(1e-4 * m_inflaton, 9.9e-1 * m_inflaton, 100)
# Pass the parameters explicitly, so that editing them above changes the result.
# Without this, calculate_fₕ silently falls back to its own defaults.
# Keyword arguments are not broadcast; only a_fin and kₕ_list are.
fₕ_list = calculate_fₕ.(a_fin, kₕ_list;
    m_inflaton=m_inflaton,
    λ_inflaton_reheaton_reheaton=λ,
    number_of_CT=1000,
)

In [None]:
# fₕ as a function of kₕ, on log–log axes.
figure = Figure()
axis = Axis(figure[1, 1];
    xlabel = L"k_h ~ [\mathrm{GeV}]",
    ylabel = L"f_h",
    xscale = log10, yscale = log10,
)
# Previously the axis was left empty. Express kₕ in units of EU (= GeV) for the x axis.
# NOTE(review): assumes fₕ_list holds plain real numbers. If they are NaturalUnits
# quantities, convert them (e.g. with EUval) before plotting — confirm.
scatterlines!(axis, [EUval(EU, kₕ) for kₕ ∈ kₕ_list], fₕ_list)

figure