# Preliminaries

In [None]:
using ArbNumerics, DifferentialEquations, Integrals, NaturalUnits

using CairoMakie, Format, LaTeXStrings

In [None]:
# Use Makie's LaTeX-font theme for every figure produced in this notebook.
set_theme!(theme_latexfonts())

# Helper scripts (contents not visible here). Judging from later use they provide
# `plot_directory`, `geomspace`, and `integral_over_s_and_t` — TODO confirm.
include("tool_script-directory.jl")
include("tool_script-geomspace.jl")

include("tool_script-integral_over_s_and_t.jl")

In [None]:
# Energy unit in which every dimensionful quantity of the notebook is expressed.
EU = GeV
# Natural-unit constants (e.g. `NU.M_Pl`) expressed in that unit.
NU = NaturalUnit(EU)

# Number constructor used throughout. Swap in `arb_float = identity` to fall back to the
# default numeric type instead of 1000-bit ArbFloats.
arb_float = x -> ArbFloat(x, bits = 1000)

In [None]:
# Initial scale factor (normalised to 1) and initial Hubble rate, given as a multiple of NU.M_Pl.
aᵢₙᵢ = arb_float(1)
Hᵢₙᵢ = arb_float(2e-5) * NU.M_Pl

In [None]:
# Toggle consulted by the plotting cell before writing the figure to disk.
save_plot_flag = false

# Playground

In [None]:
"""
    evolution_system!(du_list, u_list, params, log_a)

In-place right-hand side of the background ODE, with `log_a = log(a)` as the time variable.

The state is `u_list = [log(n_inflaton_comoving), log(ρ_total_comoving)]`, both stored as
plain numbers in units of `EU` with mass dimensions 3 and 4. `params` must provide
`m_inflaton` and `λ_inflaton_reheaton_reheaton`. The function writes
`d log(n_inflaton_comoving) / d log(a)` and `d log(ρ_total_comoving) / d log(a)` into
`du_list` and returns it.
"""
function evolution_system!(du_list, u_list, params, log_a)
    m = params.m_inflaton
    λ = params.λ_inflaton_reheaton_reheaton

    n_com = EU(exp(u_list[1]), 3)
    ρ_tot_com = EU(exp(u_list[2]), 4)
    scale = exp(log_a)

    # The inflaton cannot carry more energy than the total. Cap its comoving energy at
    # ρ_total_comoving / a (= ρ_total * a^3), then recompute the number density from the cap.
    ρ_inf_com = min(m * n_com, ρ_tot_com / scale)
    n_com = ρ_inf_com / m

    hubble = sqrt((ρ_tot_com / scale^4) / (3 * NU.M_Pl^2))

    dn_da = -(λ^2 / (32 * π * m * scale * hubble)) * exp(
        π * λ^2 / (m^4 * scale^3 * hubble) * n_com
    ) * n_com
    dρ_da = ρ_inf_com

    # Chain rule: d log(y) / d log(a) = (dy/da) * a / y.
    du_list[1] = dn_da * scale / n_com
    du_list[2] = dρ_da * scale / ρ_tot_com

    return du_list
end

In [None]:
"""
    solve_evolution_system(; a_fin, m_inflaton, λ_inflaton_reheaton_reheaton, solver,
                             initial_condition_fixing)

Integrate `evolution_system!` in `log(a)` from `log(aᵢₙᵢ)` to `log(a_fin)`.

The initial inflaton number density is `3 M_Pl² Hᵢₙᵢ² / m_inflaton`. The total comoving
energy density is fixed from that value *before* any capping. When
`initial_condition_fixing` is true, the initial number density is additionally capped by
an analytic bound that carries an extra factor 5. If the solver does not report success,
a warning is logged. The solution object is returned in either case.
"""
function solve_evolution_system(;
    a_fin=arb_float(1e6),
    m_inflaton=GeV(arb_float(1e13)),
    λ_inflaton_reheaton_reheaton=GeV(arb_float(1e7)),
    solver=Rosenbrock23(),
    initial_condition_fixing=false
)
    n_unfixed = 3 * NU.M_Pl^2 * Hᵢₙᵢ^2 / m_inflaton
    ρ_total_comoving = m_inflaton * n_unfixed * aᵢₙᵢ^4

    n_start = if initial_condition_fixing
        bound = m_inflaton^4 * Hᵢₙᵢ / (π * λ_inflaton_reheaton_reheaton^2) * log(
            32 * π * m_inflaton * Hᵢₙᵢ / λ_inflaton_reheaton_reheaton^2
        ) * 5
        min(bound, n_unfixed)
    else
        n_unfixed
    end

    # The ODE state holds logarithms of the unit-stripped values.
    u0 = [
        log(EUval(EU, n_start)),
        log(EUval(EU, ρ_total_comoving)),
    ]
    tspan = (log(aᵢₙᵢ), log(a_fin))
    p = (
        m_inflaton = m_inflaton,
        λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
    )

    sol = solve(ODEProblem(evolution_system!, u0, tspan, p), solver)
    if sol.retcode != ReturnCode.Success
        @warn "ODE return code: $(sol.retcode)"
    end

    return sol
end

In [None]:
"""
    f̃_reheaton(a, p̃_reheaton, params)

Reheaton distribution at scale factor `a` for comoving momentum `p̃_reheaton`.

Returns `arb_float(0)` unless `m_inflaton * aᵢₙᵢ / 2 ≤ p̃_reheaton ≤ m_inflaton * a / 2`.
Otherwise it evaluates the background solution at `a′ = 2 p̃_reheaton / m_inflaton`, i.e.
where `p̃_reheaton / a′ = m_inflaton / 2`. There it takes the inflaton number density,
capped by an analytic bound scaling as `(aᵢₙᵢ / a′)^3`, and the Hubble rate `H′`, and
returns `exp(π λ² n′ / (m⁴ H′)) - 1`.

`params` must provide `m_inflaton`, `λ_inflaton_reheaton_reheaton` and
`evolution_solution`.
"""
function f̃_reheaton(a, p̃_reheaton, params)
    m = params.m_inflaton
    λ = params.λ_inflaton_reheaton_reheaton
    sol = params.evolution_solution

    lower = (m * aᵢₙᵢ) / 2
    upper = (m * a) / 2
    if !(lower ≤ p̃_reheaton ≤ upper)
        return arb_float(0)
    end

    a′ = (2 * p̃_reheaton) / m

    # The solution stores [log(n_comoving), log(ρ_total_comoving)] as functions of log(a).
    state = sol(log(a′))
    n′_comoving = exp(first(state))
    ρ′_comoving = exp(last(state))

    n′_inflaton = min(
        EU(n′_comoving, 3) / a′^3,
        m^4 * aᵢₙᵢ * Hᵢₙᵢ / (π * λ^2) * log(
            32 * π * m * Hᵢₙᵢ / λ^2
        ) * (aᵢₙᵢ / a′)^3,
    )

    H′ = sqrt((EU(ρ′_comoving, 4) / a′^4) / (3 * NU.M_Pl^2))

    return exp(π * λ^2 * n′_inflaton / (m^4 * H′)) - 1
end

In [None]:
"""
    integrand_f̃_graviton(α, Ẽq, p̃₁, params)

Integrand of the graviton distribution at scale factor `α`, total comoving energy `Ẽq` and
comoving reheaton momentum `p̃₁`. The partner momenta are `k̃₂ = Ẽq - k̃₁` and
`p̃₂ = Ẽq - p̃₁`.

Computes `integral_over_s_and_t(k̃₁, Ẽq, p̃₁, s̃ₘₐₓ) / α^5 * f̃(p̃₁) f̃(p̃₂) / (H k̃₁ k̃₂)`,
where `s̃ₘₐₓ = 4 min(k̃₁ k̃₂, p̃₁ p̃₂)` and `H` is the Hubble rate at `α`.

`params` must provide `k̃₁`, `evolution_solution`, and whatever `f̃_reheaton` reads.
"""
function integrand_f̃_graviton(α, Ẽq, p̃₁, params)
    k̃₁ = params.k̃₁
    sol = params.evolution_solution

    ρ_comoving = (exp ∘ last ∘ sol ∘ log)(α)
    hubble = sqrt((EU(ρ_comoving, 4) / α^4) / (3 * NU.M_Pl^2))

    k̃₂ = Ẽq - k̃₁
    p̃₂ = Ẽq - p̃₁
    occupation = f̃_reheaton(α, p̃₁, params) * f̃_reheaton(α, p̃₂, params)

    s̃ₘₐₓ = 4 * min(k̃₁ * k̃₂, p̃₁ * p̃₂)
    st_integral = integral_over_s_and_t(k̃₁, Ẽq, p̃₁, s̃ₘₐₓ) / α^5

    return st_integral * (occupation / (hubble * k̃₁ * k̃₂))
end

"""
    integrand_reduced(x_list, params)

`integrand_f̃_graviton` pulled back to the unit cube, multiplied by the Jacobian of the map:

- `x_list[1]` → `α`, logarithmically between `max(aᵢₙᵢ, k̃₁ / m_inflaton)` and `params.a`;
- `x_list[2]` → `Ẽq`, linearly between `max(k̃₁, m_inflaton * aᵢₙᵢ)` and `m_inflaton * α`;
- `x_list[3]` → `p̃₁`, linearly over an interval centred on `Ẽq / 2`.

Returns `arb_float(0)` unless `k̃₁ < m_inflaton * a`. If the result is NaN or negative,
the sample point and the value are appended to the file `integrand.err` in the current
working directory.
"""
function integrand_reduced(x_list, params)
    x_α, x_Eq, x_p = x_list
    a = params.a
    k̃₁ = params.k̃₁
    m = params.m_inflaton

    if !(k̃₁ < m * a)
        return arb_float(0)
    end

    Jacobian = arb_float(1)

    # α: logarithmic map from αₘᵢₙ up to a.
    αₘᵢₙ = max(aᵢₙᵢ, k̃₁ / m)
    log_span = log(a / αₘᵢₙ)
    α = αₘᵢₙ * exp(x_α * log_span)
    Jacobian *= α * log_span

    # Ẽq: linear map.
    Ẽqₘᵢₙ = max(k̃₁, m * aᵢₙᵢ)
    ΔẼq = m * α - Ẽqₘᵢₙ
    Ẽq = x_Eq * ΔẼq + Ẽqₘᵢₙ
    Jacobian *= ΔẼq

    # p̃₁: both p̃₁ and Ẽq - p̃₁ must lie in [m aᵢₙᵢ / 2, m α / 2], which gives an interval
    # of this width centred on Ẽq / 2.
    p̃_width = min(m * α - Ẽq, Ẽq - m * aᵢₙᵢ)
    p̃₁ₘᵢₙ = (Ẽq - p̃_width) / 2
    p̃₁ₘₐₓ = (Ẽq + p̃_width) / 2
    p̃₁ = x_p * (p̃₁ₘₐₓ - p̃₁ₘᵢₙ) + p̃₁ₘᵢₙ
    Jacobian *= (p̃₁ₘₐₓ - p̃₁ₘᵢₙ)

    integrand = Jacobian * integrand_f̃_graviton(α, Ẽq, p̃₁, params)

    # Diagnostic dump of suspicious samples.
    if isnan(integrand) || integrand < 0
        open("integrand.err", "a+") do io
            write(io, "x: $(x_list)\n")
            write(io, "k̃₁: $(k̃₁)\n")
            write(io, "α: $(α), Ẽq: $(Ẽq), p̃₁: $(p̃₁)\n")
            write(io, "Jacobian: $(Jacobian)\n")
            write(io, "integrand: $(integrand)\n\n")
        end
    end

    return integrand
end

In [None]:
"""
    f̃_graviton(a, k̃₁, params; abstol=arb_float(0), abstol_controlling_ratio=arb_float(1e-3),
               verbose_flag=false)

Integrate `integrand_reduced` over the unit cube `[0, 1]^3` with `HCubatureJL`, for scale
factor `a` and graviton comoving momentum `k̃₁`, and return the integral value.

`params` must provide `m_inflaton`, `λ_inflaton_reheaton_reheaton` and
`evolution_solution`. `a` and `k̃₁` are added to the tuple handed to the integrand.

The absolute tolerance adapts to the result. While `|integral| * abstol_controlling_ratio`
is below `abstol`, the integral is recomputed. The first retry uses
`abstol = |integral| * abstol_controlling_ratio`; each later retry shrinks `abstol`
tenfold. With the default `abstol = 0`, a single evaluation is performed. A non-success
return code from the cubature is logged as a warning.
"""
function f̃_graviton(a, k̃₁, params;
    abstol=arb_float(0), abstol_controlling_ratio=arb_float(1e-3), verbose_flag=false
)
    integrand_params = (
        a = a,
        k̃₁ = k̃₁,
        m_inflaton = params.m_inflaton,
        λ_inflaton_reheaton_reheaton = params.λ_inflaton_reheaton_reheaton,
        evolution_solution = params.evolution_solution,
    )
    # The problem does not depend on the tolerance, so build it once.
    domain = ([arb_float(0) for _ ∈ 1:3], [arb_float(1) for _ ∈ 1:3])
    integral_problem = IntegralProblem(integrand_reduced, domain, integrand_params)

    controlling_flag = false
    while true
        integral_solution = solve(integral_problem, HCubatureJL(); abstol=abstol)
        integral_solution.retcode == ReturnCode.Success || @warn "Integral return code: $(integral_solution.retcode)"
        integral = integral_solution.u

        # Compare magnitudes. With the signed value, a (numerically) negative integral
        # would set a negative abstol and could make this loop retry forever.
        target = abs(integral) * abstol_controlling_ratio
        if target < abstol
            abstol = if controlling_flag
                abstol * 1e-1
            else
                controlling_flag = true
                target
            end
            verbose_flag && @info "Integral value: $(integral),\nretrying with abstol=$(abstol)"
        else
            verbose_flag && @info "Integral value: $(integral)\nwith abstol=$(abstol)"
            return integral
        end
    end
end

In [None]:
"""
    magic(; a, Eₕ_min_ratio, number_of_Eₕ, m_inflaton, λ_inflaton_reheaton_reheaton, solver,
            abstol, abstol_controlling_ratio, verbose_flag)

Evaluate `f̃_graviton` on a geometric grid of graviton energies at scale factor `a`.

The grid is `geomspace(Eₕ_min_ratio, 1, number_of_Eₕ) .* m_inflaton`. Each energy `Eₕ`
becomes the comoving momentum `k̃₁ = Eₕ * a`. The background is solved once, with
`initial_condition_fixing=true`, and shared by all grid points.

Returns `(Eₕ_list, fₕ_list)`. Throws `ArgumentError` unless `0 < Eₕ_min_ratio < 1`.
"""
function magic(;
    a = arb_float(1e6) * aᵢₙᵢ,
    Eₕ_min_ratio = arb_float(1e-10), number_of_Eₕ = 100,
    m_inflaton = GeV(arb_float(1e13)),
    λ_inflaton_reheaton_reheaton = GeV(arb_float(1e7)),
    solver = Rosenbrock23(),
    abstol = arb_float(0), abstol_controlling_ratio = arb_float(1e-3),
    verbose_flag = false
)
    # Validate inputs before the expensive high-precision background solve. An explicit
    # throw is used because `@assert` may be compiled out.
    0 < Eₕ_min_ratio < 1 ||
        throw(ArgumentError("Eₕ_min_ratio must satisfy 0 < Eₕ_min_ratio < 1, got $(Eₕ_min_ratio)"))

    evolution_solution = solve_evolution_system(
        a_fin=a,
        m_inflaton=m_inflaton,
        λ_inflaton_reheaton_reheaton=λ_inflaton_reheaton_reheaton,
        solver=solver,
        initial_condition_fixing=true
    )

    params = (
        m_inflaton = m_inflaton,
        λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
        evolution_solution = evolution_solution,
    )

    Eₕ_list = geomspace(Eₕ_min_ratio, arb_float(1), number_of_Eₕ) .* m_inflaton
    fₕ_list = [arb_float(0.) for _ ∈ 1:number_of_Eₕ]

    for (ii, Eₕ) ∈ enumerate(Eₕ_list)
        k̃₁ = Eₕ * a
        verbose_flag && @info "For $ii/$(number_of_Eₕ)...\nk̃₁: $(k̃₁)"
        fₕ_list[ii] = f̃_graviton(a, k̃₁, params;
            abstol=abstol, abstol_controlling_ratio=abstol_controlling_ratio, verbose_flag=verbose_flag
        )
    end

    return Eₕ_list, fₕ_list
end

In [None]:
# Example run: a / aᵢₙᵢ = 10^6, m = 10^13 GeV, λ = 10^7 GeV, starting abstol = 1.
magic(
    m_inflaton = GeV(arb_float(1e13)),
    λ_inflaton_reheaton_reheaton = GeV(arb_float(1e7)),
    a = arb_float(1e6) * aᵢₙᵢ,
    verbose_flag = true,
    abstol = arb_float(1e0),
)

# Dump

In [None]:
# Plot f_h(E_h) for several a / aᵢₙᵢ. This uses `m_inflaton`, `a_fin`, `k̃₁_list` and the
# `f̃_graviton_list_a1e*` vectors. The 1e8 and 1e10 lists are not produced by any cell in
# this section; presumably they come from re-runs with other `a_fin` (TODO confirm).
figure = Figure()

axis = Axis(figure[1, 1];
    xlabel = L"E_h ~ [\mathrm{GeV}]",
    ylabel = L"f_h",
    limits = (EUval.(GeV, (1e-10 * m_inflaton, m_inflaton)), (1e-28, 1e-21)),
    xscale = log10, yscale = log10,
)

lines!(axis, EUval.(GeV, k̃₁_list ./ a_fin), f̃_graviton_list_a1e6, label=L"a / a_\mathrm{ini} = 10^6")
lines!(axis, EUval.(GeV, k̃₁_list ./ a_fin), f̃_graviton_list_a1e8, label=L"a / a_\mathrm{ini} = 10^8")
lines!(axis, EUval.(GeV, k̃₁_list ./ a_fin), f̃_graviton_list_a1e10, label=L"a / a_\mathrm{ini} = 10^{10}")
axislegend(axis)

# Write the PDF only when saving is enabled (`||` would save exactly when it is disabled).
save_plot_flag && save(joinpath(plot_directory, "f_h_m1e13GeV_lambda1e7.pdf"), figure)

figure

In [None]:
# f̃_graviton on the standard energy grid at a / aᵢₙᵢ = 10^6.
a_fin = arb_float(1e6) * aᵢₙᵢ
m_inflaton = GeV(arb_float(1e13))
λ_inflaton_reheaton_reheaton = GeV(arb_float(1e7))
k̃₁_list = geomspace(arb_float(1e-10), arb_float(1), 100) .* (m_inflaton * a_fin)

# `magic` accepts keyword arguments only. With these settings it evaluates the same grid,
# k̃₁ = Eₕ * a_fin with Eₕ ∈ geomspace(1e-10, 1, 100) .* m_inflaton, and solves the
# background once for all points.
_, f̃_graviton_list_a1e6 = magic(
    a = a_fin,
    Eₕ_min_ratio = arb_float(1e-10),
    number_of_Eₕ = length(k̃₁_list),
    m_inflaton = m_inflaton,
    λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
    abstol = arb_float(1e-16),
    verbose_flag = true,
)

In [None]:
# Single-momentum evaluation of f̃_graviton at a / aᵢₙᵢ = 10^8.
a_fin = arb_float(1e8) * aᵢₙᵢ
m_inflaton = GeV(arb_float(1e13))
λ_inflaton_reheaton_reheaton = GeV(arb_float(1e7))
# k̃₁_list = geomspace(arb_float(1e-10), arb_float(1), 50) .* (m_inflaton * a_fin)

# `magic` accepts keyword arguments only and always scans a grid, so evaluate the single
# k̃₁ directly with `f̃_graviton` on a solved background.
evolution_solution = solve_evolution_system(
    a_fin=a_fin,
    m_inflaton=m_inflaton,
    λ_inflaton_reheaton_reheaton=λ_inflaton_reheaton_reheaton,
    solver=Rosenbrock23(),
    initial_condition_fixing=true
)
# NOTE: the decimal literal is parsed as Float64 before `arb_float`, so digits beyond
# ~17 significant figures do not reach the ArbFloat.
f̃_graviton(a_fin,
    GeV(arb_float(28117686979742.30611904412762381568232246223420018046)),
    (
        m_inflaton = m_inflaton,
        λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
        evolution_solution = evolution_solution,
    )
)

In [None]:
# Evaluate the raw graviton integrand at one (α, Ẽq, p̃₁) point, e.g. a sample reported in
# `integrand.err`, at a / aᵢₙᵢ = 10^8.
a_fin = arb_float(1e8) * aᵢₙᵢ
m_inflaton = GeV(arb_float(1e13))
λ_inflaton_reheaton_reheaton = GeV(arb_float(1e7))
k̃₁ = GeV(arb_float(4941713361323834.593603563024284407548743520486428903), 1)

evolution_solution = solve_evolution_system(;
    solver = Rosenbrock23(),
    initial_condition_fixing = true,
    a_fin = a_fin,
    m_inflaton = m_inflaton,
    λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
)
params = (;
    a = a_fin,
    k̃₁,
    m_inflaton,
    λ_inflaton_reheaton_reheaton,
    evolution_solution,
)

integrand_f̃_graviton(
    arb_float(1057.904734372786893726421624960781014189150043902259),
    GeV(arb_float(9883420889295012.704148754698051538096475225278220861), 1),
    GeV(arb_float(4941710444647506.35207437734902576904823761263911043)),
    params
)

In [None]:
"""
    collision_term_h(a, kₕ, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution;
                     abstol=0, reltol=1e-10)

Collision term for a graviton of physical momentum `kₕ` at scale factor `a`. It is
obtained by integrating over the total energy `Eq` and one reheaton energy `p₁` with
`HCubatureJL`, and multiplying by `1 / (64 (2π)^4 kₕ)`.

Returns `EU(arb_float(0))` when `a ≤ aᵢₙᵢ` or `kₕ > m_inflaton`. `abstol` and `reltol`
are forwarded to the cubature.

NOTE(review): relies on `CT_integrand` and a five-argument `f_reheaton`, which are not
defined in this notebook section. Confirm where they come from.
"""
function collision_term_h(a, kₕ, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution;
    abstol=0, reltol=1e-10
)
    a ≤ aᵢₙᵢ && return EU(arb_float(0))
    kₕ ≤ m_inflaton || return EU(arb_float(0))

    prefactor = inv(64 * (2 * π)^4 * kₕ)
    function integrand(Eq, p₁)
        sₘₐₓ = 4 * min(kₕ * (Eq - kₕ), p₁ * (Eq - p₁))

        result = CT_integrand(kₕ, Eq, p₁, sₘₐₓ) / (Eq - kₕ)
        result *= f_reheaton(a, p₁, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution)
        result *= f_reheaton(a, Eq - p₁, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution)

        return result
    end
    # Map the unit square (x_Eq, x_p₁) onto the physical (Eq, p₁) region and include the
    # Jacobian (Eqₘₐₓ - Eqₘᵢₙ) * p₁_length.
    function numerical_integrand(x_list, _)
        x_Eq, x_p₁ = x_list

        Eqₘᵢₙ = max(kₕ, m_inflaton * (aᵢₙᵢ / a))
        Eqₘₐₓ = m_inflaton
        Eqₘᵢₙ < Eqₘₐₓ || return arb_float(0)
        Eq = Eqₘᵢₙ + (Eqₘₐₓ - Eqₘᵢₙ) * x_Eq

        # p₁ ranges over an interval of this width centred on Eq / 2.
        p₁_length = min(m_inflaton - Eq, Eq - m_inflaton * (aᵢₙᵢ / a))
        EUval(EU, p₁_length) ≤ arb_float(0) && return arb_float(0)

        p₁ₘᵢₙ = (Eq - p₁_length) / 2
        p₁ = p₁ₘᵢₙ + p₁_length * x_p₁

        return EUval(EU, integrand(Eq, p₁) * (Eqₘₐₓ - Eqₘᵢₙ) * p₁_length)
    end

    # The integrand has exactly two variables, so integrate over the 2-D unit square. A
    # third dummy axis would only multiply the cubature work.
    domain = ([arb_float(0) for _ ∈ 1:2], [arb_float(1) for _ ∈ 1:2])
    integral_problem = IntegralProblem(numerical_integrand, domain)
    integral_solution = solve(integral_problem, HCubatureJL(); abstol=abstol, reltol=reltol)
    integral_solution.retcode == ReturnCode.Success || @warn "Integral return code: $(integral_solution.retcode)"

    return prefactor * EU(integral_solution.u, 2)
end

In [None]:
"""
    calculate_fₕ(a, kₕ; m_inflaton, λ_inflaton_reheaton_reheaton, solver, number_of_CT,
                 verbose_flag, abstol, reltol)

Graviton occupation number at scale factor `a` and physical momentum `kₕ`, obtained by
integrating the collision term along the redshift history of the mode.

The background is solved up to `a`. `collision_term_h` is then sampled at `number_of_CT`
geometrically spaced scale factors `a′ ∈ [kₕ a / m_inflaton, a]`, using the redshifted
momentum `kₕ′ = kₕ a / a′`, which runs from `m_inflaton` down to `kₕ`. The samples
`CT / (a′^4 H′)` are interpolated linearly in log–log space and integrated over `a′` with
`QuadGKJL`. Returns `arb_float(0)` when `kₕ > m_inflaton`.
"""
function calculate_fₕ(a, kₕ;
    m_inflaton=GeV(arb_float(1e13)),
    λ_inflaton_reheaton_reheaton=GeV(arb_float(1e7)),
    solver=Rosenbrock23(),
    number_of_CT::Int=10^6,
    verbose_flag::Bool=true,
    abstol=0, reltol=1e-10
)
    # `arb_float` is a function value, so `zero(arb_float)` has no method; build the zero
    # explicitly like the rest of the notebook does.
    kₕ ≤ m_inflaton || return arb_float(0)
    evolution_solution = solve_evolution_system(
        a_fin = a,
        m_inflaton = m_inflaton,
        λ_inflaton_reheaton_reheaton = λ_inflaton_reheaton_reheaton,
        initial_condition_fixing = true,
        solver = solver,
    )

    a′_list = geomspace(kₕ * a / m_inflaton, a, number_of_CT)
    kₕ′_list = (kₕ * a) ./ a′_list
    verbose_flag && @info "Calculating collision term..."
    # Pass the redshifted `kₕ′` (prime U+2032). An ASCII `kₕ'` would instead be the adjoint
    # of the outer `kₕ`, i.e. the same momentum at every a′.
    CT_list = [
        collision_term_h(a′, kₕ′, m_inflaton, λ_inflaton_reheaton_reheaton, evolution_solution;
            abstol=abstol, reltol=reltol
        ) for (a′, kₕ′) in zip(a′_list, kₕ′_list)
    ]
    verbose_flag && @info "Done for collision term calculation."

    ρ′_total_comoving_in_EU_list = (exp ∘ last ∘ evolution_solution ∘ log).(a′_list)
    ρ′_total_list = EU.(ρ′_total_comoving_in_EU_list, 4) ./ a′_list.^4
    H′_list = sqrt.(ρ′_total_list ./ (3 * NU.M_Pl^2))

    integrand_list = CT_list ./ (a′_list.^4 .* H′_list)
    # NOTE(review): at the first sample kₕ′ ≈ m_inflaton, where collision_term_h's Eq range
    # is empty and it can return exactly zero. That would trip this check; confirm the
    # intended lower endpoint.
    @assert all(>(0), integrand_list) "Non-positive integrand encountered in fₕ calculation: $(integrand_list)"

    # Log-log interpolation of the sampled integrand (presumably `(values, knots)` argument
    # order, as in DataInterpolations — TODO confirm the providing package).
    integrand_interpolation = LinearInterpolation(log.(integrand_list), log.(a′_list))
    integral_problem = IntegralProblem(
        (a′, _) -> exp(integrand_interpolation(log(a′))),
        (first(a′_list), last(a′_list)),
    )
    integral_solution = solve(integral_problem, QuadGKJL(); abstol=abstol, reltol=reltol)
    integral_solution.retcode == ReturnCode.Success || @warn "Integral return code: $(integral_solution.retcode)"

    return integral_solution.u
end

In [None]:
# f_h on 100 geometrically spaced momenta in [1e-4, 0.99] × m_inflaton at a / aᵢₙᵢ = 10^8.
m_inflaton = GeV(arb_float(1e13))
λ = GeV(arb_float(1e7))
a_fin = arb_float(1e8) * aᵢₙᵢ

# Dimensionless geomspace endpoints scaled by m_inflaton, as at the other call sites.
kₕ_list = geomspace(arb_float(1e-4), arb_float(9.9e-1), 100) .* m_inflaton
# Pass this cell's parameters explicitly; otherwise calculate_fₕ silently uses its own
# defaults and editing m_inflaton / λ here would have no effect. Keyword arguments are
# not broadcast.
fₕ_list = calculate_fₕ.(a_fin, kₕ_list;
    m_inflaton = m_inflaton,
    λ_inflaton_reheaton_reheaton = λ,
    number_of_CT = 1000,
)