# EFM First-Principles Simulation: Batch-Optimized for Maximum Performance

## Objective: A High-Throughput Test of Self-Organization

This notebook implements a crucial optimization to the first-principles simulation. Analysis of the previous run revealed a significant performance bottleneck: the `512³` simulation was too small to fully saturate the computational capacity of the NVIDIA A100 GPU, leading to low resource utilization and extended runtimes.

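To address this, this version has been re-architected to use **batch processing**. Instead of running one simulation at a time, it advances a **batch of independent simulations simultaneously** in a single pass. This increases GPU utilization and throughput: the total wall-clock time needed to complete all `B` runs is much lower than running them one after another, although the duration of any individual run is not reduced.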

**Key Optimizations Implemented:**
1.  **Batch Processing:** All core tensors are now shaped `(B, N, N, N)`, where `B` is the batch size. This allows the GPU to process multiple simulations in parallel.
2.  **GPU-Centric JIT Compilation:** All core functions are JIT-compiled to handle batched tensors, ensuring maximum computational efficiency.
3.  **Dynamic Parameter Fields:** The stability (`m`) and emergence (`g`) parameters are calculated dynamically for each simulation within the batch, based on local field density.
4.  **Numerical Stability:** The refined physics model incorporating `g_min` and `ν∇⁴φ` biharmonic damping is retained to ensure stability across the entire run.

This notebook is the most computationally efficient version of this simulation so far for testing the EFM's self-organization hypothesis on the available hardware.

In [None]:
import os
import torch
import torch.nn.functional as F
import gc
from tqdm.notebook import tqdm
import numpy as np
import time
from datetime import datetime
import matplotlib.pyplot as plt

try:
    # google.colab only exists inside Colab; there, mount Drive so outputs persist.
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    pass  # Not running in Colab: carry on without mounting anything.

# CUDA caching-allocator option; it only matters if set before CUDA memory is
# first allocated (nothing above allocates on the GPU).
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

# Release cached GPU blocks and run the Python GC so a notebook re-run starts clean.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

# Use the first GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
if device.type == 'cuda':
    gpu_name = torch.cuda.get_device_name(device)
    vram_gb = torch.cuda.get_device_properties(device).total_memory / 1e9
    print(f"Using GPU: {gpu_name}, VRAM: {vram_gb:.2f} GB")
else:
    print("No GPU available, running on CPU.")

# Output directory (under the mounted Drive when in Colab); created if missing.
data_path_dynamic = '/content/drive/My Drive/EFM_Simulations/data/FirstPrinciples_Dynamic_N512_Batch/'
os.makedirs(data_path_dynamic, exist_ok=True)
print(f"Batch-Optimized Simulation Data will be saved to: {data_path_dynamic}")

In [None]:
# Run configuration. Every entry is echoed below, so the printout fully
# describes the run.
config = {
    # --- Batch, grid and time stepping ---
    'batch_size': 4,  # simulations advanced together in one batched tensor (KEY OPTIMIZATION)
    'N': 512,  # grid points per spatial axis
    'L_sim_unit': 40.0,  # side length of the cubic domain
    'T_steps': 100000,
    'dt_cfl_factor': 0.001,  # dt = dt_cfl_factor * dx / c (see derived values below)
    'c_sim_unit': 1.0,

    # --- Base EFM parameters ---
    'k_efm_gravity_coupling': 0.005,
    'eta_sim': 0.01,
    'alpha_sim': 0.1,
    'delta_sim': 0.0002,

    # --- Dynamic-parameter and stability settings ---
    'rho_ref': 1.5,
    'm_ref': 1.0,
    'g_ref': 0.1,
    'g_sign_threshold': 1.0,
    'g_min': 1e-5,
    'nu_damping': 1e-4,

    # --- Initial conditions ---
    'initial_perturbation_amplitude': 15.0,
    'initial_perturbation_width': 2.0,
    'background_noise_amplitude': 1.0e-4,

    'history_every_n_steps': 500,
}

# Derived quantities: grid spacing, CFL-scaled time step and a run label.
_grid_spacing = config['L_sim_unit'] / config['N']
config.update(
    dx_sim_unit=_grid_spacing,
    dt_sim_unit=config['dt_cfl_factor'] * _grid_spacing / config['c_sim_unit'],
    run_id=f"DynamicParams_B{config['batch_size']}_N{config['N']}_Stable",
)

print(f"--- EFM Batch-Optimized Simulation Configuration ({config['run_id']}) ---")
print("\n".join(f"{key}: {value}" for key, value in config.items()))

## Core Simulation Functions (Batch-Aware & JIT-Compiled)

In [None]:
@torch.jit.script
def conv_laplacian_batch_gpu(phi_batch: torch.Tensor, dx: float) -> torch.Tensor:
    """JIT-compiled 7-point 3D Laplacian with periodic boundaries for a batch of fields.

    Args:
        phi_batch: Tensor of shape (B, D, H, W); each batch entry is an
            independent field.
        dx: Uniform grid spacing used for all three spatial axes.

    Returns:
        Tensor of shape (B, D, H, W) with the discrete Laplacian of each field.

    Raises:
        An error if ``phi_batch`` is not 4-dimensional.
    """
    if phi_batch.dim() != 4:
        raise ValueError("conv_laplacian_batch_gpu expects a (B, D, H, W) tensor")

    # Insert a singleton channel axis -> (B, 1, D, H, W). One cell of circular
    # padding on each spatial side implements periodic boundary conditions.
    phi_padded = F.pad(phi_batch.unsqueeze(1), (1, 1, 1, 1, 1, 1), mode='circular')

    stencil = torch.tensor([[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
                            [[0., 1., 0.], [1., -6., 1.], [0., 1., 0.]],
                            [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]]],
                           dtype=phi_batch.dtype, device=phi_batch.device) / (dx ** 2)
    stencil = stencil.view(1, 1, 3, 3, 3)

    # One input and one output channel, so groups must be 1. The previous
    # groups=B raised for any B > 1 (in_channels=1 is not divisible by B). The
    # batch axis is already processed in parallel by conv3d.
    return F.conv3d(phi_padded, stencil, padding=0).squeeze(1)

@torch.jit.script
def biharmonic_operator_batch_gpu(phi_batch: torch.Tensor, dx: float) -> torch.Tensor:
    """Return the discrete biharmonic (nabla^4) of every field in the batch.

    Computed by applying ``conv_laplacian_batch_gpu`` twice, so the input and
    output layout is whatever that Laplacian accepts and returns.
    """
    return conv_laplacian_batch_gpu(conv_laplacian_batch_gpu(phi_batch, dx), dx)

@torch.jit.script
def nlkg_derivative_dynamic_batch_gpu(phi: torch.Tensor, phi_dot: torch.Tensor, 
                                      k_gravity: float, eta: float, c_sq: float, alpha: float, 
                                      delta: float, dx: float, rho_ref: float, m_ref_sq: float, 
                                      g_ref: float, g_min: float, g_sign_thresh: float, 
                                      nu_damping: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Time derivatives of the batched field system with density-dependent parameters.

    The right-hand side is evaluated in float32 regardless of the storage dtype:

        rho      = k_gravity * phi**2 + 1e-12
        m2       = m_ref_sq * rho / rho_ref
        g        = s * (g_min + g_ref * rho / rho_ref),  s = -1 where rho > g_sign_thresh, else +1
        phi_ddot = c_sq * lap(phi) - (m2*phi + g*phi**3 + eta*phi**5)
                   + alpha * phi * phi_dot * |grad phi|**2
                   - delta * phi_dot**2 * phi
                   - nu_damping * lap(lap(phi))

    Args:
        phi: Field tensor; dim 0 is the batch axis and dims 1-3 are spatial.
        phi_dot: Time derivative of ``phi``, same shape.
        k_gravity, eta, c_sq, alpha, delta: Coupling constants used above.
        dx: Grid spacing for the Laplacian and the central-difference gradient.
        rho_ref, m_ref_sq, g_ref, g_min, g_sign_thresh: Density-dependent
            parameter settings used above.
        nu_damping: Coefficient of the biharmonic damping term.

    Returns:
        ``(phi_dot, phi_ddot)``: ``phi_dot`` is returned unchanged, and ``phi_ddot``
        is cast back to ``phi.dtype``.
    """
    phi_f32 = phi.to(torch.float32)
    phi_dot_f32 = phi_dot.to(torch.float32)

    # Density-dependent mass and cubic-coupling fields.
    rho_local = k_gravity * phi_f32**2 + 1e-12
    m_sq_field = m_ref_sq * (rho_local / rho_ref)
    g_field_unsigned = g_min + g_ref * (rho_local / rho_ref)
    # The cubic coefficient becomes negative wherever rho_local exceeds the threshold.
    sign_field = torch.where(rho_local > g_sign_thresh, -1.0, 1.0)
    g_field_signed = sign_field * g_field_unsigned

    lap_phi = conv_laplacian_batch_gpu(phi_f32, dx)
    potential_force = m_sq_field * phi_f32 + g_field_signed * torch.pow(phi_f32, 3) + eta * torch.pow(phi_f32, 5)

    # Central differences along the spatial axes (dims 1-3). torch.roll wraps
    # around, consistent with the periodic padding used by the Laplacian.
    grad_phi_x = (torch.roll(phi_f32, -1, 1) - torch.roll(phi_f32, 1, 1)) / (2 * dx)
    grad_phi_y = (torch.roll(phi_f32, -1, 2) - torch.roll(phi_f32, 1, 2)) / (2 * dx)
    grad_phi_z = (torch.roll(phi_f32, -1, 3) - torch.roll(phi_f32, 1, 3)) / (2 * dx)
    grad_phi_abs_sq = grad_phi_x**2 + grad_phi_y**2 + grad_phi_z**2

    alpha_term = alpha * phi_f32 * phi_dot_f32 * grad_phi_abs_sq
    delta_term = delta * torch.pow(phi_dot_f32, 2) * phi_f32
    # nabla^4 phi = lap(lap(phi)). Reuse lap_phi rather than recomputing lap(phi)
    # through biharmonic_operator_batch_gpu. This saves one full 3D convolution
    # per evaluation (four per RK4 step) and gives identical values.
    biharmonic_term = nu_damping * conv_laplacian_batch_gpu(lap_phi, dx)

    phi_ddot = c_sq * lap_phi - potential_force + alpha_term - delta_term - biharmonic_term
    return phi_dot, phi_ddot.to(phi.dtype)

@torch.jit.script
def update_phi_rk4_dynamic_batch_gpu(phi_current: torch.Tensor, phi_dot_current: torch.Tensor,
                                     dt: float, dx: float,
                                     params: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Advance ``(phi, phi_dot)`` by one classical fourth-order Runge-Kutta step.

    Args:
        phi_current: Batched field at the current time.
        phi_dot_current: Its time derivative.
        dt: Time step.
        dx: Grid spacing, forwarded to the derivative function.
        params: 1-D tensor of 11 constants, in the order
            [k_gravity, eta, c_sq, alpha, delta, rho_ref, m_ref_sq,
             g_ref, g_min, g_sign_thresh, nu_damping].

    Returns:
        ``(phi_next, phi_dot_next)``.
    """
    # TorchScript types un-annotated parameters as Tensor, and the derivative
    # function takes Python floats. Passing params[i] (a Tensor) there, or
    # passing float dt/dx here, does not compile or run. Copy the small
    # parameter vector to the host once (one transfer per step) and unpack it.
    p = params.cpu()
    k_gravity = float(p[0])
    eta = float(p[1])
    c_sq = float(p[2])
    alpha = float(p[3])
    delta = float(p[4])
    rho_ref = float(p[5])
    m_ref_sq = float(p[6])
    g_ref = float(p[7])
    g_min = float(p[8])
    g_sign_thresh = float(p[9])
    nu_damping = float(p[10])

    k1_v, k1_a = nlkg_derivative_dynamic_batch_gpu(
        phi_current, phi_dot_current,
        k_gravity, eta, c_sq, alpha, delta, dx,
        rho_ref, m_ref_sq, g_ref, g_min, g_sign_thresh, nu_damping)
    k2_v, k2_a = nlkg_derivative_dynamic_batch_gpu(
        phi_current + 0.5*dt*k1_v, phi_dot_current + 0.5*dt*k1_a,
        k_gravity, eta, c_sq, alpha, delta, dx,
        rho_ref, m_ref_sq, g_ref, g_min, g_sign_thresh, nu_damping)
    k3_v, k3_a = nlkg_derivative_dynamic_batch_gpu(
        phi_current + 0.5*dt*k2_v, phi_dot_current + 0.5*dt*k2_a,
        k_gravity, eta, c_sq, alpha, delta, dx,
        rho_ref, m_ref_sq, g_ref, g_min, g_sign_thresh, nu_damping)
    k4_v, k4_a = nlkg_derivative_dynamic_batch_gpu(
        phi_current + dt*k3_v, phi_dot_current + dt*k3_a,
        k_gravity, eta, c_sq, alpha, delta, dx,
        rho_ref, m_ref_sq, g_ref, g_min, g_sign_thresh, nu_damping)

    phi_next = phi_current + (dt / 6.0) * (k1_v + 2*k2_v + 2*k3_v + k4_v)
    phi_dot_next = phi_dot_current + (dt / 6.0) * (k1_a + 2*k2_a + 2*k3_a + k4_a)
    return phi_next, phi_dot_next

print("Batch-Aware, JIT-Optimized simulation functions defined.")

In [None]:
if __name__ == '__main__':
    print("--- INITIATING BATCH-OPTIMIZED FIRST-PRINCIPLES SIMULATION ---")
    torch.manual_seed(42)
    B, N = config['batch_size'], config['N']

    # Storage precision for phi / phi_dot. float16 is not adequate here:
    # - With dt = dt_cfl_factor * dx / c (~7.8e-5 for the configured N=512,
    #   L=40), the per-step change dt*phi_dot is far below half an fp16 ULP
    #   near the pulse peak (ULP = 2**-7 for |phi| in [8, 16)), so updates to
    #   phi are rounded away.
    # - fp16 also overflows at 65504, so the 5e7 instability check could never
    #   fire.
    # float32 doubles the memory of the state and RK4 stages. If the GPU runs
    # out of memory, lower config['batch_size'] rather than the precision.
    state_dtype = torch.float32

    # --- Initialize batched fields ---
    coords = torch.linspace(-config['L_sim_unit']/2, config['L_sim_unit']/2, N, device=device)
    X, Y, Z = torch.meshgrid(coords, coords, coords, indexing='ij')
    r_sq = X**2 + Y**2 + Z**2

    # Same Gaussian pulse for every batch entry, plus independent uniform noise
    # in [0, background_noise_amplitude) per entry.
    central_pulse = config['initial_perturbation_amplitude'] * torch.exp(-r_sq / (config['initial_perturbation_width']**2))
    noise = torch.rand(B, N, N, N, device=device) * config['background_noise_amplitude']
    phi_current = (central_pulse.unsqueeze(0) + noise).to(state_dtype)
    phi_dot_current = torch.zeros_like(phi_current)
    del X, Y, Z, r_sq, coords, central_pulse, noise
    gc.collect()
    torch.cuda.empty_cache()
    print(f"Initialized a batch of {B} simulations on the GPU.")

    # --- Pack physics constants into a tensor; order must match the RK4 unpacking ---
    config_params = torch.tensor([
        config['k_efm_gravity_coupling'], config['eta_sim'], config['c_sim_unit']**2, config['alpha_sim'], config['delta_sim'],
        config['rho_ref'], config['m_ref']**2, config['g_ref'], config['g_min'], config['g_sign_threshold'], config['nu_damping']
    ], device=device, dtype=torch.float32)

    # --- Simulation loop ---
    pbar = tqdm(range(config['T_steps']), desc=f"Batch Sim (B={B}, N={N}³)")
    sim_start_time = time.time()
    steps_completed = 0  # RK4 steps actually run; the loop may halt early

    for t_step in pbar:
        phi_current, phi_dot_current = update_phi_rk4_dynamic_batch_gpu(
            phi_current, phi_dot_current, config['dt_sim_unit'], config['dx_sim_unit'], config_params
        )
        steps_completed = t_step + 1

        if steps_completed % config['history_every_n_steps'] == 0:
            # A single isfinite reduction covers both NaN and Inf.
            if not bool(torch.isfinite(phi_current).all()):
                print(f"\nERROR: NaN/Inf detected at step {steps_completed}! Halting.")
                break
            max_phi = torch.max(torch.abs(phi_current)).item()
            pbar.set_postfix({'Max|φ|': f'{max_phi:.3e}'})
            if max_phi > 5e7:
                print("\nWarning: Instability detected. Halting.")
                break

    # CUDA kernels are queued asynchronously; wait for them so the timing is real.
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    sim_duration = time.time() - sim_start_time
    # Base the throughput on the steps that actually ran, not config['T_steps'].
    effective_rate = steps_completed * B / max(sim_duration, 1e-9)
    print(f"Simulation finished in {sim_duration:.2f} seconds. Effective it/s: {effective_rate:.2f}")

    # --- Save final state (only the first simulation of the batch, for analysis) ---
    final_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    final_data_filename = os.path.join(data_path_dynamic, f"FINAL_DATA_{config['run_id']}_{final_timestamp}.npz")
    np.savez_compressed(final_data_filename, 
                        phi_final_cpu=phi_current[0].cpu().numpy(), 
                        config=config)
    print(f"Final state of the first simulation saved to {final_data_filename}")

    del phi_current, phi_dot_current, config_params
    gc.collect()
    torch.cuda.empty_cache()
    print("\n--- SIMULATION COMPLETE. ANALYSIS CAN NOW PROCEED. ---")