In [None]:
import numpy as np
import os
import time
import matplotlib.pyplot as plt
import scipy
import matplotlib
import sympy
import numba
from numba import njit, prange , set_num_threads, get_num_threads
from typing import Callable
from mpi4py import MPI
import jax
import jax.numpy as jnp

## Bonus 1: MPI-parallel acceleration

In [None]:
# Handle to the world communicator, plus this process's index and the process count.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Publish the communicator size as Numba's thread-count environment variable.
# NOTE(review): numba is already imported by the time this runs -- confirm the
# setting is still picked up when it is exported this late.
os.environ.update({"NUMBA_NUM_THREADS": str(size)})

In [None]:
def get_acceleration_mpi(X: np.ndarray) -> np.ndarray:
    """Compute pairwise inverse-square accelerations with rows split across MPI ranks.

    For every particle ``i`` the acceleration is
    ``sum_j (X[j] - X[i]) / |X[j] - X[i]|**3`` over all ``j`` at non-zero
    distance from ``i``. Each rank fills in a contiguous slice of rows and the
    slices are combined with a sum-``Allreduce``, so every rank returns the
    full result.

    Every rank indexes its own copy of ``X``; the ranks must therefore hold
    identical positions for the combined result to be meaningful.

    Parameters
    ----------
    X : np.ndarray
        Particle positions of shape ``(N, 3)``.

    Returns
    -------
    np.ndarray
        Array of shape ``(N, 3)`` with the acceleration of every particle.
    """
    X = np.asarray(X, dtype=np.float64)
    N = X.shape[0]

    # Balanced block partition: the first `extra` ranks take one additional
    # row, so no rank ends up owning the whole remainder.
    base, extra = divmod(N, size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)

    # Full-size buffer with only this rank's rows filled; the rest stay zero
    # so that the element-wise sum across ranks assembles the complete array.
    local_a = np.zeros((N, 3))

    for i in range(start, end):
        r_i = X - X[i]                      # displacement from i to every j
        dist = np.linalg.norm(r_i, axis=1)
        # dist > 0 drops the self term (j == i) and any coincident particles.
        mask = dist > 0
        local_a[i] = (r_i[mask] / dist[mask, None] ** 3).sum(axis=0)

    a = np.zeros((N, 3))
    comm.Allreduce(local_a, a, op=MPI.SUM)

    return a

In [None]:
def measure_mpi_scaling(N_values):
    """Time ``get_acceleration_mpi`` once for each problem size in ``N_values``.

    Positions are drawn on rank 0 and broadcast, so every rank works on the
    same particle set. All ranks are synchronized before the timer starts.

    Parameters
    ----------
    N_values : iterable of int
        Numbers of bodies to benchmark.

    Returns
    -------
    list of float
        Elapsed wall-clock seconds on the calling rank, one entry per ``N``.
    """
    times_mpi = []

    for N in N_values:
        # Independent draws on each rank would feed different positions into
        # the reduction; generate once on the root and share the buffer.
        if rank == 0:
            X = np.random.randn(N, 3)
        else:
            X = np.empty((N, 3), dtype=np.float64)
        comm.Bcast(X, root=0)

        # Line the ranks up so the measurement does not include start-up skew.
        comm.Barrier()
        start_time = time.perf_counter()
        get_acceleration_mpi(X)
        end_time = time.perf_counter()

        times_mpi.append(end_time - start_time)

    return times_mpi

In [None]:
def speedup_val(N_values, times_numba_parallel, times_mpi):
    """Print the element-wise ratio of Numba-parallel timings to MPI timings.

    Parameters
    ----------
    N_values : sequence
        Problem sizes; accepted for signature symmetry but not used here.
    times_numba_parallel : sequence of float
        Timings of the Numba-parallel kernel.
    times_mpi : sequence of float
        Timings of the MPI kernel, same length as ``times_numba_parallel``.

    Returns
    -------
    None
        The ratios are only printed.
    """
    numba_timings = np.array(times_numba_parallel)
    mpi_timings = np.array(times_mpi)
    ratios = np.divide(numba_timings, mpi_timings)
    print("Speedup Factors (Numba Parallel / MPI):", ratios)

In [None]:
def plot_comparison(N_values, times_numba_parallel, times_mpi):
    """Draw both timing curves against problem size on log-log axes and show the figure.

    Parameters
    ----------
    N_values : sequence of int
        Problem sizes (x axis).
    times_numba_parallel : sequence of float
        Timings plotted with circle markers and labelled 'Numba Parallel'.
    times_mpi : sequence of float
        Timings plotted with square markers and labelled 'MPI'.
    """
    ax = plt.figure().gca()

    curves = (
        (times_numba_parallel, 'o', 'Numba Parallel'),
        (times_mpi, 's', 'MPI'),
    )
    for timings, marker, label in curves:
        ax.loglog(N_values, timings, marker=marker, linestyle='-', label=label)

    ax.set_xlabel('Number of Bodies (N)')
    ax.set_ylabel('Execution Time (s)')
    ax.legend()
    ax.set_title('MPI vs. Numba Parallel Performance')
    plt.show()

In [None]:
# Problem sizes (number of bodies) to benchmark.
N_values = [10, 50, 100, 200, 500, 1000]

# Every rank runs the timings: the MPI kernel ends in a collective reduction,
# so all ranks have to enter it.
times_numba_parallel = measure_parallel_scaling(N_values)
times_mpi = measure_mpi_scaling(N_values)

# Report from the root rank only, so a multi-process launch does not produce
# one figure and one printout per rank.
if rank == 0:
    plot_comparison(N_values, times_numba_parallel, times_mpi)
    speedup_val(N_values, times_numba_parallel, times_mpi)

## Bonus 2: Acceleration with JAX

In [None]:
def get_acceleration_jax(X: jnp.ndarray) -> jnp.ndarray:
    """Compute the pairwise inverse-square acceleration of every particle with JAX.

    Row ``i`` of the result is ``sum_j (X[j] - X[i]) / |X[j] - X[i]|**3`` taken
    over all ``j`` whose distance to ``i`` is non-zero. The sum over ``j`` is a
    single array expression; the loop over ``i`` is an explicit Python loop.

    Parameters
    ----------
    X : jnp.ndarray
        Particle positions of shape ``(N, D)``. Integer or boolean input is
        promoted to floating point first.

    Returns
    -------
    jnp.ndarray
        Accelerations with the same shape as ``X``.
    """
    X = jnp.asarray(X)
    # The accumulator inherits X's dtype; with an integer X the fractional
    # contributions could not be stored, so promote to float up front.
    if not jnp.issubdtype(X.dtype, jnp.inexact):
        X = X.astype(float)

    N = X.shape[0]
    a = jnp.zeros_like(X)

    for i in range(N):
        r_ij = X - X[i]  # displacement from particle i to every particle j
        distances = jnp.linalg.norm(r_ij, axis=1)
        # Exactly-zero distances (the self term and coincident particles) are
        # replaced by infinity so their contribution r / inf**3 is zero.
        distances = jnp.where(distances == 0, jnp.inf, distances)
        a = a.at[i].add(jnp.sum(r_ij / distances[:, None]**3, axis=0))

    return a

In [None]:
# Draw 1000 random 3-D positions with NumPy (uniform on [0, 1)) and hand them to JAX.
X = jnp.asarray(np.random.rand(1000, 3))

# Evaluate the accelerations for this particle set.
a_jax = get_acceleration_jax(X)

# Show the devices JAX reports as available.
print(
    f"JAX is running on device: {jax.devices()}"
)

In [None]:
# Time the vectorized JAX kernel.
# JAX dispatches work asynchronously, so the call can return before the
# computation has finished; block on the result before stopping the clock.
# NOTE(review): if get_acceleration_jax_vectorized is jit-compiled, this first
# call would also include tracing/compilation -- confirm whether a warm-up
# call is wanted before timing.
start = time.perf_counter()
a_jax_vectorized = jax.block_until_ready(get_acceleration_jax_vectorized(X))
jax_time = time.perf_counter() - start
print(f"JAX vectorized version time: {jax_time:.4f} seconds")