# EFM LSS & Clustering Validation (HPC Optimized with Multi-GPU DDP)

This notebook performs a simulation to validate EFM clustering scales on a CoCalc H100x4 instance using PyTorch's DistributedDataParallel (DDP) for multi-GPU optimization. It incorporates the EFM self-gravity term, aims for numerical stability with appropriate parameters derived from EFM first principles for LSS, and uses local storage with intermediate checkpointing. This version replaces DataParallel with DDP, minimizes CPU-GPU transfers, and optimizes for H100 GPUs.

## Objectives
- Simulate Large-Scale Structure (LSS) formation using the EFM NLKG equation with the `8πGkφ²` self-gravity term on a large grid (N=400).
- Optimize performance using DDP across 4x NVIDIA H100 GPUs.
- Utilize Gaussian noise as initial conditions.
- Compute power spectrum P(k) and correlation function ξ(r).

## Critical Note on Parameters:
**The physical parameters in the 'Configuration' cell below — `m_sim_yr_inv`, `g_sim`, `eta_sim`, `k_efm_gravity_coupling`, and `G_sim_Mpc_Msolar_yr` — MUST be carefully scaled from a known successful dimensionless EFM LSS simulation. This notebook sets `m_sim_yr_inv = 0` as a primary test based on EFM first principles for LSS. Monitor the run for numerical instability caused by the zero mass term.**

In [None]:
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import gc
import psutil
from tqdm import tqdm
import numpy as np
import time
from datetime import datetime
from scipy.fft import fftn, fftfreq, ifftn
import torch.nn.functional as F
import torch.amp
import matplotlib.pyplot as plt
import glob

# Configure PyTorch's CUDA caching allocator through its environment variable,
# then start the session from a clean slate (cached GPU blocks and Python garbage).
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:512')
cuda_present = torch.cuda.is_available()
if cuda_present:
    torch.cuda.empty_cache()
gc.collect()

def setup_ddp(rank, world_size):
    """Join the NCCL process group as ``rank`` out of ``world_size`` workers.

    The rendezvous endpoint is fixed to localhost:12355, i.e. all workers are
    expected to run on the same machine.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'}
    os.environ.update(rendezvous)
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

def cleanup_ddp():
    """Tear down the default process group, if there is one.

    Calling ``destroy_process_group`` without an initialised group raises, so
    the call is guarded. This makes the teardown safe to run from error paths
    (e.g. when initialisation itself failed) and safe to call more than once.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

print(f"PyTorch version: {torch.__version__}")

# Enumerate the visible CUDA devices; the primary device falls back to the CPU
# when there are none.
num_gpus = torch.cuda.device_count()
available_devices = list(map(torch.device, (f'cuda:{i}' for i in range(num_gpus))))
primary_device = torch.device('cuda:0') if num_gpus > 0 else torch.device('cpu')
print(f"Number of GPUs: {num_gpus}, Available devices: {available_devices}")
for i, device in enumerate(available_devices):
    print(f"GPU {i}: {torch.cuda.get_device_name(device)}, VRAM: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB")
print(f"System RAM: {psutil.virtual_memory().total / 1e9:.2f} GB")

# Output locations under the user's home directory. The trailing slash is kept
# because later code builds file names by plain string concatenation.
checkpoint_path_lss, data_path_lss = map(os.path.expanduser, (
    '~/EFM_Simulations/checkpoints/LSS_HPC_Opt/',
    '~/EFM_Simulations/data/LSS_HPC_Opt/',
))
for output_dir in (checkpoint_path_lss, data_path_lss):
    os.makedirs(output_dir, exist_ok=True)
print(f"LSS Checkpoints: {checkpoint_path_lss}")
print(f"LSS Data/Plots: {data_path_lss}")

## Configuration for LSS Simulation (HPC Optimized)
Parameters chosen for an N=400 grid, aiming for stability and observability of structures. The mass parameter `m_sim_yr_inv` is set to 0.

In [None]:
# SI constants used to convert into the simulation's Mpc / M_sun / yr units.
G_si_const, M_solar_kg, Mpc_m, yr_s = 6.674e-11, 1.989e30, 3.086e22, 3.156e7

# Grid geometry.
config_lss = {'N': 400, 'L_Mpc': 1000.0}
config_lss['dx_Mpc'] = config_lss['L_Mpc'] / config_lss['N']

# Speed of light, first in SI and then in Mpc/yr.
config_lss['c_si_m_s'] = 3e8
config_lss['c_sim_Mpc_yr'] = config_lss['c_si_m_s'] * (yr_s / Mpc_m)

# Time stepping: dt is a fixed fraction of the light-crossing time of one cell.
config_lss['dt_cfl_factor'] = 0.000007
config_lss['dt_yr'] = config_lss['dt_cfl_factor'] * config_lss['dx_Mpc'] / config_lss['c_sim_Mpc_yr']
config_lss['T_steps'] = 50000

# Process the grid in quarters along the first axis; fall back to a single
# chunk when a quarter does not divide N evenly.
quarter_of_grid = config_lss['N'] // 4
config_lss['chunk_size'] = quarter_of_grid if config_lss['N'] % quarter_of_grid == 0 else config_lss['N']

# Boundary handling, field-equation coefficients and initial-condition amplitude.
config_lss.update({
    'boundary_width_factor': 0.0,
    'damping_strength': 0.0,
    'm_sim_yr_inv': 0.0,
    'g_sim': 0.01,
    'eta_sim': 0.001,
    'k_efm_gravity_coupling': 0.1,
    'G_sim_Mpc_Msolar_yr': G_si_const * (yr_s**2 * M_solar_kg) / (Mpc_m**3),
    'initial_noise_amplitude': 0.01,
})

# Identifier embedded in checkpoint and plot file names.
config_lss['run_id'] = f"LSS_N{config_lss['N']}_T{config_lss['T_steps']}_m{config_lss['m_sim_yr_inv']:.1e}_k{config_lss['k_efm_gravity_coupling']:.2e}_CFL{config_lss['dt_cfl_factor']:.1e}_DDP"
config_lss.update(checkpoint_every_n_steps=1000, history_every_n_steps=100)

print(f"--- LSS Simulation Configuration ({config_lss['run_id']}) ---")
for key, value in config_lss.items():
    # Floats are shown with four significant digits, everything else verbatim.
    is_floating = isinstance(value, (float, np.float32, np.float64))
    print(f"{key}: {value:.4g}" if is_floating else f"{key}: {value}")
total_phys_time_lss_gyr = (config_lss['T_steps'] * config_lss['dt_yr']) / 1e9
print(f"Total Physical Time to be simulated: {total_phys_time_lss_gyr:.4f} Gyr")

## Core Simulation Functions (LSS with Self-Gravity, Multi-GPU Optimized)
Using DDP for multi-GPU optimization, keeping fields on GPU.

In [None]:
class EFMModule(nn.Module):
    """Right-hand side of the NLKG field equation on a periodic 3-D grid.

    The module holds no trainable parameters; it bundles the equation
    coefficients with a 7-point finite-difference Laplacian stencil.

    Args:
        dx: Grid spacing used to scale the Laplacian stencil (1 / dx**2).
        m_sq: Coefficient of the linear ``phi`` term.
        g: Coefficient of the ``phi**3`` term.
        eta: Coefficient of the ``phi**5`` term.
        k_gravity: Coupling constant in the ``phi**2`` self-gravity source.
        G_gravity: Gravitational constant in the ``phi**2`` self-gravity source.
        c_sq: Coefficient multiplying the Laplacian.
    """

    def __init__(self, dx, m_sq, g, eta, k_gravity, G_gravity, c_sq):
        super().__init__()
        self.dx = dx
        self.m_sq = m_sq
        self.g = g
        self.eta = eta
        self.k_gravity = k_gravity
        self.G_gravity = G_gravity
        self.c_sq = c_sq

        # 7-point Laplacian: -6 at the centre, +1 on each of the six face neighbours.
        stencil = torch.zeros(3, 3, 3, dtype=torch.float32)
        stencil[1, 1, 1] = -6.0
        for face in (0, 2):
            stencil[face, 1, 1] = 1.0
            stencil[1, face, 1] = 1.0
            stencil[1, 1, face] = 1.0
        stencil = (stencil / (dx**2)).view(1, 1, 3, 3, 3)
        # Registered as a (non-persistent) buffer so that ``module.to(device)``
        # moves it once, instead of copying it host->device on every call.
        self.register_buffer('stencil', stencil, persistent=False)

    def conv_laplacian(self, phi):
        """Return the periodic finite-difference Laplacian of ``phi``.

        The last three dimensions of ``phi`` are treated as the spatial grid and
        wrapped circularly. The convolution is evaluated in float32 and the
        result is cast back to ``phi``'s dtype and shape.
        """
        phi_f32 = phi.to(torch.float32)
        stencil = self.stencil.to(phi.device)  # no-op when already co-located
        # reshape (not view) so non-contiguous inputs are accepted as well.
        volume = phi_f32.reshape(-1, 1, *phi_f32.shape[-3:])
        # conv3d is downcast to float16 inside a CUDA autocast region, which
        # would silently undo the float32 cast above; keep this part out of it.
        with torch.amp.autocast(phi.device.type, enabled=False):
            padded = F.pad(volume, (1, 1, 1, 1, 1, 1), mode='circular')
            laplacian = F.conv3d(padded, stencil, padding=0)
        return laplacian.view(phi_f32.shape).to(phi.dtype)

    def nlkg_derivative(self, phi, phi_dot):
        """Return ``(d phi/dt, d phi_dot/dt)`` for the first-order form of the equation.

        ``phi_dot`` is passed through unchanged as the time derivative of ``phi``;
        the acceleration is ``c_sq * lap(phi) - V'(phi) + 8*pi*G*k*phi**2``.
        """
        lap = self.conv_laplacian(phi)
        potential_force = self.m_sq * phi + self.g * torch.pow(phi, 3) + self.eta * torch.pow(phi, 5)
        source_gravity = 8.0 * float(np.pi) * self.G_gravity * self.k_gravity * torch.pow(phi, 2)
        phi_ddot = self.c_sq * lap - potential_force + source_gravity
        return phi_dot, phi_ddot

def create_damping_mask_lss(N: int, device: torch.device):
    """Return an all-ones ``(N, N, N)`` float16 mask on ``device`` (no damping anywhere)."""
    grid_shape = (N,) * 3
    return torch.full(grid_shape, 1.0, dtype=torch.float16, device=device)

def update_phi_rk4_chunked_lss(phi, phi_dot, dt_val, m_sim, g_sim_p, eta_sim_p,
                               k_efm_gravity_coupling_p, G_sim_Mpc_Msolar_yr_p,
                               c_sim_Mpc_yr_p, dx_Mpc_p, chunk_size_val, rank, device):
    """Advance ``(phi, phi_dot)`` by one classical RK4 step of the NLKG equation.

    Args:
        phi, phi_dot: Field and its time derivative on ``device``; the first
            axis is the one that is processed in chunks.
        dt_val: Time step.
        m_sim, g_sim_p, eta_sim_p: Mass and self-interaction coefficients
            (``m_sim`` is squared before use).
        k_efm_gravity_coupling_p, G_sim_Mpc_Msolar_yr_p: Self-gravity coefficients.
        c_sim_Mpc_yr_p: Wave speed (squared before use).
        dx_Mpc_p: Grid spacing.
        chunk_size_val: Number of planes along axis 0 evaluated per Laplacian call.
        rank: Unused; kept so existing callers keep working.
        device: CUDA device the computation runs on.

    Returns:
        ``(phi_new, phi_dot_new)`` as float16 tensors.

    The RK4 stages are formed on the whole field and only the acceleration
    (the Laplacian-heavy part) is evaluated chunk by chunk. Each chunk is
    extended by one halo plane on either side, taken periodically from the
    full field, so the Laplacian at chunk edges sees its true neighbours
    rather than the chunk wrapped onto itself.
    """
    # The module has no trainable parameters, so there is nothing for
    # DistributedDataParallel to synchronise (DDP rejects such modules);
    # it is used directly.
    model = EFMModule(dx_Mpc_p, m_sim**2, g_sim_p, eta_sim_p, k_efm_gravity_coupling_p,
                      G_sim_Mpc_Msolar_yr_p, c_sim_Mpc_yr_p**2).to(device)
    model.eval()

    n_planes = phi.shape[0]

    def acceleration(field):
        """Evaluate d(phi_dot)/dt for ``field`` over the whole grid, chunk by chunk."""
        accel = torch.empty_like(field, dtype=torch.float32)
        for start in range(0, n_planes, chunk_size_val):
            stop = min(start + chunk_size_val, n_planes)
            # Planes start-1 .. stop, wrapped periodically around the grid.
            halo_idx = torch.arange(start - 1, stop + 1, device=field.device) % n_planes
            slab = field.index_select(0, halo_idx)
            # nlkg_derivative only passes its velocity argument through, so the
            # slab doubles as a placeholder for it.
            _, slab_accel = model.nlkg_derivative(slab, slab)
            # The two halo planes were wrapped within the slab; discard them.
            accel[start:stop] = slab_accel[1:-1]
        return accel

    half_dt = 0.5 * dt_val
    with torch.amp.autocast('cuda', enabled=True):
        k1_v = phi_dot
        k1_a = acceleration(phi)
        k2_v = phi_dot + half_dt * k1_a
        k2_a = acceleration(phi + half_dt * k1_v)
        k3_v = phi_dot + half_dt * k2_a
        k3_a = acceleration(phi + half_dt * k2_v)
        k4_v = phi_dot + dt_val * k3_a
        k4_a = acceleration(phi + dt_val * k3_v)
        phi_new = (phi + (dt_val / 6.0) * (k1_v + 2 * k2_v + 2 * k3_v + k4_v)).to(torch.float16)
        phi_dot_new = (phi_dot + (dt_val / 6.0) * (k1_a + 2 * k2_a + 2 * k3_a + k4_a)).to(torch.float16)

    torch.cuda.synchronize(device)
    return phi_new, phi_dot_new

def compute_field_energy_lss(phi, phi_dot, m_p, g_p, eta_p, chunk_size_val, dx_val, c_sq_p, device):
    """Return the total field energy of ``(phi, phi_dot)`` on a periodic 3-D grid.

    The energy density is ``0.5*phi_dot**2 + V(phi) + 0.5*c_sq*|grad phi|**2``
    with ``V = 0.5*m**2*phi**2 + 0.25*g*phi**4 + eta/6*phi**6``, summed over
    the grid and multiplied by the cell volume ``dx_val**3``.

    The gradient term is built from periodic forward differences. Summed over
    the grid this equals ``-0.5*c_sq*sum(phi * lap(phi))`` for the 7-point
    Laplacian, and every cell's contribution is non-negative.

    Args:
        phi, phi_dot: 3-D tensors; axis 0 is processed ``chunk_size_val``
            planes at a time.
        m_p, g_p, eta_p: Potential coefficients.
        chunk_size_val: Planes per chunk.
        dx_val: Grid spacing.
        c_sq_p: Coefficient of the gradient term.
        device: Device on which the accumulator is allocated.

    Returns:
        The energy as a Python float; ``float('nan')`` if any chunk sums to a
        non-finite value.
    """
    total_energy = torch.zeros((), device=device, dtype=torch.float64)
    n_planes = phi.shape[0]
    vol_element = dx_val**3
    inv_dx_sq = 1.0 / dx_val**2

    for start in range(0, n_planes, chunk_size_val):
        stop = min(start + chunk_size_val, n_planes)
        # Inputs are promoted to float32 explicitly; everything below is elementwise.
        phi_chunk = phi[start:stop].to(dtype=torch.float32)
        phi_dot_chunk = phi_dot[start:stop].to(dtype=torch.float32)
        # Planes start+1 .. stop, wrapped periodically, so the axis-0 difference
        # at a chunk edge uses the neighbouring chunk rather than this one.
        next_idx = torch.arange(start + 1, stop + 1, device=phi.device) % n_planes
        phi_next = phi.index_select(0, next_idx).to(dtype=torch.float32)

        kinetic_density = 0.5 * torch.pow(phi_dot_chunk, 2)
        potential_density = (0.5 * m_p**2 * torch.pow(phi_chunk, 2)
                             + 0.25 * g_p * torch.pow(phi_chunk, 4)
                             + (1.0 / 6.0) * eta_p * torch.pow(phi_chunk, 6))
        grad_sq = (torch.pow(phi_next - phi_chunk, 2)
                   + torch.pow(torch.roll(phi_chunk, -1, dims=1) - phi_chunk, 2)
                   + torch.pow(torch.roll(phi_chunk, -1, dims=2) - phi_chunk, 2))
        gradient_density = 0.5 * c_sq_p * grad_sq * inv_dx_sq
        chunk_energy = torch.sum(kinetic_density + potential_density + gradient_density,
                                 dtype=torch.float64) * vol_element

        if not bool(torch.isfinite(chunk_energy)):
            # Same type as the normal return so callers can use math/np.isnan.
            return float('nan')
        total_energy += chunk_energy
        del phi_chunk, phi_dot_chunk, phi_next, kinetic_density, potential_density, grad_sq, gradient_density

    # .item() waits for the pending GPU work before copying the scalar to the host.
    return total_energy.item()

def compute_power_spectrum_lss(phi_cpu_np, k_range, dx_val, N_grid):
    """Return the spherically averaged power spectrum of a cubic field.

    Args:
        phi_cpu_np: ``(N_grid, N_grid, N_grid)`` field, either a NumPy array or
            an object exposing ``.cpu().numpy()``.
        k_range: ``(k_min, k_max)`` limits of the 50 linearly spaced bin edges.
        dx_val: Grid spacing; wavenumbers are ``2*pi*fftfreq(N_grid, dx_val)``.
        N_grid: Number of cells per side.

    Returns:
        ``(k_bin_centers, power_binned)``, each of length 49. ``power_binned``
        is the mean of ``|FFT(phi)|**2 / N_grid**6`` over the modes whose
        ``|k|`` falls into the bin, and 0 for bins that contain no mode.
    """
    if not isinstance(phi_cpu_np, np.ndarray):
        phi_cpu_np = phi_cpu_np.cpu().numpy()

    phi_fft = fftn(phi_cpu_np.astype(np.float32))
    power = (np.abs(phi_fft)**2 / (N_grid**6)).ravel()
    del phi_fft

    k_axis = fftfreq(N_grid, d=dx_val) * 2 * np.pi
    kxx, kyy, kzz = np.meshgrid(k_axis, k_axis, k_axis, indexing='ij', sparse=True)
    k_magnitude = np.sqrt(kxx**2 + kyy**2 + kzz**2).ravel()
    del kxx, kyy, kzz

    k_bins = np.linspace(k_range[0], k_range[1], 50)
    # np.histogram returns exactly two values: (histogram, bin_edges).
    power_sum, _ = np.histogram(k_magnitude, bins=k_bins, weights=power)
    counts, _ = np.histogram(k_magnitude, bins=k_bins)
    power_binned = np.divide(power_sum, counts,
                             out=np.zeros(counts.shape, dtype=np.float64),
                             where=counts != 0)
    k_bin_centers = (k_bins[:-1] + k_bins[1:]) / 2
    del k_magnitude, power, power_sum, counts
    gc.collect()
    return k_bin_centers, power_binned

def compute_correlation_function_lss(phi_cpu_np, dx_val, N_grid, L_box):
    """Return the spherically averaged two-point autocorrelation of a cubic field.

    The periodic autocorrelation is obtained as ``ifftn(|fftn(phi)|**2) / N_grid**3``
    and then averaged in 49 linear radial bins spanning ``[0, L_box / 2]``.

    Args:
        phi_cpu_np: ``(N_grid, N_grid, N_grid)`` field, either a NumPy array or
            an object exposing ``.cpu().numpy()``.
        dx_val: Grid spacing used to convert index lags to distances.
        N_grid: Number of cells per side.
        L_box: Box size; lags beyond ``L_box / 2`` are not binned.

    Returns:
        ``(r_bin_centers, corr_binned)``, each of length 49; bins that contain
        no lattice separation are 0.
    """
    if not isinstance(phi_cpu_np, np.ndarray):
        phi_cpu_np = phi_cpu_np.cpu().numpy()

    phi_fft = fftn(phi_cpu_np.astype(np.float32))
    power = np.abs(phi_fft)**2
    del phi_fft
    correlation = (ifftn(power).real / (N_grid**3)).ravel()
    del power

    # The inverse FFT puts zero lag at index 0 and negative lags in the upper
    # half of each axis, so the lag coordinates must follow that same layout
    # (a centred layout would assign every value to the wrong radius).
    lag = np.fft.ifftshift(np.arange(N_grid) - N_grid // 2) * dx_val
    rx, ry, rz = np.meshgrid(lag, lag, lag, indexing='ij', sparse=True)
    r_magnitude = np.sqrt(rx**2 + ry**2 + rz**2).ravel()
    del rx, ry, rz

    r_bins = np.linspace(0, L_box / 2, 50)
    # np.histogram returns exactly two values: (histogram, bin_edges).
    corr_sum, _ = np.histogram(r_magnitude, bins=r_bins, weights=correlation)
    counts, _ = np.histogram(r_magnitude, bins=r_bins)
    corr_binned = np.divide(corr_sum, counts,
                            out=np.zeros(counts.shape, dtype=np.float64),
                            where=counts != 0)
    r_bin_centers = (r_bins[:-1] + r_bins[1:]) / 2
    del r_magnitude, correlation, corr_sum, counts
    gc.collect()
    return r_bin_centers, corr_binned


## Main Simulation Function (DDP)

In [None]:
def _measure_observables_lss(phi, phi_dot, config_lss, device):
    """Return ``(field_energy, density_norm)`` for the current state as Python floats."""
    energy = compute_field_energy_lss(
        phi, phi_dot, config_lss['m_sim_yr_inv'], config_lss['g_sim'], config_lss['eta_sim'],
        config_lss['chunk_size'], config_lss['dx_Mpc'], config_lss['c_sim_Mpc_yr']**2, device)
    density_norm = torch.sum(phi.to(torch.float32)**2).item() * config_lss['k_efm_gravity_coupling']
    # float() also accepts a 0-dim tensor, should the energy come back as one.
    return float(energy), density_norm


def _latest_intermediate_checkpoint_lss(checkpoint_path_lss, run_id):
    """Return the intermediate checkpoint of ``run_id`` with the highest step, or None."""
    pattern = os.path.join(checkpoint_path_lss, f"intermediate_CKPT_{run_id}_step_*.npz")

    def step_of(path):
        return int(os.path.basename(path).split('_step_')[1].split('.npz')[0])

    candidates = glob.glob(pattern)
    return max(candidates, key=step_of) if candidates else None


def _run_simulation_rank_lss(rank, config_lss, checkpoint_path_lss):
    """Evolve the field on GPU ``rank``; rank 0 records history and writes checkpoints."""
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)
    is_main = rank == 0

    def log(message):
        if is_main:
            print(message)

    n_grid = config_lss['N']
    total_steps = config_lss['T_steps']
    hist_every = config_lss['history_every_n_steps']

    log(f"Initializing fields for LSS simulation ({config_lss['run_id']}) on rank {rank}...")
    np.random.seed(42 + rank)
    phi = torch.from_numpy(np.random.randn(n_grid, n_grid, n_grid).astype(np.float16)) * config_lss['initial_noise_amplitude']
    phi = phi.to(device, dtype=torch.float16)
    phi_dot = torch.zeros_like(phi, dtype=torch.float16, device=device)
    log(f"LSS fields initialized on GPU {rank}. Phi shape: {phi.shape}, Dtype: {phi.dtype}")

    field_energy_hist_lss = np.zeros(total_steps // hist_every + 1, dtype=np.float64)
    density_norm_hist_lss = np.zeros_like(field_energy_hist_lss)
    hist_idx_lss = 0
    start_step_lss = 0

    # Every rank resumes from the same checkpoint so that all ranks agree on the
    # step range; otherwise rank 0 would wait at the final barrier while the
    # others redo the whole run from step 0.
    resumed = False
    latest_ckpt = _latest_intermediate_checkpoint_lss(checkpoint_path_lss, config_lss['run_id'])
    if latest_ckpt is None:
        log("No intermediate LSS checkpoint found. Starting from scratch.")
    else:
        log(f"Found latest intermediate LSS checkpoint: {latest_ckpt}")
        try:
            # Only plain numeric arrays are read here, so pickle support is not needed.
            with np.load(latest_ckpt, allow_pickle=False) as ckpt:
                ckpt_phi = torch.from_numpy(ckpt['phi_r_cpu']).to(device, dtype=torch.float16)
                ckpt_phi_dot = torch.from_numpy(ckpt['phi_dot_r_cpu']).to(device, dtype=torch.float16)
                ckpt_start = int(ckpt['last_step'].item()) + 1
                ckpt_energy = np.asarray(ckpt['field_energy_history'], dtype=np.float64)
                ckpt_norm = np.asarray(ckpt['density_norm_history'], dtype=np.float64)
        except Exception as e:
            # Nothing has been overwritten yet, so the freshly initialised state is intact.
            log(f"Error loading intermediate LSS checkpoint: {e}. Starting from scratch.")
        else:
            phi, phi_dot, start_step_lss = ckpt_phi, ckpt_phi_dot, ckpt_start
            # Continue right after the last stored record. The history holds the
            # initial entry plus one entry per recording step, so deriving the
            # index from the step count alone would overwrite the latest record.
            hist_idx_lss = min(len(ckpt_energy), len(ckpt_norm))
            remaining_records = max(0, total_steps - start_step_lss) // hist_every + 1
            capacity = max(len(field_energy_hist_lss), hist_idx_lss + remaining_records)
            field_energy_hist_lss = np.zeros(capacity, dtype=np.float64)
            density_norm_hist_lss = np.zeros(capacity, dtype=np.float64)
            field_energy_hist_lss[:hist_idx_lss] = ckpt_energy[:hist_idx_lss]
            density_norm_hist_lss[:hist_idx_lss] = ckpt_norm[:hist_idx_lss]
            resumed = True
            log(f"Resuming LSS simulation from step {start_step_lss}. History index set to {hist_idx_lss}.")
            del ckpt_phi, ckpt_phi_dot, ckpt_energy, ckpt_norm
            gc.collect()

    if is_main and not resumed and hist_idx_lss < len(field_energy_hist_lss):
        print("Calculating initial observables for LSS simulation...")
        energy, norm = _measure_observables_lss(phi, phi_dot, config_lss, device)
        field_energy_hist_lss[hist_idx_lss] = energy
        density_norm_hist_lss[hist_idx_lss] = norm
        print(f"Initial Field Energy: {field_energy_hist_lss[hist_idx_lss]:.4g}, Initial Density Norm: {density_norm_hist_lss[hist_idx_lss]:.4g}")
        hist_idx_lss += 1

    dist.barrier()
    # Defined up front: it is read after the loop even when the loop is skipped.
    numerical_error_lss = False

    if start_step_lss < total_steps:
        log(f"Starting/Resuming LSS simulation from step {start_step_lss} for {total_steps - start_step_lss} more steps...")
        step_range = range(start_step_lss, total_steps)
        if is_main:
            pbar_lss = tqdm(step_range, desc=f"LSS Sim ({config_lss['run_id']})", initial=start_step_lss, total=total_steps)
        else:
            pbar_lss = step_range
        sim_start_time_lss = time.time()

        # NOTE(review): the ranks exchange no field data inside this loop and only
        # rank 0 persists results, so each rank evolves its own replica. A rank
        # that stops early waits for the others at the barrier below.
        for t_step in pbar_lss:
            try:
                if not (bool(torch.isfinite(phi).all()) and bool(torch.isfinite(phi_dot).all())):
                    log(f"\nERROR: NaN/Inf detected in LSS fields BEFORE step {t_step + 1}! Stopping.")
                    numerical_error_lss = True
                    break

                phi, phi_dot = update_phi_rk4_chunked_lss(
                    phi, phi_dot, config_lss['dt_yr'],
                    config_lss['m_sim_yr_inv'], config_lss['g_sim'], config_lss['eta_sim'],
                    config_lss['k_efm_gravity_coupling'], config_lss['G_sim_Mpc_Msolar_yr'],
                    config_lss['c_sim_Mpc_yr'], config_lss['dx_Mpc'],
                    config_lss['chunk_size'], rank, device
                )

                if not bool(torch.isfinite(phi).all()):
                    log(f"\nERROR: NaN/Inf detected in LSS phi AFTER step {t_step + 1}! Stopping.")
                    numerical_error_lss = True
                    break

                steps_done = t_step + 1

                if is_main and steps_done % 1000 == 0:
                    print(f"VRAM usage after step {steps_done}:")
                    print(f"{device}: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB allocated, {torch.cuda.memory_reserved(device) / 1e9:.2f} GB reserved")

                if is_main and steps_done % hist_every == 0 and hist_idx_lss < len(field_energy_hist_lss):
                    current_field_energy, current_density_norm = _measure_observables_lss(phi, phi_dot, config_lss, device)
                    field_energy_hist_lss[hist_idx_lss] = current_field_energy
                    density_norm_hist_lss[hist_idx_lss] = current_density_norm
                    pbar_lss.set_postfix({'E_field': f'{current_field_energy:.3e}', 'Norm': f'{current_density_norm:.3e}'})
                    if not np.isfinite(current_field_energy):
                        print(f"LSS Instability: Energy is NaN/Inf at step {t_step+1}. Stop.")
                        numerical_error_lss = True
                        break
                    hist_idx_lss += 1

                if is_main and steps_done % config_lss['checkpoint_every_n_steps'] == 0 and steps_done < total_steps:
                    intermediate_ckpt_file_lss = os.path.join(checkpoint_path_lss, f"intermediate_CKPT_{config_lss['run_id']}_step_{t_step+1}.npz")
                    try:
                        np.savez_compressed(intermediate_ckpt_file_lss,
                                            phi_r_cpu=phi.cpu().numpy(),
                                            phi_dot_r_cpu=phi_dot.cpu().numpy(),
                                            last_step=t_step,
                                            config_lss_saved=config_lss,
                                            field_energy_history=field_energy_hist_lss[:hist_idx_lss],
                                            density_norm_history=density_norm_hist_lss[:hist_idx_lss])
                    except Exception as e:
                        # A failed checkpoint must not abort the run itself.
                        print(f"Error saving intermediate LSS checkpoint: {e}")

            except Exception as e:
                if is_main:
                    print(f"ERROR in LSS sim at step {t_step + 1}: {e}")
                    import traceback
                    traceback.print_exc()
                numerical_error_lss = True
                break

        dist.barrier()
        if is_main:
            sim_run_duration = time.time() - sim_start_time_lss
            print(f"LSS sim loop finished/resumed in {sim_run_duration:.2f} s. Error: {numerical_error_lss}")
    else:
        log("LSS simulation already completed to T_steps based on loaded checkpoint or start_step issue.")

    if is_main:
        final_timestamp_lss = datetime.now().strftime("%Y%m%d_%H%M%S")
        final_checkpoint_filename_lss = os.path.join(checkpoint_path_lss, f"FINAL_CKPT_{config_lss['run_id']}_{final_timestamp_lss}.npz")
        try:
            # Record the end state when the last step was not a recording step
            # (or when nothing has been recorded at all).
            needs_final_record = (
                not numerical_error_lss
                and start_step_lss < total_steps
                and (total_steps % hist_every != 0 or hist_idx_lss == 0)
                and hist_idx_lss < len(field_energy_hist_lss)
            )
            if needs_final_record:
                energy, norm = _measure_observables_lss(phi, phi_dot, config_lss, device)
                field_energy_hist_lss[hist_idx_lss] = energy
                density_norm_hist_lss[hist_idx_lss] = norm
                hist_idx_lss += 1

            np.savez_compressed(final_checkpoint_filename_lss,
                                phi_r_final_cpu=phi.cpu().numpy(),
                                phi_dot_r_final_cpu=phi_dot.cpu().numpy(),
                                field_energy_history=field_energy_hist_lss[:hist_idx_lss],
                                density_norm_history=density_norm_hist_lss[:hist_idx_lss],
                                config_lss=config_lss)
            print(f"LSS final state saved to {final_checkpoint_filename_lss}")
        except Exception as e:
            print(f"Error saving final LSS checkpoint: {e}")


def run_simulation(rank, world_size, config_lss, checkpoint_path_lss, data_path_lss):
    """Worker entry point for ``mp.spawn``: run the LSS simulation on GPU ``rank``.

    Args:
        rank: Index of this worker; also selects the CUDA device ``cuda:{rank}``.
        world_size: Total number of workers in the process group.
        config_lss: Simulation configuration dictionary.
        checkpoint_path_lss: Directory for intermediate and final ``.npz`` checkpoints.
        data_path_lss: Accepted for interface compatibility; not used by the worker.

    The process group is torn down and the CUDA cache released even when the
    simulation raises.
    """
    setup_ddp(rank, world_size)
    try:
        _run_simulation_rank_lss(rank, config_lss, checkpoint_path_lss)
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        cleanup_ddp()


## Run Simulation and Analysis

In [None]:
def _load_plot_checkpoint_lss(checkpoint_dir, run_id):
    """Load the newest checkpoint of ``run_id`` for plotting.

    FINAL checkpoints (newest by modification time) are preferred; otherwise
    the intermediate checkpoint with the highest step number is used.

    Returns:
        A dict with keys ``'energy'``, ``'norm'``, ``'phi'`` (float32 array) and
        ``'config'`` (dict or None), or None when nothing could be loaded.
    """
    final_pattern = os.path.join(checkpoint_dir, f"FINAL_CKPT_{run_id}_*.npz")
    files = sorted(glob.glob(final_pattern), key=os.path.getmtime, reverse=True)
    if not files:
        print(f"No FINAL LSS checkpoint found for {run_id}. Trying intermediate...")
        pattern_int = os.path.join(checkpoint_dir, f"intermediate_CKPT_{run_id}_step_*.npz")
        files = sorted(glob.glob(pattern_int), key=lambda f: int(os.path.basename(f).split('_step_')[1].split('.npz')[0]), reverse=True)
        if not files:
            print(f"No intermediate LSS checkpoint found either for {run_id}.")
            return None

    latest_ckpt = files[0]
    print(f"Loading LSS checkpoint: {latest_ckpt}")
    try:
        # NOTE: allow_pickle=True is required because the saved configuration is a
        # pickled dict; only load checkpoints written by this program.
        with np.load(latest_ckpt, allow_pickle=True) as data_plot:
            phi_data_key = 'phi_r_final_cpu' if 'phi_r_final_cpu' in data_plot else 'phi_r_cpu'
            config_key = 'config_lss' if 'config_lss' in data_plot else 'config_lss_saved'
            loaded = {
                'energy': data_plot['field_energy_history'],
                'norm': data_plot['density_norm_history'],
                'phi': np.asarray(data_plot[phi_data_key], dtype=np.float32),
                'config': data_plot[config_key].item() if config_key in data_plot else None,
            }
    except Exception as e:
        print(f"Error loading LSS checkpoint: {e}")
        return None
    print("LSS checkpoint data loaded for plotting.")
    return loaded


def _plot_evolution_metrics_lss(energy_hist, norm_hist, plot_config, out_dir):
    """Plot field energy and density norm against step number and save the figure."""
    n_records = len(energy_hist)
    steps = np.arange(n_records) * plot_config.get('history_every_n_steps', 100)

    plt.figure(figsize=(14, 6))
    panels = (
        (energy_hist, f"Field Energy Evo (N={plot_config.get('N')})"),
        (norm_hist, f"Density Norm Evo (N={plot_config.get('N')})"),
    )
    for panel_idx, (series, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.plot(steps, series[:n_records], marker='.')
        plt.title(title)
        plt.xlabel('Step')
        plt.grid(True)
        plt.ticklabel_format(style='sci', axis='y', scilimits=(-3, 3), useMathText=True)
    plt.tight_layout()
    plt.suptitle(f"LSS Evolution Metrics ({plot_config.get('run_id')})", fontsize=14, y=1.02)
    plt.savefig(os.path.join(out_dir, f"lss_evo_metrics_{plot_config.get('run_id', 'plot_run')}.png"))
    plt.show()
    plt.close()


def _plot_observables_lss(phi_np, plot_config, out_dir):
    """Compute, plot and summarise P(k) and xi(r) for the final field."""
    print("Computing P(k) and xi(r) for LSS final state...")
    k_range_pk = [2*np.pi/plot_config['L_Mpc']*1.5, 2*np.pi/20.0]
    k_bins_pk, pk_vals = compute_power_spectrum_lss(phi_np, k_range=k_range_pk, dx_val=plot_config['dx_Mpc'], N_grid=plot_config['N'])
    r_bins_xi, xi_vals = compute_correlation_function_lss(phi_np, dx_val=plot_config['dx_Mpc'], N_grid=plot_config['N'], L_box=plot_config['L_Mpc'])

    plt.figure(figsize=(16, 6))
    plt.subplot(1, 2, 1)
    plt.loglog(k_bins_pk, pk_vals)
    plt.title('LSS Power Spectrum P(k)')
    plt.xlabel('k (Mpc$^{-1}$)')
    plt.ylabel('P(k)')
    plt.grid(True, which='both', linestyle=':')
    plt.axvline(2*np.pi/147, color='r', linestyle='--', label='147 Mpc')
    plt.axvline(2*np.pi/628, color='g', linestyle='--', label='628 Mpc')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(r_bins_xi, xi_vals)
    # Raw strings: in a normal string literal "\x" starts a hex escape and
    # "\xi" is a syntax error.
    plt.title(r'LSS Correlation Function $\xi$(r)')
    plt.xlabel('r (Mpc)')
    plt.ylabel(r'$\xi$(r)')
    plt.grid(True, linestyle=':')
    plt.axvline(147, color='r', linestyle='--', label='147 Mpc')
    plt.axvline(628, color='g', linestyle='--', label='628 Mpc')
    plt.legend()
    # Scale the y-axis to the signal beyond the first (r ~ 0) bin.
    abs_max_xi = np.max(np.abs(xi_vals[1:])) if len(xi_vals[1:]) > 0 else 0.1
    plt.ylim(-0.5*abs_max_xi if abs_max_xi > 0 else -0.1, 1.1*abs_max_xi if abs_max_xi > 0 else 0.1)
    plt.tight_layout()
    plt.suptitle(f"LSS Observables ({plot_config.get('run_id')})", fontsize=14, y=1.02)
    plt.savefig(os.path.join(out_dir, f"lss_observables_{plot_config.get('run_id', 'plot_run')}.png"))
    plt.show()
    plt.close()

    if len(xi_vals) > 1 and np.any(np.abs(xi_vals[1:]) > 1e-6):
        print(f"Correlation peak (max of abs after r=0) near r ~ {r_bins_xi[np.argmax(np.abs(xi_vals[1:]))+1] if len(xi_vals[1:]) > 0 else 'N/A'} Mpc")
    else:
        print("Correlation function is effectively zero.")
    if len(pk_vals) > 0 and np.any(pk_vals > 1e-9):
        print(f"Power spectrum peak (max value) at k ~ {k_bins_pk[np.argmax(pk_vals)]:.3f} Mpc^-1 (scale ~ {2*np.pi/k_bins_pk[np.argmax(pk_vals)]:.1f} Mpc)")
    else:
        print("Power spectrum is effectively zero.")


def main():
    """Run the simulation on every visible GPU, then plot the newest checkpoint.

    One worker per GPU is started with ``mp.spawn``. Afterwards the newest
    checkpoint for the configured run id is loaded, the evolution metrics are
    plotted, and P(k) and xi(r) of the stored field are computed and plotted.
    Returns early when no GPU is available.
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        print("No GPUs available. Exiting.")
        return
    mp.spawn(run_simulation,
             args=(world_size, config_lss, checkpoint_path_lss, data_path_lss),
             nprocs=world_size,
             join=True)

    print("--- LSS Final Analysis and Plotting (Multi-GPU Optimized) ---")
    loaded = _load_plot_checkpoint_lss(checkpoint_path_lss, config_lss['run_id'])

    if loaded is not None and len(loaded['energy']) > 0:
        # Prefer the configuration stored with the checkpoint over the live one.
        plot_config_final_lss = loaded['config'] if loaded['config'] is not None else config_lss
        hist_field_energy_plot = loaded['energy']
        hist_density_norm_plot = loaded['norm']
        phi_r_final_np = loaded['phi']
        n_records = len(hist_field_energy_plot)

        _plot_evolution_metrics_lss(hist_field_energy_plot, hist_density_norm_plot, plot_config_final_lss, data_path_lss)

        if n_records > 1:
            print(f"\n--- Final LSS Properties ({plot_config_final_lss.get('run_id')}) ---")
            print(f"Final Field Energy: {hist_field_energy_plot[n_records-1]:.4g}")
            print(f"Final Density Norm: {hist_density_norm_plot[n_records-1]:.4g}")
            # Compare against 1e-7 times the norm expected for the initial noise field.
            decay_threshold = 1e-7 * plot_config_final_lss.get('k_efm_gravity_coupling', 0.01) * (plot_config_final_lss.get('initial_noise_amplitude', 0.01)**2 * plot_config_final_lss.get('N')**3)
            if hist_density_norm_plot[n_records-1] < decay_threshold:
                print("WARNING: Field appears to have decayed to very low values!")

        field_is_usable = (
            phi_r_final_np.ndim == 3
            and phi_r_final_np.shape[0] > 1
            and np.max(np.abs(phi_r_final_np)) > 1e-7
        )
        if field_is_usable:
            _plot_observables_lss(phi_r_final_np, plot_config_final_lss, data_path_lss)
        else:
            print("Final LSS field data not suitable for P(k)/xi(r) plotting (e.g., effectively all zeros or not 3D).")
        del phi_r_final_np, loaded
    else:
        print("LSS simulation history not available or error occurred. Cannot plot.")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("LSS plotting and analysis cell finished.")

# Run only when executed as a script, so that importing this module does not
# start the simulation; main() launches its worker processes via mp.spawn.
if __name__ == '__main__':
    main()
