In [None]:
import torch
import matplotlib.pyplot as plt

# Parameters
N = 100          # Number of particles
G = 1.0          # Gravitational constant
dt = 0.01        # Timestep
steps = 200      # Simulation steps

# Initialize on the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pos = torch.rand(N, 2, device=device) * 10  # Random x, y in [0, 10)
vel = torch.zeros(N, 2, device=device)      # Start at rest
mass = torch.ones(N, device=device)         # Equal masses

# G * m_i * m_j is the same at every step, so build the (N, N) table once.
pair_strength = G * mass.unsqueeze(1) * mass.unsqueeze(0)

# Simulation loop
# .copy() is required: on the CPU, pos.cpu().numpy() shares memory with `pos`,
# and the in-place updates below would silently overwrite every stored frame.
positions = [pos.cpu().numpy().copy()]  # Store for plotting
for _ in range(steps):
    # Pairwise separations: r_vec[i, j] = pos[j] - pos[i], i.e. the vector
    # pointing from particle i TOWARD particle j, so the force on i attracts.
    r_vec = pos.unsqueeze(0) - pos.unsqueeze(1)  # Shape: (N, N, 2)
    r_sq = (r_vec ** 2).sum(dim=-1) + 1e-6       # Squared distance + softening
    r = r_sq.sqrt()

    # Gravitational force of j on i: magnitude times unit direction.
    # The diagonal (i == j) contributes zero because r_vec is zero there.
    force_mag = pair_strength / r_sq
    force = force_mag.unsqueeze(-1) * (r_vec / r.unsqueeze(-1))

    # Net force per particle
    net_force = force.sum(dim=1)  # Sum over all interaction partners j
    accel = net_force / mass.unsqueeze(-1)

    # Update: velocity first, then position with the updated velocity
    vel += accel * dt
    pos += vel * dt
    positions.append(pos.cpu().numpy().copy())

# Plot last frame
plt.scatter(positions[-1][:, 0], positions[-1][:, 1], s=10)
plt.title("N-Body Simulation (Final State)")
plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt
import time

def run_nbody_simulation(N, steps, device):
    """Run a 2-D all-pairs N-body simulation and return its wall-clock time.

    Particles start at rest at random positions in [0, 10) x [0, 10) with
    unit mass; each step evaluates all N x N softened gravitational
    interactions and advances velocity, then position.

    Args:
        N: Number of particles.
        steps: Number of integration steps to run.
        device: Device the tensors are created on (e.g. 'cpu' or 'cuda').

    Returns:
        Elapsed seconds (float) spent in the simulation loop. For CUDA
        devices the device is synchronized before the clock is started and
        before it is stopped, so the value covers kernel execution rather
        than just the asynchronous kernel launches.
    """
    # Initialize
    pos = torch.rand(N, 2, device=device) * 10
    vel = torch.zeros(N, 2, device=device)
    mass = torch.ones(N, device=device)
    G = 1.0
    dt = 0.01

    # G * m_i * m_j does not change between steps; compute it once.
    pair_strength = G * mass.unsqueeze(1) * mass.unsqueeze(0)

    # CUDA kernels run asynchronously: drain the initialization work so it is
    # not attributed to the timed loop.
    if pos.is_cuda:
        torch.cuda.synchronize(pos.device)
    start_time = time.perf_counter()

    # Simulation loop
    for _ in range(steps):
        # r_vec[i, j] = pos[j] - pos[i]: points from i toward j (attractive).
        r_vec = pos.unsqueeze(0) - pos.unsqueeze(1)
        r_sq = (r_vec ** 2).sum(dim=-1) + 1e-6  # softened squared distance
        r = r_sq.sqrt()

        force_mag = pair_strength / r_sq
        force = force_mag.unsqueeze(-1) * (r_vec / r.unsqueeze(-1))

        net_force = force.sum(dim=1)
        accel = net_force / mass.unsqueeze(-1)

        vel += accel * dt
        pos += vel * dt

    # Wait for all queued kernels to finish before reading the clock;
    # otherwise only the launch overhead would be measured on the GPU.
    if pos.is_cuda:
        torch.cuda.synchronize(pos.device)
    end_time = time.perf_counter()
    return end_time - start_time

# Benchmark parameters
particle_counts = range(500, 3000, 500)
steps = 200
num_runs = 3  # Number of runs to average over
has_cuda = torch.cuda.is_available()

# Untimed warm-up: one short run per device so that one-time start-up work is
# not charged to the first timed sample.
run_nbody_simulation(particle_counts[0], 1, device='cpu')
if has_cuda:
    run_nbody_simulation(particle_counts[0], 1, device='cuda')
    torch.cuda.synchronize()

print("\nN-body Simulation Benchmark")
print("-" * 60)
print(f"{'Particles':>10} | {'CPU Time (s)':>15} | {'GPU Time (s)':>15} | {'Speedup':>10}")
print("-" * 60)

for N in particle_counts:
    # CPU benchmark
    cpu_times = [run_nbody_simulation(N, steps, device='cpu') for _ in range(num_runs)]
    avg_cpu_time = sum(cpu_times) / num_runs

    # GPU benchmark (if available)
    if has_cuda:
        gpu_times = []
        for _ in range(num_runs):
            torch.cuda.synchronize()  # Start each run with an idle GPU
            gpu_time = run_nbody_simulation(N, steps, device='cuda')
            torch.cuda.synchronize()  # Drain the queue before the next run
            gpu_times.append(gpu_time)
        avg_gpu_time = sum(gpu_times) / num_runs
        speedup = avg_cpu_time / avg_gpu_time
    else:
        # No GPU: report NaN so the table keeps its shape
        avg_gpu_time = float('nan')
        speedup = float('nan')

    print(f"{N:>10} | {avg_cpu_time:>15.3f} | {avg_gpu_time:>15.3f} | {speedup:>10.2f}x")

# Original visualization code
N = 100    # For visualization
G = 1.0    # Gravitational constant (defined here so the cell does not depend on an earlier cell)
dt = 0.01  # Timestep
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pos = torch.rand(N, 2, device=device) * 10
vel = torch.zeros(N, 2, device=device)
mass = torch.ones(N, device=device)

# G * m_i * m_j is the same at every step, so build the (N, N) table once.
pair_strength = G * mass.unsqueeze(1) * mass.unsqueeze(0)

# Store positions for plotting.
# .copy() is required: on the CPU, pos.cpu().numpy() shares memory with `pos`,
# and the in-place updates below would silently overwrite every stored frame.
positions = [pos.cpu().numpy().copy()]
for _ in range(steps):
    # r_vec[i, j] = pos[j] - pos[i]: points from particle i toward particle j,
    # which makes the resulting force attractive.
    r_vec = pos.unsqueeze(0) - pos.unsqueeze(1)
    r_sq = (r_vec ** 2).sum(dim=-1) + 1e-6  # softened squared distance
    r = r_sq.sqrt()

    force_mag = pair_strength / r_sq
    force = force_mag.unsqueeze(-1) * (r_vec / r.unsqueeze(-1))

    net_force = force.sum(dim=1)  # sum over interaction partners j
    accel = net_force / mass.unsqueeze(-1)

    vel += accel * dt
    pos += vel * dt
    positions.append(pos.cpu().numpy().copy())

# Plot last frame
plt.scatter(positions[-1][:, 0], positions[-1][:, 1], s=10)
plt.title("N-Body Simulation (Final State)")
plt.show()


In [None]:
import torch
import triton
import triton.language as tl

# All Triton inputs and outputs in this notebook live on the first CUDA device.
DEVICE = torch.device('cuda', 0)


@triton.jit
def nbody_force_kernel(
    pos_x_ptr: tl.tensor,
    pos_y_ptr: tl.tensor,
    mass_ptr: tl.tensor, 
    force_x_ptr: tl.tensor,
    force_y_ptr: tl.tensor,
    n_elements: tl.int32,
    G: tl.float32,
    softening: tl.float32,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """All-pairs softened gravitational force on each particle (2-D).

    Each program instance owns BLOCK_SIZE consecutive particles ``i`` and
    accumulates, over every particle ``j`` in ``[0, n_elements)``:

        f_i += G * m_i * m_j * (p_j - p_i) / (|p_j - p_i|^2 + softening)^(3/2)

    The partner particle ``j`` is loaded as a scalar and broadcast against the
    whole block, so every (i, j) pair is visited exactly once. The self term
    (i == j) has zero separation and contributes nothing as long as
    ``softening`` is positive; with ``softening == 0`` it would divide by zero.

    Args:
        pos_x_ptr, pos_y_ptr: particle coordinates, read at indices < n_elements.
        mass_ptr: particle masses, read at indices < n_elements.
        force_x_ptr, force_y_ptr: outputs, written at indices < n_elements.
        n_elements: number of particles.
        G: gravitational constant.
        softening: value added to the squared distance.
        BLOCK_SIZE: particles handled per program instance.
    """
    # Get the program ID
    pid = tl.program_id(axis=0)

    # Indices of the particles owned by this program
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load position and mass for this particle block. Out-of-range lanes get
    # mass 0 so they compute a harmless zero force (and are never stored).
    x = tl.load(pos_x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(pos_y_ptr + offsets, mask=mask, other=0.0)
    m = tl.load(mass_ptr + offsets, mask=mask, other=0.0)

    # Initialize force accumulators
    fx = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    fy = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    # Loop over ALL partner particles. j is always in range, so no mask is
    # needed on these scalar loads.
    for j in range(0, n_elements):
        xj = tl.load(pos_x_ptr + j)
        yj = tl.load(pos_y_ptr + j)
        mj = tl.load(mass_ptr + j)

        # Separation from each owned particle i to particle j (scalar
        # broadcast against the block), pointing toward j => attraction.
        dx = xj - x
        dy = yj - y
        r_sq = dx * dx + dy * dy + softening
        r_inv = 1.0 / tl.sqrt(r_sq)
        f = G * m * mj * r_inv * r_inv * r_inv

        fx += f * dx
        fy += f * dy

    # Store results
    tl.store(force_x_ptr + offsets, fx, mask=mask)
    tl.store(force_y_ptr + offsets, fy, mask=mask)

def compute_forces(pos_x, pos_y, mass, G=1.0, softening=1e-6):
    """Launch ``nbody_force_kernel`` and return the per-particle forces.

    Args:
        pos_x: 1-D tensor of x coordinates.
        pos_y: 1-D tensor of y coordinates, same length as ``pos_x``.
        mass: 1-D tensor of masses, same length as ``pos_x``.
        G: Gravitational constant passed to the kernel.
        softening: Value passed to the kernel as its softening term.

    Returns:
        Tuple ``(force_x, force_y)`` of float32 tensors on ``DEVICE`` with the
        same shape as ``pos_x``.

    Raises:
        ValueError: If the inputs are not 1-D tensors of equal length. The
            kernel indexes all three with the same offsets, so a shorter
            tensor would be read out of bounds.
    """
    # Ensure inputs are float32, contiguous and on the GPU. Inputs that
    # already satisfy this are passed through without a copy.
    pos_x = pos_x.to(device=DEVICE, dtype=torch.float32).contiguous()
    pos_y = pos_y.to(device=DEVICE, dtype=torch.float32).contiguous()
    mass = mass.to(device=DEVICE, dtype=torch.float32).contiguous()

    if pos_x.dim() != 1 or pos_y.shape != pos_x.shape or mass.shape != pos_x.shape:
        raise ValueError(
            "pos_x, pos_y and mass must be 1-D tensors of equal length, got shapes "
            f"{tuple(pos_x.shape)}, {tuple(pos_y.shape)}, {tuple(mass.shape)}"
        )

    N = pos_x.shape[0]
    # The position tensors are already float32 on DEVICE, so empty_like
    # inherits the right dtype and device.
    force_x = torch.empty_like(pos_x)
    force_y = torch.empty_like(pos_y)

    # Nothing to compute for an empty system; skip the kernel launch.
    if N == 0:
        return force_x, force_y

    # One program instance per BLOCK_SIZE particles
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )

    nbody_force_kernel[grid](
        pos_x_ptr=pos_x,
        pos_y_ptr=pos_y,
        mass_ptr=mass,
        force_x_ptr=force_x,
        force_y_ptr=force_y,
        n_elements=N,
        G=G,
        softening=softening,
        BLOCK_SIZE=1024,
    )
    return force_x, force_y

In [None]:
# PyTorch reference implementation, used to verify and benchmark the Triton kernel
import time
import numpy as np

def compute_forces_torch(pos_x, pos_y, mass, G=1.0, softening=1e-6):
    """Dense PyTorch evaluation of the softened all-pairs gravitational force.

    Builds full pairwise matrices, so memory grows with the square of the
    number of particles.

    Args:
        pos_x: x coordinates (expected to be a 1-D tensor of length N).
        pos_y: y coordinates, same layout as ``pos_x``.
        mass: particle masses, same layout as ``pos_x``.
        G: Gravitational constant.
        softening: Value added to every squared pair distance.

    Returns:
        Tuple ``(fx, fy)`` with the net force components on each particle.
    """
    # sep[i, j] = coordinate of j minus coordinate of i (points from i to j)
    sep_x = pos_x[None, :] - pos_x[:, None]
    sep_y = pos_y[None, :] - pos_y[:, None]

    # 1 / (softened distance)^3 for every pair
    dist_sq = sep_x * sep_x + sep_y * sep_y + softening
    inv_dist_cubed = 1.0 / dist_sq.sqrt().pow(3)

    # m_i * m_j / r^3, shared by both force components
    pair_weight = (mass[None, :] * mass[:, None]) * inv_dist_cubed

    # Sum over partners j (dim=1) to get the net force on each particle i
    fx = G * (pair_weight * sep_x).sum(dim=1)
    fy = G * (pair_weight * sep_y).sum(dim=1)
    return fx, fy

In [None]:
# Sanity check: compare the Triton kernel against the PyTorch reference
N = 1000
pos_x = 10.0 * torch.rand(N, device=DEVICE)
pos_y = 10.0 * torch.rand(N, device=DEVICE)
mass = torch.ones(N, device=DEVICE)

# Evaluate the same particle system with both implementations
fx_torch, fy_torch = compute_forces_torch(pos_x, pos_y, mass)
fx_triton, fy_triton = compute_forces(pos_x, pos_y, mass)

# Mean absolute deviation per force component
fx_diff = (fx_torch - fx_triton).abs().mean()
fy_diff = (fy_torch - fy_triton).abs().mean()

print("Verification against PyTorch implementation:")
for component, mean_abs_diff in (("fx", fx_diff), ("fy", fy_diff)):
    print(f"Average absolute difference in {component}: {mean_abs_diff:.3e}")



In [None]:


# Problem sizes to benchmark, and the per-size mean timings collected below
sizes = [1000, 5000, 10000, 20000]
times_torch = []
times_triton = []


def _mean_runtime_ms(force_fn, *args, repeats=3):
    """Return (mean wall-clock milliseconds per call, last result) of force_fn(*args).

    The GPU is synchronized once after the final call and before the stop
    time is read, so the queued kernels are included in the measurement.
    """
    tic = time.perf_counter()
    for _ in range(repeats):
        result = force_fn(*args)
    torch.cuda.synchronize()
    toc = time.perf_counter()
    return (toc - tic) * 1000 / repeats, result


print("Benchmarking N-body force computation:")
print(f"{'N':>10} {'Torch (ms)':>12} {'Triton (ms)':>12} {'Speedup':>10}")
print("-" * 46)

for N in sizes:
    # Fresh random particle system for this size
    pos_x = torch.rand(N, device=DEVICE) * 10.0
    pos_y = torch.rand(N, device=DEVICE) * 10.0
    mass = torch.ones(N, device=DEVICE)

    # Warmup: one untimed call of each implementation, then drain the queue
    for force_fn in (compute_forces_torch, compute_forces):
        force_fn(pos_x, pos_y, mass)
    torch.cuda.synchronize()

    # Timed runs: PyTorch first, then Triton
    torch_time, (fx_torch, fy_torch) = _mean_runtime_ms(
        compute_forces_torch, pos_x, pos_y, mass
    )
    triton_time, (fx_triton, fy_triton) = _mean_runtime_ms(
        compute_forces, pos_x, pos_y, mass
    )

    # Keep the timings for the plot
    times_torch.append(torch_time)
    times_triton.append(triton_time)

    # One table row per problem size
    speedup = torch_time / triton_time
    print(f"{N:>10d} {torch_time:>12.2f} {triton_time:>12.2f} {speedup:>10.2f}x")

# Plot results: mean runtime per implementation against problem size
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
for series, series_label in ((times_torch, 'PyTorch'), (times_triton, 'Triton')):
    ax.plot(sizes, series, 'o-', label=series_label)
ax.set_xlabel('Number of particles')
ax.set_ylabel('Time (ms)')
ax.set_title('N-body Force Computation Performance')
ax.legend()
ax.grid(True)
ax.set_yscale('log')  # log axis for the timings
plt.show()
