This notebook investigates a heating ODE. It starts from the steady-state solution developed in `cluster_notebook` and evolves it in time so that a heating function can be included.

In [None]:
import random
import statistics
import matplotlib.pyplot as plt
import numpy as np
#from scipy.integrate import odeint
#from scipy.stats import linregress
from scipy.interpolate import interp1d
from scipy.integrate import solve_ivp
from scipy.constants import k as k_B  # Boltzmann constant in J/K
from scipy.constants import G, proton_mass
from functools import partial
import lab_functions_1 as lf
from numba import njit
from scipy.integrate import cumulative_trapezoid
# Physical constants converted from scipy's SI values to CGS.
G = G * 1e3                 # gravitational constant [cm^3 g^-1 s^-2]
mp_g = proton_mass*1e3      # proton mass [g]
k_Bcgs = k_B*1e7            # Boltzmann constant [erg/K]
mu = 0.6                    # mean molecular weight (used in c_s^2 = gamma k T / (mu m_p))
gamma = 5/3                 # adiabatic index
ktc = 3.0857e21             # 1 kpc in cm
etkv = 6.2415*10**8         # erg -> keV conversion factor
kevtk = 1.16*10**7          # keV -> K conversion factor
z = 0.597                   # presumably the redshift passed to the lf helpers -- confirm
Mdot1 = 6.30391e25          # 1 Msun/yr expressed in g/s (used as the Mdot unit in printouts)
stelmass = 3e12             # stellar mass passed to lf helpers (presumably Msun -- confirm)
totmass = 2e15              # total mass passed to lf helpers (presumably Msun -- confirm)

In [None]:
# Circular-velocity table on a linear radius grid (0.1 - 20000 kpc), cubic-interpolated so the
# ODE right-hand sides can evaluate v_c(r) cheaply. Masses come from the module constants
# (same values as the previously hard-coded 3e12 / 2e15) so they cannot drift apart.
radii_grid = np.linspace(0.1*ktc, 20000*ktc, 80000)
vc_grid = np.array([lf.vcgrab(r, z, stelmass, totmass) for r in radii_grid])
vc_interp = interp1d(radii_grid, vc_grid, kind='cubic', fill_value='extrapolate')

# Gravitational potential on a log grid (0.01 kpc - 1e7 kpc); finer grid = better accuracy.
# With I(r) = int_{r_min}^{r} v_c^2 / r' dr', phi(r) = -(I(r_max) - I(r)), i.e.
# phi = -int_r^{r_max} v_c^2 / r' dr', so d(phi)/dr = v_c^2 / r and phi(r_max) = 0.
r_grid = np.geomspace(ktc*0.01, 10000000 * ktc, 100000)
vc2_over_r = np.array([(lf.vcgrab(r, z, stelmass, totmass)**2 / r) for r in r_grid])

phi_cumint = cumulative_trapezoid(vc2_over_r, r_grid, initial=0.0)

phi_values = -(phi_cumint[-1] - phi_cumint)

phi_interp = interp1d(r_grid, phi_values, kind='cubic', fill_value="extrapolate")

def phi(r):
    """Return the gravitational potential at radius ``r`` [cm].

    Values come from the precomputed cubic interpolant ``phi_interp``, which
    extrapolates outside its tabulated radius range.
    """
    potential = phi_interp(r)
    return potential

In [None]:
@njit
def compute_dvdr_dTdr(v, T, r, Mdot, vc2, Lambda):
    """Return (dv/dr, dT/dr) for the steady-state flow equations at one radius.

    Parameters: velocity ``v``, temperature ``T`` [K], radius ``r`` [cm], mass flow
    rate ``Mdot``, squared circular velocity ``vc2`` and cooling function ``Lambda``.
    The logarithmic velocity gradient is singular where the Mach number equals 1.

    Raises ZeroDivisionError when the sound speed is zero (T == 0). This is the same
    exception type the division below would raise, now with an explicit message
    instead of a stray debug print.
    """
    cs2 = (gamma * k_Bcgs * T) / (mu * mp_g)
    if cs2 == 0:
        raise ZeroDivisionError("compute_dvdr_dTdr: sound speed is zero (T == 0)")
    tflow = r / abs(v)
    mach2 = v**2 / cs2
    # Mass conservation gives the density; cooling time uses the number density.
    rho = Mdot / (4 * np.pi * r**2 * v)
    n = rho / (mu * mp_g)
    tcool = (3 * k_Bcgs * T) / (2 * Lambda * n)

    # Denominator (mach2 - 1) vanishes at the sonic point.
    dlnvdlnr = (2 - (vc2 / cs2) - (tflow / (gamma * tcool))) / (mach2 - 1.0)
    dlnTdlnr = (tflow / tcool) - (2 / 3) * (2 + dlnvdlnr)

    dvdr = (v / r) * dlnvdlnr
    dTdr = (T / r) * dlnTdlnr

    return dvdr, dTdr

In [None]:
def TheODE(r, C, Mdot, Lambdatype, recorder=None):
    """Right-hand side [dv/dr, dT/dr] of the steady-state ODE, in solve_ivp form.

    ``C`` holds (v, T). When ``recorder`` (a dict of lists) is given, every call
    appends r, v, T, rho and the two derivatives to it.
    """
    velocity, temperature = C
    vc2 = vc_interp(r)**2

    # The cooling function needs the number density, which follows from mass conservation.
    density = Mdot / (4 * np.pi * r**2 * velocity)
    number_density = density / (mu * mp_g)
    cooling = lf.Lambdacalc(np.log10(temperature), r, Lambdatype, number_density)

    dvdr, dTdr = compute_dvdr_dTdr(velocity, temperature, r, Mdot, vc2, cooling)

    if recorder is not None:
        samples = (
            ("rarray", r),
            ("varray", velocity),
            ("Tarray", temperature),
            ("rhoarray", density),
            ("dvdr", dvdr),
            ("dTdr", dTdr),
        )
        for key, value in samples:
            recorder[key].append(value)
    return [dvdr, dTdr]

In [None]:
# CLASSES (THE BANE OF MY EXISTENCE)
class IntegrationResult:
    """Wrap a ``solve_ivp`` result with the shooting parameters that produced it.

    Parameters
    ----------
    res : result object from ``scipy.integrate.solve_ivp`` (indexable, with ``.t``).
    stop_reason : label describing why the integration stopped.
    xval, R0, v0, T0 : shooting variable and the initial radius / velocity / temperature.
    Mdot : mass flow rate used for the integration. It was previously accepted and
        discarded; it is now stored.
    """

    def __init__(self, res, stop_reason, xval=None, R0=None, v0=None, T0=None, Mdot=None):
        self.res = res
        self._stop_reason = stop_reason
        self.xval = xval
        self.R0 = R0
        self.v0 = v0
        self.T0 = T0
        self.Mdot = Mdot

    def stopReason(self):
        """Return the stop-reason label."""
        return self._stop_reason

    def Rs(self):
        """Return the radii reported by the solver (``res.t``)."""
        return self.res.t

    def __getitem__(self, key):
        """Forward item access to the wrapped solver result."""
        return self.res[key]

    def __repr__(self):
        return (f"{type(self).__name__}(stop_reason={self._stop_reason!r}, "
                f"xval={self.xval!r}, R0={self.R0!r}, Mdot={self.Mdot!r})")


# EVENTS LIST
def event_unbound(r, C, Mdot, Lambdatype, _):
    """Terminal solve_ivp event on the Bernoulli parameter.

    Returns 0.5 v^2 + 1.5 c_s^2 + phi(r); it fires when this crosses zero going
    upward, i.e. when the gas becomes unbound.
    """
    velocity, temperature = C
    sound_speed_sq = (gamma * k_Bcgs * temperature) / (mu * mp_g)
    return 0.5 * velocity**2 + 1.5 * sound_speed_sq + phi(r)
event_unbound.direction = 1
event_unbound.terminal = True

def event_lowT(r, C, Mdot, Lambdatype, _):
    """Terminal solve_ivp event: fires when T (C[1]) falls through 10**4.2 K."""
    return C[1] - 10**4.2
event_lowT.direction = -1
event_lowT.terminal = True

def event_sonic_point(r, C, Mdot, Lambdatype, _):
    """Terminal solve_ivp event: fires when the Mach number drops through 1."""
    velocity, temperature = C
    sound_speed = np.sqrt((gamma * k_Bcgs * temperature) / (mu * mp_g))
    return velocity / sound_speed - 1.0
event_sonic_point.direction = -1
event_sonic_point.terminal = True

def event_max_R(r, C, Mdot, Lambdatype, _):
    """Terminal solve_ivp event: fires when r rises through 20000 kpc."""
    _velocity, _temperature = C  # unpacked only for parity with the other events
    r_limit = 20000 * ktc
    return r - r_limit
event_max_R.direction = 1
event_max_R.terminal = True

def event_overstepdlnv(r, C, Mdot, Lambdatype, _):
    """Terminal solve_ivp event: fires when |d ln v / d ln r| rises through 50.

    Recomputes the same logarithmic velocity gradient as ``compute_dvdr_dTdr``.
    """
    vel, temp = C

    circ2 = vc_interp(r)**2
    snd2 = (gamma * k_Bcgs * temp) / (mu * mp_g)

    t_flow = r / np.abs(vel)
    mach = vel / np.sqrt(snd2)

    density = Mdot / (4 * np.pi * r**2 * vel)
    number_density = density / (mu * mp_g)
    cooling = lf.Lambdacalc(np.log10(temp), r, Lambdatype, number_density)

    t_cool = (3 * k_Bcgs * temp) / (2 * cooling * number_density)

    numerator = 2 - (circ2 / snd2) - (t_flow / (gamma * t_cool))
    return np.abs(numerator / (mach**2 - 1.0)) - 50
event_overstepdlnv.direction = 1
event_overstepdlnv.terminal = True

# Events handed to solve_ivp. event_names[i] is the stop-reason label for my_event_list[i],
# so the two lists must stay index-aligned.
my_event_list = [event_sonic_point, event_unbound, event_lowT, event_max_R, event_overstepdlnv]
event_names = ['sonic point', 'unbound', 'lowT', 'max R reached', 'overstepdlnv']

# SHOOTING METHOD
def _classify_stop(res_raw, Rmax):
    """Translate a solve_ivp result into one of the shooter's stop-reason labels."""
    if res_raw.status < 0:
        return 'integration failure'
    # event_names is index-aligned with my_event_list, which is what solve_ivp received.
    for name, t_evt in zip(event_names, res_raw.t_events):
        if len(t_evt) > 0:
            return name
    if res_raw.t[-1] >= Rmax:
        return 'max R reached'
    return 'unknown'


def sonic_point_shooting(Rsonic, Lambdatype, Rmax=20000*ktc, tol=1e-8, epsilon=1e-5, dlnMdlnRInit=-1, x_high=0.99, x_low=0.01, return_all_results=False):
    """Bisect on x to find the transonic steady solution with sonic point at ``Rsonic``.

    The sonic-point sound speed is set by c_s^2 = v_c(Rsonic)^2 / (2 x). Each trial
    starts a fractional step ``epsilon`` outside the sonic point and is classified by
    how the integration stops:

    * 'sonic point', 'lowT', 'overstepdlnv' -> x too high (x_high = x)
    * 'unbound' -> x too low (x_low = x)
    * 'max R reached' -> accepted; the loop stops.

    Parameters
    ----------
    Rsonic : sonic radius [cm].
    Lambdatype : cooling-function selector forwarded to the ``lf`` helpers.
    Rmax : outer integration radius [cm].
    tol : bisection stops once x_high - x_low <= tol.
    epsilon : fractional offset from the sonic point for the initial conditions.
    dlnMdlnRInit : reference slope used to choose between the two dlnT/dlnR roots.
    x_high, x_low : initial bisection bracket on x.
    return_all_results : if True, return the dict {x: IntegrationResult} instead.

    Returns
    -------
    IntegrationResult for the accepted x, or None if none reached max R
    (a dict when ``return_all_results`` is True).
    """
    results = {}
    dlnMdlnRold = dlnMdlnRInit

    # Iteration variable: x = v_c^2 / (2 c_s^2) at the sonic point.
    while x_high - x_low > tol:
        x = 0.5 * (x_high + x_low)
        cs2_sonic = vc_interp(Rsonic)**2 / (2 * x)
        v_sonic = cs2_sonic**0.5
        T_sonic = mu * mp_g * cs2_sonic / (gamma * k_Bcgs)
        tflow_to_tcool = (10/3) * (1 - x)

        rho_sonic = lf.rhocalc(v_sonic, tflow_to_tcool, T_sonic, Rsonic, Lambdatype)
        # `== False` (rather than `is False`) also rejects a zero density.
        if rho_sonic == False:
            x_high = x
            continue
        Mdot = 4 * np.pi * Rsonic**2 * rho_sonic * v_sonic

        dlnTdlnR1, dlnTdlnR2 = lf.dlnTdlnrcalc(Rsonic, x, z, T_sonic, Lambdatype)
        if dlnTdlnR1 is None:
            x_high = x
            continue

        # Pick the root whose dlnM/dlnR is closest to the reference slope.
        dlnMdlnR1, dlnMdlnR2 = [3 - 5*x - 2*slope for slope in (dlnTdlnR1, dlnTdlnR2)]
        if abs(dlnMdlnR1 - dlnMdlnRold) < abs(dlnMdlnR2 - dlnMdlnRold):
            dlnTdlnR = dlnTdlnR1
        else:
            dlnTdlnR = dlnTdlnR2

        dlnvdlnR = -1.5 * dlnTdlnR + 3 - 5 * x

        # Linear step of fractional size epsilon off the singular sonic point.
        T0 = T_sonic * (1 + epsilon * dlnTdlnR)
        v0 = v_sonic * (1 + epsilon * dlnvdlnR)
        R0 = Rsonic * (1 + epsilon)

        # Early checks before paying for a full integration.
        cs2_0 = (gamma * k_Bcgs * T0) / (mu * mp_g)
        mach0 = v0 / np.sqrt(cs2_0)
        if mach0 > 1.0:
            print("starts supersonic")
            x_high = x
            continue

        bern = 0.5 * v0**2 + 1.5 * cs2_0 + phi(R0)
        if bern > 0:
            print("starts unbound")
            x_low = x
            continue

        res_raw = solve_ivp(TheODE, [R0, Rmax], [v0, T0], args=(Mdot, Lambdatype, None), method='RK45',
            atol=1e-5, rtol=1e-5, events=my_event_list, dense_output=True)
        stop_reason = _classify_stop(res_raw, Rmax)
        res = IntegrationResult(res_raw, stop_reason, xval=x, R0=R0, v0=v0, T0=T0, Mdot=Mdot)

        if stop_reason in ('sonic point', 'lowT', 'overstepdlnv'):
            x_high = x
        elif stop_reason == 'unbound':
            x_low = x
        elif stop_reason == 'max R reached':
            results[x] = res
            print(f"x = {x}, Rsonic = {Rsonic/ktc} Mdot = {Mdot/Mdot1}")
            break
        else:
            print(f"Warning: Unexpected stopReason '{stop_reason}' — stopping loop.")
            break

    if return_all_results:
        return results
    if not results:
        print("no result reached maximum R")
        return None
    # At most one entry is stored (the loop breaks right after storing), keyed by the last x.
    return results[x]

In [None]:
def find_converged_x(Rsonic, Lambdatype):
    """Run the sonic-point shooter and return its IntegrationResult, or None.

    Prints a notice when no trial reached the maximum radius.
    """
    converged = sonic_point_shooting(Rsonic, Lambdatype)
    if converged is not None:
        return converged
    print("No solution reached max radius.")
    return None


def find_mdot(Rsonic, Lambdatype, result=None):
    """Return Mdot = 4 pi Rsonic^2 rho v at the sonic point of the converged solution.

    If ``result`` is None the shooter is run first. Returns np.nan when no
    solution is found or ``lf.rhocalc`` reports failure (returns False).
    """
    shot = result if result is not None else sonic_point_shooting(Rsonic, Lambdatype)
    if shot is None:
        return np.nan

    x = shot.xval
    sound_speed_sq = vc_interp(Rsonic)**2 / (2 * x)
    v_sonic = sound_speed_sq**0.5
    T_sonic = mu * mp_g * sound_speed_sq / (gamma * k_Bcgs)
    flow_to_cool = (10 / 3) * (1 - x)

    rho_sonic = lf.rhocalc(v_sonic, flow_to_cool, T_sonic, Rsonic, Lambdatype)
    if rho_sonic is False:
        return np.nan
    return 4 * np.pi * Rsonic**2 * rho_sonic * v_sonic

def postprocess(b, Mdot, Lambdatype):
    """Re-shoot from sonic radius ``b`` and record the steady profile at accepted steps.

    The profile is recorded by re-evaluating ``TheODE`` at each step the solver
    accepted. Passing the recorder into solve_ivp directly would also capture the
    intermediate RK stages and rejected steps. That yields a non-monotonic radius
    array with values off the actual solution, which breaks the finite differences
    taken later.

    Returns
    -------
    (x, R0, v0, T0, recorder), where ``recorder`` maps "rarray", "varray", "Tarray",
    "rhoarray", "dvdr", "dTdr" to lists ordered by increasing radius.

    Raises
    ------
    ValueError
        If the shooting method finds no solution reaching the maximum radius.
    """
    b_result = sonic_point_shooting(b, Lambdatype)
    if b_result is None:
        raise ValueError(f"No steady-state solution reached max radius for Rsonic = {b / ktc} kpc")
    x = b_result.xval
    R0 = b_result.R0
    v0 = b_result.v0
    T0 = b_result.T0
    Rmax = 20000 * ktc

    res = solve_ivp(TheODE, [R0, Rmax], [v0, T0], args=(Mdot, Lambdatype, None), method='RK45', max_step=Rmax / 100,
        atol=1e-5, rtol=1e-5)

    recorder = {"rarray": [], "varray": [], "Tarray": [], "rhoarray": [], "dvdr": [], "dTdr": []}
    for r_i, state in zip(res.t, res.y.T):
        TheODE(r_i, state, Mdot, Lambdatype, recorder)

    return x, R0, v0, T0, recorder

In [None]:
def BrentLooper(Mdot, Rsonlow, Rsonhigh, Lambdatype, tol=(2e-6 * Mdot1), xtol=None):
    """Find the sonic radius whose steady-state Mdot equals ``Mdot`` (Brent's method).

    Solves f(Rsonic) = find_mdot(Rsonic) - Mdot = 0 on [Rsonlow, Rsonhigh] and
    returns ``postprocess`` for the converged radius.

    Parameters
    ----------
    Mdot : target mass flow rate, in the units ``find_mdot`` returns.
    Rsonlow, Rsonhigh : bracket on the sonic radius [cm]; f must change sign across it.
    Lambdatype : cooling-function selector.
    tol : convergence threshold on |f|. Also used as the radius tolerance unless
        ``xtol`` is given, which preserves the original single-tolerance behaviour.
    xtol : optional radius tolerance [cm] for |b - a| and Brent's step checks.

    Returns
    -------
    The tuple returned by ``postprocess``.

    Raises
    ------
    ValueError
        If the bracket does not straddle a sign change, or find_mdot yields NaN.
    """
    if xtol is None:
        xtol = tol
    target = Mdot
    a, b = Rsonlow, Rsonhigh
    fa = find_mdot(a, Lambdatype) - target
    fb = find_mdot(b, Lambdatype) - target

    # NaN compares False against everything, so without this check a failed
    # endpoint would slip past the sign test below and poison the iteration.
    if np.isnan(fa) or np.isnan(fb):
        print(fa, fb)
        raise ValueError("Not bounded correctly! Mdot is undefined (NaN) at a bracket endpoint.")
    if fa * fb >= 0:
        print(fa, fb)
        raise ValueError("Not bounded correctly!")

    # Keep b as the best estimate (smallest |f|).
    if abs(fa) < abs(fb):
        a, b = b, a
        fa, fb = fb, fa

    c, fc = a, fa
    d = b - a  # only read when mflag is False, by which point it has been reassigned
    mflag = True

    while abs(b - a) > xtol:
        if abs(fb) < tol:
            return postprocess(b, Mdot, Lambdatype)
        if fa != fc and fb != fc:
            # Inverse quadratic interpolation
            s = (a * fb * fc / ((fa - fb) * (fa - fc)) +
                 b * fa * fc / ((fb - fa) * (fb - fc)) +
                 c * fa * fb / ((fc - fa) * (fc - fb)))
        else:
            # Secant method
            s = b - fb * (b - a) / (fb - fa)

        # Fall back to bisection when the interpolated step is not trustworthy.
        if a < b:
            cond1 = not ((3 * a + b) / 4 < s < b)
        else:
            cond1 = not (b < s < (3 * a + b) / 4)
        cond2 = mflag and abs(s - b) >= abs(b - c) / 2
        cond3 = (not mflag) and abs(s - b) >= abs(c - d) / 2
        cond4 = mflag and abs(b - c) < xtol
        cond5 = (not mflag) and abs(c - d) < xtol

        if cond1 or cond2 or cond3 or cond4 or cond5:
            s = (a + b) / 2
            mflag = True
        else:
            mflag = False

        fs = find_mdot(s, Lambdatype) - target
        if np.isnan(fs):
            raise ValueError(f"Mdot is undefined (NaN) at Rsonic = {s / ktc} kpc during Brent iteration.")
        d, c = c, b
        fc = fb

        if fa * fs < 0:
            b, fb = s, fs
        else:
            a, fa = s, fs

        if abs(fa) < abs(fb):
            a, b = b, a
            fa, fb = fb, fa
    return postprocess(b, Mdot, Lambdatype)

In [None]:
def secderiv(r, dTdr):
    """
    Numerically compute d²T/dr² from known dT/dr and r arrays.

    Uses ``np.gradient`` with the sample coordinates. In the interior this is
    second-order accurate even on non-uniform grids (e.g. adaptive ODE steps) and
    reduces to the plain central difference on uniform grids. The endpoints use
    one-sided first-order differences. Inputs are cast to float, so integer arrays
    are no longer truncated.

    Parameters:
    - r: 1-D array of radius values (strictly increasing)
    - dTdr: array of dT/dr values at each radius

    Returns:
    - d2Tdr2: float array of second derivative values, same shape

    Raises:
    - ValueError: if the shapes differ or fewer than 3 points are given.
    """
    r = np.asarray(r, dtype=float)
    dTdr = np.asarray(dTdr, dtype=float)

    if r.shape != dTdr.shape:
        raise ValueError("r and dTdr must be the same shape.")
    if len(r) < 3:
        raise ValueError("Need at least 3 points to compute second derivative.")

    return np.gradient(dTdr, r, edge_order=1)

In [None]:
def TimeODE(C, Lambdatype, Mdot=2500):
    """Return [drho/dt, dv/dt, dT/dt] at a single radius with spatial gradients held fixed.

    Parameters
    ----------
    C : sequence (r, v, T, rho, dvdr, dTdr, dTdr2): radius [cm], velocity,
        temperature [K], density, and dv/dr, dT/dr, d^2T/dr^2 at that radius.
    Lambdatype : cooling-function selector forwarded to ``lf.Lambdacalc``.
    Mdot : mass flow rate, used via the mass flux Mdot / (4 pi r^2). It must be in
        the same units as the Mdot given to the steady-state solver (TheODE uses it
        as g/s in rho = Mdot / (4 pi r^2 v)).
        NOTE(review): the default 2500 looks like a Msun/yr figure -- confirm.

    Returns
    -------
    [drhodt, dvdt, dTdt], in the state order [rho, v, T].

    Raises
    ------
    ValueError
        If T <= 0, rho <= 0 or v == 0.
    """
    r, v, T, rho, dvdr, dTdr, dTdr2 = C
    # NOTE(review): alpha multiplies rho_star in a rate term, so it should carry units
    # of 1/s. 3.16e13 is ~1 Myr in seconds -- confirm whether 1/3.16e13 was intended.
    alpha = 3.16e13

    # sanity check
    if T <= 0 or rho <= 0 or v == 0:
        raise ValueError(f"Non-physical inputs at r={r:.3e}: T={T}, rho={rho}, v={v}")

    # specific internal energy and stellar injection measures
    e = (3 * k_Bcgs * T) / (2 * mu * mp_g)
    v_star = v
    e_star = (1/2) * v_star**2 + (3/2) * k_Bcgs * T / (mu * mp_g)

    # electron/ion number densities
    ne = rho / (mp_g * 1.143)
    ni = rho / (mp_g * 1.231)

    # spatial derivatives (held fixed during the time evolution)
    dedr = (e / T) * dTdr
    drhodr = -rho * ((2 / r) + (dvdr / v))
    dnedr = drhodr / (mp_g * 1.143)

    # stellar density profile
    rho_star = lf.Hernqdens(r, z, stelmass)

    # Spitzer conductivity k = a T^2.5 / (b - ln(sqrt(ne) T^-1.5))
    a_sp = 1.84e-7
    b_sp = 23
    logterm = np.log(np.sqrt(ne) * T**(-1.5))
    denom = b_sp - logterm
    k = (a_sp * T**2.5) / denom

    # Quotient rule: dk/dr = a (2.5 T^1.5 T' denom - T^2.5 denom') / denom^2.
    # The coefficient a multiplies both terms (it was missing from the first).
    dlogterm_dr = (1 / (2 * ne)) * dnedr + (-1.5 / T) * dTdr
    ddenom_dr = -dlogterm_dr
    dkdr = a_sp * ((2.5 * T**1.5 * dTdr * denom) - (T**2.5 * ddenom_dr)) / (denom**2)

    # r^2 times the conductive heating rate:
    # d/dr(r^2 k dT/dr) = 2 r k T' + r^2 (k T'' + k' T'), later divided by r^2.
    KTERM = 2 * r * k * dTdr + r**2 * (k * dTdr2 + dkdr * dTdr)

    # The Big Derivatives
    drhodt = alpha * rho_star

    vc2 = vc_interp(r)**2
    mass_flux = Mdot / (4 * np.pi * r**2)  # equals rho*v for a steady flow
    # Gravitational acceleration is v_c^2 / r (dimensionally matching dP/dr / rho).
    dvdt = (1 / rho) * ((-2 / 3) * (rho * dedr + e * drhodr) - rho * vc2 / r + alpha * rho_star * v_star
        - mass_flux * dvdr - drhodt * v)

    Lambda = lf.Lambdacalc(np.log10(T), r, Lambdatype, rho / (mu * mp_g))
    # Enthalpy advection (5/3) * mass_flux * de/dr uses the same Mdot units as the
    # momentum equation (the stray extra Mdot1 factor is removed).
    dedt = (1 / rho) * (-ne * ni * Lambda + alpha * rho_star * e_star + (1 / r**2) * KTERM
        - (5 / 3) * mass_flux * dedr - drhodt * e)
    dTdt = (2 * mu * mp_g / (3 * k_Bcgs)) * dedt

    return [drhodt, dvdt, dTdt]

In [None]:
def TimeSolver(Mdot, Rsonlow, Rsonhigh, Lambdatype, tol=(2e-6 * Mdot1), evolve_time=False, t_span=(0, 3.16e13), t_eval=None):
    """Solve for the steady state, then optionally evolve every radius forward in time.

    The time evolution integrates each recorded radius independently with
    ``TimeODE``. The spatial gradients dv/dr, dT/dr and d^2T/dr^2 are frozen at
    their steady-state values.

    Returns
    -------
    (x, R0, v0, T0, recorder, recorderT). ``recorderT`` is None unless
    ``evolve_time`` is True. It then holds "t_eval", "rarray" and the (ntimes, nr)
    arrays "rhoarray", "varray", "Tarray". Times a failed per-radius integration
    never reached are NaN.

    Raises
    ------
    ValueError
        If the steady-state recorder is missing arrays or their lengths differ.
    """
    # Run steady-state solver first
    x, R0, v0, T0, recorder = BrentLooper(Mdot, Rsonlow, Rsonhigh, Lambdatype, tol=tol)

    keys = ("rarray", "varray", "Tarray", "rhoarray", "dvdr", "dTdr")
    arrays = [recorder.get(key) for key in keys]
    if any(arr is None for arr in arrays):
        raise ValueError("Recorder missing required arrays.")
    length = len(arrays[0])
    if not all(len(arr) == length for arr in arrays):
        raise ValueError("Recorder arrays lengths mismatch.")
    rarray, varray, Tarray, rhoarray, dvdr, dTdr = arrays

    if not evolve_time:
        return x, R0, v0, T0, recorder, None

    # Only the time evolution needs the second derivative (conduction term).
    dTdr2 = secderiv(rarray, dTdr)

    if t_eval is None:
        t_eval = np.linspace(t_span[0], t_span[1], 300)
    n_times = len(t_eval)

    T_evol = np.full((n_times, length), np.nan)
    rho_evol = np.full((n_times, length), np.nan)
    v_evol = np.full((n_times, length), np.nan)

    def ode_wrapper(t, y, r, dvdr_i, dTdr_i, dTdr2_i):
        # y is ordered [rho, v, T]; TimeODE expects (r, v, T, rho, dvdr, dTdr, dTdr2).
        C = [r, y[1], y[2], y[0], dvdr_i, dTdr_i, dTdr2_i]
        return TimeODE(C, Lambdatype=Lambdatype, Mdot=Mdot)

    columns = zip(rarray, rhoarray, varray, Tarray, dvdr, dTdr, dTdr2)
    for i, (r_i, rho_i, v_i, T_i, dvdr_i, dTdr_i, dTdr2_i) in enumerate(columns):
        sol = solve_ivp(ode_wrapper, t_span, [rho_i, v_i, T_i],
            args=(r_i, dvdr_i, dTdr_i, dTdr2_i), t_eval=t_eval,
            method="RK45", rtol=1e-5, atol=1e-5)
        if not sol.success:
            print(f"Warning: time integration failed at r = {r_i / ktc:.3f} kpc: {sol.message}")
        # On failure sol.y only covers the t_eval points actually reached.
        n_done = sol.y.shape[1]
        rho_evol[:n_done, i] = sol.y[0]
        v_evol[:n_done, i] = sol.y[1]
        T_evol[:n_done, i] = sol.y[2]

    recorderT = {
        "t_eval": t_eval,
        "rarray": np.asarray(rarray),
        "rhoarray": rho_evol,
        "varray": v_evol,
        "Tarray": T_evol,
    }

    return x, R0, v0, T0, recorder, recorderT

In [None]:
def _plot_profile(rarray, profiles, times_to_plot, time_labels, color_list, ylabel, title):
    """Draw one log-log figure of profiles[idx] against radius for each requested time index."""
    plt.figure(figsize=(8, 5))
    for idx, label, color in zip(times_to_plot, time_labels, color_list):
        plt.plot(rarray, profiles[idx], label=label, color=color)
    plt.xlabel("Radius [cm]")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xscale("log")
    plt.yscale("log")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_time_series_profiles(Mdot, Rsonlow, Rsonhigh, Lambdatype, times_to_plot=None):
    """Run TimeSolver with time evolution and plot T, v and rho profiles at selected times.

    ``times_to_plot`` is a sequence of indices into the solver's ``t_eval``
    (default: start, quarters and end).
    """
    # Call the solver with time evolution enabled
    _, _, _, _, recorder, recorderT = TimeSolver(
        Mdot=Mdot,
        Rsonlow=Rsonlow,
        Rsonhigh=Rsonhigh,
        Lambdatype=Lambdatype,
        evolve_time=True
    )

    t_eval = recorderT["t_eval"]
    # The time-evolution recorder may not carry the radial grid; fall back to the
    # steady-state recorder, whose radii the evolution arrays are indexed by.
    if "rarray" in recorderT:
        rarray = recorderT["rarray"]
    else:
        rarray = np.asarray(recorder["rarray"])

    if times_to_plot is None:
        n_t = len(t_eval)
        times_to_plot = [0, n_t//4, n_t//2, 3*n_t//4, -1]
    # Accept any sequence (list, tuple, numpy array) of indices.
    times_to_plot = list(times_to_plot)

    time_labels = [f"{t_eval[i]:.2e} s" for i in times_to_plot]
    color_list = plt.cm.viridis(np.linspace(0, 1, len(times_to_plot)))

    panels = (
        (recorderT["Tarray"], "Temperature [K]", "Temperature Profile Over Time"),
        (recorderT["varray"], "Velocity [cm/s]", "Velocity Profile Over Time"),
        (recorderT["rhoarray"], "Density [g/cm³]", "Density Profile Over Time"),
    )
    for profiles, ylabel, title in panels:
        _plot_profile(rarray, profiles, times_to_plot, time_labels, color_list, ylabel, title)

In [None]:
# NOTE(review): BrentLooper compares Mdot against find_mdot(), i.e. 4*pi*r^2*rho*v
# (g/s if lf.rhocalc returns g/cm^3). If 2500 Msun/yr is intended, this presumably
# needs to be 2500 * Mdot1 -- confirm.
plot_time_series_profiles(
    Mdot=2500,
    Rsonlow=0.1 * ktc,
    Rsonhigh=20 * ktc,
    Lambdatype=0.25,
)