# n-Body Simulation

**Imports**

In [None]:
# @title
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import matplotlib.animation as animation
from IPython.display import HTML
from numba import njit, prange
from joblib import Parallel, delayed
import plotly.graph_objects as go

**Collision detection**

In [None]:
# @title
import numpy as np

class UnionFind:
    """Disjoint-set forest (union by rank, path halving) used to group colliding bodies.

    Parameters
    ----------
    n : int
        Number of elements, labelled ``0 .. n-1``; each starts in its own set.
    """

    def __init__(self, n):
        self.parent = np.arange(n)          # parent[i] == i marks a root
        self.rank = np.zeros(n, dtype=int)  # upper bound on each root's tree height

    def find(self, x):
        """Return the root of the set containing ``x``, halving the path as it walks."""
        parent = self.parent
        while parent[x] != x:
            grandparent = parent[parent[x]]
            parent[x] = grandparent  # skip one level for future lookups
            x = grandparent
        return x

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; do nothing if they already coincide."""
        a, b = self.find(x), self.find(y)
        if a == b:
            return
        if self.rank[a] < self.rank[b]:
            a, b = b, a  # make ``a`` the root of the taller tree
        self.parent[b] = a
        if self.rank[a] == self.rank[b]:
            self.rank[a] += 1

def resolve_collisions_with_clusters(pos, vel, mass, r_coll):
    """
    Merge bodies closer than ``r_coll`` into clusters that share one trajectory.

    Every pair with separation ``< r_coll`` is linked; linked bodies form a
    cluster transitively (via ``UnionFind``). All members of a cluster are moved
    to the cluster's centre of mass and given its mass-weighted velocity, which
    conserves total mass, momentum and centre of mass.

    Parameters
    ----------
    pos : (N,3) array
        Positions of all particles. Modified in place.
    vel : (N,3) array
        Velocities of all particles. Modified in place.
    mass : (N,) array
        Mass of each particle.
    r_coll : float
        Collision radius.

    Returns
    -------
    pos, vel : the same (updated) arrays that were passed in
    """
    N = len(pos)
    r2_coll = r_coll**2

    # --- 1) Detect all colliding pairs. This runs every time step, so the inner
    #        loop over j is one vectorised distance test per row instead of
    #        N^2/2 Python iterations. Pairs come out in the same (i, j) order.
    pairs = []
    for i in range(N - 1):
        rij = pos[i + 1:] - pos[i]
        d2 = np.einsum('ij,ij->i', rij, rij)
        for j in np.nonzero(d2 < r2_coll)[0] + i + 1:
            pairs.append((i, j))
    if not pairs:
        return pos, vel  # common case: nothing to merge

    # --- 2) Union colliding pairs and group particles by cluster root ---
    uf = UnionFind(N)
    for i, j in pairs:
        uf.union(i, j)
    clusters = {}
    for i in range(N):
        clusters.setdefault(uf.find(i), []).append(i)

    # --- 3) Give each multi-member cluster its centre-of-mass position and velocity ---
    for members in clusters.values():
        if len(members) == 1:
            continue  # no collision for this single body
        idx = np.array(members)
        m = mass[idx]
        m_total = np.sum(m)
        v_common = np.sum(m[:, None] * vel[idx], axis=0) / m_total
        p_common = np.sum(m[:, None] * pos[idx], axis=0) / m_total
        vel[idx] = v_common
        pos[idx] = p_common

    return pos, vel

**Calculate trajectory**

In [None]:
# @title
def _accelerations_py(pos, mass, G, eps):
    """Softened direct-sum gravitational acceleration on every body.

    a_i = G * sum_j m_j (x_j - x_i) / (|x_j - x_i|^2 + eps^2)^(3/2)

    The j == i term has a zero separation vector and contributes nothing as long
    as eps > 0 (with eps == 0 it would evaluate 0 * inf = nan). This is plain
    Python; ``_accel_kernel`` compiles it with Numba. ``prange`` behaves like
    ``range`` unless it is compiled with ``parallel=True``.
    """
    N = pos.shape[0]
    acc = np.zeros((N, 3))
    for i in prange(N):  # iterations are independent: each writes only acc[i]
        r_ij = pos - pos[i]  # vectors from body i to every body
        dist2 = np.sum(r_ij**2, axis=1) + eps**2
        inv_dist3 = dist2**-1.5
        inv_dist3 *= mass
        acc[i] = G * np.sum(r_ij * inv_dist3[:, None], axis=0)
    return acc


# Compiled kernels, keyed by the ``parallel`` flag. Compilation happens lazily on
# first use and only once per session (nested @njit functions recompiled on every call).
_ACCEL_KERNELS = {}


def _accel_kernel(parallel):
    """Return the Numba-compiled acceleration kernel (multithreaded if ``parallel``)."""
    key = bool(parallel)
    if key not in _ACCEL_KERNELS:
        _ACCEL_KERNELS[key] = njit(parallel=key)(_accelerations_py)
    return _ACCEL_KERNELS[key]


def _potential_energy(pos, mass, G, eps):
    """Softened potential energy W = -G/2 * sum_{i != j} m_i m_j / sqrt(r_ij^2 + eps^2)."""
    PE = 0.0
    for i in range(pos.shape[0]):
        r_ij = pos - pos[i]
        dist = np.sqrt(np.sum(r_ij**2, axis=1) + eps**2)
        dist[i] = np.inf  # exclude the self-interaction
        PE -= 0.5 * G * mass[i] * np.dot(1.0 / dist, mass)
    return PE


def _total_angular_momentum(pos, vel, mass):
    """Total angular momentum sum_i m_i (x_i x v_i) as a 3-vector."""
    return np.dot(mass, np.cross(pos, vel))


def trajectory(pos, vel, mass, G, eps, steps, res, dt, collision=True, parallel=False, p_idx=-1):
    """
    Integrate an N-body system with a kick-drift-kick leapfrog and record its history.

    The initial state is first shifted into the centre-of-mass frame (position and
    velocity). A snapshot is recorded every ``max(steps // res, 1)`` steps, i.e.
    roughly ``res`` times, and exactly ``steps`` integration steps are taken.

    Parameters
    ----------
    pos, vel : (N,3) array-like
        Initial positions and velocities. Copied: the caller's arrays are not modified.
    mass : (N,) array-like
        Body masses.
    G : float
        Gravitational constant in the units used for pos/vel/mass/time.
    eps : float
        Softening length (must be > 0). ``2*eps`` is also the collision radius.
    steps : int
        Number of integration steps; the run covers ``steps * dt``.
    res : int
        Approximate number of snapshots to record.
    dt : float
        Time step.
    collision : bool
        Merge bodies closer than ``2*eps`` after every step
        (``resolve_collisions_with_clusters``).
    parallel : bool
        Compute accelerations with the multithreaded Numba kernel.
    p_idx : int
        If >= 0, print periapsis passages of body ``p_idx`` relative to body 0.
        The first line gives time and distance; later lines add the elapsed time
        and the precession angle in arcsec. If < 0, print a progress dot per snapshot.

    Returns
    -------
    pos_hist, vel_hist : list of (N,3) arrays
    time_hist, KE_hist, PE_hist, virial_hist : list of float
        Snapshot time, kinetic energy, potential energy and virial ratio 2K/|W|.
    """
    # Float copies: never mutate the caller's arrays, and integer inputs
    # (e.g. np.array([1, 1, 1, 1]) masses) work with the in-place shifts below.
    pos = np.array(pos, dtype=float)
    vel = np.array(vel, dtype=float)
    mass = np.array(mass, dtype=float)
    accelerations = _accel_kernel(parallel)

    def compute_energies(pos, vel):
        KE = 0.5 * np.dot(mass, np.sum(vel**2, axis=1))
        return KE, _potential_energy(pos, mass, G, eps)

    def leapfrog_step(pos, vel, acc):
        """One kick-drift-kick step; returns the new position, velocity and acceleration."""
        vel_half = vel + 0.5 * dt * acc
        pos_new = pos + dt * vel_half
        acc_new = accelerations(pos_new, mass, G, eps)
        vel_new = vel_half + 0.5 * dt * acc_new
        return pos_new, vel_new, acc_new

    pos_hist, vel_hist, time_hist = [], [], []
    KE_hist, PE_hist, virial_hist = [], [], []

    cnt = max(steps // res, 1)
    time = 0.0
    time_last = 0.0
    M_total = np.sum(mass)
    pos -= np.dot(mass, pos) / M_total  # centre-of-mass frame
    vel -= np.dot(mass, vel) / M_total

    acc = accelerations(pos, mass, G, eps)
    L_ang = _total_angular_momentum(pos, vel, mass)
    print(f"Total system mass: {M_total:.4e} m_sun")
    print(f"Initial angular momentum: {L_ang} Mpc²⋅m_sun/Myr")

    dr_p_last = 0.0
    x_p_last = np.zeros(3)
    first_periastron_time = None
    first_periastron_x = None
    if p_idx >= 0: print("t [days], r_p [AU], Δt [days], Δφ [arcsec]")

    for step in range(steps + 1):
        if step % cnt == 0:
            KE, PE = compute_energies(pos, vel)
            KE_hist.append(KE)
            PE_hist.append(PE)
            virial_hist.append(2 * KE / np.abs(PE))
            pos_hist.append(pos.copy())
            vel_hist.append(vel.copy())
            time_hist.append(time)
            if p_idx < 0: print(".", end='')
        if step == steps:
            break  # steps*dt reached; integrating again would overshoot the requested span

        time += dt
        pos, vel, acc = leapfrog_step(pos, vel, acc)
        if collision:
            # NOTE: acc is not recomputed after a merge, so the next half-kick
            # uses accelerations evaluated at the pre-merge positions.
            pos, vel = resolve_collisions_with_clusters(pos, vel, mass, 2 * eps)

        if p_idx >= 0:  # periapsis tracking of body p_idx relative to body 0
            x_p = pos[p_idx] - pos[0]
            v_p = vel[p_idx] - vel[0]
            r_p = np.linalg.norm(x_p)
            dr_p = np.dot(x_p, v_p) / r_p  # radial velocity
            if dr_p_last < 0 and dr_p >= 0:  # radial velocity turned positive: periapsis passed
                # Linearly interpolate the zero crossing between the last two steps.
                alpha = -dr_p_last / (dr_p - dr_p_last)
                t_p = time_last + alpha * (time - time_last)
                x_p_fine = x_p_last + alpha * (x_p - x_p_last)
                r_p = np.linalg.norm(x_p_fine)
                if first_periastron_time is None:
                    print(f"{t_p:.1f}, {r_p:.3e}")
                    first_periastron_time = t_p
                    first_periastron_x = x_p_fine
                else:
                    delta_t = t_p - first_periastron_time
                    direction = np.dot(x_p_fine - first_periastron_x, v_p)
                    # atan2(|a x b|, a.b) stays accurate for arcsecond-scale angles, where
                    # arccos(cos) loses precision and yields nan once cos rounds above 1.
                    angle = np.arctan2(np.linalg.norm(np.cross(x_p_fine, first_periastron_x)),
                                       np.dot(x_p_fine, first_periastron_x))
                    delta_phi_arcsec = angle * 206265.0
                    if direction < 0: delta_phi_arcsec *= -1.0
                    print(f"{t_p:.1f}, {r_p:.3e}, {delta_t:.1f}, {delta_phi_arcsec:.3e}")

            dr_p_last = dr_p
            x_p_last = x_p.copy()
            time_last = time

    L_ang = _total_angular_momentum(pos, vel, mass)
    print()
    print(f"Final angular momentum: {L_ang} Mpc²⋅m_sun/Myr")

    return pos_hist, vel_hist, time_hist, KE_hist, PE_hist, virial_hist

**Charts**

In [None]:
# @title
def charts(time_hist, KE_hist, PE_hist, virial_hist, time_unit='', energy_unit=''):
    """
    Plot the energy evolution and the virial ratio of a simulation side by side.

    Parameters
    ----------
    time_hist : sequence of float
        Snapshot times.
    KE_hist, PE_hist : sequence of float
        Kinetic and potential energy per snapshot; their sum is drawn as total energy.
    virial_hist : sequence of float
        Virial ratio per snapshot (presumably 2K/|W| as produced by ``trajectory``).
    time_unit, energy_unit : str
        Axis unit labels; the ``[unit]`` suffix is omitted when empty.
    """
    def with_unit(label, unit):
        return f"{label} [{unit}]" if unit else label

    KE = np.asarray(KE_hist)
    PE = np.asarray(PE_hist)

    fig, (ax_energy, ax_virial) = plt.subplots(1, 2, figsize=(12, 5))

    ax_energy.plot(time_hist, KE, label='Kinetic Energy')
    ax_energy.plot(time_hist, PE, label='Potential Energy')
    ax_energy.plot(time_hist, KE + PE, label='Total Energy')
    ax_energy.set_xlabel(with_unit("Time", time_unit))
    ax_energy.set_ylabel(with_unit("Energy", energy_unit))
    ax_energy.legend()
    ax_energy.set_title("Energy Evolution")

    ax_virial.plot(time_hist, virial_hist, label='Virial ratio 2K/|W|')
    ax_virial.set_xlabel(with_unit("Time", time_unit))
    ax_virial.set_ylabel("Virial ratio")
    ax_virial.legend()
    ax_virial.set_title("Virial balance")

    fig.tight_layout()
    plt.show()

**Animation**

In [None]:
# @title
def _camera_button(label, eye, projection, aspect, up=None):
    """Build a Plotly ``relayout`` button that moves the 3D camera to a preset view."""
    view = {"scene.camera.eye": eye}
    if up is not None:
        view["scene.camera.up"] = up
    view["scene.camera.projection.type"] = projection
    view["scene.aspectratio"] = {"x": aspect, "y": aspect, "z": aspect}
    return {"label": label, "method": "relayout", "args": [view]}


_play_button = {
    "label": "Play",
    "method": "animate",
    "args": [None, {"frame": {"duration": 50, "redraw": True}, "fromcurrent": True}],
}

# Shared Plotly update menu: a Play button plus preset camera views.
# The YZ view sets no up vector, so presumably the previously set one is kept.
menues = [{
    "type": "buttons",
    "buttons": [
        _play_button,
        _camera_button("XY Plane", {"x": 0, "y": 0, "z": 2}, "orthographic", 1.8,
                       up={"x": 0, "y": 1, "z": 0}),
        _camera_button("XZ Plane", {"x": 0, "y": -2, "z": 0}, "orthographic", 1.8,
                       up={"x": 0, "y": 0, "z": -1}),
        _camera_button("YZ Plane", {"x": 2, "y": 0, "z": 0}, "orthographic", 1.8),
        _camera_button("3D View", {"x": 1.5, "y": -1.5, "z": 1.5}, "perspective", 1.2,
                       up={"x": 0, "y": 0, "z": -1}),
    ],
}]

def animation(pos_hist, time_hist, R_scale, chart_title, baryon_set, center_body=-1, marker_size=10, time_unit='', pos_unit=''):
    """
    Show an interactive Plotly 3D animation of a simulation history.

    Bodies selected by ``baryon_set`` are drawn red; all others are drawn blue.

    Parameters
    ----------
    pos_hist : list of (N,3) arrays
        Snapshot positions (e.g. from ``trajectory``). Not modified.
    time_hist : sequence of float
        Snapshot times, shown as slider labels.
    R_scale : float
        Half-width of each axis range.
    chart_title : str
        Plot title.
    baryon_set : (N,) bool array-like
        Mask selecting the bodies drawn in red.
    center_body : int
        If >= 0, positions are shown relative to this body.
    marker_size : int
        Marker size for all bodies.
    time_unit, pos_unit : str
        Unit labels for the slider and the axes.
    """
    red = np.asarray(baryon_set, dtype=bool)
    blue = ~red

    def centered(pos):
        # Build a new array: shifting in place would corrupt the caller's history.
        return pos - pos[center_body] if center_body >= 0 else pos

    def traces(pos):
        return [
            go.Scatter3d(x=pos[blue, 0], y=pos[blue, 1], z=pos[blue, 2], mode='markers', marker=dict(color='blue', size=marker_size)),
            go.Scatter3d(x=pos[red, 0], y=pos[red, 1], z=pos[red, 2], mode='markers', marker=dict(color='red', size=marker_size)),
        ]

    frames = [
        go.Frame(data=traces(centered(pos)), name=f'frame{i}', layout=go.Layout(title_text=chart_title))
        for i, pos in enumerate(pos_hist)
    ]
    slider_steps = [
        {"args": [[f"frame{i}"], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate", "transition": {"duration": 0}}],
         "label": f"{time_hist[i]:,.1e} {time_unit}",
         "method": "animate"}
        for i in range(len(pos_hist))
    ]

    fig = go.Figure(
        data=traces(centered(pos_hist[0])),
        layout=go.Layout(
            title=chart_title,
            width=1000,
            height=800,
            scene=dict(
                xaxis=dict(title='x [' + pos_unit + ']', range=[-R_scale, R_scale]),
                yaxis=dict(title='y [' + pos_unit + ']', range=[-R_scale, R_scale]),
                zaxis=dict(title='z [' + pos_unit + ']', range=[-R_scale, R_scale]),
                camera=dict(
                    eye=dict(x=1.5, y=-1.5, z=1.5),
                    up=dict(x=0, y=0, z=-1),
                    projection=dict(type='perspective')
                ),
                aspectratio=dict(x=1.2, y=1.2, z=1.2)
            ),
            updatemenus=menues,
            sliders=[{"steps": slider_steps}],
        ),
        frames=frames,
    )
    fig.show()
    # To also save a standalone copy: fig.write_html(chart_title + '.html')

In [None]:
# @title
def animation2(pos_hist, time_hist, R_scale, chart_title, bodies, center_body=-1, time_unit='', pos_unit=''):
    """
    Show, and save as ``<chart_title>.html``, a Plotly 3D animation with per-body styling.

    Parameters
    ----------
    pos_hist : list of (N,3) arrays
        Snapshot positions (e.g. from ``trajectory``). Not modified.
    time_hist : sequence of float
        Snapshot times, shown as slider labels.
    R_scale : float
        Half-width of each axis range.
    chart_title : str
        Plot title; also the stem of the saved HTML file name.
    bodies : sequence of sequences
        One entry per body; items 3, 4 and 5 are the marker colour, marker size and text label.
    center_body : int
        If >= 0, positions are shown relative to this body.
    time_unit, pos_unit : str
        Unit labels for the slider and the axes.
    """
    colors = [body[3] for body in bodies]
    sizes = [body[4] for body in bodies]
    labels = [body[5] for body in bodies]

    def centered(pos):
        # Build a new array: shifting in place would corrupt the caller's history.
        return pos - pos[center_body] if center_body >= 0 else pos

    def traces(pos):
        x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
        return [
            go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(color=colors, size=sizes)),
            go.Scatter3d(x=x, y=y, z=z, mode='text', text=labels, textposition="top center", textfont=dict(size=12, color="black")),
        ]

    frames = [
        go.Frame(data=traces(centered(pos)), name=f'frame{i}', layout=go.Layout(title_text=chart_title))
        for i, pos in enumerate(pos_hist)
    ]
    slider_steps = [
        {"args": [[f"frame{i}"], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate", "transition": {"duration": 0}}],
         "label": f"{time_hist[i]:,.1e} {time_unit}",
         "method": "animate"}
        for i in range(len(pos_hist))
    ]

    fig = go.Figure(
        data=traces(centered(pos_hist[0])),
        layout=go.Layout(
            title=chart_title,
            width=1000,
            height=800,
            scene=dict(
                xaxis=dict(title='x [' + pos_unit + ']', range=[-R_scale, R_scale]),
                yaxis=dict(title='y [' + pos_unit + ']', range=[-R_scale, R_scale]),
                zaxis=dict(title='z [' + pos_unit + ']', range=[-R_scale, R_scale]),
                camera=dict(
                    eye=dict(x=1.5, y=-1.5, z=1.5),
                    up=dict(x=0, y=0, z=-1),
                    projection=dict(type='perspective')
                ),
                aspectratio=dict(x=1.2, y=1.2, z=1.2)
            ),
            updatemenus=menues,
            sliders=[{"steps": slider_steps}],
        ),
        frames=frames,
    )
    fig.show()
    fig.write_html(chart_title + '.html')  # Save to HTML

**Constants**

In [None]:
# Constants in physical units; G below is expressed in Mpc, Myr and solar masses
Mpc_m = 3.085677581e22 # m/Mpc
Mpc_km = Mpc_m/1000 # km/Mpc
Myr_s = 1e6 * 365 * 24 * 3600 # s/Myr (365-day years, not the 365.25-day Julian year)
m_sun = 1.989e30 # kg
G = 6.67430e-11 * m_sun / Mpc_m**3 * Myr_s**2 # Mpc³ / (M_sun⋅Myr²)
m_earth = 5.972e24  # Mass of the Earth, kg
m_moon = m_earth/81  # Mass of the moon, kg (Earth/Moon mass ratio approximated as 81)
r01 = 1.496e11 / Mpc_m # Sun-Earth distance (~1 AU), in Mpc
r12 = 384000000 / Mpc_m # Earth-Moon distance, in Mpc

# Solar System

This models the Solar System: the Sun, Mercury, Venus, Earth and the Moon, Mars, Jupiter with 4 moons, Saturn with 6 moons, Uranus with 4 moons and Neptune with one moon (25 bodies in total). The initial conditions are sourced from JPL Horizons through astroquery.

In [None]:
try:
    from astroquery.jplhorizons import Horizons
except ImportError:
    !pip install astroquery
    from astroquery.jplhorizons import Horizons

# Gravitational constant
G_metric = 6.67430e-20   # km^3 / (kg s^2)

# Standard Gravitational Parameters (GM) in km^3/s^2
#     ID    GM (km^3/s^2)        Name        Color       Size  Label
Bodies = [
    ['Sun', 1.32712440018e11, "Sun", 'yellow', 20, "Sun"],
    ['199', 2.203186855e4, "Mercury", 'gray', 10, "Mercury"],
    ['299', 3.24858592e5, "Venus", 'orange', 10, "Venus"],
    ['399', 3.986004418e5, "Earth", 'blue', 10, "Earth"],
    ['301', 4.9048695e3, "Moon", 'lightgray', 10, ""],
    ['499', 4.282837e4, "Mars", 'red', 10, "Mars"],
    ['599', 1.26686534e8, "Jupiter", 'brown', 10, "Jupiter"],
    ['699', 3.7931187e7, "Saturn", 'goldenrod', 10, "Saturn"],
    ['799', 5.793939e6, "Uranus", 'lightblue', 10, "Uranus"],
    ['899', 6.836529e6, "Neptune", 'blue', 10, "Neptune"],
    ['501', 5.959924e3, "Io", 'darkgray', 7, ""],
    ['502', 3.202738e3, "Europa", 'darkgray', 7, ""],
    ['503', 9.887819e3, "Ganymede", 'darkgray', 7, ""],
    ['504', 7.179289e3, "Callisto", 'darkgray', 7, ""],
    ['601', 2.5025, "Mimas", 'blue', 7, ""],
    ['602', 7.2103, "Enceladus", 'blue', 7, ""],
    ['603', 41.208, "Tethys", 'blue', 7, ""],
    ['604', 73.117, "Dione", 'blue', 7, ""],
    ['605', 153.94, "Rhea", 'blue', 7, ""],
    ['606', 8.97813e3, "Titan", 'blue', 7, ""],
    ['701', 83.43, "Ariel", 'green', 7, ""],
    ['703', 222.8, "Titania", 'green', 7, ""],
    ['704', 205.34, "Oberon", 'green', 7, ""],
    ['705', 4.3, "Miranda", 'green', 7, ""],
    ['801', 1.428e3, "Triton", 'brown', 7, ""]
]

# Convert GM → mass
def mass_from_GM(body_id):
    return Bodies[body_id][1] / G_metric   # gives kg

N = len(Bodies)

# Preallocate arrays
body_mass = np.zeros(N)
body_pos  = np.zeros((N, 3))
body_vel  = np.zeros((N, 3))

# For computing barycenter
mass_all = 0.0
x_all = y_all = z_all = 0.0

for body in range(N):
    obj = Horizons(
        id=Bodies[body][0],
        location='@0',     # SSB frame
        epochs={'2025-01-01'},
        id_type='majorbody'
    )

    # Extract SSB state vector (in AU, AU/day)
    vectors = obj.vectors()
    x, y, z  = vectors['x'][0],  vectors['y'][0],  vectors['z'][0]
    vx, vy, vz = vectors['vx'][0], vectors['vy'][0], vectors['vz'][0]
    mass = mass_from_GM(body) # Mass

    # Store in arrays
    body_mass[body] = mass
    body_pos[body] = np.array([x, y, z])
    body_vel[body] = np.array([vx, vy, vz])

    x_all += mass * x
    y_all += mass * y
    z_all += mass * z
    mass_all += mass

    print(f"{Bodies[body][2]}:")
    print("  position (AU):", x, y, z)
    print("  velocity (AU/day):", vx, vy, vz)
    print("  mass (kg):", mass)
    print()

# Finally compute barycenter
barycenter = np.array([x_all, y_all, z_all]) / mass_all

print(f"System Center of Mass:")
print("  position (AU):", barycenter[0], barycenter[1], barycenter[2])
print("  mass (kg):", mass_all)
print()

# Parameters of the n-body system
steps = 100000
res = 500
dt = 365 / steps # covering 1000 days
epsilon = 1e-10
AU_m = 1.495978707e11  # m/AU
day_s = 86400          # s/day
G_AU = 6.67430e-11 / AU_m**3 * day_s**2 # AU²⋅(AU/day²) / kg

pos_hist,vel_hist,time_hist,KE_hist,PE_hist,virial_hist = trajectory(body_pos, body_vel, body_mass, G_AU, epsilon, steps, res, dt, p_idx=-1)
charts(time_hist,KE_hist,PE_hist,virial_hist, time_unit='days', energy_unit="M$_\\odot$⋅AU²⋅days⁻²")

In [None]:
# Heliocentric view of the whole system (Bodies[0] is the Sun), axes ±30 AU.
animation2(
    pos_hist, time_hist, 30.0, 'Solar System', Bodies,
    center_body=0,
    time_unit='days',
    pos_unit='AU',
)

In [None]:
# Label every body with its name (the moons had empty labels), then zoom in on
# Jupiter (Bodies[6]) with axes of ±0.02 AU.
for body in Bodies:
    body[5] = body[2]
animation2(
    pos_hist, time_hist, 0.02, 'Jupiter with its moons', Bodies,
    center_body=6,
    time_unit='days',
    pos_unit='AU',
)

__Calculating Mercury's perihelion precession over 100 years__

In [None]:
# Parameters of the n-body system: 10 million steps spanning 100 years (36,500 days)
steps = 10_000_000
res = 500
dt = 100 * 365 / steps  # days per step
# Long run, disabled by default. p_idx=1 (Mercury) prints its perihelion passages
# relative to body 0 (the Sun).
# pos_hist,vel_hist,time_hist,KE_hist,PE_hist,virial_hist = trajectory(body_pos, body_vel, body_mass, G_AU, epsilon, steps, res, dt, p_idx=1)

# Sun, Earth and Moon

This models the Moon's orbit around the Earth and the orbit of the Earth–Moon system around the Sun.

In [None]:
# Parameters of the n-body system (units: Mpc, Myr, M_sun)
steps = 100000
res = 500
dt = 1e-6 / steps  # 1e-6 Myr = 1 year in total

# Sun, Earth, Moon
body_mass = np.array([1, m_earth / m_sun, m_moon / m_sun])
v_earth = np.sqrt(G * body_mass[0] / r01)  # circular speed of the Earth around the Sun
v_moon = np.sqrt(G * body_mass[1] / r12)   # circular speed of the Moon around the Earth
body_pos = np.array([
    [0, 0, 0],
    [r01, 0, 0],
    [r01 + r12, 0, 0],
])
body_vel = np.array([
    [0, 0, 0],
    [0, v_earth, 0],
    [0, v_earth + v_moon, 0],  # the Moon carries the Earth's orbital velocity plus its own
])
epsilon = r12 / 200

baryon_set = np.zeros(body_pos.shape[0], dtype=bool)
baryon_set[1] = True  # draw the Earth in red

pos_hist, vel_hist, time_hist, KE_hist, PE_hist, virial_hist = trajectory(
    body_pos, body_vel, body_mass, G, epsilon, steps, res, dt)
charts(time_hist, KE_hist, PE_hist, virial_hist, time_unit='Myr', energy_unit="M$_\\odot$⋅Mpc²⋅Myr⁻²")
animation(pos_hist, time_hist, 1.2 * r12, 'Sun, Earth and moons', baryon_set,
          center_body=1, time_unit='Myr', pos_unit='Mpc')

# Sun, Earth with two moons

This adds a second moon to the Earth, on the opposite side. After a few orbits of both moons around the Earth, they start to attract each other and eventually collide.

In [None]:
# Parameters of the n-body system (units: Mpc, Myr, M_sun)
steps = 100000
res = 500
dt = 1e-6 / steps  # 1e-6 Myr = 1 year in total

# Sun, Earth and two equal moons on opposite sides of the Earth
body_mass = np.array([1, m_earth / m_sun, m_moon / m_sun, m_moon / m_sun])
v_earth = np.sqrt(G * body_mass[0] / r01)  # circular speed of the Earth around the Sun
v_moon = np.sqrt(G * body_mass[1] / r12)   # circular speed of a moon around the Earth
body_pos = np.array([
    [0, 0, 0],
    [r01, 0, 0],
    [r01 + r12, 0, 0],  # outer moon
    [r01 - r12, 0, 0],  # inner moon
])
body_vel = np.array([
    [0, 0, 0],
    [0, v_earth, 0],
    [0, v_earth + v_moon, 0],
    [0, v_earth - v_moon, 0],  # opposite side, same sense of rotation around the Earth
])
epsilon = r12 / 200

baryon_set = np.zeros(body_pos.shape[0], dtype=bool)
baryon_set[1] = True  # draw the Earth in red

pos_hist, vel_hist, time_hist, KE_hist, PE_hist, virial_hist = trajectory(
    body_pos, body_vel, body_mass, G, epsilon, steps, res, dt)
charts(time_hist, KE_hist, PE_hist, virial_hist, time_unit='Myr', energy_unit="M$_\\odot$⋅Mpc²⋅Myr⁻²")
animation(pos_hist, time_hist, 1.2 * r12, 'Sun, Earth with two moons', baryon_set,
          center_body=1, time_unit='Myr', pos_unit='Mpc')

# 3-body Eight

The 3-body Eight is a dynamically stable constellation. The red body is a small test body that does not materially influence the 3-body Eight (blue bodies), but illustrates the chaotic trajectory of a small body in the gravitational field of the 3-body Eight.

In [None]:
G = 1
# === Figure-eight choreography for three equal masses, plus a light test body ===
mass = np.array([1.0, 1.0, 1.0, 1e-3])  # bodies 0-2: the eight; body 3: test body

# Positions (from Moore 1993); the test body starts off the curve
pos = np.array([
    [ 0.97000436, -0.24308753, 0.0],
    [-0.97000436,  0.24308753, 0.0],
    [ 0.0,         0.0,        0.0],
    [-0.33,       -0.3,        0.0]
])

# Velocities: bodies 0 and 1 move together, body 2 balances their momentum; test body at rest
vel = np.array([
    [ 0.4662036850,  0.4323657300, 0.0],
    [ 0.4662036850,  0.4323657300, 0.0],
    [-0.93240737,   -0.86473146,   0.0],
    [ 0.0,           0.0,          0.0]
])

# Parameters of the n-body system
steps = 100000
res = 100
dt = 10 / steps  # 10 time units in total
epsilon = 1/100

baryon_set = np.zeros(pos.shape[0], dtype=bool)
baryon_indices = np.array([3])
baryon_set[baryon_indices] = True  # draw the test body in red

pos_hist, vel_hist, time_hist, KE_hist, PE_hist, virial_hist = trajectory(
    pos, vel, mass, G, epsilon, steps, res, dt, collision=False)
charts(time_hist, KE_hist, PE_hist, virial_hist)
animation(pos_hist, time_hist, 1.2, 'Eight Curve', baryon_set)

# Trojan configuration

A Trojan configuration is a three-body orbital arrangement where a small body orbits stably 60° ahead of or behind a planet at the L4 or L5 Lagrange points, sharing the same orbital period and forming an equilateral triangle with the planet and star.

In [None]:
G = 1
M = 1.0    # mass of the central body ("planet"; plays the star's role)
m = 0.001  # mass of the orbiting body ("moon"; plays the planet's role)
a = 1.0    # orbital radius

# Central body at the origin, orbiter on the x-axis, Trojan 60° ahead (L4 point)
planet_pos = np.array([0.0, 0.0, 0.0])
moon_pos   = np.array([a, 0.0, 0.0])
trojan_pos = a * np.array([np.cos(np.pi/3), np.sin(np.pi/3), 0.0])

# Circular velocities (G = 1); the Trojan's is perpendicular to its radius vector
v_orbit = np.sqrt(M/a)
planet_vel = np.array([0.0, 0.0, 0.0])
moon_vel   = np.array([0.0, v_orbit, 0.0])
trojan_vel = v_orbit * np.array([-np.sin(np.pi/3), np.cos(np.pi/3), 0.0])

mass = np.array([M, m, 1e-8])  # the Trojan is a near-massless test body
pos = np.array([planet_pos, moon_pos, trojan_pos])
vel = np.array([planet_vel, moon_vel, trojan_vel])

steps = 10000
res = 100
dt = 10 / steps  # 10 time units in total
epsilon = 1/200

baryon_set = np.zeros(pos.shape[0], dtype=bool)
baryon_indices = np.array([2])
baryon_set[baryon_indices] = True  # draw the Trojan in red

pos_hist, vel_hist, time_hist, KE_hist, PE_hist, virial_hist = trajectory(
    pos, vel, mass, G, epsilon, steps, res, dt)
charts(time_hist, KE_hist, PE_hist, virial_hist)
animation(pos_hist, time_hist, 2, 'Trojan', baryon_set)

# Two binaries

This is a dynamically stable 4-body constellation in which two binary systems orbit each other. The two binaries have different orbital periods, and in this example they even rotate in opposite directions.

In [None]:
def double_binary_scatter(d_sep=5.0, v_rel=0.1):
    """
    Two identical equal-mass circular binaries approaching each other (G = 1 units).

    Binary 1 (bodies 0, 1) is centred at the origin and binary 2 (bodies 2, 3) at
    x = d_sep. Both have unit internal separation and rotate in the same sense
    in the x-y plane. The binaries move towards each other along x with relative
    speed ``v_rel``.

    Parameters
    ----------
    d_sep : float
        Initial distance between the binary centres.
    v_rel : float
        Relative approach speed of the two binaries.

    Returns
    -------
    pos, vel : (4,3) float arrays
    masses : (4,) array of ones
    """
    masses = np.ones(4)

    r_bin = 1.0
    # Relative circular speed of a binary is sqrt(G (m_a + m_b) / r); each partner
    # carries half of it (equal masses), matching double_binary().
    v_bin = np.sqrt((masses[0] + masses[1]) / r_bin)

    offsets = np.array([[-r_bin / 2, 0, 0], [+r_bin / 2, 0, 0]])   # around each binary centre
    spins = np.array([[0, -v_bin / 2, 0], [0, +v_bin / 2, 0]])     # internal circular motion
    approach = np.array([v_rel / 2, 0, 0])

    pos = np.vstack([offsets, offsets + [d_sep, 0, 0]])
    vel = np.vstack([spins + approach, spins - approach])

    return pos, vel, masses

def double_binary(masses, d_sep=5.0, r_A = 0.5, r_B = 0.5):
    """
    Two circular binaries, A and B, on a mutual circular orbit (G = 1 units).

    Binary A (bodies 0, 1) is centred at x = -d_sep/2 and binary B (bodies 2, 3)
    at x = +d_sep/2; all motion is in the x-y plane, and A and B spin in opposite
    senses. Separations and velocities are split between partners in inverse
    proportion to their masses. Each binary therefore keeps its centre of mass at
    its nominal centre, and the system has zero net momentum. For equal masses,
    every body gets exactly half of the relevant separation and relative speed.

    Parameters
    ----------
    masses : (4,) array-like
        Masses of bodies 0..3; returned unchanged.
    d_sep : float
        Distance between the two binary centres.
    r_A, r_B : float
        Internal separation of binary A and of binary B.

    Returns
    -------
    pos, vel : (4,3) float arrays
    masses : the ``masses`` argument
    """
    m0, m1, m2, m3 = masses
    M_A = m0 + m1
    M_B = m2 + m3

    # Relative circular speeds (G = 1): inside each binary and between the binaries.
    v_A = np.sqrt(M_A / r_A)
    v_B = np.sqrt(M_B / r_B)
    v_orbit = np.sqrt(np.sum(masses) / d_sep)

    # Each partner's share of its binary's relative motion (inverse to its mass).
    f0, f1 = m1 / M_A, m0 / M_A
    f2, f3 = m3 / M_B, m2 / M_B
    # Each binary's share of the mutual orbit.
    g_A, g_B = M_B / (M_A + M_B), M_A / (M_A + M_B)

    pos = np.array([
        [-f0 * r_A - d_sep / 2, 0, 0],
        [+f1 * r_A - d_sep / 2, 0, 0],
        [-f2 * r_B + d_sep / 2, 0, 0],
        [+f3 * r_B + d_sep / 2, 0, 0],
    ])
    vel = np.array([
        [0, -f0 * v_A - g_A * v_orbit, 0],  # binary A: counter-clockwise spin
        [0, +f1 * v_A - g_A * v_orbit, 0],
        [0, +f2 * v_B + g_B * v_orbit, 0],  # binary B: clockwise spin
        [0, -f3 * v_B + g_B * v_orbit, 0],
    ])

    return pos, vel, masses

In [None]:
# Two binaries in G = 1 units: a wide binary A (red) and a tight binary B (blue),
# 3 units apart, spinning in opposite senses.
G = 1
masses = np.array([1, 1, 1, 1])
epsilon = 1/200
steps = 100000
dt = 50/steps  # 50 time units in total
res = 500
pos, vel, mass = double_binary(masses, d_sep=3, r_A=0.7, r_B=0.3)

baryon_set = np.zeros(pos.shape[0], dtype=bool)
baryon_indices = np.array([0, 1])
baryon_set[baryon_indices] = True  # binary A in red

pos_hist, vel_hist, time_hist, KE_hist, PE_hist, virial_hist = trajectory(
    pos, vel, mass, G, epsilon, steps, res, dt)
charts(time_hist, KE_hist, PE_hist, virial_hist)
animation(pos_hist, time_hist, 6, 'Two binaries', baryon_set)

# Tetrahedral

Another 4-body constellation: four equal masses at the vertices of a regular tetrahedron, started in rigid rotation about the z-axis.

In [None]:
import numpy as np

def tetrahedral_4body():
    """
    Four unit masses at rest on a regular tetrahedron of unit edge length.

    The vertices are centred on the origin. With zero initial velocities the
    configuration simply falls in on itself under its own gravity.

    Returns
    -------
    pos : (4,3) array
    vel : (4,3) array of zeros
    masses : (4,) array of ones
    """
    masses = np.ones(4)

    # Vertices (±1, 0, -h) and (0, ±1, +h) with h = 1/sqrt(2) have all six edges of
    # length 2; scaling by 0.5 gives unit edges.
    h = 1.0 / np.sqrt(2.0)
    pos = np.array([
        [ 1.0,  0.0, -h],
        [-1.0,  0.0, -h],
        [ 0.0,  1.0,  h],
        [ 0.0, -1.0,  h]
    ]) * 0.5

    # For a static central configuration (free-fall collapse)
    vel = np.zeros_like(pos)

    return pos, vel, masses

import numpy as np

def tetrahedral_4body_rotating(G=1.0, M=1.0):
    """
    Four equal masses on a regular tetrahedron, given rigid-rotation velocities.

    The vertices sit at unit distance from the origin. The angular speed is chosen
    so that ``omega**2 * R`` equals the magnitude of the gravitational acceleration
    at vertex 0. Every body then gets velocity ``omega * (z_hat x r)``, i.e. rotation
    about the z-axis.

    NOTE(review): by symmetry, the gravity at each vertex points towards the centre.
    Its in-plane part matches the centripetal requirement, but its z-component is
    not balanced by rotation about z. The configuration is therefore probably not
    an exact rigid-rotation equilibrium — confirm in a run before calling it stable.

    Parameters
    ----------
    G : float
        Gravitational constant (N-body units by default).
    M : float
        Total mass, split equally over the four bodies.

    Returns
    -------
    pos : (4,3) array
    vel : (4,3) array
    masses : (4,) array
    omega : float
        Angular speed of the imposed rotation.
    """
    n_bodies = 4
    m = M / n_bodies

    # Alternate corners of a cube form a regular tetrahedron centred at the origin.
    pos = np.array([
        [ 1,  1,  1],
        [-1, -1,  1],
        [-1,  1, -1],
        [ 1, -1, -1]
    ], dtype=float)
    pos /= np.linalg.norm(pos[0])  # unit distance from the centre

    # Gravitational acceleration at vertex 0 from the other three bodies.
    acc0 = np.zeros(3)
    for j in range(1, n_bodies):
        sep = pos[j] - pos[0]
        acc0 += G*m * sep / np.linalg.norm(sep)**3

    radius = np.linalg.norm(pos[0])
    omega = np.sqrt(np.linalg.norm(acc0) / radius)

    # Row-wise v_i = omega * (z_hat x r_i).
    z_axis = np.array([0, 0, 1])
    vel = omega * np.cross(z_axis, pos)

    masses = np.ones(n_bodies) * m
    return pos, vel, masses, omega

In [None]:
G=1
epsilon = 1/200
steps = 50000
dt = 20/steps  # 20 time units in total
res = 500  # set explicitly instead of silently reusing the previous cell's value
pos, vel, mass, omega = tetrahedral_4body_rotating()
baryon_set = np.zeros(pos.shape[0], dtype=bool)
baryon_indices = np.array([0,1])
baryon_set[baryon_indices] = True  # draw two of the four bodies in red

pos_hist,vel_hist,time_hist,KE_hist,PE_hist,virial_hist = trajectory(pos, vel, mass, G, epsilon, steps, res, dt)
charts(time_hist,KE_hist,PE_hist,virial_hist)
animation(pos_hist, time_hist, 5, 'Tetrahedal', baryon_set)

# Plummer sphere

A Plummer sphere is a widely used analytic model in stellar dynamics that describes a smooth, spherically symmetric distribution of matter with a finite total mass and a soft, non-singular core. Its density profile decreases with radius as

$\rho(r)=\frac{3M}{4\pi b^3}\left(1+\frac{r^2}{b^2}\right)^{-5/2}$,

where M is the total mass and b is the scale radius (the parameter `a` in the code below) that sets the core size. Near the center, the density is approximately constant, avoiding the central cusp found in other models, while at large radii the density falls off as $r^{-5}$, ensuring a finite mass and potential. Because it has closed-form expressions for density, potential, and cumulative mass, the Plummer sphere is commonly used to initialize star clusters or dark-matter halos in N-body simulations, providing a stable and computationally convenient equilibrium configuration.

In [None]:
def sample_plummer_sphere(N, M=1.0, a=1.0, G=1.0, rng=None):
    """
    Sample N-body equilibrium initial conditions for a Plummer sphere.

    Radii are drawn by inverting the cumulative mass profile
    M(<r)/M = r^3 / (r^2 + a^2)^(3/2), with isotropic directions. Speeds are drawn
    by rejection sampling from the isotropic Plummer distribution function, under
    which q = v / v_esc has density proportional to q^2 (1 - q^2)^(7/2).

    Parameters:
    -----------
    N : int
        Number of particles
    M : float
        Total mass
    a : float
        Plummer scale radius
    G : float
        Gravitational constant (default N-body units)
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source. Defaults to NumPy's global random state, so that
        ``np.random.seed`` controls the result.

    Returns:
    --------
    pos : (N,3)
        Positions
    vel : (N,3)
        Velocities
    masses : (N,)
        Equal particle masses summing to M
    """
    rng = np.random if rng is None else rng

    # === 1) Positions ===
    X = rng.random(N)
    r = a * (X**(-2/3) - 1.0)**(-0.5)  # inverse of the cumulative mass fraction

    # isotropic angles
    theta = np.arccos(1 - 2*rng.random(N))
    phi = 2 * np.pi * rng.random(N)

    x = r * np.sin(theta) * np.cos(phi)
    y = r * np.sin(theta) * np.sin(phi)
    z = r * np.cos(theta)
    pos = np.vstack((x, y, z)).T

    # === 2) Speeds ===
    # Local escape speed sqrt(2 Phi(r)) with Phi(r) = G M / sqrt(r^2 + a^2).
    v_esc = np.sqrt(2 * G * M / np.sqrt(r**2 + a**2))
    speed = np.empty(N)
    for i in range(N):
        # Sample the dimensionless q = v / v_esc, so that acceptance does not depend
        # on the units of v. g(q) = q^2 (1 - q^2)^3.5 peaks at ~0.092 (q^2 = 2/9),
        # so 0.1 is a valid envelope (~43% acceptance).
        while True:
            q = rng.random()
            if 0.1 * rng.random() < q**2 * (1.0 - q**2)**3.5:
                break
        speed[i] = q * v_esc[i]

    # Random isotropic velocity directions
    costh = 2*rng.random(N) - 1
    sinth = np.sqrt(1 - costh**2)
    phi_v = 2*np.pi * rng.random(N)
    directions = np.column_stack((sinth * np.cos(phi_v), sinth * np.sin(phi_v), costh))
    vel = speed[:, None] * directions

    # === 3) Equal particle masses ===
    masses = np.ones(N) * (M / N)

    return pos, vel, masses

In [None]:
# Seed NumPy's global random state for a reproducible sample. (`np.random.seed = 41`
# would replace the seed function with an int instead of calling it.) Copying the
# state of RandomState(41) is equivalent to np.random.seed(41), and it still works
# if np.random.seed has already been rebound earlier in the session.
np.random.set_state(np.random.RandomState(41).get_state())
N_gal = 1000
a = 0.5
G = 1
pos, vel, mass = sample_plummer_sphere(N_gal, M=1, a=a, G=G)

steps = 1000
res = 100
dt = 50 / steps # covering 50 time units
epsilon = 1/100

baryon_set = np.zeros(pos.shape[0], dtype=bool)  # all particles drawn blue

pos_hist,vel_hist,time_hist,KE_hist,PE_hist,virial_hist = trajectory(pos, vel, mass, G, epsilon, steps, res, dt, collision=False, parallel=True)
charts(time_hist,KE_hist,PE_hist,virial_hist)
animation(pos_hist, time_hist, 15, 'Plummer sphere', baryon_set, marker_size = 2)

# Galaxy merger

This constellation shows two galaxies, initialized as Plummer spheres, moving towards each other and eventually merging. Such mergers are a central process in cosmological N-body simulations, for example of the ΛCDM model.

In [None]:
# Seed NumPy's global random state for reproducible samples. (`np.random.seed = 41`
# would replace the seed function with an int instead of calling it.) Copying the
# state of RandomState(41) is equivalent to np.random.seed(41), and it still works
# if np.random.seed has already been rebound earlier in the session.
np.random.set_state(np.random.RandomState(41).get_state())
N_gal = 1000
a = 0.2
G = 1
offset = np.array([10, 0, 0])       # separation of the two galaxy centres
v_rel = np.array([0.1, -0.10, 0])   # half of the relative velocity, applied with opposite signs
pos1, vel1, mass1 = sample_plummer_sphere(N_gal, M=1, a=a, G=G)
pos2, vel2, mass2 = sample_plummer_sphere(N_gal, M=1, a=a, G=G)
pos1 -= offset/2
pos2 += offset/2
vel1 += v_rel
vel2 -= v_rel
pos = np.vstack((pos1, pos2))
vel = np.vstack((vel1, vel2))
mass = np.concatenate((mass1, mass2))

steps = 2000
res = 100
dt = 100 / steps # covering 100 time units
epsilon = 1/100

baryon_set = np.zeros(pos.shape[0], dtype=bool)
baryon_set[:N_gal] = True  # first galaxy in red, second in blue

pos_hist,vel_hist,time_hist,KE_hist,PE_hist,virial_hist = trajectory(pos, vel, mass, G, epsilon, steps, res, dt, collision=False, parallel=True)
charts(time_hist,KE_hist,PE_hist,virial_hist)
animation(pos_hist, time_hist, 15, 'Galaxy merger', baryon_set, marker_size=2)