In [None]:
import numpy as np
from astropy import units as u, constants as const
from matplotlib import pyplot as plt
from scipy.stats import beta
from scipy.special import beta as betafct
from numba import jit, jitclass
from scipy import interpolate
import logging

import sys
import os
sys.path.insert(1, '/home/jsipple/one_d_spherical_collapse/one-d-spherical-collapse-of-fuzzy-dark-matter/src')

import importlib
import simulation_strategies, collapse, plotting, utils
importlib.reload(simulation_strategies)
importlib.reload(collapse)
importlib.reload(plotting)
importlib.reload(utils)
from simulation_strategies import *
from collapse import *
from plotting import *
from utils import *

In [None]:
def my_mpl():
    """Apply this notebook's matplotlib rc defaults.

    Serif 20pt font, grid on, 3pt lines, 8pt major / 6pt minor ticks and
    12x9 inch figures.
    """
    major_tick = 8
    rc_groups = (
        ('font', {'family': 'serif', 'size': 20}),
        ('axes', {'grid': True}),
        ('lines', {'lw': 3}),
        ('xtick.minor', {'size': major_tick - 2}),
        ('ytick.minor', {'size': major_tick - 2}),
        ('xtick.major', {'size': major_tick}),
        ('ytick.major', {'size': major_tick}),
        ('figure', {'figsize': [12, 9]}),
    )
    for group, settings in rc_groups:
        plt.rc(group, **settings)
my_mpl()

# Root logging at DEBUG with timestamps; matplotlib's own logger is raised to
# ERROR so its debug chatter does not flood the cell output.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                    level=logging.DEBUG)
logger = logging.getLogger(__name__)
logging.getLogger('matplotlib').setLevel(logging.ERROR)

In [None]:
# Background quantities in code units (G = 1) at the initial time t_i.
# H = 2/(3 t) and rho = 1/(6 pi G t^2) are the matter-dominated
# (Einstein-de Sitter) relations.
t_i = 1
G = 1
rho = 1 / (6 * np.pi * G * t_i ** 2)
H = 2 / (3 * t_i)

In [None]:
import cProfile
import pstats

# Configuration for a single-shell (N=1) SphericalCollapse run.
# (An earlier variant used 't_max': 5.7.)
config = {
    'softlen': 0,
    'density_strategy': "const",
    'point_mass': 0,
    'safety_factor': 1e-4,
    'gamma': 0,
    'thickness_coef': 0,
    'j_coef': 1e-3,
    'ang_mom_strategy': 'const',
    'energy_strategy': 'kin_softgrav_softrot',
    'accel_strategy': 'soft_all',
    't_max': 2,
    'dt_min': 1e-12,
    'N': 1,
    'r_min': 0,
    'H': 0,
    'r_max': 1,
    'm_tot': 1,
    'stepper_strategy': 'beeman',
    'timescale_strategy': 'dyn',
    'm_enc_strategy': 'overlap_inclusive',
    'save_strategy': 'vflip',
    'deque_size': 100,
}
x = SphericalCollapse(config)

# Profile the run. The profiler is disabled in `finally` so that an exception
# inside run() does not leave it collecting stats for every later cell.
profiler = cProfile.Profile()
profiler.enable()
try:
    results = x.run()
finally:
    profiler.disable()

# Top 20 functions by cumulative time.
stats = pstats.Stats(profiler).sort_stats('cumulative')
stats.print_stats(20)

In [None]:
# Number of saved snapshots, and their mean spacing in time. The spacing is
# computed from the data rather than from a hard-coded t_max (2 in the config
# above), so it stays correct when t_max changes.
n_saved = len(results['t'])
print(n_saved)
np.mean(np.diff(results['t'])) if n_saved > 1 else float('nan')

In [None]:
# Relative drift of the shell-summed total energy, then a zoom of reldif,
# r and the saved dt around the step with the largest |drift|.
total_energy_sum = np.sum(results['e_tot'], axis=1)
print(total_energy_sum)
reldif = (total_energy_sum - total_energy_sum[0]) / total_energy_sum[0]
print(reldif)
print(np.max(np.abs(reldif)))
amax = np.argmax(np.abs(reldif))
print(amax)
print(results['t'][amax])
around = 10
# Clamp the window to the data. A negative start would silently wrap to the
# end of the arrays. The start is kept >= 1 because the dt slice below needs
# one extra leading sample.
lo = max(amax - around, 1)
hi = min(amax + around, len(results['t']))
plt.figure()
plt.plot(results['t'][lo:hi], reldif[lo:hi])
plt.figure()
plt.plot(results['t'][lo:hi], results['r'][lo:hi])
plt.figure()
plt.plot(results['t'][lo:hi], np.diff(results['t'][lo - 1:hi]))

In [None]:
# Dynamical timescale over the run on a logarithmic y-axis.
plt.gca().semilogy(results['t'], results['t_dyn'])



In [None]:
def plot_global_property(results, property_name, limit_axis=True, ylim=None):
    """
    Plot the progress of a global (1-D) property over time.

    Parameters:
    - results: dict, output from SphericalCollapse.run()
    - property_name: str, name of the property to plot ('t_dyn', 'e_tot', etc.)
    - limit_axis: bool, clip the y-axis to the 1st-99th percentile of the data
    - ylim: tuple, custom y-axis limits (min, max); takes precedence over limit_axis

    Raises:
    - ValueError if property_name is not a key of results.
    """
    if property_name not in results:
        raise ValueError(f"Property '{property_name}' not found in results.")

    time = results['t']
    property_data = results[property_name]

    plt.figure(figsize=(10, 6))
    plt.plot(time, property_data)
    plt.xlabel('Time')
    plt.ylabel(property_name.capitalize())
    plt.title(f'{property_name.capitalize()} vs Time')
    plt.grid(True)

    if ylim:
        plt.ylim(ylim)
    elif limit_axis:
        # nanpercentile: a single NaN (e.g. a None converted to float) would
        # otherwise make both limits NaN and break plt.ylim.
        values = np.asarray(property_data, dtype=float)
        lower, upper = np.nanpercentile(values, [1, 99])
        # Skip non-finite or degenerate limits and let matplotlib autoscale.
        if np.isfinite(lower) and np.isfinite(upper) and lower < upper:
            plt.ylim(lower, upper)
    

def plot_shell_property(results, property_name, shell_indices=None, num_shells=5, limit_axis=True, ylim=None, yscale='linear', title=None, vs_r=False):
    """
    Plot a per-shell property for a selection of shells.

    Parameters:
    - results: dict, output from SphericalCollapse.run()
    - property_name: str, name of the property to plot ('r', 'v', 'a', 'e_tot', 'e_g', 'e_k', 'e_r', ...)
    - shell_indices: list of int, indices of shells to plot. If None, evenly spaced shells are selected.
    - num_shells: int, number of shells to plot if shell_indices is None
    - limit_axis: bool, clip the y-axis to the 1st-99th percentile range, padded by 10% outward
    - ylim: tuple, custom y-axis limits (min, max); takes precedence over limit_axis
    - yscale: str, y-axis scale passed to plt.yscale ('linear', 'log', ...)
    - title: str or None, plot title; defaults to property_name
    - vs_r: bool, plot against each shell's radius results['r'][:, idx] instead of time

    A 1-D (global) property is delegated to plot_global_property.

    Raises:
    - ValueError if property_name is not a key of results.
    """
    if property_name not in results:
        raise ValueError(f"Property '{property_name}' not found in results.")

    time = results['t']
    property_data = results[property_name]

    if property_data.ndim == 1:
        # A global quantity: there are no shells to select.
        return plot_global_property(results, property_name, limit_axis, ylim)

    if shell_indices is None:
        total_shells = property_data.shape[1]
        shell_indices = np.linspace(0, total_shells - 1, min(total_shells, num_shells), dtype=int)

    plt.figure(figsize=(10, 6))
    for idx in shell_indices:
        x_data = results['r'][:, idx] if vs_r else time
        plt.plot(x_data, property_data[:, idx], label=f'Shell {idx+1}')

    plt.yscale(yscale)
    plt.xlabel('Radius' if vs_r else 'Time')
    plt.ylabel(property_name.capitalize())
    plt.title(property_name if title is None else title)
    plt.legend()
    plt.grid(True)

    if ylim:
        plt.ylim(ylim)
    elif limit_axis:
        data_to_plot = np.asarray(property_data[:, shell_indices], dtype=float)
        lower, upper = np.nanpercentile(data_to_plot, [1, 99])
        # Pad outward whatever the sign. Scaling a negative lower bound by 0.9
        # (or a negative upper bound by 1.1) would move it inward and clip data.
        lower = lower * 0.9 if lower > 0 else lower * 1.1
        upper = upper * 1.1 if upper > 0 else upper * 0.9
        if np.isfinite(lower) and np.isfinite(upper) and lower < upper:
            plt.ylim(lower, upper)
    

def plot_energy_components(results, shell_index=None, limit_axis=True, ylim=None):
    """
    Plot the energy components ('e_tot', 'e_g', 'e_k', 'e_r') over time.

    Per-shell (2-D) energy arrays are reduced either to one shell's column
    (shell_index given) or to the sum over all shells (shell_index None), which
    gives the global energy. 1-D arrays are plotted as they are.

    Parameters:
    - results: dict, output from SphericalCollapse.run()
    - shell_index: int or None, index of the shell to plot. If None, plot global (shell-summed) energy.
    - limit_axis: bool, clip the y-axis to the 1st-99th percentile range, padded by 10% outward
    - ylim: tuple, custom y-axis limits (min, max); takes precedence over limit_axis
    """
    time = results['t']
    energy_components = ['e_tot', 'e_g', 'e_k', 'e_r']

    plt.figure(figsize=(10, 6))
    plotted = []
    for component in energy_components:
        if component not in results:
            continue
        energy_data = np.asarray(results[component])
        if energy_data.ndim > 1:
            # Previously a None shell_index plotted every shell separately,
            # contradicting the documented "global energy" behaviour.
            if shell_index is None:
                energy_data = energy_data.sum(axis=1)
            else:
                energy_data = energy_data[:, shell_index]
        plt.plot(time, energy_data, label=component.capitalize())
        plotted.append(np.asarray(energy_data, dtype=float).ravel())

    plt.xlabel('Time')
    plt.ylabel('Energy')
    title = 'Energy Components vs Time'
    if shell_index is not None:
        title += f' for Shell {shell_index}'
    plt.title(title)
    plt.legend()
    plt.grid(True)

    if ylim:
        plt.ylim(ylim)
    elif limit_axis and plotted:
        all_energy_data = np.concatenate(plotted)
        lower_percentile, upper_percentile = np.nanpercentile(all_energy_data, [1, 99])
        # Pad outward whatever the sign of each bound.
        lower_limit = lower_percentile * 0.9 if lower_percentile > 0 else lower_percentile * 1.1
        upper_limit = upper_percentile * 1.1 if upper_percentile > 0 else upper_percentile * 0.9
        if np.isfinite(lower_limit) and np.isfinite(upper_limit) and lower_limit < upper_limit:
            plt.ylim(lower_limit, upper_limit)
    


def plot_timescales(results, shell_index=None, limit_axis=False, ylim=None):
    """
    Overlay the available timescale diagnostics on a logarithmic y-axis.

    The spacing between saved snapshots is always drawn (black dashed line,
    'saved_dt'). Each named timescale is added when it is a key of `results`
    and is not entirely None.

    Parameters:
    - results: dict, output from SphericalCollapse.run()
    - shell_index: int or None; for 2-D timescale arrays, plot only this shell's
      column. If None, 2-D arrays are plotted as they are (one line per column).
    - limit_axis: bool, clip the y-axis to the 1st-99th percentile of the plotted timescales
    - ylim: tuple, custom y-axis limits (min, max); takes precedence over limit_axis
    """
    t = results['t']
    candidates = ['t_dyn', 't_vel', 't_acc', 't_cross', 't_zero', 't_rmin', 't_rmina', 't_dynnext', 't_dynr', 'dt']

    plt.figure(figsize=(10, 6))
    # Actual spacing of the saved snapshots, for reference.
    plt.plot(t[:-1], np.diff(t), color='k', label='saved_dt', linestyle='--', zorder=99)

    collected = []
    for name in candidates:
        if name not in results:
            continue
        series = results[name]
        if all(entry is None for entry in series):
            continue  # diagnostic not recorded in this run
        if series.ndim > 1 and shell_index is not None:
            series = series[:, shell_index]
        plt.plot(t, series, label=name)
        collected.extend(series)

    suffix = '' if shell_index is None else f' for Shell {shell_index}'
    plt.xlabel('Time')
    plt.ylabel('Timescale')
    plt.title('Timescale Components vs Time' + suffix)
    plt.legend()
    plt.grid(True)

    plt.yscale('log')

    if ylim:
        plt.ylim(ylim)
    elif limit_axis:
        plt.ylim(np.percentile(collected, 1), np.percentile(collected, 99))
    

def plot_total_energy_relative_change(results, limit_axis=True, ylim=None, yscale='linear'):
    """
    Plot (E(t) - E(0)) / E(0), where E is the total energy summed over all shells.

    With yscale='log' the absolute value of the relative change is plotted.

    Parameters:
    - results: dict, output from SphericalCollapse.run()
    - limit_axis: bool, clip the y-axis to the 1st-99th percentile range, padded by 10% outward
    - ylim: tuple, custom y-axis limits (min, max); takes precedence over limit_axis
    - yscale: str, scale of y-axis ('linear', 'log', 'symlog', etc.)
    """
    t = results['t']
    summed = np.sum(results['e_tot'], axis=1)  # shell-summed energy per step
    e0 = summed[0]
    rel = (summed - e0) / e0
    if yscale == 'log':
        rel = np.abs(rel)

    plt.figure(figsize=(10, 6))
    plt.plot(t, rel)
    plt.xlabel('Time')
    plt.ylabel('Relative Change in Total Energy')
    plt.title(r'$\Delta E/E_0$')
    plt.grid(True)
    plt.yscale(yscale)

    if ylim:
        plt.ylim(ylim)
        return
    if not limit_axis:
        return

    lo = np.percentile(rel, 1)
    hi = np.percentile(rel, 99)
    plt.ylim(lo * 0.9 if lo > 0 else lo * 1.1,
             hi * 1.1 if hi > 0 else hi * 0.9)

# Example usage: overview plots for the run above.
plot_timescales(results)
shells = None  # e.g. [0, 24, 49, 74, 99] to select specific shells

# Radius (linear, log, and against radius) and gravitational energy.
plot_shell_property(results, 'r', ylim=[-0.1, 2.2], shell_indices=shells)
plot_shell_property(results, 'r', yscale='log', limit_axis=False, shell_indices=shells)
plot_shell_property(results, 'r', vs_r=True, ylim=[-0.1, 2.2], shell_indices=shells)
plot_shell_property(results, 'e_g', shell_indices=shells)

# Drift of the shell-summed total energy on linear and log axes.
plot_total_energy_relative_change(results, limit_axis=False, yscale='linear')
plot_total_energy_relative_change(results, limit_axis=False, yscale='log', ylim=[1e-2, None])

# Enclosed mass, velocity and acceleration: first against time, then against radius.
for against_radius in (False, True):
    for quantity in ('m_enc', 'v', 'a'):
        plot_shell_property(results, quantity, limit_axis=False, shell_indices=shells, vs_r=against_radius)

# Energy components of the first and the last shell.
for shell in (0, -1):
    plot_energy_components(results, shell_index=shell, limit_axis=False)



In [None]:
def _plot_rebound_panel(ax, t_window, t_rebound, curves, ylabel):
    """Draw (label, values, color-or-None) curves with markers plus a dashed rebound line on one axis."""
    for label, values, color in curves:
        style = {} if color is None else {'color': color}
        ax.plot(t_window, values, label=label, marker='o', linestyle='-', **style)
    ax.axvline(t_rebound, color='r', linestyle='--', label='Rebound')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


def visualize_rebound_moments(results, shell_index=0, window_before=10, window_after=10):
    """
    Visualize parameters around each rebound of a specific shell.

    A rebound is a saved step at which the shell's velocity has just crossed
    from negative to non-negative. For each rebound, one figure is drawn with
    r, v, a, the energy components and |E_tot| / |E_tot(0)| over the window
    [rebound - window_before, rebound + window_after]. A rebound is skipped if
    its window would overlap the previously plotted one.

    Parameters:
    - results: dict, output from SphericalCollapse.run()
    - shell_index: int, index of the shell to analyze
    - window_before: int, number of saved steps before the rebound
    - window_after: int, number of saved steps after the rebound
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Extract this shell's series.
    t = results['t']
    r = results['r'][:, shell_index]
    v = results['v'][:, shell_index]
    a = results['a'][:, shell_index]
    e_tot = results['e_tot'][:, shell_index]
    e_k = results['e_k'][:, shell_index]
    e_g = results['e_g'][:, shell_index]
    e_r = results['e_r'][:, shell_index]

    # Velocity crosses from negative to non-negative; +1 selects the first
    # step after the crossing.
    rebound_indices = np.where((v[:-1] < 0) & (v[1:] >= 0))[0] + 1

    print(f"Found {len(rebound_indices)} rebound moments for shell {shell_index}.")

    # |E_tot| is normalised by its initial value; NaN if that value is zero.
    e_ref = np.abs(e_tot[0])

    # Exclusive end index of the previously plotted window.
    last_plotted = -np.inf

    for idx, rebound_idx in enumerate(rebound_indices):
        # Windows are half-open [start, end), so a new window overlaps the
        # previous one only if it starts strictly before the previous end.
        if rebound_idx - window_before < last_plotted:
            print(f"Skipping rebound {idx+1} at index {rebound_idx} to prevent overlapping windows.")
            continue

        start = rebound_idx - window_before
        end = rebound_idx + window_after + 1

        # Keep the window inside the data.
        if start < 0:
            start = 0
            print(f"Rebound {idx+1}: Adjusted start from {rebound_idx - window_before} to {start}.")
        if end > len(t):
            end = len(t)
            print(f"Rebound {idx+1}: Adjusted end from {rebound_idx + window_after +1} to {end}.")

        window = slice(start, end)
        t_window = t[window]
        t_rebound = t[rebound_idx]

        print(f"Rebound {idx+1}: {rebound_idx - start} points before, {end - rebound_idx} points after.")

        if e_ref != 0:
            e_tot_normalized = np.abs(e_tot[window]) / e_ref
        else:
            e_tot_normalized = np.full(len(t_window), np.nan)

        fig, axs = plt.subplots(5, 1, figsize=(12, 20), sharex=True)
        fig.suptitle(f'Shell {shell_index} Rebound {idx+1} at t ≈ {t_rebound:.4f}', fontsize=16)

        _plot_rebound_panel(axs[0], t_window, t_rebound,
                            [('Radius (r)', r[window], None)], 'Radius')
        _plot_rebound_panel(axs[1], t_window, t_rebound,
                            [('Velocity (v)', v[window], 'g')], 'Velocity')
        _plot_rebound_panel(axs[2], t_window, t_rebound,
                            [('Acceleration (a)', a[window], 'm')], 'Acceleration')
        _plot_rebound_panel(axs[3], t_window, t_rebound,
                            [('Total Energy (E_tot)', e_tot[window], None),
                             ('Kinetic Energy (E_k)', e_k[window], None),
                             ('Gravitational Energy (E_g)', e_g[window], None),
                             ('Rotational Energy (E_r)', e_r[window], None)],
                            'Energy')
        _plot_rebound_panel(axs[4], t_window, t_rebound,
                            [('Normalized Total Energy', e_tot_normalized, 'b')],
                            'Normalized Total Energy')
        axs[4].set_xlabel('Time')

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

        last_plotted = end

In [None]:
# Smallest radius reached by any shell at any saved step. np.min reduces over
# the whole (time, shell) array. The builtin min() would compare whole rows
# and raise "truth value of an array is ambiguous" for more than one shell.
print(np.min(results['r']))

In [None]:
# Zoom on each rebound of shell 0 with 5 saved steps on either side.
visualize_rebound_moments(results, window_before=5, window_after=5, shell_index=0)
# visualize_rebound_moments(results, shell_index=1, window_before=10, window_after=10)

In [None]:
# Shape of the stored deque snapshots. visualize_pre_rebound indexes the first
# axis by saved-step index (rebound step + offset).
results['deque'].shape

In [None]:
def visualize_pre_rebound(results, shell_index=0, num_after_rebound=1, rebound_number=0):
    """
    Visualize the moments leading up to a rebound using the stored deque snapshots.

    A rebound is a saved step at which the shell's velocity has just crossed
    from negative to non-negative. The deque stored at (rebound step +
    num_after_rebound) is plotted. Each snapshot in it is read as a dict with
    keys 't', 'r', 'v', 'a', 'e_tot', 'e_g', 'e_k', 'e_r' (per-shell arrays
    except 't'). NOTE(review): presumably that deque holds the most recent
    snapshots recorded before that step (see 'deque_size' in the config);
    confirm against the save strategy.

    Parameters:
    - results: dict, output from SphericalCollapse.run(); must contain 't', 'r',
      'v', 'e_tot' and 'deque'.
    - shell_index: int, default=0
        Index of the shell to analyze.
    - num_after_rebound: int, default=1
        Offset (in saved steps) after the rebound at which to take the deque.
        Clamped to the last saved step if it runs past the end.
    - rebound_number: int, default=0
        Which rebound of the shell to visualize (0 = first).

    Raises:
    - ValueError if the shell has fewer than rebound_number + 1 rebounds.
    """
    v = results['v'][:, shell_index]

    # Velocity crosses from negative to non-negative; +1 selects the first
    # step after the crossing.
    rebound_indices = np.where((v[:-1] < 0) & (v[1:] >= 0))[0] + 1
    if rebound_number >= len(rebound_indices):
        raise ValueError(
            f"Shell {shell_index} has {len(rebound_indices)} rebound(s); "
            f"rebound_number={rebound_number} is out of range."
        )
    rebound_idx = rebound_indices[rebound_number]

    snapshots = results['deque']
    deque_idx = rebound_idx + num_after_rebound
    if deque_idx >= len(snapshots):
        print(f"Requested snapshot {deque_idx} is past the end of the run; "
              f"using the last saved step {len(snapshots) - 1} instead.")
        deque_idx = len(snapshots) - 1
    deque = snapshots[deque_idx]

    def shell_series(key):
        # This shell's value of `key` in every snapshot of the deque.
        return [snap[key][shell_index] for snap in deque]

    times = [snap['t'] for snap in deque]
    radii = shell_series('r')
    velocities = shell_series('v')
    accelerations = shell_series('a')
    e_tot = shell_series('e_tot')
    e_g = shell_series('e_g')
    e_k = shell_series('e_k')
    e_r = shell_series('e_r')

    rebound_time = results['t'][rebound_idx]
    rebound_r = results['r'][rebound_idx, shell_index]
    rebound_v = results['v'][rebound_idx, shell_index]

    # Relative change of the shell-summed total energy against its initial
    # value. Both sides are all-shell totals. The previous version divided a
    # single shell's energy by the all-shell total, which is only meaningful
    # for N = 1.
    initial_total_energy = np.sum(results['e_tot'], axis=1)[0]
    total_energy = np.array([np.sum(snap['e_tot']) for snap in deque])
    e_tot_rel = (total_energy - initial_total_energy) / initial_total_energy

    fig, axs = plt.subplots(5, 1, figsize=(12, 24), sharex=True)
    fig.suptitle(f'Shell {shell_index} Moments Before Rebound at t ≈ {rebound_time:.4f}', fontsize=18)

    # Radius vs Time
    axs[0].plot(times, radii, marker='o', linestyle='-', color='blue', label='Radius (r)')
    axs[0].axvline(x=rebound_time, color='red', linestyle='--', label='Rebound')
    axs[0].set_ylabel('Radius')
    axs[0].legend(fontsize=10)

    # Velocity vs Time
    axs[1].plot(times, velocities, marker='o', linestyle='-', color='green', label='Velocity (v)')
    axs[1].axvline(x=rebound_time, color='red', linestyle='--', label='Rebound')
    axs[1].set_ylabel('Velocity')
    axs[1].legend(fontsize=10)

    # Acceleration vs Time
    axs[2].plot(times, accelerations, marker='o', linestyle='-', color='purple', label='Acceleration (a)')
    axs[2].axvline(x=rebound_time, color='red', linestyle='--', label='Rebound')
    axs[2].set_ylabel('Acceleration')
    axs[2].legend(fontsize=10)

    # Energy components of this shell (absolute values, log scale).
    axs[3].plot(times, np.abs(e_tot), marker='o', linestyle='-', label='Total Energy (E_tot)')
    axs[3].plot(times, np.abs(e_k), marker='o', linestyle='--', label='Kinetic Energy (E_k)')
    axs[3].plot(times, np.abs(e_g), marker='o', linestyle='-.', label='Gravitational Energy (E_g)')
    axs[3].plot(times, np.abs(e_r), marker='o', linestyle=':', label='Rotational Energy (E_r)')
    axs[3].axvline(x=rebound_time, color='red', linestyle='--', label='Rebound')
    axs[3].set_ylabel('Energy (Absolute Value)')
    axs[3].set_yscale('log')
    axs[3].legend(fontsize=10)

    # Relative change of the shell-summed total energy.
    axs[4].plot(times, e_tot_rel, marker='o', linestyle='-', color='orange', label='Relative Total Energy')
    axs[4].axvline(x=rebound_time, color='red', linestyle='--', label='Rebound')
    axs[4].set_ylabel('Relative Total Energy')
    axs[4].set_xlabel('Time')
    axs[4].legend(fontsize=10)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    print(f"Rebound occurred at t = {rebound_time:.6f}")
    print(f"Radius at rebound: {rebound_r:.6f}")
    print(f"Velocity at rebound: {rebound_v:.6f}")

In [None]:
# Deque snapshot taken 50 saved steps after the first rebound of shell 0.
visualize_pre_rebound(results, num_after_rebound=50, shell_index=0)
# visualize_pre_rebound(results, shell_index=1)

In [None]:
def test_r_t_comparison(base_config, m, H, r, J, rtol=1e-2):
    """
    Compare the simulated r(t) against a scipy integration of the same radial ODE.

    The reference solves r'' = -G m / r**2 + J**2 / r**3 (G = 1) from
    r(0) = r, v(0) = H * r. It integrates up to t_max = sqrt(4 pi / m * r_far**3),
    where r_far is the outer root of E_tot r**2 + G M m r - L**2 / (2 m) = 0
    (or r itself if larger or undefined).

    Parameters:
    - base_config: dict, SphericalCollapse config; must contain 'save_dt'.
      'm_tot', 'j_coef', 'H', 'r_max', 't_max' and 'save_strategy' are overridden.
    - m: float, mass (used as both the shell mass and the attracting mass M).
    - H: float, initial expansion rate, v(0) = H * r.
    - r: float, initial radius (passed as 'r_max').
    - J: float, angular momentum per unit mass (L = J * m), passed as 'j_coef'.
    - rtol: float, relative tolerance of the comparison.

    Raises:
    - AssertionError if the ODE solver fails or the trajectories disagree.
    """
    # solve_ivp was previously used without ever being imported.
    from scipy.integrate import solve_ivp

    G = 1
    M = m
    initial_conditions = [r, H * r]  # [r(0), v(0)]

    def acceleration(t, y):
        radius, velocity = y
        accel = -G * m / radius**2 + (J ** 2) / radius**3
        return [velocity, accel]

    # Turning points: roots of qa*r**2 + qb*r + qc = 0 with the coefficients below.
    v = H * r
    L = J * m
    E_k = (1/2)*m*v**2
    E_g = -G*m*M/r
    E_rot = L**2/(2*m*r**2)
    E_tot = E_k + E_g + E_rot
    qa = -E_tot
    qb = -G*M*m
    qc = L**2/(2*m)
    r_far_analytical = (-qb + np.sqrt(qb**2 - 4*qa*qc)) / (2*qa)
    r_far_analytical = np.nanmax([r_far_analytical, r])
    t_max = np.sqrt(4*np.pi/m * r_far_analytical**3)

    t_span = (0, t_max)
    t_eval = np.linspace(t_span[0], t_span[1], int(t_max / base_config["save_dt"]) + 1)

    sol = solve_ivp(acceleration, t_span, initial_conditions, t_eval=t_eval, method='RK45', rtol=1e-8)
    # pytest.fail was used here without pytest being imported; AssertionError is
    # what pytest itself reports as a test failure.
    if not sol.success:
        raise AssertionError(f"ODE solver failed to integrate: {sol.message}")

    r_numerical = sol.y[0]
    t_numerical = sol.t

    config = {**base_config, "m_tot": m, "j_coef": J, 'H': H, 'r_max': r, 't_max': t_max, 'save_strategy': 'default'}
    sim = SphericalCollapse(config)
    results = sim.run()

    t_simulated = np.asarray(results['t'])
    r_simulated = np.asarray(results['r'])
    if r_simulated.ndim > 1:
        # NOTE(review): uses the last shell column, presumably the one starting
        # at r_max = r. Confirm the shell ordering when N > 1.
        r_simulated = r_simulated[:, -1]

    r_expected = np.interp(t_simulated, t_numerical, r_numerical)
    # The original computed r_expected but never compared it.
    np.testing.assert_allclose(r_simulated, r_expected, rtol=rtol)

In [None]:
def analyze_energy_conservation(results, n_largest=5):
    """
    Plot shell-summed energy components, the relative drift of the total
    energy, and its per-step rate dE/dt. Then print the steps with the
    largest |dE/dt|.

    Parameters:
    - results: dict, output from SphericalCollapse.run(); 'e_tot', 'e_k',
      'e_g', 'e_r' are (time, shell) arrays summed over axis 1.
    - n_largest: int, number of steps to report (largest first).
    """
    t = np.asarray(results['t'])
    e_tot = np.sum(results['e_tot'], axis=1)
    e_k = np.sum(results['e_k'], axis=1)
    e_g = np.sum(results['e_g'], axis=1)
    e_r = np.sum(results['e_r'], axis=1)

    e_rel_change = (e_tot - e_tot[0]) / e_tot[0]

    # Repeated save times (dt == 0) would give inf/nan rates. Silence the
    # warning here and exclude non-finite rates from the ranking below.
    with np.errstate(divide='ignore', invalid='ignore'):
        de_dt = np.diff(e_tot) / np.diff(t)

    fig, axs = plt.subplots(3, 1, figsize=(12, 18), sharex=True)

    axs[0].plot(t, e_tot, label='Total')
    axs[0].plot(t, e_k, label='Kinetic')
    axs[0].plot(t, e_g, label='Gravitational')
    axs[0].plot(t, e_r, label='Rotational')
    axs[0].set_ylabel('Energy')
    axs[0].legend()
    axs[0].set_title('Energy Components')

    axs[1].plot(t, e_rel_change)
    axs[1].set_ylabel('Relative Energy Change')
    axs[1].set_title('Relative Total Energy Change')

    axs[2].plot(t[1:], de_dt)
    axs[2].set_ylabel('dE/dt')
    axs[2].set_xlabel('Time')
    axs[2].set_title('Energy Change Rate')

    plt.tight_layout()

    # Largest finite |dE/dt| first. NaNs would otherwise sort to the top.
    finite_steps = np.flatnonzero(np.isfinite(de_dt))
    ranked = finite_steps[np.argsort(np.abs(de_dt[finite_steps]))[::-1][:n_largest]]
    print("Time steps with largest energy changes:")
    for i in ranked:
        print(f"Time: {t[i+1]:.6f}, dE/dt: {de_dt[i]:.6e}")

# Use the function
analyze_energy_conservation(results)

In [None]:
from utils import load_simulation_data


def process_results(results, sim_name):
    """
    Print the run name and draw the standard diagnostic plots for one run:
    r(t) on linear and log axes, and the relative drift of the shell-summed
    total energy on linear and log axes.

    Parameters:
    - results: dict of simulation outputs (as returned by load_simulation_data
      or SphericalCollapse.run()).
    - sim_name: str, label printed and used as the title of the radius plots.
    """
    print(sim_name)
    shells = None  # e.g. [0, 24, 49, 74, 99] to select specific shells
    plot_shell_property(results, 'r', ylim=[-0.1, 2], shell_indices=shells, title=sim_name)
    plot_shell_property(results, 'r', yscale='log', limit_axis=False, shell_indices=shells, title=sim_name)
    plot_total_energy_relative_change(results, limit_axis=False, yscale='linear')
    plot_total_energy_relative_change(results, limit_axis=False, yscale='log', ylim=[1e-5, None])


# Process one reference run plus every .h5 file in runs/<folder>.
runs_dir = '/home/jsipple/one_d_spherical_collapse/one-d-spherical-collapse-of-fuzzy-dark-matter/src/runs'
folder = 'thickness_coef'
src_dir = os.path.join(runs_dir, folder)
fn = os.path.join(runs_dir, 'sim_N2_soft0_sf0.001_tc0_jc0.h5')

params, results = load_simulation_data(fn)
process_results(results, os.path.basename(fn))

# sorted(): os.listdir order is arbitrary, which would make the figure order
# (and the `results` left in the namespace) vary between runs.
for filename in sorted(os.listdir(src_dir)):
    if filename.endswith('.h5'):
        file_path = os.path.join(src_dir, filename)
        params, results = load_simulation_data(file_path)
        # Separate folder and run name. Plain concatenation produced
        # e.g. "thickness_coefsim_...".
        process_results(results, os.path.join(folder, os.path.splitext(filename)[0]))
        

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def cycloid_solution(t, A):
    """
    Radial position and velocity r(t), v(t) in extended (float128) precision.

    With omega = sqrt(8 pi / A) and theta = omega * t:
        r = (A / 2) * (1 + cos(theta))
        v = -(A / 2) * omega * sin(theta)

    NOTE(review): theta is linear in t, so this is harmonic motion in t rather
    than the parametric cycloid (r = A(1 + cos eta)/2 with t a nonlinear
    function of eta). A total energy built from it need not be conserved.
    TODO confirm the intended model.

    Parameters:
    t : array-like
        Time values
    A : float
        Amplitude parameter (r starts at A when t = 0)

    Returns:
    r : array-like
        Radial positions
    v : array-like
        Velocities
    """
    f128 = np.float128
    amp = f128(A)
    omega = np.sqrt(f128(8) * np.pi / amp)
    theta = omega * f128(t)
    half_amp = f128(0.5) * amp
    r = half_amp * (f128(1) + np.cos(theta))
    v = -f128(0.5) * amp * omega * np.sin(theta)
    return r, v

def compute_energies(r, v, m, G):
    """
    Kinetic, potential and total energy in float128.

    E_k = m v**2 / 2,  E_p = -G m**2 / r,  E_tot = E_k + E_p.

    Parameters:
    r : array-like
        Radial positions
    v : array-like
        Velocities
    m : float
        Mass of the particle (it also appears squared in E_p)
    G : float
        Gravitational constant

    Returns:
    E_k : array-like
        Kinetic energy
    E_p : array-like
        Potential energy
    E_tot : array-like
        Total energy
    """
    mass = np.float128(m)
    kinetic = np.float128(0.5) * mass * v**2
    potential = -np.float128(G) * mass**2 / r
    return kinetic, potential, kinetic + potential

# Parameters in extended precision.
G = np.float128(1.0)  # Gravitational constant
m = np.float128(1.0)  # Mass of the particle
A = np.float128(1.0)  # Amplitude parameter
# One full period, 2 pi / omega with omega = sqrt(8 pi / A).
t_max = np.float128(2) * np.pi * np.sqrt(np.float128(A) / (np.float128(8) * np.pi))
t = np.linspace(np.float128(0), t_max, 1000, dtype=np.float128)

r, v = cycloid_solution(t, A)
E_k, E_p, E_tot = compute_energies(r, v, m, G)

plt.figure(figsize=(12, 8))

ax_top = plt.subplot(211)
ax_top.plot(t, r, label='r(t)')
ax_top.plot(t, v, label='v(t)')
ax_top.set_xlabel('Time')
ax_top.set_ylabel('r, v')
ax_top.legend()
ax_top.set_title('Cycloid Solution')

ax_bottom = plt.subplot(212)
ax_bottom.plot(t, np.abs(E_k), label='Kinetic Energy')
ax_bottom.plot(t, np.abs(E_p), label='Potential Energy')
ax_bottom.plot(t, np.abs(E_tot), label='Total Energy')
ax_bottom.set_xlabel('Time')
ax_bottom.set_ylabel('Energy')
ax_bottom.set_yscale('log')
ax_bottom.set_ylim(1e0, 1e3)
ax_bottom.legend()
ax_bottom.set_title('Energy Conservation')

plt.tight_layout()

# Relative drift of the total energy over the period.
E_tot_initial = E_tot[0]
E_tot_relative_error = (E_tot - E_tot_initial) / E_tot_initial
print(f"Maximum relative error in total energy: {np.max(np.abs(E_tot_relative_error)):.2e}")
print(f"Standard deviation of relative error in total energy: {np.std(E_tot_relative_error):.2e}")

In [None]:
# Define the gravitational acceleration function
def gravitational_acceleration(r, G, M):
    return -G * M / r**2

# Velocity Verlet integrator with reflecting sphere
def velocity_verlet(r0, v0, dt, steps, G, M, r_min):
    r, v = np.zeros(steps), np.zeros(steps)
    r[0], v[0] = r0, v0
    
    for i in range(1, steps):
        a = gravitational_acceleration(r[i-1], G, M)
        r[i] = r[i-1] + v[i-1]*dt + 0.5*a*dt**2
        
        # Reflect off the sphere at r_min
        if r[i] < r_min:
            r[i] = 2*r_min - r[i]
            v[i-1] = -v[i-1]  # Reverse velocity for elastic collision
        
        a_next = gravitational_acceleration(r[i], G, M)
        v[i] = v[i-1] + 0.5*(a + a_next)*dt
    
    return r, v

# Beeman integrator with reflecting sphere
def beeman(r0, v0, dt, steps, G, M, r_min):
    r, v = np.zeros(steps), np.zeros(steps)
    r[0], v[0] = r0, v0
    a_prev = gravitational_acceleration(r0, G, M)
    a = a_prev
    
    for i in range(1, steps):
        r[i] = r[i-1] + v[i-1]*dt + (4*a - a_prev)*dt**2/6
        
        # Reflect off the sphere at r_min
        if r[i] < r_min:
            r[i] = 2*r_min - r[i]
            v[i-1] = -v[i-1]  # Reverse velocity for elastic collision
        
        a_next = gravitational_acceleration(r[i], G, M)
        v[i] = v[i-1] + (2*a_next + 5*a - a_prev)*dt/6
        a_prev, a = a, a_next
    
    return r, v

# Test set-up: radial motion onto a reflecting sphere of radius r_min.
G = 1.0
M = 1.0
r0 = 1.0
v0 = 0.5
dt = 0.0001
steps = 100000  # long enough to see several bounces
r_min = 0.1  # radius of the reflecting sphere
t = np.arange(steps) * dt

r_vv, v_vv = velocity_verlet(r0, v0, dt, steps, G, M, r_min)
r_beeman, v_beeman = beeman(r0, v0, dt, steps, G, M, r_min)

# Energy per unit mass of each trajectory and its relative drift.
E_vv = 0.5 * v_vv**2 - G*M/r_vv
E_beeman = 0.5 * v_beeman**2 - G*M/r_beeman
rel_err_vv = (E_vv - E_vv[0]) / E_vv[0]
rel_err_beeman = (E_beeman - E_beeman[0]) / E_beeman[0]

plt.figure(figsize=(15, 10))

ax_r = plt.subplot(221)
ax_r.plot(t, r_vv, label='Velocity Verlet')
ax_r.plot(t, r_beeman, label='Beeman')
ax_r.axhline(y=r_min, color='r', linestyle='--', label='Reflecting Sphere')
ax_r.set_xlabel('Time')
ax_r.set_ylabel('Radius')
ax_r.legend()
ax_r.set_title('Radius vs Time')

ax_v = plt.subplot(222)
ax_v.plot(t, v_vv, label='Velocity Verlet')
ax_v.plot(t, v_beeman, label='Beeman')
ax_v.set_xlabel('Time')
ax_v.set_ylabel('Velocity')
ax_v.legend()
ax_v.set_title('Velocity vs Time')

ax_e = plt.subplot(223)
ax_e.plot(t, E_vv, label='Velocity Verlet')
ax_e.plot(t, E_beeman, label='Beeman')
ax_e.set_xlabel('Time')
ax_e.set_ylabel('Total Energy')
ax_e.legend()
ax_e.set_title('Total Energy vs Time')

ax_rel = plt.subplot(224)
ax_rel.plot(t, rel_err_vv, label='Velocity Verlet')
ax_rel.plot(t, rel_err_beeman, label='Beeman')
ax_rel.set_xlabel('Time')
ax_rel.set_ylabel('Relative Energy Error')
ax_rel.legend()
ax_rel.set_title('Relative Energy Error vs Time')

plt.tight_layout()

# Energy-conservation statistics.
print("Velocity Verlet:")
print(f"Maximum relative error in total energy: {np.max(np.abs(rel_err_vv)):.2e}")
print(f"Standard deviation of relative error in total energy: {np.std(rel_err_vv):.2e}")

print("\nBeeman:")
print(f"Maximum relative error in total energy: {np.max(np.abs(rel_err_beeman)):.2e}")
print(f"Standard deviation of relative error in total energy: {np.std(rel_err_beeman):.2e}")


In [None]:
# Zoom on the early-time relative drift of the shell-summed total energy.
x = np.sum(results['e_tot'], axis=1)
ax = plt.gca()
ax.plot(results['t'], (x - x[0]) / x[0])
ax.set_xlim(0, 0.02)
ax.set_ylim(-0.01, 0)

In [None]:
def r(theta):
    """Cycloid radius r(theta) = 1 - cos(theta)."""
    return np.subtract(1, np.cos(theta))

def t(theta):
    """Cycloid time t(theta) = theta - sin(theta)."""
    return np.subtract(theta, np.sin(theta))

def v(rvals, tvals):
    """Velocity dr/dt by finite differences on the (possibly non-uniform) grid tvals."""
    derivative = np.gradient(rvals, tvals)
    return derivative

def a(vvals, tvals):
    """Acceleration dv/dt by finite differences on the (possibly non-uniform) grid tvals."""
    derivative = np.gradient(vvals, tvals)
    return derivative

def E(rvals, vvals, avals):
    """
    Orbital energy per unit mass, E = v**2 / 2 + a * r.

    For point-mass gravity a = -G M / r**2, so a * r = -G M / r and this equals
    v**2 / 2 - G M / r without needing G M explicitly. The previous formula
    0.5 * (v**2 - a * r) evaluates to v**2 / 2 + G M / (2 r), which is not the
    energy and is not conserved along the orbit.
    """
    return 0.5 * vvals**2 + avals * rvals

# Two periods of the parametric cycloid. v, a and E are obtained by finite
# differences in t.
theta = np.linspace(0, 4*np.pi, 1000)
rvals, tvals = r(theta), t(theta)
vvals = v(rvals, tvals)
avals = a(vvals, tvals)
Evals = E(rvals, vvals, avals)

fig = plt.figure(figsize=(24, 5))
panels = (('r', rvals), ('v', vvals), ('a', avals), ('E', Evals))
for position, (label, values) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 4, position)
    ax.plot(tvals, values)
    ax.set_xlabel('t')
    ax.set_ylabel(label)
    ax.set_title(f'{label} vs t')

plt.tight_layout()


In [None]:
def poorly_formatted_function(x, y):
    """Return x + y."""
    return x + y


print("This is not properly indented")

list_comp = [x for x in range(10) if x % 2 == 0]
