In [1]:
%pip install torch numpy matplotlib imageio pillow opencv-python tifffile ffmpeg

Note: you may need to restart the kernel to use updated packages.


In [2]:
# %% [code]
# ==============================================================================
# AETHERIA: MASTER PYTHON SCRIPT (Consolidated from Notebook)
# ==============================================================================
# This script consolidates all code cells from the AETHERIA notebook
# into a single runnable Python file.
#
# To run:
# 1. Ensure all required libraries are installed:
#    pip install torch numpy matplotlib imageio imageio-ffmpeg pillow
# 2. Configure the parameters in the "PHASE 3: GLOBAL PARAMETERS" section below.
# 3. Run the script from your terminal:
#    python your_script_name.py
# ==============================================================================

# %% [markdown]
# # PHASE 0: SETUP & IMPORTS
# ---
# Import necessary libraries.
# Note: IPython imports are removed as this is a standalone script.
# %% [code]
import torch
import torch.nn as nn
import numpy as np
import imageio.v2 as imageio
import time
import os
import glob
from PIL import Image
import re # To parse checkpoint filenames
import gc # Import garbage collection for memory management
# from IPython.display import display # Removed (Notebook-specific)
# from IPython.display import Video # Removed (Notebook-specific)

# ------------------------------------------------------------------------------
# 0.1: SETUP AND CONTROL CONSTANTS
# ------------------------------------------------------------------------------
# Set up the device to use CUDA if available, otherwise use CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        print(f"Detected {torch.cuda.device_count()} GPUs. DataParallel will be managed internally.")
    else:
        print("Detected 1 GPU.")
else:
    print("⚠️  No GPU detected. Training and simulation will be slow on CPU.")

# Enable cuDNN auto-tuner for potential speed improvements on GPU.
torch.backends.cudnn.benchmark = True

# --- Checkpoint Directory Configuration ---
# Directory to save training checkpoints.
CHECKPOINT_DIR = "checkpoints_optimized"
# Folder for state snapshots written by the large (inference-size) simulation.
# NOTE(review): the name says 1024, but GRID_SIZE_INFERENCE is configured
# separately — confirm the folder name still matches the grid actually used.
LARGE_SIM_CHECKPOINT_DIR = "large_sim_checkpoints_1024"

# Make sure both output folders exist; existing folders are left untouched.
for _required_dir in (CHECKPOINT_DIR, LARGE_SIM_CHECKPOINT_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

print(f"Training Checkpoint Directory: {CHECKPOINT_DIR}")
print(f"Large Simulation Checkpoint Directory: {LARGE_SIM_CHECKPOINT_DIR}")


# ------------------------------------------------------------------------------
# 0.3: INPUT MODEL CONFIGURATION (OPTIONAL)
# ------------------------------------------------------------------------------

# Set to True if you want to attempt loading a trained model from a specific input path.
# If set to False, the script will look for locally trained models in CHECKPOINT_DIR.
USE_INPUT_MODEL = False

# If USE_INPUT_MODEL is True, specify the path to the .pth file of the model you want to load.
INPUT_MODEL_PATH = "" # <--- Specify the path here if USE_INPUT_MODEL is True

# Internal variable to track if an input model will be used.
# It becomes True only when an input model was requested AND the file exists.
USING_INPUT_MODEL_FLAG = False

if USE_INPUT_MODEL and INPUT_MODEL_PATH and os.path.exists(INPUT_MODEL_PATH):
    print(f"\n✅ Configured to use input model from: {INPUT_MODEL_PATH}")
    USING_INPUT_MODEL_FLAG = True
elif USE_INPUT_MODEL:
    # Requested, but the path is empty or does not point to an existing file.
    print(f"\n⚠️ USE_INPUT_MODEL is True, but INPUT_MODEL_PATH is not specified or the file does not exist at '{INPUT_MODEL_PATH}'.")
    print("Will look for locally trained models in CHECKPOINT_DIR instead.")
    USING_INPUT_MODEL_FLAG = False
else:
    print("\nℹ️ Not configured to use a specific input model. Will look for locally trained models in CHECKPOINT_DIR.")
    USING_INPUT_MODEL_FLAG = False

# When True (together with a configured input model), a warning is printed that
# the input file must be a full training checkpoint, not just model weights.
USE_INPUT_MODEL_FOR_TRAINING_RESUME = False
if USE_INPUT_MODEL_FOR_TRAINING_RESUME and USING_INPUT_MODEL_FLAG:
    print("⚠️ Warning: USE_INPUT_MODEL_FOR_TRAINING_RESUME is True and an input model was configured.")
    print("This may cause issues if the input model is not a full training checkpoint.")
    print("Ensure the file at INPUT_MODEL_PATH is a full training checkpoint if you continue.")

print("Input model configuration (optional) completed.")


# %% [markdown]
# # PHASE 3: GLOBAL PARAMETERS & CONFIGURATION
# ---
# Configure all tunable parameters for the pipeline here.
# %% [code]
# ==============================================================================
# --- PHASE 3: GLOBAL PARAMETERS & CONFIGURATION ---
# ==============================================================================

# --- Execution Control ---
# Define which parts of the pipeline will run.
# Pipeline stage switches: each flag enables one phase of the run.
RUN_TRAINING = False            # train the M-Law operator
RUN_POST_TRAINING_VIZ = True    # render a visualisation at the training grid size
RUN_LARGE_SIM = False           # run the long, large-grid simulation
CONTINUE_TRAINING = False       # resume from the newest checkpoint in CHECKPOINT_DIR

# Model / lattice dimensions used while training.
GRID_SIZE_TRAINING = 256        # side length of the square training grid
D_STATE = 32                    # components of the per-cell state vector
HIDDEN_CHANNELS = 256           # width of the operator's 1x1-conv MLP

# Training loop.
EPISODES_TO_ADD = 1_000         # episodes to run in this session
STEPS_PER_EPISODE = 150         # simulation steps per episode
LR_RATE_M = 5e-6                # optimizer learning rate
PERSISTENCE_COUNT = 10          # k for truncated BPTT (BPTT-k)

# Reward weights; alpha and gamma are interpolated from *_START to *_END
# over the first 75% of the episodes.
ALPHA_START = 3.0               # R_Density_Target weight at the start
ALPHA_END = 30.0                # R_Density_Target weight at the end
GAMMA_START = 3.0               # R_Stability weight at the start
GAMMA_END = 0.6                 # R_Stability weight at the end
BETA_CAUSALITY = 3.0            # constant R_Causality weight

# Weights for the additional reward terms.
LAMBDA_ACTIVITY_VAR = 1.0       # R_Activity_Var weight
LAMBDA_VELOCIDAD = 0.5          # R_Velocidad weight

# Density target and explosion penalty.
TARGET_STD_DENSITY = 1.2        # desired standard deviation of the density field
EXPLOSION_THRESHOLD = 0.7       # per-cell density above which the penalty applies
EXPLOSION_PENALTY_MULTIPLIER = 20.0  # scale of that penalty

# Stagnation detection: this many consecutive episodes without the loss
# beating the best loss by MIN_LOSS_IMPROVEMENT counts as a plateau.
STAGNATION_WINDOW = 500
MIN_LOSS_IMPROVEMENT = 5e-6

# Reactivation after a plateau.
REACTIVATION_COUNT = 2          # how many reactivations are allowed
REACTIVATION_STATE_MODE = 'random'  # one of 'random', 'seeded', 'complex_noise'
REACTIVATION_LR_MULTIPLIER = 0.5    # learning-rate factor applied per reactivation

# Gradient clipping threshold.
GRADIENT_CLIP = 0.85

# Write a training checkpoint every this many episodes.
SAVE_EVERY_EPISODES = 50

# Post-training visualisation (training grid size).
NUM_FRAMES_VIZ = 1_500          # simulation steps rendered into the video
FPS_VIZ_TRAINING = 24           # playback frame rate of that video

# Large simulation (inference grid size).
GRID_SIZE_INFERENCE = 1_250     # side length of the large grid
NUM_INFERENCE_STEPS = 50_000    # total simulation steps

# Large-simulation initialisation.
INITIAL_STATE_MODE_INFERENCE = 'random'  # one of 'random', 'seeded', 'complex_noise'
LOAD_STATE_CHECKPOINT_INFERENCE = True   # resume from a saved large-sim state if available
STATE_CHECKPOINT_PATH_INFERENCE = ""     # explicit state file to load (empty = none given)

# Save a large-simulation state snapshot every this many steps.
LARGE_SIM_CHECKPOINT_INTERVAL = 1_500

# Video output for the large simulation.
VIDEO_FPS = 20                  # playback frame rate
VIDEO_SAVE_INTERVAL_STEPS = 50  # record one frame every N steps
VIDEO_DOWNSCALE_FACTOR = 1      # integer shrink factor (1 keeps full resolution)
# NOTE(review): originally documented as "0-51, lower is better" (a CRF scale).
# If this value is passed to imageio's ffmpeg writer `quality` argument, that
# scale is 0-10 with higher meaning better — confirm against the writer call.
VIDEO_QUALITY = 8

# Live preview during the large simulation.
REAL_TIME_VIZ_INTERVAL = 100    # show a frame every N steps (None or 0 disables it)
REAL_TIME_VIZ_TYPE = 'phase'    # 'density', 'channels', 'magnitude', 'phase' or 'change'
REAL_TIME_VIZ_DOWNSCALE = 4     # shrink factor for the preview

print("Global parameters set.")


# %% [markdown]
# # PHASE 1: QCA ENGINE CLASSES
# ---
# Definition of `QCA_State`, `QCA_Operator_Deep`, and `Aetheria_Motor`.
# %% [code]
# ------------------------------------------------------------------------------
# 1.1: QCA_State Class (The State of the Universe)
# ------------------------------------------------------------------------------
class QCA_State:
    """Lattice state of the quantum cellular automaton.

    Stores the real and imaginary parts of a ``d_state``-component complex
    vector for every cell of a ``size`` x ``size`` grid. They are kept as two
    tensors of shape ``(1, size, size, d_state)`` allocated on ``DEVICE``.
    """

    def __init__(self, size, d_state):
        self.size = size
        self.d_state = d_state
        self.x_real = torch.zeros(1, size, size, d_state, device=DEVICE)
        self.x_imag = torch.zeros(1, size, size, d_state, device=DEVICE)

    def _reset_state_random(self):
        """Fill both parts with uniform noise in [-1e-2, 1e-2), then normalize each cell."""
        noise_r = (torch.rand(1, self.size, self.size, self.d_state, device=DEVICE) * 2 - 1) * 1e-2
        noise_i = (torch.rand(1, self.size, self.size, self.d_state, device=DEVICE) * 2 - 1) * 1e-2
        self.x_real.data = noise_r
        self.x_imag.data = noise_i
        self.normalize_state()

    def _reset_state_seeded(self):
        """Zero the lattice and place a small square 'seed' of activity at its centre.

        The seed covers cells ``[c - s, c + s]`` on both axes, clipped to the
        grid. Here ``c = size // 2`` and ``s = max(1, size // 64)``. Inside it,
        the first four channels get the fixed pattern
        real[0] = 0.5, imag[1] = 0.5, real[2] = -0.5, imag[3] = -0.5.
        Channels beyond ``d_state`` are simply skipped, so lattices with fewer
        than four channels still receive a (partial) seed instead of staying empty.
        Each cell is then normalized.
        """
        self.x_real.data.fill_(0)
        self.x_imag.data.fill_(0)
        center = self.size // 2
        seed_size = max(1, self.size // 64)
        # Clip the square patch to the grid once instead of testing every cell.
        lo = max(0, center - seed_size)
        hi = min(self.size, center + seed_size + 1)
        # (target tensor, channel index, value) for the seed pattern.
        seed_pattern = (
            (self.x_real, 0, 0.5),
            (self.x_imag, 1, 0.5),
            (self.x_real, 2, -0.5),
            (self.x_imag, 3, -0.5),
        )
        for target, channel, value in seed_pattern:
            if channel < self.d_state:
                target[0, lo:hi, lo:hi, channel] = value
        self.normalize_state()

    def _reset_state_complex_noise(self):
        """Initialize with small uniform noise plus fixed sinusoidal patterns on channels 0-3.

        The pattern is built on a [-1, 1] x [-1, 1] coordinate grid; channels
        that do not exist (``d_state`` < 4) are skipped. Each cell is then normalized.
        """
        y_coords, x_coords = torch.meshgrid(torch.linspace(-1, 1, self.size, device=DEVICE),
                                            torch.linspace(-1, 1, self.size, device=DEVICE),
                                            indexing='ij')
        radial_dist = torch.sqrt(x_coords**2 + y_coords**2)
        angle = torch.atan2(y_coords, x_coords)
        pattern1_r = torch.sin(x_coords * 10 + angle * 5) * 0.1
        pattern1_i = torch.cos(y_coords * 12 + angle * 6) * 0.1
        pattern2_r = torch.sin(radial_dist * 15 + x_coords * 8) * 0.05
        pattern2_i = torch.cos(radial_dist * 18 + y_coords * 9) * 0.05
        noise_r = (torch.rand(self.size, self.size, self.d_state, device=DEVICE) * 2 - 1) * 1e-3
        noise_i = (torch.rand(self.size, self.size, self.d_state, device=DEVICE) * 2 - 1) * 1e-3
        if self.d_state > 0: noise_r[:, :, 0] += pattern1_r
        if self.d_state > 1: noise_i[:, :, 1] += pattern1_i
        if self.d_state > 2: noise_r[:, :, 2] += pattern2_r
        if self.d_state > 3: noise_i[:, :, 3] += pattern2_i
        self.x_real.data = noise_r.unsqueeze(0).to(DEVICE)
        self.x_imag.data = noise_i.unsqueeze(0).to(DEVICE)
        self.normalize_state()

    def normalize_state(self):
        """Rescale every cell's complex vector to unit L2 norm.

        1e-8 is added under the square root so all-zero cells stay zero instead
        of producing NaN.
        """
        prob_sq = self.x_real.pow(2) + self.x_imag.pow(2)
        norm = torch.sqrt(prob_sq.sum(dim=-1, keepdim=True) + 1e-8)
        self.x_real.data = self.x_real.data / norm
        self.x_imag.data = self.x_imag.data / norm

    def get_cat(self):
        """Return the state as (1, 2 * d_state, size, size): real channels, then imaginary."""
        x_real_c = self.x_real.permute(0, 3, 1, 2)
        x_imag_c = self.x_imag.permute(0, 3, 1, 2)
        return torch.cat([x_real_c, x_imag_c], dim=1)

# ------------------------------------------------------------------------------
# 1.2: QCA_Operator_Deep Class (The "Deep" M-Law)
# ------------------------------------------------------------------------------
class QCA_Operator_Deep(nn.Module):
    """Learned local update rule (the "deep M-Law") for the QCA lattice.

    ``forward`` takes the concatenated state in (B, C, H, W) layout with
    ``C = 2 * d_state`` (real channels first, then imaginary). It returns
    ``(delta_real, delta_imag)``, each of shape ``(H, W, d_state)``. Because it
    squeezes dim 0 before permuting to (H, W, C), it only supports B == 1.
    The per-component biases ``M_bias_real`` / ``M_bias_imag`` are parameters
    of this module but are not applied inside ``forward``.
    """

    def __init__(self, d_state, hidden_channels):
        super().__init__()
        self.d_state = d_state
        n_in = 2 * d_state          # real + imaginary input channels
        n_neigh = n_in * 8          # 8 output maps per input channel

        # Fixed (non-trainable) depthwise 3x3 filter whose centre tap is zero,
        # i.e. each output is the sum of the 8 surrounding cells.
        # NOTE(review): all 8 maps produced for one input channel share the same
        # all-ones kernel, so they are identical copies of that neighbour sum
        # rather than 8 distinct neighbours — confirm this is intended.
        self.conv_neighbors = nn.Conv2d(n_in, n_neigh, kernel_size=3,
                                        padding=1, groups=n_in, bias=False)
        ring_kernel = torch.ones(n_neigh, 1, 3, 3)
        ring_kernel[:, :, 1, 1] = 0.0
        self.conv_neighbors.weight.data = ring_kernel
        self.conv_neighbors.weight.requires_grad = False

        # Trainable 1x1-conv MLP: four hidden layers (ELU in between), then a
        # linear head producing 8 * d_state channels (4*d_state real + 4*d_state imag).
        widths = [n_neigh] + [hidden_channels] * 4 + [8 * d_state]
        layers = []
        for depth, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            if depth:
                layers.append(nn.ELU(inplace=True))
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=1, bias=False))
        self.processing_net = nn.Sequential(*layers)

        # Trainable per-component biases.
        self.M_bias_real = nn.Parameter(torch.zeros(d_state))
        self.M_bias_imag = nn.Parameter(torch.zeros(d_state))

    def forward(self, x_cat):
        """Compute the state update for every cell.

        The 8 * d_state output channels are split into a real and an imaginary
        half. Each half is viewed as 4 groups of d_state, averaged over the
        groups and scaled by 0.1.
        """
        weight_device = self.conv_neighbors.weight.device
        hidden = self.processing_net(self.conv_neighbors(x_cat.to(weight_device)))
        hwc = hidden.squeeze(0).permute(1, 2, 0)  # (H, W, 8 * d_state)
        rows, cols = hwc.shape[0], hwc.shape[1]

        real_raw, imag_raw = hwc.split(4 * self.d_state, dim=-1)

        def _collapse(half):
            return half.reshape(rows, cols, 4, self.d_state).mean(dim=2) * 0.1

        return _collapse(real_raw), _collapse(imag_raw)

# ------------------------------------------------------------------------------
# 1.3: Aetheria_Motor Class (The Evolution Engine)
# ------------------------------------------------------------------------------
class Aetheria_Motor:
    """Evolution engine pairing a QCA lattice state with the operator that updates it."""

    def __init__(self, size, d_state, operator_model):
        self.size, self.d_state = size, d_state
        # Place the operator on DEVICE; wrap it in DataParallel when several GPUs are visible.
        self.operator = operator_model.to(DEVICE)
        n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        if n_gpus > 1:
            print(f"Using {n_gpus} GPUs for the operator.")
            self.operator = nn.DataParallel(self.operator)

        self.state = QCA_State(size, d_state)

    def evolve_step(self):
        """Advance the lattice state by one step, in place, without autograd.

        Each cell becomes ``x + delta + bias``, where ``delta`` comes from the
        operator and ``bias`` is the underlying module's ``M_bias_*``. Every
        cell's vector is then rescaled to unit norm (1e-8 under the square
        root). The result is written back through ``.data`` of the existing
        state tensors.
        """
        with torch.no_grad():
            parallel = isinstance(self.operator, nn.DataParallel)
            core = self.operator.module if parallel else self.operator
            input_device = self.operator.device_ids[0] if parallel else DEVICE

            d_re, d_im = self.operator(self.state.get_cat().to(input_device))

            cand_re = self.state.x_real.squeeze(0) + d_re + core.M_bias_real.to(d_re.device)
            cand_im = self.state.x_imag.squeeze(0) + d_im + core.M_bias_imag.to(d_im.device)

            scale = torch.sqrt((cand_re.pow(2) + cand_im.pow(2)).sum(dim=-1, keepdim=True) + 1e-8)
            self.state.x_real.data = (cand_re / scale).unsqueeze(0)
            self.state.x_imag.data = (cand_im / scale).unsqueeze(0)

print("QCA_State, QCA_Operator_Deep, and Aetheria_Motor classes defined.")


# %% [markdown]
# # PHASE 4: VISUALIZATION & CHECKPOINTING FUNCTIONS
# ---
# Helper functions for generating visualization frames and saving/loading state.
# %% [code]
# ------------------------------------------------------------------------------
# 4.1: Visualization Helper Functions
# ------------------------------------------------------------------------------

def downscale_frame(frame, downscale_factor):
    """Shrink an image frame by ``downscale_factor`` using Lanczos resampling.

    Args:
        frame: numpy image array (H, W[, C]) accepted by ``PIL.Image.fromarray``
            (the frame renderers in this file produce uint8 RGB).
        downscale_factor: shrink factor; values <= 1 return ``frame`` unchanged.

    Returns:
        The resized frame as a numpy array. Each output dimension is at least
        1 pixel.
    """
    if downscale_factor <= 1:
        return frame
    height, width = frame.shape[:2]
    # Clamp to >= 1 px: PIL rejects zero-sized targets, which would otherwise
    # happen for frames smaller than the factor. int() also tolerates float factors.
    new_height = max(1, int(height // downscale_factor))
    new_width = max(1, int(width // downscale_factor))
    img = Image.fromarray(frame)
    img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return np.array(img_resized)


def get_density_frame_gpu(state):
    """Render total probability density as a blue-to-red ramp.

    Per-cell density is the sum over channels of ``x_real**2 + x_imag**2``,
    min-max scaled to [0, 1]. Red is the density, blue is ``1 - density`` and
    green is 0. A field whose range is below 1e-8 is drawn as zero density
    (pure blue).

    Returns:
        uint8 numpy array of shape (H, W, 3).
    """
    density = (state.x_real.pow(2) + state.x_imag.pow(2)).squeeze(0).sum(dim=2).detach()
    lo, hi = density.min(), density.max()
    span = hi - lo
    scaled = torch.zeros_like(density).to(state.x_real.device) if span < 1e-8 else (density - lo) / span
    scaled = scaled.clamp(0.0, 1.0)
    rgb = torch.stack([scaled, torch.zeros_like(scaled), 1.0 - scaled], dim=2).clamp(0.0, 1.0)
    return (rgb * 255).byte().cpu().numpy()

def get_channel_frame_gpu(state, num_channels=3):
    """Overlay the probability of the first ``num_channels`` channels as colour.

    Channel ``i`` is min-max scaled independently and added to colour plane
    ``i % 3`` (R, G, B) with weight ``1 / n``, where ``n`` is the number of
    channels shown. Channels with a range below 1e-8 contribute nothing.

    Returns:
        uint8 numpy array of shape (size, size, 3).
    """
    prob = state.x_real.pow(2) + state.x_imag.pow(2)
    canvas = torch.zeros(state.size, state.size, 3, device=state.x_real.device)

    n_shown = min(num_channels, state.d_state)
    if n_shown == 0:
        return (canvas * 255).byte().cpu().numpy()

    weight = 1.0 / n_shown
    for ch in range(n_shown):
        plane = prob[0, :, :, ch].detach()
        lo, hi = plane.min(), plane.max()
        plane_scaled = torch.zeros_like(plane) if (hi - lo) < 1e-8 else (plane - lo) / (hi - lo)
        canvas[:, :, ch % 3] += plane_scaled * weight

    return (canvas.clamp(0, 1) * 255).byte().cpu().numpy()

def get_state_magnitude_frame_gpu(state):
    """Render the per-cell state-vector norm as a grayscale image.

    The norm ``sqrt(sum(x_real**2 + x_imag**2) + 1e-8)`` is min-max scaled to
    [0, 1]. If its range is below 1e-8 the image is black.

    Returns:
        uint8 numpy array of shape (H, W, 3) with equal R, G, B.
    """
    prob = state.x_real.pow(2) + state.x_imag.pow(2)
    magnitude = torch.sqrt(prob.squeeze(0).sum(dim=2) + 1e-8).detach()
    lo, hi = magnitude.min(), magnitude.max()
    span = hi - lo
    if span < 1e-8:
        gray = torch.zeros_like(magnitude).to(state.x_real.device)
    else:
        gray = (magnitude - lo) / span
    gray = gray.clamp(0.0, 1.0)
    return (torch.stack([gray] * 3, dim=2) * 255).byte().cpu().numpy()

def get_state_phase_frame_gpu(state):
    """Render the phase of the channel-summed amplitude as a cyclic colour wheel.

    For each cell, the real and imaginary parts are summed over channels, and
    ``atan2(sum_imag, sum_real)`` (in [-pi, pi]) is mapped to a hue in [0, 1].
    The hue drives three cosines offset by 0, 2*pi/3 and 4*pi/3, so distinct
    phases get distinct colours. Offsets that coincide modulo 2*pi would make
    two channels identical and map phases +t and -t to the same colour.

    Returns:
        uint8 numpy array of shape (H, W, 3).
    """
    sum_real = state.x_real.squeeze(0).sum(dim=2)
    sum_imag = state.x_imag.squeeze(0).sum(dim=2)
    phase_map = torch.atan2(sum_imag, sum_real).detach()
    hue = ((phase_map + torch.pi) / (2 * torch.pi)).clamp(0.0, 1.0)
    angle = 2 * torch.pi * hue
    R = 0.5 + 0.5 * torch.cos(angle)
    G = 0.5 + 0.5 * torch.cos(angle - 2 * torch.pi / 3)
    B = 0.5 + 0.5 * torch.cos(angle - 4 * torch.pi / 3)
    img_rgb = torch.stack([R, G, B], dim=2).clamp(0.0, 1.0)
    return (img_rgb * 255).byte().cpu().numpy()

def get_state_change_magnitude_frame_gpu(state, prev_state):
    """Render the per-cell magnitude of change between two states as grayscale.

    The magnitude is ``sqrt(sum((dr)**2 + (di)**2) + 1e-8)``, where ``dr`` and
    ``di`` are the real/imaginary differences ``state - prev_state``. It is
    min-max scaled to [0, 1]; a range below 1e-12 yields a black image.

    Args:
        state: current state (``x_real``, ``x_imag`` tensors).
        prev_state: earlier state with tensors of the same shape; they are moved
            onto ``state``'s device before the subtraction.

    Returns:
        uint8 numpy array of shape (H, W, 3) with equal R, G, B.
    """
    state_real = state.x_real.detach()
    state_imag = state.x_imag.detach()
    # Align with the current state's device (not the global DEVICE), matching
    # the other frame renderers and working wherever `state` lives.
    target_device = state_real.device
    prev_state_real = prev_state.x_real.detach().to(target_device)
    prev_state_imag = prev_state.x_imag.detach().to(target_device)
    diff_real = state_real - prev_state_real
    diff_imag = state_imag - prev_state_imag
    change_magnitude_sq = diff_real.pow(2) + diff_imag.pow(2)
    change_magnitude_map = torch.sqrt(change_magnitude_sq.squeeze(0).sum(dim=2) + 1e-8)
    m_min, m_max = change_magnitude_map.min(), change_magnitude_map.max()
    norm_factor = m_max - m_min
    if norm_factor < 1e-12:
        normalized_change = torch.zeros_like(change_magnitude_map)
    else:
        normalized_change = (change_magnitude_map - m_min) / norm_factor
    img_gray = normalized_change.clamp(0.0, 1.0)
    img_rgb = torch.stack([img_gray, img_gray, img_gray], dim=2)
    return (img_rgb * 255).byte().cpu().numpy()

# ------------------------------------------------------------------------------
# 4.2: State Checkpointing Functions
# ------------------------------------------------------------------------------

def load_qca_state(motor_instance, checkpoint_filepath):
    """Restore a motor's QCA state from a file written by ``save_qca_state``.

    The file must hold a dict with ``'x_real'`` and ``'x_imag'`` tensors whose
    shapes equal the motor's current state tensors. They are copied in through
    ``.data``.

    Args:
        motor_instance: object exposing ``state.x_real`` / ``state.x_imag`` tensors.
        checkpoint_filepath: path of the checkpoint file.

    Returns:
        The stored ``'step'`` (or -1 if the file has none) on success. Returns
        -1 on any failure: missing file, not a dict, missing keys, shape
        mismatch, or unreadable file. Errors are printed, not raised.
    """
    try:
        target_device = motor_instance.state.x_real.device
        # NOTE(review): torch.load unpickles the file; only load state files you trust.
        checkpoint = torch.load(checkpoint_filepath, map_location=target_device)
        is_valid = (
            isinstance(checkpoint, dict)
            and 'x_real' in checkpoint
            and 'x_imag' in checkpoint
            and checkpoint['x_real'].shape == motor_instance.state.x_real.shape
            and checkpoint['x_imag'].shape == motor_instance.state.x_imag.shape
        )
        if not is_valid:
            print("❌ Error loading state: Checkpoint file invalid or dimensions mismatch.")
            return -1

        motor_instance.state.x_real.data = checkpoint['x_real'].data.to(motor_instance.state.x_real.device)
        motor_instance.state.x_imag.data = checkpoint['x_imag'].data.to(motor_instance.state.x_imag.device)
        print(f"✅ State loaded successfully from: {checkpoint_filepath}")
        return checkpoint.get('step', -1)
    except FileNotFoundError:
        print(f"❌ Error loading state: File '{checkpoint_filepath}' not found.")
        return -1
    except Exception as e:
        print(f"❌ Error loading state from '{checkpoint_filepath}': {e}")
        return -1

def save_qca_state(motor_instance, step, checkpoint_dir):
    """Write the motor's current QCA state and step number to ``checkpoint_dir``.

    The file is named ``large_sim_state_step_{step}.pth``. It holds a dict with
    ``'step'``, ``'x_real'`` and ``'x_imag'`` (tensors copied to CPU), the format
    ``load_qca_state`` reads back. A missing ``checkpoint_dir`` is created.
    Failures are printed, not raised.
    """
    checkpoint_filename = os.path.join(
        checkpoint_dir,
        f"large_sim_state_step_{step}.pth"
    )
    try:
        # Create the folder on demand; an empty dir means "current directory".
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save({
            'step': step,
            'x_real': motor_instance.state.x_real.data.cpu(),
            'x_imag': motor_instance.state.x_imag.data.cpu()
        }, checkpoint_filename)
        print(f"\n💾 Large simulation checkpoint saved: step {step}")
    except Exception as e:
        print(f"⚠️  Error saving large simulation checkpoint at step {step}: {e}")

print("Visualization and state checkpointing functions defined.")


# %% [markdown]
# # PHASE 2: ADVANCED TRAINER (PEF v3)
# ---
# Definition of the `QC_Trainer_v3` class for training the M-Law.
# %% [code]
# ------------------------------------------------------------------------------
# 2.1: QC_Trainer_v3 Class
# ------------------------------------------------------------------------------
class QC_Trainer_v3:
    def __init__(self, motor, lr_rate):
        self.motor = motor
        if isinstance(self.motor.operator, nn.DataParallel):
            params_to_optimize = self.motor.operator.module.parameters()
        else:
            params_to_optimize = self.motor.operator.parameters()

        self.optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=lr_rate,
            weight_decay=1e-6,
            betas=(0.9, 0.999)
        )

        self.history = {
            'Loss': [],
            'R_Density_Target': [],
            'R_Causalidad': [],
            'R_Stability': [],
            'P_Explosion': [],
            'Gradient_Norm': [],
            'R_Activity_Var': [],
            'R_Velocidad': []
        }

        self.current_episode = 0
        self.best_loss = float('inf')
        self.stagnation_counter = 0
        self.reactivation_counter = 0
        self.gradient_norms = []

    def _calculate_annealed_alpha_gamma(self, total_episodes):
        """Calculates annealed weights for Alpha and Gamma."""
        total_episodes_for_annealing = total_episodes * 0.75
        progress = min(1.0, self.current_episode / max(1.0, total_episodes_for_annealing))
        alpha_progress = 1 - (1 - progress) ** 1.5
        gamma_progress = progress ** 0.7
        current_alpha = ALPHA_START + (ALPHA_END - ALPHA_START) * alpha_progress
        current_gamma = GAMMA_START + (GAMMA_END - GAMMA_START) * gamma_progress
        return current_alpha, current_gamma

    def _save_checkpoint(self, episode, is_best=False):
        """Saves the training state to a .pth file."""
        if is_best:
             filename = f"{CHECKPOINT_DIR}/qca_best_eps{episode}.pth"
        else:
             filename = f"{CHECKPOINT_DIR}/qca_checkpoint_eps{episode}.pth"

        if isinstance(self.motor.operator, nn.DataParallel):
            model_state_dict = self.motor.operator.module.state_dict()
        else:
            model_state_dict = self.motor.operator.state_dict()

        state = {
            'episode': episode,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_loss': self.best_loss,
            'stagnation_counter': self.stagnation_counter,
            'reactivation_counter': self.reactivation_counter,
            'history': self.history
        }
        torch.save(state, filename)
        print(f"\n[Checkpoint saved to: {filename}]")

    def _load_checkpoint(self):
        """Loads the latest training checkpoint from CHECKPOINT_DIR."""
        try:
            list_of_files = glob.glob(f"{CHECKPOINT_DIR}/qca_checkpoint_eps*.pth")
            list_of_best_files = glob.glob(f"{CHECKPOINT_DIR}/qca_best_eps*.pth")
            all_checkpoint_files = list_of_files + list_of_best_files

            if not all_checkpoint_files:
                print("No checkpoints found. Starting from scratch.")
                return

            latest_file = max(all_checkpoint_files, key=os.path.getmtime)
            checkpoint = torch.load(latest_file, map_location=DEVICE)

            if isinstance(self.motor.operator, nn.DataParallel):
                 self.motor.operator.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                first_key = next(iter(checkpoint['model_state_dict']))
                if first_key.startswith('module.'):
                    new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
                    self.motor.operator.load_state_dict(new_state_dict)
                else:
                    self.motor.operator.load_state_dict(checkpoint['model_state_dict'])

            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.current_episode = checkpoint['episode'] + 1
            self.best_loss = checkpoint['best_loss']

            loaded_history = checkpoint.get('history', {})
            for key in self.history.keys():
                self.history[key] = loaded_history.get(key, []) # Init missing keys

            self.stagnation_counter = checkpoint.get('stagnation_counter', 0)
            self.reactivation_counter = checkpoint.get('reactivation_counter', 0)

            print(f"Checkpoint loaded: {latest_file}. Resuming from episode {self.current_episode}.")

        except Exception as e:
            print(f"Error loading checkpoint: {e}. Starting from scratch.")
            self.current_episode = 0
            self.history = {k: [] for k in self.history.keys()}
            self.best_loss = float('inf')
            self.stagnation_counter = 0
            self.reactivation_counter = 0
            self.gradient_norms = []

    def check_stagnation_and_reactivate(self, total_episodes):
        """Checks for training stagnation and triggers reactivation if configured."""
        current_loss = self.history['Loss'][-1] if self.history['Loss'] else float('inf')

        if current_loss < (self.best_loss - MIN_LOSS_IMPROVEMENT):
            self.best_loss = current_loss
            self.stagnation_counter = 0
            return False # Not stagnated

        else:
            self.stagnation_counter += 1
            if self.stagnation_counter >= STAGNATION_WINDOW:
                print(f"\nSTAGNATION DETECTED at episode {self.current_episode}!")
                print(f"No improvement of {MIN_LOSS_IMPROVEMENT} in {STAGNATION_WINDOW} episodes.")

                if self.reactivation_counter < REACTIVATION_COUNT:
                    self.reactivation_counter += 1
                    print(f"Attempting reactivation {self.reactivation_counter}/{REACTIVATION_COUNT}...")

                    if REACTIVATION_STATE_MODE == 'random':
                        self.motor.state._reset_state_random()
                        print("-> Resetting state with random noise.")
                    elif REACTIVATION_STATE_MODE == 'seeded':
                         self.motor.state._reset_state_seeded()
                         print("-> Resetting state with central seed.")
                    elif REACTIVATION_STATE_MODE == 'complex_noise':
                         self.motor.state._reset_state_complex_noise()
                         print("-> Resetting state with complex noise.")
                    else:
                         print(f"State reactivation mode '{REACTIVATION_STATE_MODE}' not recognized. Resetting to random.")
                         self.motor.state._reset_state_random()

                    current_lr = self.optimizer.param_groups[0]['lr']
                    new_lr = current_lr * REACTIVATION_LR_MULTIPLIER
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = new_lr
                    print(f"-> Learning rate adjusted from {current_lr:.2e} to {new_lr:.2e}.")

                    self.stagnation_counter = 0
                    print("-> Reactivation complete. Continuing training.")
                    return False # Do not stop

                else:
                    print(f"Maximum number of reactivations ({REACTIVATION_COUNT}) reached.")
                    return True # Stop training

        return False # Not stagnated yet

    def train_episode(self, total_episodes):
        """Run one full training episode with truncated BPTT (BPTT-k).

        The motor state is reset to random noise and then unrolled for
        ``STEPS_PER_EPISODE`` steps through the operator. Each step yields a
        scalar loss (the negated shaped reward plus the explosion penalty).
        Step losses are accumulated and back-propagated every
        ``PERSISTENCE_COUNT`` steps and at the final step. After each update
        the state is detached, which cuts the autograd graph between windows.

        If a NaN/Inf appears in the state, the operator output, the next state
        or the step loss, the episode is aborted and its loss becomes NaN.

        Args:
            total_episodes: Total number of episodes in the run. It is
                forwarded to ``_calculate_annealed_alpha_gamma`` to obtain the
                annealed alpha/gamma reward weights.

        Returns:
            The mean per-step loss over the steps that completed, or NaN if
            the episode was aborted because of non-finite values.

        Side effects:
            Updates operator parameters through ``self.optimizer``, appends one
            value per metric to ``self.history``, clears
            ``self.gradient_norms`` and increments ``self.current_episode``.
        """

        def _all_finite(*tensors):
            # True only if none of the tensors contains NaN or +/-Inf.
            return all(bool(torch.isfinite(t).all()) for t in tensors)

        # Forward calls go through the (possibly DataParallel) wrapper.
        # Parameter/attribute access must go through the underlying module.
        operator = self.motor.operator
        operator_module = operator.module if isinstance(operator, nn.DataParallel) else operator

        self.motor.state._reset_state_random()
        alpha, gamma = self._calculate_annealed_alpha_gamma(total_episodes)

        episode_total_loss = 0.0
        bptt_cumulative_loss = 0.0
        valid_steps = 0
        current_real = self.motor.state.x_real.clone().requires_grad_(True).to(DEVICE)
        current_imag = self.motor.state.x_imag.clone().requires_grad_(True).to(DEVICE)

        activity_variances_per_step_mean = []
        density_variances_per_step = []

        # Most recently computed reward terms; recorded in history at episode end.
        R_density_target = R_Causalidad = R_Stability = P_Explosion = None

        for t in range(STEPS_PER_EPISODE):
            if not _all_finite(current_real, current_imag):
                print(f"‚ö†Ô∏è  NaN/Inf detected in state at step {t} of episode {self.current_episode}.")
                episode_total_loss = float('nan')
                break

            prev_real_detached = current_real.detach()
            prev_imag_detached = current_imag.detach()

            # permute(0, 3, 1, 2) moves the last (state) axis to dim 1, and real and
            # imaginary parts are stacked along it. Assumed channels-first operator
            # input -- TODO confirm against QCA_Operator_Deep.
            x_cat = torch.cat([current_real.permute(0, 3, 1, 2),
                               current_imag.permute(0, 3, 1, 2)], dim=1).to(DEVICE)
            F_int_real, F_int_imag = operator(x_cat)

            if not _all_finite(F_int_real, F_int_imag):
                print(f"‚ö†Ô∏è  NaN/Inf detected in F_int at step {t} of episode {self.current_episode}.")
                episode_total_loss = float('nan')
                break

            # Fetched every step on purpose: .to() may return a differentiable copy,
            # which must not outlive an optimizer step or be reused across backward().
            bias_real = operator_module.M_bias_real.to(DEVICE)
            bias_imag = operator_module.M_bias_imag.to(DEVICE)

            new_real = torch.clamp(current_real.squeeze(0) + F_int_real + bias_real, -1.5, 1.5)
            new_imag = torch.clamp(current_imag.squeeze(0) + F_int_imag + bias_imag, -1.5, 1.5)

            # Normalise each cell so its squared amplitudes summed over the last
            # axis are ~1. The 1e-8 guards against division by zero.
            prob_sq = new_real.pow(2) + new_imag.pow(2)
            norm = torch.sqrt(prob_sq.sum(dim=-1, keepdim=True) + 1e-8)
            next_real = new_real / norm
            next_imag = new_imag / norm

            if not _all_finite(next_real, next_imag):
                print(f"‚ö†Ô∏è  NaN/Inf detected in next_state at step {t} of episode {self.current_episode}.")
                episode_total_loss = float('nan')
                break

            # Density comes from the clamped, pre-normalisation amplitudes.
            # After normalisation every cell would sum to ~1.
            density_map = torch.clamp(prob_sq.sum(dim=-1), 0.0, 3.0)

            # --- Reward and Penalty Calculation ---
            # Superlinear penalty on the deviation of the density std from its target.
            current_std_density = density_map.std()
            density_error = torch.abs(current_std_density - TARGET_STD_DENSITY)
            R_density_target = -density_error * (1.0 + density_error)

            # Causality: penalise the mean absolute change w.r.t. the previous (detached) state.
            change_real = next_real - prev_real_detached.squeeze(0)
            change_imag = next_imag - prev_imag_detached.squeeze(0)
            R_Causalidad = -(change_real.abs().mean() + change_imag.abs().mean())

            # Stability: penalise the variance of the squared amplitudes along the last axis.
            density_t_plus_1 = next_real.pow(2) + next_imag.pow(2)
            R_Stability = -density_t_plus_1.var(dim=-1).mean()

            change_magnitude_per_cell = torch.sqrt(change_real.pow(2) + change_imag.pow(2)).sum(dim=-1)
            activity_variances_per_step_mean.append(change_magnitude_per_cell.mean().item())
            density_variances_per_step.append(density_map.var().item())

            # Hinge penalty once the peak density exceeds EXPLOSION_THRESHOLD.
            P_Explosion = torch.relu(density_map.max() - EXPLOSION_THRESHOLD) * EXPLOSION_PENALTY_MULTIPLIER

            # Step-wise reward/penalty (for BPTT)
            reward_step_bptt = (alpha * R_density_target) + \
                               (BETA_CAUSALITY * R_Causalidad) + \
                               (gamma * R_Stability) + \
                               (LAMBDA_ACTIVITY_VAR * change_magnitude_per_cell.var()) - \
                               (LAMBDA_VELOCIDAD * density_map.var())

            step_loss = -reward_step_bptt + P_Explosion

            if not _all_finite(step_loss):
                print(f"‚ö†Ô∏è  NaN/Inf detected in step_loss at step {t} of episode {self.current_episode}.")
                episode_total_loss = float('nan')
                break

            bptt_cumulative_loss = bptt_cumulative_loss + step_loss
            # step_loss is known to be finite at this point.
            episode_total_loss += step_loss.item()
            valid_steps += 1

            # --- Truncated Backpropagation (BPTT-k) ---
            if (t + 1) % PERSISTENCE_COUNT == 0 or (t + 1) == STEPS_PER_EPISODE:
                # Here bptt_cumulative_loss is always a tensor: this step's loss was just added.
                if bptt_cumulative_loss != 0 and _all_finite(bptt_cumulative_loss):
                    self.optimizer.zero_grad()
                    bptt_cumulative_loss.backward()

                    params_to_clip = [p for p in operator_module.parameters()
                                      if p.requires_grad and p.grad is not None]
                    if params_to_clip:
                        # clip_grad_norm_ returns the total L2 norm measured *before*
                        # clipping, which is the value logged. This costs one host sync
                        # instead of one per parameter.
                        total_norm = torch.nn.utils.clip_grad_norm_(params_to_clip, GRADIENT_CLIP)
                        self.gradient_norms.append(float(total_norm))
                    else:
                        self.gradient_norms.append(0.0)

                    self.optimizer.step()
                else:
                    print(f"‚ö†Ô∏è  BPTT cumulative loss is NaN/Inf/Zero at step {t} of episode {self.current_episode}. Skipping backward pass.")
                    self.gradient_norms.append(0.0)

                # Cut the graph: the next window starts from a detached copy of the state.
                bptt_cumulative_loss = 0.0
                current_real = next_real.unsqueeze(0).detach().to(DEVICE).requires_grad_(True)
                current_imag = next_imag.unsqueeze(0).detach().to(DEVICE).requires_grad_(True)
            else:
                current_real = next_real.unsqueeze(0)
                current_imag = next_imag.unsqueeze(0)

        # --- End of Episode ---
        episodic_R_Activity_Var = 0.0
        if len(activity_variances_per_step_mean) > 1:
            episodic_R_Activity_Var = float(np.var(activity_variances_per_step_mean)) * LAMBDA_ACTIVITY_VAR

        episodic_R_Velocidad = 0.0
        if density_variances_per_step:
            episodic_R_Velocidad = np.mean(density_variances_per_step) * LAMBDA_VELOCIDAD

        avg_loss = episode_total_loss / max(valid_steps, 1)

        def _last_value(term):
            # Python scalar of the most recent reward term, or NaN if no step completed.
            return term.item() if valid_steps > 0 and term is not None else float('nan')

        # --- Store Metrics in History ---
        self.history['Loss'].append(avg_loss)
        self.history['R_Density_Target'].append(_last_value(R_density_target))
        self.history['R_Causalidad'].append(_last_value(R_Causalidad))
        self.history['R_Stability'].append(_last_value(R_Stability))
        self.history['P_Explosion'].append(_last_value(P_Explosion))
        self.history['R_Activity_Var'].append(episodic_R_Activity_Var if valid_steps > 0 else float('nan'))
        self.history['R_Velocidad'].append(episodic_R_Velocidad if valid_steps > 0 else float('nan'))

        if self.gradient_norms:
            self.history['Gradient_Norm'].append(np.mean(self.gradient_norms))
            self.gradient_norms = []
        else:
            self.history['Gradient_Norm'].append(0.0)

        self.current_episode += 1
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return avg_loss

# Cell-completion marker: confirms in the notebook/script output that the QC_Trainer_v3 cell executed.
print("QC_Trainer_v3 class defined.")


# %% [markdown]
# # PHASES 5, 6 & 7: MAIN EXECUTION PIPELINE
# ---
# This is the main runnable part of the script.
# The RUN_TRAINING, RUN_POST_TRAINING_VIZ and RUN_LARGE_SIM flags from PHASE 3 select which phases run.
# %% [code]
def main_pipeline():
    """
    Main function to run the AETHERIA pipeline based on global flags.
    """
    print("--- STARTING AETHERIA PIPELINE EXECUTION ---")

    # These variables need to be accessible across phases
    Aetheria_Motor_Train = None
    M_FILENAME = None
    trainer = None
    model_id = "Deep_v3" # Default model ID

    # --------------------------------------------------------------------------
    # FASE 5: MAIN TRAINING LOGIC
    # --------------------------------------------------------------------------
    if RUN_TRAINING:
        print("\n" + "="*60)
        print(">>> STARTING TRAINING PHASE (PHASE 5) <<<")
        print("="*60)

        # Initialize the model (QCA_Operator_Deep)
        model_M = QCA_Operator_Deep(D_STATE, HIDDEN_CHANNELS)

        # Instantiate the Aetheria_Motor for the training phase.
        Aetheria_Motor_Train = Aetheria_Motor(GRID_SIZE_TRAINING, D_STATE, model_M)

        print(f"Motor and M-Law ({model_id}) initialized. Training grid: {GRID_SIZE_TRAINING}x{GRID_SIZE_TRAINING}.")

        # Correctly get the number of trainable parameters
        trainable_params = sum(p.numel() for p in (Aetheria_Motor_Train.operator.module.parameters()
                                                   if isinstance(Aetheria_Motor_Train.operator, nn.DataParallel)
                                                   else Aetheria_Motor_Train.operator.parameters()) if p.requires_grad)
        print(f"Trainable Parameters: {trainable_params}")

        # Instantiate the trainer
        trainer = QC_Trainer_v3(Aetheria_Motor_Train, LR_RATE_M)

        # Load checkpoint if CONTINUE_TRAINING is True.
        if CONTINUE_TRAINING:
            print("Attempting to continue training...")
            trainer._load_checkpoint()
        else:
            print("Starting new training from scratch.")

        # MAIN TRAINING LOOP
        print("\n--- Training (PEF v3 - Optimized Parameters + Reactivation + New Rewards) ---")
        print(f"Starting from episode {trainer.current_episode}. Training for {EPISODES_TO_ADD} more episodes.")
        print(f"Model: {model_id}, BPTT-k: {PERSISTENCE_COUNT}, Initial LR: {LR_RATE_M}")
        print(f"Target Std Density: {TARGET_STD_DENSITY}, Explosion Threshold: {EXPLOSION_THRESHOLD}, Explosion Multiplier: {EXPLOSION_PENALTY_MULTIPLIER}")
        print(f"Annealing Alpha: {ALPHA_START} -> {ALPHA_END}, Gamma: {GAMMA_START} -> {GAMMA_END}, Beta Causality: {BETA_CAUSALITY}")
        print(f"New Reward Weights: Activity Var (Œª={LAMBDA_ACTIVITY_VAR}), Velocidad (Œª={LAMBDA_VELOCIDAD})")
        print(f"Gradient Clip: {GRADIENT_CLIP}, Stagnation Window: {STAGNATION_WINDOW}, Min Loss Improvement: {MIN_LOSS_IMPROVEMENT}")
        print(f"Reactivation: {REACTIVATION_COUNT} attempts with state '{REACTIVATION_STATE_MODE}' and LR * {REACTIVATION_LR_MULTIPLIER}")
        print(f"Save Checkpoint every {SAVE_EVERY_EPISODES} episodes.")

        start_time = time.time()
        final_episode = trainer.current_episode + EPISODES_TO_ADD

        try:
            for episode in range(trainer.current_episode, final_episode):
                avg_loss = trainer.train_episode(final_episode)

                if np.isnan(avg_loss) or np.isinf(avg_loss):
                    print(f"‚ö†Ô∏è  Episode {episode:04}: Training failed (NaN/Inf in loss). Skipping save and continuing.")
                    Aetheria_Motor_Train.state._reset_state_random()
                    continue

                # Print progress
                if episode % 10 == 0 or episode == final_episode - 1 or episode == trainer.current_episode:
                    alpha, gamma = trainer._calculate_annealed_alpha_gamma(final_episode)
                    last_r_density = trainer.history['R_Density_Target'][-1] if trainer.history['R_Density_Target'] else float('nan')
                    last_r_causalidad = trainer.history['R_Causalidad'][-1] if trainer.history['R_Causalidad'] else float('nan')
                    last_r_stability = trainer.history['R_Stability'][-1] if trainer.history['R_Stability'] else float('nan')
                    last_p_explosion = trainer.history['P_Explosion'][-1] if trainer.history['P_Explosion'] else float('nan')
                    last_grad_norm = trainer.history['Gradient_Norm'][-1] if trainer.history['Gradient_Norm'] else float('nan')
                    last_r_activity_var = trainer.history['R_Activity_Var'][-1] if trainer.history['R_Activity_Var'] else float('nan')
                    last_r_velocidad = trainer.history['R_Velocidad'][-1] if trainer.history['R_Velocidad'] else float('nan')

                    print(f"Eps {episode:04}: Loss={avg_loss:.3e} | "
                          f"R_Dens(Tgt={TARGET_STD_DENSITY:.2f})={last_r_density:.3f} | "
                          f"R_Caus={last_r_causalidad:.3f} | "
                          f"R_Stab={last_r_stability:.3f} | "
                          f"R_ActVar={last_r_activity_var:.3e} | "
                          f"R_Vel={last_r_velocidad:.3f} | "
                          f"P_Expl={last_p_explosion:.3e} | "
                          f"GradNorm={last_grad_norm:.3e} | "
                          f"Œ±={alpha:.2f}, Œ≥={gamma:.2f}, LR={trainer.optimizer.param_groups[0]['lr']:.2e}")

                # Save checkpoint
                if episode % SAVE_EVERY_EPISODES == 0 and episode > (trainer.current_episode - EPISODES_TO_ADD): # Avoid saving at start if continuing
                    trainer._save_checkpoint(episode)
                    if not np.isnan(avg_loss) and not np.isinf(avg_loss) and avg_loss < trainer.best_loss:
                        trainer._save_checkpoint(episode, is_best=True)
                        print(f"üèÜ New best model saved at episode {episode}")

                # Check for stagnation
                if trainer.check_stagnation_and_reactivate(final_episode):
                    print("Training stopped due to stagnation with no reactivations left.")
                    break

        except KeyboardInterrupt:
            print("\nTraining interrupted by user. Saving checkpoint...")
            if trainer: trainer._save_checkpoint(trainer.current_episode)
        except Exception as e:
            print(f"\n‚ùå Critical error during training at episode {trainer.current_episode if trainer else 'N/A'}: {e}")
            print("Saving current state before stopping...")
            if trainer: trainer._save_checkpoint(trainer.current_episode)
            raise e # Re-raise error

        end_time = time.time()
        print(f"\nTraining completed in {end_time - start_time:.2f} seconds.")
        if trainer: print(f"Final episode reached: {trainer.current_episode}")

        # SAVE FINAL MODEL (M Law)
        print("\n--- SAVING FINAL FUNDAMENTAL LAW (M) ---")
        TIMESTAMP = int(time.time())
        if trainer:
            M_FILENAME = f"{CHECKPOINT_DIR}/PEF_{model_id}_G{GRID_SIZE_TRAINING}_Eps{trainer.current_episode}_{TIMESTAMP}_FINAL.pth"
            try:
                if isinstance(trainer.motor.operator, nn.DataParallel):
                    model_state_dict_to_save = trainer.motor.operator.module.state_dict()
                else:
                    model_state_dict_to_save = trainer.motor.operator.state_dict()
                torch.save(model_state_dict_to_save, M_FILENAME)
                print(f"‚úÖ Final Fundamental Law (M) saved: {M_FILENAME}")
            except Exception as e:
                print(f"‚ùå Error saving final model: {e}")

    else:
        print("\n>>> TRAINING PHASE (PHASE 5) SKIPPED <<<")
        # If training is skipped, we still need to initialize the training motor
        # and load the model if we plan to run visualization or large simulation.
        if RUN_POST_TRAINING_VIZ or RUN_LARGE_SIM:
            print("Initializing training-size motor to load model...")
            model_M = QCA_Operator_Deep(D_STATE, HIDDEN_CHANNELS)
            Aetheria_Motor_Train = Aetheria_Motor(GRID_SIZE_TRAINING, D_STATE, model_M)

            # Attempt to load the most recent FINAL model
            model_files = glob.glob(f"{CHECKPOINT_DIR}/PEF_Deep_v3_G{GRID_SIZE_TRAINING}_Eps*_FINAL.pth")
            M_FILENAME = max(model_files, key=os.path.getctime, default=None) if model_files else None

            if M_FILENAME and os.path.exists(M_FILENAME):
                print(f"üì¶ Detected latest .pth file (assumed trained model): {M_FILENAME}. Attempting to load weights...")
                try:
                    model_state_dict = torch.load(M_FILENAME, map_location=DEVICE)
                    if isinstance(Aetheria_Motor_Train.operator, nn.DataParallel):
                        Aetheria_Motor_Train.operator.module.load_state_dict(model_state_dict)
                    else:
                        first_key = next(iter(model_state_dict))
                        if first_key.startswith('module.'):
                            new_state_dict = {k.replace('module.', ''): v for k, v in model_state_dict.items()}
                            Aetheria_Motor_Train.operator.load_state_dict(new_state_dict)
                        else:
                            Aetheria_Motor_Train.operator.load_state_dict(model_state_dict)
                    Aetheria_Motor_Train.operator.eval()
                    print("‚úÖ Model weights loaded successfully into Aetheria_Motor_Train.")
                except Exception as e:
                    print(f"‚ùå Error loading model weights '{M_FILENAME}': {e}")
                    print("‚ö†Ô∏è  Could not load trained model. Visualization/simulation may not work as expected.")
            else:
                print(f"‚ùå No PEF_Deep_v3_G{GRID_SIZE_TRAINING}_Eps*_FINAL.pth files found in '{CHECKPOINT_DIR}'. No trained model will be loaded.")


    # --------------------------------------------------------------------------
    # FASE 6: POST-TRAINING VISUALIZATION (Training Size)
    # --------------------------------------------------------------------------
    if RUN_POST_TRAINING_VIZ:
        print("\n" + "="*60)
        print(">>> STARTING POST-TRAINING VISUALIZATION PHASE (PHASE 6) <<<")
        print("="*60)

        if Aetheria_Motor_Train is not None: # Ensure the training motor exists and model is loaded
            print(f"Generating {NUM_FRAMES_VIZ} frames with the M-Law on {GRID_SIZE_TRAINING}x{GRID_SIZE_TRAINING} grid...")

            Aetheria_Motor_Train.operator.eval()
            Aetheria_Motor_Train.state._reset_state_random() # Start from a fresh random state
            NUM_FRAMES_VIZ_TRAINING = NUM_FRAMES_VIZ # Use the parameter from the top

            FRAMES_DENSITY_TRAINING = []
            FRAMES_CHANNELS_TRAINING = []
            FRAMES_MAGNITUDE_TRAINING = []
            FRAMES_PHASE_TRAINING = []
            FRAMES_CHANGE_TRAINING = []

            prev_state_for_change_viz_training = QCA_State(Aetheria_Motor_Train.size, Aetheria_Motor_Train.d_state)
            prev_state_for_change_viz_training.x_real.data = Aetheria_Motor_Train.state.x_real.data.clone()
            prev_state_for_change_viz_training.x_imag.data = Aetheria_Motor_Train.state.x_imag.data.clone()

            with torch.no_grad():
                for t in range(NUM_FRAMES_VIZ_TRAINING):
                    # Store current state *before* evolution for change calculation
                    current_state_clone_for_change_viz_training = QCA_State(Aetheria_Motor_Train.state.size, Aetheria_Motor_Train.state.d_state)
                    current_state_clone_for_change_viz_training.x_real.data = Aetheria_Motor_Train.state.x_real.data.clone().to(DEVICE)
                    current_state_clone_for_change_viz_training.x_imag.data = Aetheria_Motor_Train.state.x_imag.data.clone().to(DEVICE)

                    Aetheria_Motor_Train.evolve_step()
                    next_state = Aetheria_Motor_Train.state

                    # Generate frames
                    FRAMES_DENSITY_TRAINING.append(get_density_frame_gpu(next_state))
                    FRAMES_CHANNELS_TRAINING.append(get_channel_frame_gpu(next_state, num_channels=min(3, D_STATE)))
                    FRAMES_MAGNITUDE_TRAINING.append(get_state_magnitude_frame_gpu(next_state))
                    FRAMES_PHASE_TRAINING.append(get_state_phase_frame_gpu(next_state))
                    FRAMES_CHANGE_TRAINING.append(get_state_change_magnitude_frame_gpu(next_state, prev_state_for_change_viz_training))

                    # Update previous state
                    prev_state_for_change_viz_training = current_state_clone_for_change_viz_training

                    if (t + 1) % max(1, (NUM_FRAMES_VIZ_TRAINING // 10)) == 0:
                        print(f"-> Capturing frame {t+1}/{NUM_FRAMES_VIZ_TRAINING}...")

            print("‚úÖ Visualization frame capture completed.")

            print("\n--- SAVING VISUALIZATION VIDEOS (Training Size) ---")
            try:
                # Use the filename from the final saved model to name the visualization videos.
                if M_FILENAME and os.path.exists(M_FILENAME):
                    BASE_FILENAME_VIZ_TRAINING = os.path.basename(M_FILENAME).replace('_FINAL.pth', '')
                elif trainer is not None and hasattr(trainer, 'current_episode'):
                    TIMESTAMP_VIZ = int(time.time())
                    BASE_FILENAME_VIZ_TRAINING = f"Viz_TrainSize_G{GRID_SIZE_TRAINING}_Eps{trainer.current_episode}_{TIMESTAMP_VIZ}"
                    print(f"‚ö†Ô∏è M_FILENAME not available. Using default base name: {BASE_FILENAME_VIZ_TRAINING}")
                else:
                    TIMESTAMP_VIZ = int(time.time())
                    BASE_FILENAME_VIZ_TRAINING = f"Viz_TrainSize_G{GRID_SIZE_TRAINING}_{TIMESTAMP_VIZ}"
                    print(f"‚ö†Ô∏è M_FILENAME and trainer not available. Using default base name: {BASE_FILENAME_VIZ_TRAINING}")

                MP4_DENSITY_FILENAME_TRAINING = f"{BASE_FILENAME_VIZ_TRAINING}_1_DENSITY.mp4"
                MP4_CHANNELS_FILENAME_TRAINING = f"{BASE_FILENAME_VIZ_TRAINING}_2_CHANNELS.mp4"
                MP4_MAGNITUDE_FILENAME_TRAINING = f"{BASE_FILENAME_VIZ_TRAINING}_3_MAGNITUDE.mp4"
                MP4_PHASE_FILENAME_TRAINING = f"{BASE_FILENAME_VIZ_TRAINING}_4_PHASE.mp4"
                MP4_CHANGE_FILENAME_TRAINING = f"{BASE_FILENAME_VIZ_TRAINING}_5_CHANGE.mp4"

                if FRAMES_DENSITY_TRAINING:
                    try:
                        imageio.mimsave(MP4_DENSITY_FILENAME_TRAINING, FRAMES_DENSITY_TRAINING, fps=FPS_VIZ_TRAINING, codec='libx264', quality=8)
                        imageio.mimsave(MP4_CHANNELS_FILENAME_TRAINING, FRAMES_CHANNELS_TRAINING, fps=FPS_VIZ_TRAINING, codec='libx264', quality=8)
                        imageio.mimsave(MP4_MAGNITUDE_FILENAME_TRAINING, FRAMES_MAGNITUDE_TRAINING, fps=FPS_VIZ_TRAINING, codec='libx264', quality=8)
                        imageio.mimsave(MP4_PHASE_FILENAME_TRAINING, FRAMES_PHASE_TRAINING, fps=FPS_VIZ_TRAINING, codec='libx264', quality=8)
                        if FRAMES_CHANGE_TRAINING:
                             imageio.mimsave(MP4_CHANGE_FILENAME_TRAINING, FRAMES_CHANGE_TRAINING, fps=FPS_VIZ_TRAINING, codec='libx264', quality=8)
                        print(f"‚úÖ MP4 Videos ({GRID_SIZE_TRAINING}x{GRID_SIZE_TRAINING}) saved.")
                        print(f"   -> {MP4_DENSITY_FILENAME_TRAINING}")
                        print(f"   -> {MP4_CHANNELS_FILENAME_TRAINING}")
                        # (IPython.display.Video calls removed)
                    except Exception as e:
                        print(f"‚ùå Error saving visualization videos (training size): {e}")
                else:
                    print("‚ùå Error: No frames were generated for post-training visualization.")

            except Exception as e:
                print(f"‚ùå General error in post-training visualization video saving: {e}")
        else:
            print("‚ö†Ô∏è Aetheria_Motor_Train was not initialized. Skipping post-training visualization.")

        print("\n>>> POST-TRAINING VISUALIZATION PHASE (PHASE 6) COMPLETED <<<")

    else:
        print("\n>>> POST-TRAINING VISUALIZATION PHASE (PHASE 6) SKIPPED <<<")


    # --------------------------------------------------------------------------
    # FASE 7: MAIN PROLONGED LARGE SIMULATION LOGIC
    # --------------------------------------------------------------------------
    if RUN_LARGE_SIM:
        print("\n" + "="*60)
        print(">>> STARTING PROLONGED LARGE SIMULATION PHASE (PHASE 7) <<<")
        print("="*60)

        print(f"\n--- CONFIGURING LARGE SIMULATION ({GRID_SIZE_INFERENCE}x{GRID_SIZE_INFERENCE}) ---")
        print(f"Using D_STATE={D_STATE}, HIDDEN_CHANNELS={HIDDEN_CHANNELS} (from optimized training)")

        # Instantiate the operator model
        operator_model_inference = QCA_Operator_Deep(
            d_state=D_STATE,
            hidden_channels=HIDDEN_CHANNELS
        )

        # Instantiate the Aetheria_Motor for the large simulation.
        large_scale_motor = Aetheria_Motor(GRID_SIZE_INFERENCE, D_STATE, operator_model_inference)

        # --- Load Trained Model Weights ---
        if not M_FILENAME: # If M_FILENAME is not set (e.g., training skipped and no file found)
            model_files = glob.glob(f"{CHECKPOINT_DIR}/PEF_Deep_v3_G{GRID_SIZE_TRAINING}_Eps*_FINAL.pth")
            M_FILENAME = max(model_files, key=os.path.getctime, default=None) if model_files else None

        if M_FILENAME and os.path.exists(M_FILENAME):
            print(f"üì¶ Loading weights from: {M_FILENAME}")
            try:
                model_state_dict = torch.load(M_FILENAME, map_location=DEVICE)
                if isinstance(large_scale_motor.operator, nn.DataParallel):
                    large_scale_motor.operator.module.load_state_dict(model_state_dict)
                else:
                    first_key = next(iter(model_state_dict))
                    if first_key.startswith('module.'):
                        new_state_dict = {k.replace('module.', ''): v for k, v in model_state_dict.items()}
                        large_scale_motor.operator.load_state_dict(new_state_dict)
                    else:
                        large_scale_motor.operator.load_state_dict(model_state_dict)
                large_scale_motor.operator.eval()
                print("‚úÖ Model weights loaded successfully.")
            except Exception as e:
                print(f"‚ùå Error loading model weights '{M_FILENAME}': {e}")
                print("‚ö†Ô∏è  Simulation will run with random M-Law weights.")
        else:
            print(f"‚ùå No trained model file found in '{CHECKPOINT_DIR}'. Simulation will run with random M-Law weights.")

        # --- Load state from a checkpoint or start from scratch ---
        start_step = 0
        latest_checkpoint_filepath = None

        if LOAD_STATE_CHECKPOINT_INFERENCE:
            if STATE_CHECKPOINT_PATH_INFERENCE and os.path.exists(STATE_CHECKPOINT_PATH_INFERENCE):
                latest_checkpoint_filepath = STATE_CHECKPOINT_PATH_INFERENCE
                print(f"\nAttempting to load state from specified checkpoint: {latest_checkpoint_filepath}")
            else:
                if STATE_CHECKPOINT_PATH_INFERENCE:
                     print(f"‚ö†Ô∏è Specified checkpoint '{STATE_CHECKPOINT_PATH_INFERENCE}' not found. Searching for latest...")
                checkpoint_files = [f for f in os.listdir(LARGE_SIM_CHECKPOINT_DIR) if f.startswith("large_sim_state_step_") and f.endswith(".pth")]
                if checkpoint_files:
                    def extract_step(filename):
                        match = re.search(r"large_sim_state_step_(\d+)\.pth", filename)
                        return int(match.group(1)) if match else 0
                    checkpoint_files.sort(key=extract_step)
                    latest_checkpoint_filename = checkpoint_files[-1]
                    latest_checkpoint_filepath = os.path.join(LARGE_SIM_CHECKPOINT_DIR, latest_checkpoint_filename)
                    print(f"\nDetected latest large sim checkpoint in '{LARGE_SIM_CHECKPOINT_DIR}': {latest_checkpoint_filepath}. Attempting to load...")
                else:
                    print(f"\nLOAD_STATE_CHECKPOINT_INFERENCE is True but no large simulation checkpoints were found in '{LARGE_SIM_CHECKPOINT_DIR}'.")

            if latest_checkpoint_filepath and os.path.exists(latest_checkpoint_filepath):
                loaded_step = load_qca_state(large_scale_motor, latest_checkpoint_filepath)
                if loaded_step != -1:
                    start_step = loaded_step
                    print(f"Simulation resumed from step {start_step}.")
                else:
                    print("‚ùå Failed to load state checkpoint. Starting new simulation.")
                    start_step = 0
            else:
                 print("\nNo large simulation checkpoint found to load. Starting new simulation.")
                 start_step = 0

        if start_step == 0:
            print(f"\nStarting new simulation with initial state mode: '{INITIAL_STATE_MODE_INFERENCE}'.")
            if INITIAL_STATE_MODE_INFERENCE == 'random':
                large_scale_motor.state._reset_state_random()
            elif INITIAL_STATE_MODE_INFERENCE == 'seeded':
                large_scale_motor.state._reset_state_seeded()
            elif INITIAL_STATE_MODE_INFERENCE == 'complex_noise':
                large_scale_motor.state._reset_state_complex_noise()
            else:
                print(f"Initial state mode '{INITIAL_STATE_MODE_INFERENCE}' not recognized. Defaulting to random.")
                large_scale_motor.state._reset_state_random()

        # ----------------------------------------------------------------------
        # Main Large Simulation Loop
        # ----------------------------------------------------------------------

        TIMESTAMP_SIM = int(time.time())
        GRID_SIZE_STR = str(GRID_SIZE_INFERENCE)
        SIMULATION_ID = f"{GRID_SIZE_STR}_{TIMESTAMP_SIM}_{INITIAL_STATE_MODE_INFERENCE}"
        if VIDEO_SAVE_INTERVAL_STEPS is not None and VIDEO_SAVE_INTERVAL_STEPS > 0 and VIDEO_DOWNSCALE_FACTOR > 1:
            SIMULATION_ID += f"_down{VIDEO_DOWNSCALE_FACTOR}"

        MP4_DENSITY_FILENAME = f"MUNDO_DENSITY_{SIMULATION_ID}.mp4"
        MP4_CHANNELS_FILENAME = f"MUNDO_CHANNELS_{SIMULATION_ID}.mp4"
        MP4_MAGNITUDE_FILENAME = f"MUNDO_MAGNITUDE_{SIMULATION_ID}.mp4"
        MP4_PHASE_FILENAME = f"MUNDO_PHASE_{SIMULATION_ID}.mp4"
        MP4_CHANGE_FILENAME = f"MUNDO_CHANGE_{SIMULATION_ID}.mp4"

        print(f"\nüé¨ Starting simulation of {NUM_INFERENCE_STEPS} steps on {GRID_SIZE_INFERENCE}x{GRID_SIZE_INFERENCE} from step {start_step}...")
        if VIDEO_SAVE_INTERVAL_STEPS is not None and VIDEO_SAVE_INTERVAL_STEPS > 0:
            print(f"üìπ Videos will be saved as MUNDO_*_{SIMULATION_ID}.mp4 (Frame every {VIDEO_SAVE_INTERVAL_STEPS} steps @ {VIDEO_FPS} FPS, Downscale: {VIDEO_DOWNSCALE_FACTOR}, Quality: {VIDEO_QUALITY})")
        else:
            print("‚ÑπÔ∏è Video saving is disabled.")
        if REAL_TIME_VIZ_INTERVAL is not None and REAL_TIME_VIZ_INTERVAL > 0:
             print(f"üëÅÔ∏è Real-time visualization enabled (Frame every {REAL_TIME_VIZ_INTERVAL} steps, Downscale: {REAL_TIME_VIZ_DOWNSCALE}, Type: '{REAL_TIME_VIZ_TYPE}').")
        else:
             print("‚ÑπÔ∏è Real-time visualization is disabled.")
        if LARGE_SIM_CHECKPOINT_INTERVAL is not None and LARGE_SIM_CHECKPOINT_INTERVAL > 0:
             print(f"üíæ Raw state checkpoints will be saved every {LARGE_SIM_CHECKPOINT_INTERVAL} steps in '{LARGE_SIM_CHECKPOINT_DIR}'.")
        else:
             print("‚ÑπÔ∏è Raw state checkpointing is disabled.")

        writer_density, writer_channels, writer_magnitude, writer_phase, writer_change = None, None, None, None, None
        if VIDEO_SAVE_INTERVAL_STEPS is not None and VIDEO_SAVE_INTERVAL_STEPS > 0:
            try:
                writer_density = imageio.get_writer(MP4_DENSITY_FILENAME, fps=VIDEO_FPS, codec='libx264', quality=VIDEO_QUALITY)
                writer_channels = imageio.get_writer(MP4_CHANNELS_FILENAME, fps=VIDEO_FPS, codec='libx264', quality=VIDEO_QUALITY)
                writer_magnitude = imageio.get_writer(MP4_MAGNITUDE_FILENAME, fps=VIDEO_FPS, codec='libx264', quality=VIDEO_QUALITY)
                writer_phase = imageio.get_writer(MP4_PHASE_FILENAME, fps=VIDEO_FPS, codec='libx264', quality=VIDEO_QUALITY)
                writer_change = imageio.get_writer(MP4_CHANGE_FILENAME, fps=VIDEO_FPS, codec='libx264', quality=VIDEO_QUALITY)
                print("‚úÖ Video writers initialized.")
            except Exception as e:
                print(f"‚ùå Error initializing video writers: {e}")
                writer_density, writer_channels, writer_magnitude, writer_phase, writer_change = None, None, None, None, None

        prev_state_for_change_viz = None
        if start_step > 0:
            prev_step_checkpoint_filepath = os.path.join(LARGE_SIM_CHECKPOINT_DIR, f"large_sim_state_step_{start_step-1}.pth")
            if os.path.exists(prev_step_checkpoint_filepath):
                print(f"Attempting to load previous state for change visualization from: {prev_step_checkpoint_filepath}")
                temp_operator = QCA_Operator_Deep(D_STATE, HIDDEN_CHANNELS).to(DEVICE)
                temp_motor = Aetheria_Motor(GRID_SIZE_INFERENCE, D_STATE, temp_operator)
                if load_qca_state(temp_motor, prev_step_checkpoint_filepath) != -1:
                    prev_state_for_change_viz = temp_motor.state
                    print("Loaded previous state for change visualization.")
                else:
                    print("‚ùå Failed to load previous state for change visualization.")
            else:
                print(f"‚ùå Previous state checkpoint for step {start_step-1} not found. Change video/viz might be inaccurate.")

        if prev_state_for_change_viz is None and ((VIDEO_SAVE_INTERVAL_STEPS is not None and VIDEO_SAVE_INTERVAL_STEPS > 0) or \
           (REAL_TIME_VIZ_INTERVAL is not None and REAL_TIME_VIZ_INTERVAL > 0 and REAL_TIME_VIZ_TYPE == 'change')):
            prev_state_for_change_viz = QCA_State(large_scale_motor.state.size, large_scale_motor.state.d_state)
            prev_state_for_change_viz.x_real.data = large_scale_motor.state.x_real.data.clone().to(DEVICE)
            prev_state_for_change_viz.x_imag.data = large_scale_motor.state.x_imag.data.clone().to(DEVICE)
            print("Initialized previous state clone for change visualization.")

        t = start_step
        try:
            with torch.no_grad():
                for t in range(start_step, NUM_INFERENCE_STEPS):

                    current_state_clone_for_change_viz = None
                    if prev_state_for_change_viz is not None:
                        current_state_clone_for_change_viz = QCA_State(large_scale_motor.state.size, large_scale_motor.state.d_state)
                        current_state_clone_for_change_viz.x_real.data = large_scale_motor.state.x_real.data.clone().to(DEVICE)
                        current_state_clone_for_change_viz.x_imag.data = large_scale_motor.state.x_imag.data.clone().to(DEVICE)

                    large_scale_motor.evolve_step()
                    current_state = large_scale_motor.state

                    # Real-time Visualization (Print statement only)
                    if REAL_TIME_VIZ_INTERVAL is not None and REAL_TIME_VIZ_INTERVAL > 0 and (t + 1) % REAL_TIME_VIZ_INTERVAL == 0:
                        print(f"--- Real-time frame {t+1} (display disabled in .py script) ---")
                        # (Original display(Image.fromarray(...)) call removed)

                    # Save Video Frame
                    if VIDEO_SAVE_INTERVAL_STEPS is not None and VIDEO_SAVE_INTERVAL_STEPS > 0 and ((t + 1) % VIDEO_SAVE_INTERVAL_STEPS == 0 or (t == start_step and start_step == 0)):
                        try:
                            density_frame = get_density_frame_gpu(current_state)
                            channels_frame = get_channel_frame_gpu(current_state, num_channels=min(3, D_STATE))
                            magnitude_frame = get_state_magnitude_frame_gpu(current_state)
                            phase_frame = get_state_phase_frame_gpu(current_state)

                            if VIDEO_DOWNSCALE_FACTOR > 1:
                                density_frame = downscale_frame(density_frame, VIDEO_DOWNSCALE_FACTOR)
                                channels_frame = downscale_frame(channels_frame, VIDEO_DOWNSCALE_FACTOR)
                                magnitude_frame = downscale_frame(magnitude_frame, VIDEO_DOWNSCALE_FACTOR)
                                phase_frame = downscale_frame(phase_frame, VIDEO_DOWNSCALE_FACTOR)

                            if writer_density: writer_density.append_data(density_frame)
                            if writer_channels: writer_channels.append_data(channels_frame)
                            if writer_magnitude: writer_magnitude.append_data(magnitude_frame)
                            if writer_phase: writer_phase.append_data(phase_frame)

                            if prev_state_for_change_viz is not None:
                                change_frame = get_state_change_magnitude_frame_gpu(current_state, prev_state_for_change_viz)
                                if VIDEO_DOWNSCALE_FACTOR > 1:
                                     change_frame = downscale_frame(change_frame, VIDEO_DOWNSCALE_FACTOR)
                                if writer_change: writer_change.append_data(change_frame)
                        except Exception as e:
                            print(f"‚ö†Ô∏è  Error generating/saving video frame at step {t+1}: {e}")

                    if current_state_clone_for_change_viz is not None:
                        prev_state_for_change_viz = current_state_clone_for_change_viz

                    # Save State Checkpoint
                    if LARGE_SIM_CHECKPOINT_INTERVAL is not None and LARGE_SIM_CHECKPOINT_INTERVAL > 0 and (t + 1) % LARGE_SIM_CHECKPOINT_INTERVAL == 0:
                        save_qca_state(large_scale_motor, t + 1, LARGE_SIM_CHECKPOINT_DIR)

                    # Print Progress
                    if (t + 1) % max(1, (NUM_INFERENCE_STEPS // 20)) == 0 or (t + 1) == NUM_INFERENCE_STEPS:
                        print(f"üìà Simulation Progress: {t+1}/{NUM_INFERENCE_STEPS} steps completed.")

                    if (t + 1) % 200 == 0:
                        gc.collect()
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()

                # Save final checkpoint
                if LARGE_SIM_CHECKPOINT_INTERVAL is not None and LARGE_SIM_CHECKPOINT_INTERVAL > 0 and (NUM_INFERENCE_STEPS % LARGE_SIM_CHECKPOINT_INTERVAL != 0):
                     save_qca_state(large_scale_motor, NUM_INFERENCE_STEPS, LARGE_SIM_CHECKPOINT_DIR)
                print("‚úÖ Simulation loop completed.")

        except KeyboardInterrupt:
            print(f"\n‚èπÔ∏è Simulation interrupted by user at step {t+1}.")
            if LARGE_SIM_CHECKPOINT_INTERVAL is not None and LARGE_SIM_CHECKPOINT_INTERVAL > 0:
                save_qca_state(large_scale_motor, t + 1, LARGE_SIM_CHECKPOINT_DIR)
        except Exception as e:
            print(f"\n‚ùå Error during simulation at step {t+1}: {e}")
            if LARGE_SIM_CHECKPOINT_INTERVAL is not None and LARGE_SIM_CHECKPOINT_INTERVAL > 0:
                save_qca_state(large_scale_motor, t + 1, LARGE_SIM_CHECKPOINT_DIR)
            raise e # Re-raise error
        finally:
            try:
                if writer_density: writer_density.close()
                if writer_channels: writer_channels.close()
                if writer_magnitude: writer_magnitude.close()
                if writer_phase: writer_phase.close()
                if writer_change: writer_change.close()
                if any([writer_density, writer_channels, writer_magnitude, writer_phase, writer_change]):
                     print("üìπ Video writers closed successfully.")
            except Exception as e:
                print(f"‚ùå Error closing video writers: {e}")

        # --- Final Summary ---
        print("\n" + "="*60)
        print("üéâ PROLONGED LARGE SIMULATION FINISHED (Completed or Interrupted)")
        print("="*60)
        print("üìÅ Video files generated in this run (if video saving was enabled):")
        if VIDEO_SAVE_INTERVAL_STEPS is not None and VIDEO_SAVE_INTERVAL_STEPS > 0:
            print(f"   ‚Ä¢ Density: {MP4_DENSITY_FILENAME}")
            print(f"   ‚Ä¢ Channels: {MP4_CHANNELS_FILENAME}")
            print(f"   ‚Ä¢ Magnitude: {MP4_MAGNITUDE_FILENAME}")
            print(f"   ‚Ä¢ Phase: {MP4_PHASE_FILENAME}")
            print(f"   ‚Ä¢ Change: {MP4_CHANGE_FILENAME}")
        else:
            print("   ‚Ä¢ Video saving was disabled.")
        print(f"   ‚Ä¢ Raw state checkpoints in: {LARGE_SIM_CHECKPOINT_DIR}/")

        if VIDEO_SAVE_INTERVAL_STEPS is not None and VIDEO_SAVE_INTERVAL_STEPS > 0 and os.path.exists(MP4_DENSITY_FILENAME):
            file_size = os.path.getsize(MP4_DENSITY_FILENAME) / (1024*1024)
            print(f"üìä Estimated size of Density video: {file_size:.1f} MB")

        print("\nüéØ Download Instructions:")
        print("   You can find the saved MP4 files and .pth checkpoints in your working directory")
        print(f"   (or '{LARGE_SIM_CHECKPOINT_DIR}/' for state checkpoints).")

        print("\n>>> PROLONGED LARGE SIMULATION PHASE (PHASE 7) COMPLETED <<<")

    else:
        print("\n>>> PROLONGED LARGE SIMULATION PHASE (PHASE 7) SKIPPED <<<")

    print("\n--- AETHERIA PIPELINE EXECUTION FINISHED ---")


# ==============================================================================
# SCRIPT ENTRY POINT
# ==============================================================================
if __name__ == "__main__":
    # Run the whole AETHERIA pipeline only when this code is the top-level
    # program. This is true both for `python script.py` and for a Jupyter/IPython
    # kernel, which executes notebook cells in the `__main__` namespace, so running
    # this cell starts the pipeline immediately. Importing the file as a module
    # does not trigger it.
    main_pipeline()

Using device: cuda
Detected 1 GPU.
Training Checkpoint Directory: checkpoints_optimized
Large Simulation Checkpoint Directory: large_sim_checkpoints_1024

ℹ️ Not configured to use a specific input model. Will look for locally trained models in CHECKPOINT_DIR.
Input model configuration (optional) completed.
Global parameters set.
QCA_State, QCA_Operator_Deep, and Aetheria_Motor classes defined.
Visualization and state checkpointing functions defined.
QC_Trainer_v3 class defined.
--- STARTING AETHERIA PIPELINE EXECUTION ---

>>> TRAINING PHASE (PHASE 5) SKIPPED <<<
Initializing training-size motor to load model...
📦 Detected latest .pth file (assumed trained model): checkpoints_optimized/PEF_Deep_v3_G256_Eps2_1762202475_FINAL.pth. Attempting to load weights...
✅ Model weights loaded successfully into Aetheria_Motor_Train.

>>> STARTING POST-TRAINING VISUALIZATION PHASE (PHASE 6) <<<
Generating 1500 frames with the M-Law on 256x256 grid...
-> Capturing frame 150/1500...
-> Captur