In [None]:
import numpy as np
%matplotlib qt
import matplotlib.pyplot as plt


def getEnergy(pos, vel, mass, G):
    """
    Calculate the total kinetic and potential energy of an N-body system in D dimensions.

    pos   is an N x D matrix of positions, where D is the number of dimensions
    vel   is an N x D matrix of velocities
    mass  is an N x 1 vector of masses (a flat length-N vector is also accepted)
    G     is Newton's Gravitational constant

    Returns (KE, PE):
        KE = sum_i 1/2 m_i |v_i|^2
        PE = -G * sum_{i<j} m_i m_j / |r_i - r_j|
    The potential is unsoftened; pairs at zero separation (including each
    particle with itself) contribute nothing to PE.
    """
    pos = np.asarray(pos, dtype=float)
    vel = np.asarray(vel, dtype=float)
    # Force a column vector so that broadcasting against N x D and N x N
    # arrays is always per-particle, whatever shape the caller passed.
    mass = np.asarray(mass, dtype=float).reshape(-1, 1)

    # Kinetic Energy: (N x 1) * (N x D) broadcasts row-wise, weighting each
    # particle's squared velocity components by its own mass. (Multiplying the
    # column by a length-N row of |v|^2 would instead form an N x N outer product.)
    KE = 0.5 * np.sum(mass * vel ** 2)

    # Potential Energy
    # diff[i, j, :] = r_i - r_j  -> N x N x D; r is the N x N distance matrix
    diff = pos[:, np.newaxis, :] - pos[np.newaxis, :, :]
    r = np.sqrt(np.sum(diff ** 2, axis=-1))

    # 1/r for separated pairs, 0 on the diagonal / for coincident particles
    inv_r = np.zeros_like(r)
    separated = r > 0
    inv_r[separated] = 1.0 / r[separated]

    # Sum over the upper triangle to count each interaction only once
    PE = G * np.sum(np.triu(-(mass * mass.T) * inv_r, 1))

    return KE, PE


def getAcc(pos, mass, G, softening):
    """
    Calculate the acceleration on each particle due to Newton's Law of gravitation.

    pos  is an N x D matrix of positions, where D is the number of dimensions
    mass is an N x 1 vector of masses (a flat length-N vector is also accepted)
    G is Newton's Gravitational constant
    softening is the softening length
    Returns a, an N x D matrix of accelerations:
        a_i = G * sum_j m_j (r_j - r_i) / (|r_j - r_i|^2 + softening^2)^(3/2)
    """
    pos = np.asarray(pos, dtype=float)
    # Column vector of masses regardless of the caller's shape
    mass = np.asarray(mass, dtype=float).reshape(-1, 1)

    # dpos[i, j, :] = r_j - r_i  -> N x N x D
    dpos = pos[np.newaxis, :, :] - pos[:, np.newaxis, :]

    # The squared distance must be summed over ALL dimensions before any
    # component is weighted; using a running partial sum gives wrong results
    # for every component except the last.
    r2 = np.sum(dpos ** 2, axis=-1)
    denom = r2 + softening ** 2

    # 1 / (r^2 + eps^2)^(3/2); left at 0 where the denominator vanishes (the
    # self term when softening == 0), which contributes nothing since dpos == 0 there.
    inv_r3 = np.zeros_like(denom)
    positive = denom > 0
    inv_r3[positive] = denom[positive] ** -1.5

    # weight[i, j] = m_j / (r_ij^2 + eps^2)^(3/2)
    weight = inv_r3 * mass.T
    acc = G * np.sum(weight[:, :, np.newaxis] * dpos, axis=1)

    return acc


# Registers the '3d' projection. On Matplotlib < 3.2 this import must run
# BEFORE any add_subplot(..., projection='3d') call, so it lives up here.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

# Simulation parameters
N = 50  # Number of particles
t = 0  # current time of the simulation
tEnd = 10.0  # time at which simulation ends
dt = 0.01  # timestep
softening = 0.1  # softening length
G = 1.0  # Newton's Gravitational Constant
plotRealTime = True  # switch on for plotting as the simulation goes along
boxSize = 1.0  # particles are kept inside [-boxSize, boxSize] on every axis by reflecting walls
useGravity = False  # apply the leapfrog gravity kicks; False = free particles bouncing in the box
trailLength = 50  # number of past steps drawn behind each particle

# Generate Initial Conditions
np.random.seed(17)  # set the random number generator seed

pos = np.random.rand(N, 3).astype(np.float64) * 2 * boxSize - boxSize
vel = np.random.randn(N, 3).astype(np.float64) * 0.7
mass = (10.0 * np.ones((N, 1)) / N).astype(np.float64)
# Convert to Center-of-Mass frame
vel -= np.mean(mass * vel, 0) / np.mean(mass)

# calculate initial gravitational accelerations
acc = getAcc(pos, mass, G, softening)
KE, PE = getEnergy(pos, vel, mass, G)

# number of timesteps
Nt = int(np.ceil(tEnd / dt))

# save energies, particle orbits for plotting trails
pos_save = np.zeros((N, 3, Nt + 1))
vel_save = np.zeros((N, 3, Nt + 1))
pos_save[:, :, 0] = pos
vel_save[:, :, 0] = vel
KE_save = np.zeros(Nt + 1)
KE_save[0] = KE
PE_save = np.zeros(Nt + 1)
PE_save[0] = PE
t_all = np.arange(Nt + 1) * dt

# prep figure
fig = plt.figure(figsize=(4, 5), dpi=80)
grid = plt.GridSpec(3, 1, wspace=0.0, hspace=0.3)
ax1 = fig.add_subplot(grid[0:2, 0], projection='3d')
ax2 = fig.add_subplot(grid[2, 0])

# NaN column used to break the trail polyline between particles
trail_gap = np.full((N, 1), np.nan)

# Simulation Main Loop
for i in range(Nt):
    if useGravity:
        # (1/2) kick
        vel += acc * dt / 2.0

    # drift
    pos += vel * dt

    # Reflect particles off the walls. Both masks are taken before any update
    # so each component is reflected at most once per step.
    above = pos > boxSize
    below = pos < -boxSize
    pos[above] = boxSize - (pos[above] - boxSize)  # Reflect position inside boundary
    pos[below] = -boxSize + (-boxSize - pos[below])  # Reflect position inside boundary
    vel[above | below] *= -1  # Reverse velocity component

    if useGravity:
        # update accelerations
        acc = getAcc(pos, mass, G, softening)
        # (1/2) kick
        vel += acc * dt / 2.0

    # update time
    t += dt

    # save energies, positions for plotting trail
    vel_save[:, :, i + 1] = vel
    pos_save[:, :, i + 1] = pos

    # get energy of system
    KE, PE = getEnergy(pos, vel, mass, G)
    KE_save[i + 1] = KE
    PE_save[i + 1] = PE

    # plot in real time
    is_last = i == Nt - 1
    if plotRealTime or is_last:
        # Snapshots 0..i+1 are filled; the slice end is exclusive, so use i + 2
        # to include the state just computed.
        stop = i + 2
        start = max(stop - 1 - trailLength, 0)

        ax1.clear()
        # Appending a NaN column keeps each particle's trail a separate line
        # instead of joining the end of one trail to the start of the next.
        xx = np.hstack([pos_save[:, 0, start:stop], trail_gap])
        yy = np.hstack([pos_save[:, 1, start:stop], trail_gap])
        zz = np.hstack([pos_save[:, 2, start:stop], trail_gap])
        ax1.plot3D(xx.ravel(), yy.ravel(), zz.ravel(), 'blue', alpha=0.5)
        ax1.scatter(pos[:, 0], pos[:, 1], pos[:, 2], color='red', s=10)

        ax1.set_xlim([-boxSize, boxSize])
        ax1.set_ylim([-boxSize, boxSize])
        ax1.set_zlim([-boxSize, boxSize])
        ax1.set_xlabel('X axis')
        ax1.set_ylabel('Y axis')
        ax1.set_zlabel('Z axis')

        ax2.clear()

        ax2.scatter(t_all[:stop], KE_save[:stop], color='red', s=1, label='KE')
        ax2.scatter(t_all[:stop], PE_save[:stop], color='blue', s=1, label='PE')
        ax2.scatter(t_all[:stop], KE_save[:stop] + PE_save[:stop], color='black', s=1,
                    label='Etot')

        ax2.set(xlim=(0, tEnd), ylim=(-300, 300))
        ax2.set_aspect(0.007)
        ax2.set_xlabel('Time')
        ax2.set_ylabel('Energy')
        if is_last:
            # Labels were previously set on the final frame but never shown
            ax2.legend(loc='upper right')

        plt.pause(0.0001)



idem plotit 0
idem plotit 1
idem plotit 2
idem plotit 3
idem plotit 4
idem plotit 5
idem plotit 6
idem plotit 7
idem plotit 8
idem plotit 9
idem plotit 10
idem plotit 11
idem plotit 12
idem plotit 13
idem plotit 14
idem plotit 15
idem plotit 16
idem plotit 17
idem plotit 18
idem plotit 19
idem plotit 20
idem plotit 21
idem plotit 22
idem plotit 23
idem plotit 24
idem plotit 25
idem plotit 26
idem plotit 27
idem plotit 28
idem plotit 29
idem plotit 30
idem plotit 31
idem plotit 32
idem plotit 33
idem plotit 34
idem plotit 35
idem plotit 36
idem plotit 37
idem plotit 38
idem plotit 39
idem plotit 40
idem plotit 41
idem plotit 42
idem plotit 43
idem plotit 44
idem plotit 45
idem plotit 46
idem plotit 47
idem plotit 48
idem plotit 49


In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Energy budget of the simulation run in the previous cell:
# (series, legend label, line colour)
energy_curves = [
    (KE_save, 'Kinetic Energy', 'red'),
    (PE_save, 'Potential Energy', 'blue'),
    (KE_save + PE_save, 'Total Energy', 'black'),
]
for series, curve_label, curve_color in energy_curves:
    ax.plot(t_all, series, label=curve_label, color=curve_color)

ax.set_xlabel('Time')
ax.set_ylabel('Energy')
ax.set_title('Energy vs Time')
ax.legend()
ax.grid(True)
plt.show()