# In [1]:
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib. pyplot as plt
from matplotlib.cm import get_cmap

# In [19]:
class Body:
    """A point mass in the plane with an initial state and a recorded path.

    The initial position and velocity are kept both as scalars (``x0``,
    ``y0``, ``vx0``, ``vy0``) and as 2-element arrays (``init_position``,
    ``init_velocity``). ``trajectory`` and ``time`` start as ``None`` and
    are filled in through :meth:`update_pos`.
    """

    def __init__(self, x:float, y:float, m:float, vx: float = 0.0, vy: float = 0.0) -> None:
        """Store the initial state and mass; no trajectory exists yet.

        Args:
            x: Initial x coordinate.
            y: Initial y coordinate.
            m: Mass of the body.
            vx: Initial velocity along x (defaults to 0.0).
            vy: Initial velocity along y (defaults to 0.0).
        """
        # Scalar components of the initial state.
        self.x0, self.y0 = x, y
        self.vx0, self.vy0 = vx, vy

        # The same initial state packed as 2-vectors.
        self.init_position: np.ndarray = np.array((x, y))
        self.init_velocity: np.ndarray = np.array((vx, vy))

        self.mass: float = m

        # Placeholders until update_pos() is called.
        self.trajectory: "np.ndarray | None" = None
        self.time: "np.ndarray | None" = None

    def update_pos(self, new_pos_matrix: np.ndarray, timesteps: np.ndarray):
        """Replace the stored trajectory and its time stamps.

        Both arguments are stored as given (no copy, no validation).

        Args:
            new_pos_matrix: Positions to store in ``self.trajectory``.
            timesteps: Time values to store in ``self.time``.
        """
        self.time, self.trajectory = timesteps, new_pos_matrix

# In [16]:
class Simulation():
    """Planar Newtonian N-body simulation integrated with ``solve_ivp``.

    Each body must expose ``x0``, ``y0``, ``vx0``, ``vy0``, ``mass`` and an
    ``update_pos(positions, times)`` method. After :meth:`run`, every body
    holds its own ``(2, n_samples)`` trajectory.
    """

    def __init__(self, bodies: "list[Body]", method: str = "RK45", t_max:int=10, t_step:float=0.1, grav_const = 6.67e-11):
        """Store the simulation configuration.

        Args:
            bodies: Bodies to simulate (any number >= 1).
            method: Integration method name forwarded to ``solve_ivp``.
            t_max: End time of the integration; the start time is 0.
            t_step: Desired spacing between the output samples.
            grav_const: Gravitational constant G used in the force law.
        """
        self.bodies: "list[Body]" = bodies
        self.t_max: int = t_max
        self.dt: float = t_step
        self.method: str = method
        self.G = grav_const

    def run(self):
        """Integrate the equations of motion and store each body's path.

        The state vector is laid out as all positions first
        ``[x1, y1, ..., xN, yN]`` followed by all velocities
        ``[vx1, vy1, ..., vxN, vyN]``; ``sol.y`` uses the same layout.

        Returns:
            The result object returned by ``scipy.integrate.solve_ivp``.

        Raises:
            ValueError: If there are no bodies, or ``t_max``/``t_step`` is
                not positive.
        """
        bodies = self.bodies
        n = len(bodies)
        if n == 0:
            raise ValueError("Simulation needs at least one body")
        if self.t_max <= 0 or self.dt <= 0:
            raise ValueError("t_max and t_step must both be positive")

        # np.linspace takes a sample COUNT, not a step size. Derive the count
        # from the requested step; the spacing equals t_step exactly when
        # t_max is a multiple of it. Endpoints 0 and t_max are always sampled.
        num_samples = max(int(round(self.t_max / self.dt)) + 1, 2)
        t = np.linspace(0.0, self.t_max, num_samples)

        g = self.G
        masses = np.array([body.mass for body in bodies], dtype=float)
        split = 2 * n  # index where positions end and velocities begin

        def system(t, y):
            pos = y[:split].reshape(n, 2)

            # offsets[i, j] is the vector pointing from body i to body j.
            offsets = pos[np.newaxis, :, :] - pos[:, np.newaxis, :]
            dist = np.linalg.norm(offsets, axis=2)
            # The i == j offset is the zero vector, so any non-zero distance
            # on the diagonal removes the self-term without dividing by zero.
            np.fill_diagonal(dist, 1.0)

            # a_i = G * sum_j m_j * (x_j - x_i) / |x_j - x_i|^3
            # (the body's own mass cancels out of its acceleration).
            weights = masses[np.newaxis, :] / dist**3
            acc = g * (weights[:, :, np.newaxis] * offsets).sum(axis=1)

            dydt = np.empty_like(y)
            dydt[:split] = y[split:]      # d(position)/dt = velocity
            dydt[split:] = acc.ravel()    # d(velocity)/dt = acceleration
            return dydt

        positions = [c for body in bodies for c in (body.x0, body.y0)]
        velocities = [c for body in bodies for c in (body.vx0, body.vy0)]
        y0 = np.array(positions + velocities, dtype=float)

        sol = solve_ivp(system, [0, self.t_max], y0, method=self.method, t_eval=t)

        for i, body in enumerate(bodies):
            body.update_pos(sol.y[2 * i:2 * i + 2, :], sol.t)

        return sol

    def plot_trajectories(self):
        """Plot every body's x/y path on one figure, show it, and return it.

        Raises:
            RuntimeError: If a body has no trajectory yet (``run`` not called).
        """
        if any(body.trajectory is None for body in self.bodies):
            raise RuntimeError("run() must be called before plot_trajectories()")

        fig = plt.figure(figsize=(5, 7), dpi = 600)

        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed
        # in matplotlib 3.9; it returns one discrete colour per body.
        cmap = plt.get_cmap("viridis", len(self.bodies))

        for i, body in enumerate(self.bodies):
            plt.plot(body.trajectory[0, :], body.trajectory[1, :], marker='o', linestyle='-', color=cmap(i))

        plt.xlabel('x position')
        plt.ylabel('y position')
        plt.title('Object Trajectory')
        plt.grid(True)
        plt.show()
        return fig
            