# In [7]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# from matplotlib.colors import ListedColormap # No longer used
import time
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Any
import os

# Probe for a usable CuPy/CUDA stack.  The flag starts out False and is only
# flipped once CuPy imports *and* a trivial reduction runs without raising.
DEFAULT_GPU_ENABLED = False
try:
    import cupy as cp
except ImportError:
    # CuPy is not installed: stay on the NumPy code path.
    pass
else:
    try:
        cp.array([1, 2, 3]).sum()
    except Exception:
        # CUDA runtime errors and any other CuPy failure: treat the GPU as
        # unavailable rather than aborting the import of this module.
        pass
    else:
        DEFAULT_GPU_ENABLED = True

@dataclass
class SimConfig:
    """Configuration for the point-vortex simulation and its rendering.

    Upper-case fields are user-settable options.  ``xp``, ``rng`` and
    ``float_type`` are derived in ``__post_init__`` from ``GPU_ENABLED`` and
    ``RANDOM_SEED``; they are not accepted by the constructor.

    Raises:
        ValueError: if ``TRACER_COLORING_MODE`` is not one of ``"group"``,
            ``"scalar"`` or ``"speed"``.
    """

    # --- Problem size, domain and time stepping ---
    N_VORTICES: int = 20
    N_TRACERS: int = 500000
    DOMAIN_RADIUS: float = 1.0          # radius of the circular domain centred on the origin
    SIMULATION_TIME: float = 3.0
    DT: float = 0.002                   # integrator time step
    OUTPUT_FILENAME: str = "point_vortex_dynamics.mp4"
    PLOT_INTERVAL: int = 2              # a history frame is recorded every PLOT_INTERVAL steps
    DPI: int = 120

    # Squared core parameters passed to the smoothed vortex kernel
    # (vortex-on-vortex and vortex-on-tracer respectively).
    VORTEX_CORE_A_SQ: float = 0.001
    TRACER_CORE_A_SQ: float = 0.0005

    # Fraction of DOMAIN_RADIUS beyond which a vortex triggers a boundary warning.
    BOUNDARY_WARN_THRESHOLD: float = 0.995

    GPU_ENABLED: bool = DEFAULT_GPU_ENABLED
    RANDOM_SEED: Optional[int] = 42     # None -> unseeded generator

    # Derived in __post_init__: array module (numpy or cupy), its random
    # Generator, and the floating-point dtype used for simulation arrays.
    xp: Any = field(init=False)
    rng: Any = field(init=False)
    float_type: Any = field(init=False)

    TRACER_PARTICLE_SIZE: float = 0.3
    TRACER_ALPHA: float = 0.4

    # Colormap settings for tracers
    TRACER_CMAP: str = "jet"  # Default to "jet" as often desired for scalar fields

    # Mode for tracer coloring.  Options: "group", "scalar", "speed"
    #   "group":  colors based on NUM_TRACER_GROUPS (discrete patches)
    #   "scalar": colors based on a continuous scalar value per tracer
    #   "speed":  scalar values start at zero and tracers are recoloured
    #             every frame from their displacement between frames
    TRACER_COLORING_MODE: str = "group"

    NUM_TRACER_GROUPS: int = 3  # Used if TRACER_COLORING_MODE is "group"

    # (size multiplier, alpha multiplier) pairs for extra scatter layers drawn
    # with the tracers.
    TRACER_GLOW_LAYERS: List[Tuple[float, float]] = field(
        default_factory=lambda: [
            (0.10, 0.10),   # outer, larger, faint halo
            (0.05, 0.05),   # inner, smaller, faint halo
        ]
    )

    # Vortex marker size = SCALE * |strength| / max|strength| + BASE.
    # NOTE: the mixed-case name below is kept as-is because other code reads it.
    VORTEX_MARKER_SIZE_BASE: float = 10
    Vortex_MARKER_SIZE_SCALE: float = 20
    VORTEX_COLOR_POS: str = '#FFFF00'  # Yellow for positive strength
    VORTEX_COLOR_NEG: str = '#FF00FF'  # Magenta for negative strength

    FIGURE_BG_COLOR: str = '#080808'
    AXES_BG_COLOR: str = '#101010'

    FPS: int = 30                     # frame rate for the writer
    FFMPEG_CODEC: str = "libx264"     # or "h264_nvenc" / "hevc_nvenc" if FFmpeg has NVENC
    FFMPEG_PRESET: str = "ultrafast"  # ultrafast ≪ superfast ≪ veryfast ≪ fast ≪ medium
    FFMPEG_CRF: int = 23              # visually loss-less ≤ 18, streaming 23–28
    FFMPEG_THREADS: int = 0           # 0 = let FFmpeg pick (#logical cores)
    FFMPEG_CQ: int = 19               # NOTE(review): presumably the NVENC constant-quality value — confirm

    def __post_init__(self):
        """Select the array backend, build the RNG and validate the colouring mode."""
        if self.GPU_ENABLED:
            self.xp = cp
            self.float_type = cp.float32
            print("CuPy active. Using GPU acceleration with float32.")
        else:
            self.xp = np
            self.float_type = np.float64
            print("CuPy not available or disabled. Running on CPU with NumPy with float64.")

        # default_rng(None) draws fresh entropy, so one call covers both the
        # seeded and the unseeded case.
        self.rng = self.xp.random.default_rng(self.RANDOM_SEED)

        # Modules are singletons: compare by identity, not equality.
        if self.N_TRACERS > 75000 and self.xp is np:
            print(f"Warning: N_TRACERS ({self.N_TRACERS}) is high for CPU. Simulation might be slow.")

        valid_modes = ("group", "scalar", "speed")
        if self.TRACER_COLORING_MODE not in valid_modes:
            raise ValueError(
                f"Invalid TRACER_COLORING_MODE: {self.TRACER_COLORING_MODE}. "
                "Must be 'group', 'scalar' or 'speed'."
            )


def initialize_vortices(config: SimConfig):
    """Create the initial vortex positions and strengths.

    With four or more vortices the first four sit on the corners of a square
    of half-side ``0.4 * DOMAIN_RADIUS`` with strengths +1.5/-1.5 alternating,
    and any remaining ones are placed at random.  With fewer than four, all of
    them are placed at random.  The summed strength is printed.

    Returns:
        ``(positions, strengths)`` with shapes ``(N_VORTICES, 2)`` and
        ``(N_VORTICES,)`` in ``config.float_type``.
    """
    xp = config.xp
    dtype = config.float_type
    n_total = config.N_VORTICES

    positions = xp.zeros((n_total, 2), dtype=dtype)
    strengths = xp.zeros(n_total, dtype=dtype)
    if n_total == 0:
        return positions, strengths

    def scatter_from(first, base_values, scale_low, scale_high):
        # Fill vortices first..end.  Draw order is radius, angle, base-strength
        # index, strength scale; it is kept fixed so seeded runs reproduce.
        count = n_total - first
        radii = config.rng.uniform(0.1, 0.7, count).astype(dtype) * config.DOMAIN_RADIUS
        angles = config.rng.uniform(0, 2 * xp.pi, count).astype(dtype)
        positions[first:, 0] = radii * xp.cos(angles)
        positions[first:, 1] = radii * xp.sin(angles)

        choices = xp.array(base_values, dtype=dtype)
        picked = config.rng.integers(0, len(choices), count)
        base = choices[picked]
        strengths[first:] = base * config.rng.uniform(scale_low, scale_high, count).astype(dtype)

    if n_total >= 4:
        # Four primary vortices on a square, alternating signs so they sum to zero.
        s = 0.4 * config.DOMAIN_RADIUS
        corners = ((-s, s, 1.5), (s, s, -1.5), (-s, -s, -1.5), (s, -s, 1.5))
        for slot, (x, y, gamma) in enumerate(corners):
            positions[slot] = xp.array([x, y])
            strengths[slot] = gamma
        if n_total > 4:
            scatter_from(4, [-0.75, 0.75, -0.5, 0.5], 0.5, 1.0)
    else:
        # Fewer than four vortices: everything is random.
        scatter_from(0, [-1.0, 1.0], 0.5, 1.5)

    total_initial_strength = xp.sum(strengths)
    print(f"Total initial vortex strength: {total_initial_strength:.3e}")

    return positions, strengths

def initialize_tracers(config: SimConfig):
    """Create initial tracer positions and the per-tracer colour scalar.

    Behaviour depends on ``config.TRACER_COLORING_MODE``:

    * ``"group"``  - tracers are split into ``max(1, NUM_TRACER_GROUPS)``
      circular patches (radius ``0.25 R``) whose centres lie on a ring of
      radius ``0.45 R``; group ``i`` gets scalar ``(i + 0.5) / num_groups``.
      The last group absorbs the remainder of the integer division.
    * ``"scalar"`` - one disk of radius ``0.7 R`` with uniform random scalars
      in [0, 1].  (Normalised initial radius would be an alternative scalar.)
    * ``"speed"``  - same disk, scalars left at zero.

    Returns:
        ``(tracer_pos, tracer_scalar_values)`` with shapes ``(N_TRACERS, 2)``
        and ``(N_TRACERS,)``.

    Raises:
        ValueError: for any other colouring mode.
    """
    xp = config.xp
    dtype = config.float_type
    n_total = config.N_TRACERS
    domain_r = config.DOMAIN_RADIUS

    tracer_pos = xp.zeros((n_total, 2), dtype=dtype)
    tracer_scalar_values = xp.zeros(n_total, dtype=dtype)
    if n_total == 0:
        return tracer_pos, tracer_scalar_values

    def sample_disk(count, disk_radius):
        # Radius is drawn before angle; sqrt of the uniform variate gives a
        # uniform density per unit area.
        u = config.rng.uniform(0, 1, count).astype(dtype)
        rho = disk_radius * xp.sqrt(u)
        phi = config.rng.uniform(0, 2 * xp.pi, count).astype(dtype)
        return rho, phi

    mode = config.TRACER_COLORING_MODE
    if mode == "group":
        print(f"Initializing tracers with 'group' coloring mode using {config.NUM_TRACER_GROUPS} groups and cmap '{config.TRACER_CMAP}'.")
        num_groups = max(1, config.NUM_TRACER_GROUPS)
        per_group = n_total // num_groups
        patch_radius = domain_r * 0.25
        ring_radius = domain_r * 0.45

        for g in range(num_groups):
            lo = g * per_group
            # Every group but the last holds per_group tracers; the last takes the rest.
            hi = n_total if g == num_groups - 1 else lo + per_group
            if hi == lo:
                continue

            group_angle = (2 * xp.pi / num_groups) * g
            cx = ring_radius * xp.cos(group_angle)
            cy = ring_radius * xp.sin(group_angle)

            rho, phi = sample_disk(hi - lo, patch_radius)
            tracer_pos[lo:hi, 0] = cx + rho * xp.cos(phi)
            tracer_pos[lo:hi, 1] = cy + rho * xp.sin(phi)

            # One colormap value per group, spread over (0, 1).
            tracer_scalar_values[lo:hi] = (g + 0.5) / num_groups

    elif mode == "scalar" or mode == "speed":
        if mode == "scalar":
            print(f"Initializing tracers with 'scalar' coloring mode using cmap '{config.TRACER_CMAP}'.")
        else:
            print("Initializing tracers with 'speed' colouring mode.")

        rho, phi = sample_disk(n_total, domain_r * 0.7)
        tracer_pos[:, 0] = rho * xp.cos(phi)
        tracer_pos[:, 1] = rho * xp.sin(phi)

        if mode == "scalar":
            # Continuous scalar per tracer, uniform in [0, 1].
            tracer_scalar_values = config.rng.uniform(0.0, 1.0, n_total).astype(dtype)
        else:
            tracer_scalar_values[:] = 0.0  # will be recoloured every frame

    else:
        # This case should ideally be caught by SimConfig.__post_init__
        raise ValueError(f"Unknown TRACER_COLORING_MODE: {config.TRACER_COLORING_MODE}")

    # Whatever the mode, shrink the whole cloud if its farthest tracer reaches
    # the boundary so that everything starts strictly inside the domain.
    farthest_sq = xp.max(xp.sum(tracer_pos**2, axis=1))
    if farthest_sq > 0:  # Avoid sqrt of 0
        farthest = xp.sqrt(farthest_sq)
        if farthest >= domain_r:
            tracer_pos *= (config.DOMAIN_RADIUS * 0.99) / farthest

    return tracer_pos, tracer_scalar_values

def _lamb_oseen_factor(r_sq, core_a_sq, xp, float_type):
    epsilon = 1e-12 if float_type == xp.float64 else 1e-7
    r_sq_safe = xp.where(r_sq < epsilon, epsilon, r_sq)
    # For r_sq -> 0, (1 - exp(-r_sq/a_sq)) / r_sq -> 1/a_sq
    # (1 - (1 - r_sq/a_sq + O((r_sq/a_sq)^2))) / r_sq = (r_sq/a_sq) / r_sq = 1/a_sq
    val = (1.0 - xp.exp(-r_sq_safe / core_a_sq)) / r_sq_safe
    limit_val = 1.0 / core_a_sq 
    return xp.where(r_sq < epsilon * 10, limit_val, val)

def _get_velocities_induced_by_vortices(target_positions, vortex_positions, vortex_strengths, core_a_sq, config: SimConfig, total_vortex_strength_for_bg_flow):
    """Velocity induced at each target point by the vortices and their images.

    Three contributions are summed:

    1. the real vortices, through the smoothed kernel with ``core_a_sq``;
    2. one image per vortex at ``R^2 / |x|^2 * x`` with opposite strength
       (``|x|^2`` is clamped away from zero so a vortex at the origin does not
       divide by zero);
    3. a rotation ``K * (-y, x)`` with
       ``K = total_strength / (2 * pi * R^2)`` when the total strength is
       non-negligible.

    Args:
        target_positions: ``(M, 2)`` evaluation points.
        vortex_positions: ``(N, 2)`` vortex locations.
        vortex_strengths: ``(N,)`` vortex strengths.
        core_a_sq: squared core parameter for the kernel.
        config: simulation configuration (``xp``, ``float_type``, ``DOMAIN_RADIUS``).
        total_vortex_strength_for_bg_flow: summed vortex strength.

    Returns:
        ``(M, 2)`` velocities; all zeros when either input is empty.
    """
    xp = config.xp
    velocities = xp.zeros_like(target_positions)

    n_targets = target_positions.shape[0]
    n_sources = vortex_positions.shape[0]
    if n_sources == 0 or n_targets == 0:
        return velocities

    small = 1e-9 if config.float_type == xp.float64 else 1e-6
    targets = target_positions[:, xp.newaxis, :]  # M x 1 x 2

    def add_contribution(source_positions, source_strengths):
        # Accumulate the perpendicular (-dy, dx) flow from one set of sources.
        delta = targets - source_positions[xp.newaxis, :, :]   # M x N x 2
        dist_sq = xp.sum(delta**2, axis=2)                     # M x N
        kernel = _lamb_oseen_factor(dist_sq, core_a_sq, xp, config.float_type)
        circulation = source_strengths[xp.newaxis, :] / (2 * xp.pi)  # 1 x N
        weight = circulation * kernel                          # M x N
        velocities[:, 0] += xp.sum(-weight * delta[:, :, 1], axis=1)
        velocities[:, 1] += xp.sum(weight * delta[:, :, 0], axis=1)

    # Image system: inverse points with respect to the domain circle.
    radius_sq = xp.sum(vortex_positions**2, axis=1)
    radius_sq = xp.where(radius_sq < small, small, radius_sq)
    image_positions = (config.DOMAIN_RADIUS**2 / radius_sq[:, xp.newaxis]) * vortex_positions
    image_strengths = -vortex_strengths

    add_contribution(vortex_positions, vortex_strengths)
    add_contribution(image_positions, image_strengths)

    if xp.abs(total_vortex_strength_for_bg_flow) > small:
        rotation_rate = total_vortex_strength_for_bg_flow / (2 * xp.pi * config.DOMAIN_RADIUS**2)
        velocities[:, 0] += -rotation_rate * target_positions[:, 1]
        velocities[:, 1] += rotation_rate * target_positions[:, 0]

    return velocities

def get_vortex_velocities(v_positions, v_strengths, config: SimConfig, total_vortex_strength):
    """Velocity of every vortex due to the other vortices, all images and the background rotation.

    The self-interaction among real vortices is masked out, whereas every
    image (including a vortex's own image) contributes.  The kernel uses
    ``config.VORTEX_CORE_A_SQ``.

    Args:
        v_positions: ``(N, 2)`` vortex locations.
        v_strengths: ``(N,)`` vortex strengths.
        config: simulation configuration.
        total_vortex_strength: summed vortex strength, used for the
            ``K * (-y, x)`` background term.

    Returns:
        ``(N, 2)`` velocities (zeros for an empty input).
    """
    xp = config.xp
    velocities = xp.zeros_like(v_positions)

    count = v_positions.shape[0]
    if count == 0:
        return velocities

    dtype = config.float_type
    core_sq = config.VORTEX_CORE_A_SQ
    receivers = v_positions[:, xp.newaxis, :]  # N x 1 x 2

    # --- real vortices, diagonal (self) entries zeroed ---
    sep = receivers - v_positions[xp.newaxis, :, :]     # N x N x 2
    sep_sq = xp.sum(sep**2, axis=2)                     # N x N
    is_self = xp.eye(count, dtype=bool)
    kernel_real = _lamb_oseen_factor(sep_sq, core_sq, xp, dtype)
    circ_real = v_strengths[xp.newaxis, :] / (2 * xp.pi)
    weight_real = circ_real * kernel_real * (~is_self)

    velocities[:, 0] = xp.sum(-weight_real * sep[:, :, 1], axis=1)
    velocities[:, 1] = xp.sum(weight_real * sep[:, :, 0], axis=1)

    # --- image vortices at the inverse points, opposite sign ---
    small = 1e-9 if dtype == xp.float64 else 1e-6
    radius_sq = xp.sum(v_positions**2, axis=1)
    radius_sq = xp.where(radius_sq < small, small, radius_sq)
    image_positions = (config.DOMAIN_RADIUS**2 / radius_sq[:, xp.newaxis]) * v_positions
    image_strengths = -v_strengths

    sep_img = receivers - image_positions[xp.newaxis, :, :]  # N x N x 2
    sep_img_sq = xp.sum(sep_img**2, axis=2)                  # N x N
    kernel_img = _lamb_oseen_factor(sep_img_sq, core_sq, xp, dtype)
    circ_img = image_strengths[xp.newaxis, :] / (2 * xp.pi)
    weight_img = circ_img * kernel_img

    velocities[:, 0] += xp.sum(-weight_img * sep_img[:, :, 1], axis=1)
    velocities[:, 1] += xp.sum(weight_img * sep_img[:, :, 0], axis=1)

    # --- background rotation when the total strength is non-negligible ---
    if xp.abs(total_vortex_strength) > small:
        rotation_rate = total_vortex_strength / (2 * xp.pi * config.DOMAIN_RADIUS**2)
        velocities[:, 0] += -rotation_rate * v_positions[:, 1]
        velocities[:, 1] += rotation_rate * v_positions[:, 0]

    return velocities

def get_tracer_velocities(t_positions, v_positions, v_strengths, config: SimConfig, total_vortex_strength):
    """Velocity at each tracer position induced by the vortex system.

    Delegates to ``_get_velocities_induced_by_vortices`` using the tracer core
    parameter ``config.TRACER_CORE_A_SQ``.
    """
    tracer_core_sq = config.TRACER_CORE_A_SQ
    return _get_velocities_induced_by_vortices(
        target_positions=t_positions,
        vortex_positions=v_positions,
        vortex_strengths=v_strengths,
        core_a_sq=tracer_core_sq,
        config=config,
        total_vortex_strength_for_bg_flow=total_vortex_strength,
    )

def rk4_step_system(v_pos, t_pos, v_str, total_v_str, config: SimConfig):
    """Advance vortices and tracers together by one classical RK4 step of size ``config.DT``.

    At every stage the tracer velocities are evaluated against the vortex
    positions of that same stage, so both sets are integrated as one system.

    Returns:
        ``(new_v_pos, new_t_pos)``; the inputs are not modified.
    """
    h = config.DT

    def slopes(v_state, t_state):
        # Time derivatives of (vortex positions, tracer positions) for one stage.
        dv = get_vortex_velocities(v_state, v_str, config, total_v_str)
        dtr = get_tracer_velocities(t_state, v_state, v_str, config, total_v_str)
        return dv, dtr

    k1_v, k1_t = slopes(v_pos, t_pos)
    k2_v, k2_t = slopes(v_pos + 0.5 * h * k1_v, t_pos + 0.5 * h * k1_t)
    k3_v, k3_t = slopes(v_pos + 0.5 * h * k2_v, t_pos + 0.5 * h * k2_t)
    k4_v, k4_t = slopes(v_pos + h * k3_v, t_pos + h * k3_t)

    # Weighted average of the four stage slopes (1, 2, 2, 1) / 6.
    new_v_pos = v_pos + (h / 6.0) * (k1_v + 2*k2_v + 2*k3_v + k4_v)
    new_t_pos = t_pos + (h / 6.0) * (k1_t + 2*k2_t + 2*k3_t + k4_t)
    return new_v_pos, new_t_pos

def enforce_boundaries(positions, config: SimConfig, current_sim_time: float, is_vortex=False):
    """Pull points that left the circular domain back inside (in place).

    Tracers beyond ``DOMAIN_RADIUS`` are rescaled radially to
    ``0.9999 * DOMAIN_RADIUS``.  For vortices a warning is printed as soon as
    any of them exceeds ``BOUNDARY_WARN_THRESHOLD * DOMAIN_RADIUS``, and those
    beyond ``DOMAIN_RADIUS`` are rescaled to ``0.999 * DOMAIN_RADIUS``.

    Args:
        positions: ``(K, 2)`` array, modified in place.
        config: simulation configuration.
        current_sim_time: time stamp used only in the warning text.
        is_vortex: selects the vortex rules instead of the tracer rules.

    Returns:
        The same ``positions`` array.
    """
    xp = config.xp
    if positions.shape[0] == 0:
        return positions

    radius = config.DOMAIN_RADIUS
    dist = xp.sqrt(xp.sum(positions**2, axis=1))

    if is_vortex:
        trigger_radius = radius * config.BOUNDARY_WARN_THRESHOLD
        pull_back_fraction = 0.999
    else:
        trigger_radius = radius * 1.0
        pull_back_fraction = 0.9999

    flagged = dist > trigger_radius
    if not xp.any(flagged):
        return positions

    if is_vortex:
        flagged_idx = xp.where(flagged)[0]
        farthest = xp.max(dist[flagged])
        print(f"WARNING: {len(flagged_idx)} vortices near/past boundary threshold ({trigger_radius:.2f}) at t={current_sim_time:.3f}. Max dist: {farthest:.3f}.")

    # Only points strictly outside the domain are moved; they keep their
    # direction and land just inside the boundary.
    outside = dist > radius
    if xp.any(outside):
        positions[outside] *= (radius * pull_back_fraction / dist[outside, xp.newaxis])
    return positions

def calculate_angular_impulse(v_positions, v_strengths, xp):
    """Return ``sum_i strength_i * |position_i|^2`` (a zero scalar array for no vortices)."""
    if v_positions.shape[0] == 0:
        return xp.array(0.0, dtype=v_positions.dtype)
    squared_radius = xp.sum(v_positions**2, axis=1)
    weighted = v_strengths * squared_radius
    return xp.sum(weighted)

def calculate_linear_impulse(v_positions, v_strengths, xp):
    """Return ``(P_x, P_y) = (sum_i strength_i * y_i, -sum_i strength_i * x_i)``.

    For an empty vortex set both components are zero scalar arrays of the
    positions' dtype.
    """
    if v_positions.shape[0] == 0:
        zero_x = xp.array(0.0, dtype=v_positions.dtype)
        zero_y = xp.array(0.0, dtype=v_positions.dtype)
        return zero_x, zero_y

    xs = v_positions[:, 0]
    ys = v_positions[:, 1]
    impulse_x = xp.sum(v_strengths * ys)
    impulse_y = -xp.sum(v_strengths * xs)
    return impulse_x, impulse_y

def run_simulation(config: SimConfig):
    """Run the simulation and return the recorded history as NumPy data.

    Vortices and tracers are initialised, advanced with ``rk4_step_system``
    and clamped to the domain after every step.  Every ``PLOT_INTERVAL`` steps
    (and always at step 0) the positions, time and vortex impulses are
    recorded.  Progress with the relative angular-impulse drift is printed
    roughly twenty times per run.

    Args:
        config: simulation configuration.

    Returns:
        dict with keys ``"tracer_pos"``, ``"vortex_pos"`` (lists of arrays, one
        per recorded frame), ``"times"`` (NumPy array), ``"tracer_scalar_values"``,
        ``"vortex_strengths"``, and for each of ``Lz``/``Px``/``Py`` the
        recorded history plus ``"initial_*"`` and ``"*_denom"`` (the initial
        value, or 1 when that is ~0, used for relative errors).  GPU arrays are
        copied to host memory before returning.
    """
    xp = config.xp

    vortex_pos, vortex_strengths = initialize_vortices(config)
    tracer_pos, tracer_scalar_values = initialize_tracers(config)

    # Floor with a tiny tolerance: a bare int() drops the final step whenever
    # the float ratio lands just below an integer (e.g. 0.3 / 0.1 = 2.999...).
    num_steps = int(config.SIMULATION_TIME / config.DT + 1e-9)

    tracer_pos_history = []
    vortex_pos_history = []
    times_history = []
    angular_impulse_history = []
    linear_impulse_Px_history = []
    linear_impulse_Py_history = []

    total_vortex_strength = xp.sum(vortex_strengths)
    initial_Lz = calculate_angular_impulse(vortex_pos, vortex_strengths, xp)
    initial_Px, initial_Py = calculate_linear_impulse(vortex_pos, vortex_strengths, xp)

    # Denominators for relative errors; fall back to 1 when the initial value is ~0.
    Lz_denom_for_rel_err = initial_Lz if xp.abs(initial_Lz) > 1e-9 else xp.array(1.0, dtype=config.float_type)
    Px_denom_for_rel_err = initial_Px if xp.abs(initial_Px) > 1e-9 else xp.array(1.0, dtype=config.float_type)
    Py_denom_for_rel_err = initial_Py if xp.abs(initial_Py) > 1e-9 else xp.array(1.0, dtype=config.float_type)

    current_sim_time_val = 0.0

    start_sim_time_wc = time.time()
    for step in range(num_steps + 1):
        if step % config.PLOT_INTERVAL == 0:
            tracer_pos_history.append(tracer_pos.copy())
            vortex_pos_history.append(vortex_pos.copy())
            times_history.append(current_sim_time_val)

            Lz = calculate_angular_impulse(vortex_pos, vortex_strengths, xp)
            Px, Py = calculate_linear_impulse(vortex_pos, vortex_strengths, xp)
            angular_impulse_history.append(Lz)
            linear_impulse_Px_history.append(Px)
            linear_impulse_Py_history.append(Py)

        if step == num_steps:
            break  # Exit after saving the last state

        vortex_pos, tracer_pos = rk4_step_system(vortex_pos, tracer_pos, vortex_strengths, total_vortex_strength, config)

        vortex_pos = enforce_boundaries(vortex_pos, config, current_sim_time_val, is_vortex=True)
        tracer_pos = enforce_boundaries(tracer_pos, config, current_sim_time_val, is_vortex=False)

        current_sim_time_val += config.DT

        if step > 0 and (step % max(1, num_steps // 20) == 0 or step == num_steps - 1):
            # Most recently *recorded* Lz; one always exists because step 0 is recorded.
            Lz_curr = angular_impulse_history[-1]
            rel_Lz_error = xp.abs((Lz_curr - initial_Lz) / Lz_denom_for_rel_err)
            print(f"Step {step}/{num_steps}, Sim Time: {current_sim_time_val:.2f}s, Rel. ΔLz/Lz₀: {rel_Lz_error:.2e}")

    end_sim_time_wc = time.time()
    print(f"Simulation finished in {end_sim_time_wc - start_sim_time_wc:.2f} seconds (wall clock).")

    history_pack = {
        "tracer_pos": tracer_pos_history, "vortex_pos": vortex_pos_history,
        "times": times_history, "tracer_scalar_values": tracer_scalar_values,
        "vortex_strengths": vortex_strengths,
        "Lz": angular_impulse_history, "initial_Lz": initial_Lz, "Lz_denom": Lz_denom_for_rel_err,
        "Px": linear_impulse_Px_history, "initial_Px": initial_Px, "Px_denom": Px_denom_for_rel_err,
        "Py": linear_impulse_Py_history, "initial_Py": initial_Py, "Py_denom": Py_denom_for_rel_err,
    }

    # Copy GPU arrays to host memory for Matplotlib.  config.xp is the CuPy
    # module whenever GPU_ENABLED is set, so the global `cp` name is not needed.
    if config.GPU_ENABLED:
        for key, val in history_pack.items():
            if isinstance(val, list) and len(val) > 0 and isinstance(val[0], xp.ndarray):
                history_pack[key] = [xp.asnumpy(arr) for arr in val]
            elif isinstance(val, xp.ndarray):
                history_pack[key] = xp.asnumpy(val)

    # Normalise the scalar-like entries.  The accepted array types are built
    # from `xp` rather than the global `cp`: referencing `cp.ndarray` here
    # raised NameError on machines where CuPy is not installed.
    array_types = (np.ndarray, xp.ndarray)
    for key in ("initial_Lz", "Lz_denom", "initial_Px", "Px_denom", "initial_Py", "Py_denom",
                "vortex_strengths", "tracer_scalar_values"):
        val = history_pack[key]
        if not isinstance(val, np.ndarray) and hasattr(val, 'get'):
            # .get() copies a device value to the host; wrap it as a NumPy array.
            history_pack[key] = np.array(val.get())
        elif isinstance(val, array_types) and val.ndim == 0:
            # 0-dim array: rebuild it from its Python scalar.
            history_pack[key] = np.array(val.item())

    history_pack["times"] = np.array(history_pack["times"])  # Ensure times is a NumPy array

    return history_pack


def animate(data_pack, config: SimConfig):
    num_frames = len(data_pack["tracer_pos"])
    if num_frames == 0 and (config.N_VORTICES == 0 or len(data_pack["vortex_pos"]) == 0):
        print("No history to animate.")
        return

    fig = plt.figure(figsize=(12, 12)) 
    gs = fig.add_gridspec(4, 1, height_ratios=[12, 1, 1, 0.5], hspace=0.2) 
    
    ax_main = fig.add_subplot(gs[0])
    ax_Lz = fig.add_subplot(gs[1])
    ax_P = fig.add_subplot(gs[2]) 
    ax_info = fig.add_subplot(gs[3])
    ax_info.axis('off')

    fig.patch.set_facecolor(config.FIGURE_BG_COLOR)
    ax_main.set_facecolor(config.AXES_BG_COLOR)
    ax_main.set_aspect('equal')
    ax_main.set_xlim(-config.DOMAIN_RADIUS * 1.02, config.DOMAIN_RADIUS * 1.02)
    ax_main.set_ylim(-config.DOMAIN_RADIUS * 1.02, config.DOMAIN_RADIUS * 1.02)
    title_str = f"2D Point Vortex Dynamics ({config.N_VORTICES} Vortices, {config.N_TRACERS} Tracers, Mode: {config.TRACER_COLORING_MODE})"
    ax_main.set_title(title_str, color='white', fontsize=16)
    ax_main.set_xticks([])
    ax_main.set_yticks([])
    for spine in ax_main.spines.values(): spine.set_edgecolor('gray')

    domain_circle = plt.Circle((0, 0), config.DOMAIN_RADIUS, color='gray', fill=False, ls='-', lw=1.0, alpha=0.5)
    ax_main.add_artist(domain_circle)

    tracer_cmap = plt.cm.get_cmap(config.TRACER_CMAP)
    
    # Ensure initial positions are NumPy arrays for Matplotlib
    initial_tracer_positions_np = np.empty((0,2), dtype=np.float32)
    if config.N_TRACERS > 0 and len(data_pack["tracer_pos"]) > 0:
        initial_tracer_positions_np = np.asarray(data_pack["tracer_pos"][0])
        
    # -- extra code for 'speed' colouring --
    global_max_speed = 1.0
    if config.TRACER_COLORING_MODE == "speed" and num_frames > 1:
        print("Scanning trajectory to find peak tracer speed …")
        times_np   = np.asarray(data_pack["times"])
        # stack into 3-D array (frames, tracers, xy)
        tp         = np.stack([np.asarray(p) for p in data_pack["tracer_pos"]])
        peak = 0.0
        for f in range(1, num_frames):
            dt  = times_np[f] - times_np[f - 1]
            if dt <= 0:          # safety
                continue
            speed_mag = np.linalg.norm(tp[f] - tp[f - 1], axis=1) / dt
            s_max     = speed_mag.max()
            if s_max > peak:
                peak = s_max
        global_max_speed = max(peak, 1e-8)  # avoid divide-by-zero
        print(f"Peak tracer speed ≈ {global_max_speed:.3e}")
    
    initial_tracer_scalar_values_np = np.empty((0,), dtype=np.float32)
    if config.N_TRACERS > 0 :
         initial_tracer_scalar_values_np = np.asarray(data_pack["tracer_scalar_values"])


    scatter_layers = []
    if config.N_TRACERS > 0 and initial_tracer_positions_np.shape[0] > 0:
        # Normalize scalar values if they are not already in [0,1] range for colormap
        # For "group" mode, they are already ~[0,1]. For "scalar" with random, also [0,1].
        # If using e.g. radius for scalar, ensure normalization was done in initialize_tracers.
        norm_scalar_values = initial_tracer_scalar_values_np
        if np.any(norm_scalar_values < 0) or np.any(norm_scalar_values > 1):
            min_val, max_val = np.min(norm_scalar_values), np.max(norm_scalar_values)
            if max_val > min_val:
                norm_scalar_values = (norm_scalar_values - min_val) / (max_val - min_val)
            else: # all values are the same
                norm_scalar_values = np.full_like(norm_scalar_values, 0.5)


        scatter_colors_mapped = tracer_cmap(norm_scalar_values)
        
        for size_mult, alpha_mult in reversed(config.TRACER_GLOW_LAYERS):
            glow_scatter = ax_main.scatter(
                initial_tracer_positions_np[:,0], initial_tracer_positions_np[:,1], 
                s=config.TRACER_PARTICLE_SIZE * size_mult, 
                c=scatter_colors_mapped, marker='o', edgecolors='none', 
                alpha=config.TRACER_ALPHA * alpha_mult, zorder=1)
            scatter_layers.append(glow_scatter)

        main_scatter = ax_main.scatter(
            initial_tracer_positions_np[:,0], initial_tracer_positions_np[:,1], 
            s=config.TRACER_PARTICLE_SIZE, c=scatter_colors_mapped, 
            marker='o', edgecolors='none', alpha=config.TRACER_ALPHA, zorder=2)
        scatter_layers.append(main_scatter)

    vortex_scatter = None 
    if config.N_VORTICES > 0:
        v_strengths_np = np.asarray(data_pack["vortex_strengths"])
        vortex_colors = [config.VORTEX_COLOR_POS if s > 0 else config.VORTEX_COLOR_NEG for s in v_strengths_np] 
        
        max_abs_strength = np.max(np.abs(v_strengths_np)) if len(v_strengths_np) > 0 else 1.0
        if max_abs_strength < 1e-9: max_abs_strength = 1.0 
        vortex_sizes = config.Vortex_MARKER_SIZE_SCALE * np.abs(v_strengths_np) / max_abs_strength + config.VORTEX_MARKER_SIZE_BASE
        
        initial_vortex_positions_np = np.empty((0,2), dtype=np.float32)
        if len(data_pack["vortex_pos"]) > 0 :
             initial_vortex_positions_np = np.asarray(data_pack["vortex_pos"][0])

        if initial_vortex_positions_np.shape[0] > 0:
            vortex_scatter = ax_main.scatter(initial_vortex_positions_np[:,0], initial_vortex_positions_np[:,1], 
                                            s=vortex_sizes, c=vortex_colors, 
                                            edgecolors='black', linewidths=0.5, zorder=10, alpha=0.9)
    if vortex_scatter is None: # Create a dummy scatter if no vortices, for update function consistency
        vortex_scatter = ax_main.scatter([],[], s=[], c=[])


    time_text = ax_main.text(0.02, 0.95, '', transform=ax_main.transAxes, color='white', fontsize=12)
    
    ax_Lz.set_xlim(0, config.SIMULATION_TIME)
    Lz_hist_np = np.asarray(data_pack["Lz"])
    Lz0_np = np.asarray(data_pack["initial_Lz"]).item() # Ensure scalar
    Lz_denom_np = np.asarray(data_pack["Lz_denom"]).item() # Ensure scalar

    if abs(Lz_denom_np) > 1e-9 and len(Lz_hist_np) > 0:
        rel_err_Lz = (Lz_hist_np - Lz0_np) / Lz_denom_np
        max_abs_rel_err = np.max(np.abs(rel_err_Lz)) if len(rel_err_Lz) > 0 else 1e-5
        plot_Lz_y_limit = max(1e-5, max_abs_rel_err * 1.2) # Ensure positive limit
        ax_Lz.set_ylim(-plot_Lz_y_limit, plot_Lz_y_limit)
        ax_Lz.set_ylabel("Rel. ΔLz/Lz₀", color='lightgray', fontsize=10)
    else: # Denominator is zero or no history
        ax_Lz.set_ylabel("Lz (abs)", color='lightgray', fontsize=10) 
        if len(Lz_hist_np)>0:
             min_Lz_val = np.min(Lz_hist_np)
             max_Lz_val = np.max(Lz_hist_np)
             margin = (max_Lz_val - min_Lz_val) * 0.1 + 1e-5
             ax_Lz.set_ylim(min_Lz_val - margin, max_Lz_val + margin)
        else:
             ax_Lz.set_ylim(-1,1) # Default if no data

    ax_Lz.tick_params(axis='x', colors='lightgray'); ax_Lz.tick_params(axis='y', colors='lightgray')
    ax_Lz.set_facecolor(config.AXES_BG_COLOR); [s.set_edgecolor('gray') for s in ax_Lz.spines.values()]
    line_Lz, = ax_Lz.plot([], [], lw=1.5, color='#FFD700') 

    ax_P.set_xlim(0, config.SIMULATION_TIME)
    ax_P.set_ylabel("Px, Py (abs)", color='lightgray', fontsize=10)
    ax_P.tick_params(axis='x', colors='lightgray'); ax_P.tick_params(axis='y', colors='lightgray')
    ax_P.set_facecolor(config.AXES_BG_COLOR); [s.set_edgecolor('gray') for s in ax_P.spines.values()]
    line_Px, = ax_P.plot([], [], lw=1.5, color='#00FF00', label='Px') 
    line_Py, = ax_P.plot([], [], lw=1.5, color='#FF00FF', label='Py') 
    ax_P.legend(fontsize='small', facecolor=config.AXES_BG_COLOR, edgecolor='gray', labelcolor='lightgray', loc='upper right')

    gpu_info_text = f"GPU: ON ({config.xp.__name__})" if config.GPU_ENABLED else f"GPU: OFF ({config.xp.__name__})"
    info_str = f"DT: {config.DT:.1e} | Sim Time: {config.SIMULATION_TIME:.1f}s | {gpu_info_text} | Seed: {config.RANDOM_SEED}"
    fig_info_text = ax_info.text(0.5, 0.5, info_str, color='lightgray', ha='center', va='center', fontsize=10)

    plt.tight_layout(rect=[0, 0.02, 1, 0.97]) 

    update_elements = []
    if config.N_TRACERS > 0 and initial_tracer_positions_np.shape[0] > 0: update_elements.extend(scatter_layers)
    if config.N_VORTICES > 0 : update_elements.append(vortex_scatter) # vortex_scatter always exists
    update_elements.extend([time_text, line_Lz, line_Px, line_Py])


    def update(frame_idx):
        """Redraw every animated artist for frame ``frame_idx``.

        Serves as the per-frame callback handed to ``FuncAnimation``. It moves
        the tracer and vortex scatters to the positions stored for this frame,
        refreshes the time label, and extends the Lz / Px / Py curves so they
        end at this frame. Returns ``update_elements``.
        """
        shown = frame_idx + 1  # number of history samples visible in this frame

        # --- Tracer cloud -------------------------------------------------
        tracer_xy = np.asarray(data_pack["tracer_pos"][frame_idx])
        if config.N_TRACERS > 0 and tracer_xy.size > 0:
            if config.TRACER_COLORING_MODE != "speed":
                # 'group' / 'scalar' colouring: only the positions change.
                for layer in scatter_layers:
                    layer.set_offsets(tracer_xy)
            else:
                # 'speed' colouring: recolour by finite-difference speed
                # relative to the previous stored frame.
                if frame_idx == 0:
                    # No previous frame to difference against.
                    normalised = np.zeros(tracer_xy.shape[0], dtype=np.float32)
                else:
                    elapsed = data_pack["times"][frame_idx] - data_pack["times"][frame_idx - 1]
                    previous_xy = np.asarray(data_pack["tracer_pos"][frame_idx - 1])
                    speeds = np.linalg.norm(tracer_xy - previous_xy, axis=1) / elapsed
                    normalised = np.clip(speeds / global_max_speed, 0.0, 1.0)
                rgba = tracer_cmap(normalised)
                for layer in scatter_layers:
                    layer.set_offsets(tracer_xy)
                    layer.set_facecolors(rgba)

        # --- Vortices -----------------------------------------------------
        vortex_xy = np.asarray(data_pack["vortex_pos"][frame_idx])
        if config.N_VORTICES > 0 and vortex_xy.shape[0] > 0:
            vortex_scatter.set_offsets(vortex_xy)

        # --- Time label ---------------------------------------------------
        times_so_far = np.asarray(data_pack["times"][:shown])
        time_text.set_text(f"Time: {times_so_far[-1]:.2f}s")

        # --- Angular-momentum curve ---------------------------------------
        lz_so_far = Lz_hist_np[:shown]
        if len(lz_so_far) > 0:
            if abs(Lz_denom_np) > 1e-9:
                # Relative drift with respect to the initial value.
                line_Lz.set_data(times_so_far, (lz_so_far - Lz0_np) / Lz_denom_np)
            else:
                # Denominator is (near) zero: fall back to absolute values.
                line_Lz.set_data(times_so_far, lz_so_far)

        # --- Linear-momentum curves ---------------------------------------
        px_so_far = np.asarray(data_pack["Px"][:shown])
        py_so_far = np.asarray(data_pack["Py"][:shown])
        if len(px_so_far) > 0:
            line_Px.set_data(times_so_far, px_so_far)
            line_Py.set_data(times_so_far, py_so_far)
            if frame_idx == 0:
                # Fix the y-range once, from the complete Px/Py history.
                combined = np.concatenate((data_pack["Px"], data_pack["Py"]))
                if len(combined) == 0:
                    ax_P.set_ylim(-1, 1)
                else:
                    lowest, highest = combined.min(), combined.max()
                    pad = (highest - lowest) * 0.1 + 1e-5
                    ax_P.set_ylim(lowest - pad, highest + pad)

        # Progress report every 20th frame.
        if frame_idx % 20 == 0:
            print(f"Animating frame {frame_idx+1}/{num_frames}")
        return update_elements

    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=max(20, 1000 // 30), blit=True)
    
    print(f"Saving animation to {config.OUTPUT_FILENAME} (this may take a while)...")
    start_save_time = time.time()
    if config.FFMPEG_CODEC in {"h264_nvenc", "hevc_nvenc"}:
        extra = [
            "-preset",  config.FFMPEG_PRESET,   # speed / compression
            "-rc",      "vbr",                  # NVENC rate-control mode
            "-cq",      str(config.FFMPEG_CQ),  # constant-quality target
            "-pix_fmt", "yuv420p",
        ]
    else:  # software encoders such as libx264 / libx265
        extra = [
            "-preset",  config.FFMPEG_PRESET,
            "-crf",     str(config.FFMPEG_CRF),
            "-pix_fmt", "yuv420p",
        ]

    # optional threading flag (safe for both paths)
    if config.FFMPEG_THREADS != 0:
        extra.extend(["-threads", str(config.FFMPEG_THREADS)])
    writer = animation.FFMpegWriter(
        fps      = config.FPS,
        codec    = config.FFMPEG_CODEC,
        metadata = dict(artist='AI Simulation Project'),
        extra_args = extra
    )
    ani.save(config.OUTPUT_FILENAME, writer=writer, dpi=config.DPI)
    end_save_time = time.time()
    print(f"Animation saved in {end_save_time - start_save_time:.2f} seconds.")
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: build one SimConfig, run the simulation, and render
    # the animation if more than one time sample was produced.

    # --- Example: small 'group' colouring configuration ---
    # sim_config_group = SimConfig(
    #     N_VORTICES=4, 
    #     N_TRACERS=100000, 
    #     SIMULATION_TIME=0.3, 
    #     GPU_ENABLED=True, # Set to False to test CPU
    #     TRACER_COLORING_MODE="group", 
    #     TRACER_CMAP="viridis", # Example: 'viridis', 'plasma', 'magma'
    #     NUM_TRACER_GROUPS=4, 
    #     OUTPUT_FILENAME="vortex_group_viridis.mp4"
    # )
    # print("Starting simulation with 'group' tracer coloring...")
    # simulation_data_package = run_simulation(sim_config_group)
    # if len(simulation_data_package["times"]) > 1:
    #     print("Starting animation for 'group' mode...")
    #     animate(simulation_data_package, sim_config_group)
    # else:
    #     print("Simulation too short or no data to animate for 'group' mode.")

    # --- Main run: large tracer cloud, discrete 'group' colouring ---
    # The variable names and the output filename still contain "scalar" so
    # that existing references keep working, but the configured mode is
    # "group"; the status messages below report the mode actually in use.
    #
    # NOTE(review): a recorded run of exactly this configuration stopped with
    # MemoryError while allocating one (1200000, 2) float32 tracer snapshot
    # (~9 MiB) late in the run. Tracer positions are kept per stored frame
    # (animate() indexes data_pack["tracer_pos"][frame_idx]), so the footprint
    # grows with N_TRACERS times the number of stored frames. Presumably
    # PLOT_INTERVAL controls how many frames are stored -- confirm against
    # run_simulation before tuning it.
    sim_config_scalar = SimConfig(
        N_VORTICES=20,
        N_TRACERS=1200000,
        SIMULATION_TIME=10.0,
        # Follow the CuPy probe done at import time instead of forcing the GPU
        # path; set to False to test the CPU path explicitly.
        GPU_ENABLED=DEFAULT_GPU_ENABLED,
        TRACER_COLORING_MODE="group",
        TRACER_CMAP="YlGnBu_r",  # colormap used for the tracer groups
        OUTPUT_FILENAME="vortex_scalar_YlGnBu_r.mp4",
        NUM_TRACER_GROUPS=5,
        RANDOM_SEED=42,
    )
    coloring_mode = sim_config_scalar.TRACER_COLORING_MODE

    try:
        simulation_data_package_scalar = run_simulation(sim_config_scalar)
    except MemoryError:
        # Give an actionable hint, then let the original error propagate.
        print(
            f"Out of memory while simulating N_TRACERS={sim_config_scalar.N_TRACERS} "
            f"for SIMULATION_TIME={sim_config_scalar.SIMULATION_TIME}; "
            "reduce N_TRACERS or SIMULATION_TIME and retry."
        )
        raise

    if len(simulation_data_package_scalar["times"]) > 1:
        print(f"Starting animation for '{coloring_mode}' mode...")
        animate(simulation_data_package_scalar, sim_config_scalar)
    else:
        print(f"Simulation too short or no data to animate for '{coloring_mode}' mode.")

    # --- Other examples ---
    # sim_config = SimConfig(N_VORTICES=3, N_TRACERS=50000, SIMULATION_TIME=3.0, GPU_ENABLED=False, RANDOM_SEED=123, VORTEX_COLOR_POS="cyan", VORTEX_COLOR_NEG="red", OUTPUT_FILENAME="vortex_3_long_cpu.mp4")
    # sim_config = SimConfig(N_VORTICES=0, N_TRACERS=10000, SIMULATION_TIME=0.3, TRACER_COLORING_MODE="scalar", TRACER_CMAP="coolwarm", OUTPUT_FILENAME="no_vortices_scalar_coolwarm.mp4")
    # sim_config = SimConfig(N_VORTICES=7, N_TRACERS=0, OUTPUT_FILENAME="no_tracers_vortices_only.mp4")
    
    print("All specified simulations and animations are done.")

CuPy active. Using GPU acceleration with float32.
Total initial vortex strength: 1.153e+00
Initializing tracers with 'group' coloring mode using 5 groups and cmap 'YlGnBu_r'.
Step 250/5000, Sim Time: 0.50s, Rel. ΔLz/Lz₀: 6.98e-05
Step 500/5000, Sim Time: 1.00s, Rel. ΔLz/Lz₀: 1.28e-04
Step 750/5000, Sim Time: 1.50s, Rel. ΔLz/Lz₀: 1.84e-04
Step 1000/5000, Sim Time: 2.00s, Rel. ΔLz/Lz₀: 2.19e-04
Step 1250/5000, Sim Time: 2.50s, Rel. ΔLz/Lz₀: 2.80e-04
Step 1500/5000, Sim Time: 3.00s, Rel. ΔLz/Lz₀: 3.20e-04
Step 1750/5000, Sim Time: 3.50s, Rel. ΔLz/Lz₀: 3.81e-04
Step 2000/5000, Sim Time: 4.00s, Rel. ΔLz/Lz₀: 4.26e-04
Step 2250/5000, Sim Time: 4.50s, Rel. ΔLz/Lz₀: 4.82e-04
Step 2500/5000, Sim Time: 5.00s, Rel. ΔLz/Lz₀: 4.79e-04
Step 2750/5000, Sim Time: 5.50s, Rel. ΔLz/Lz₀: 4.89e-04
Step 3000/5000, Sim Time: 6.00s, Rel. ΔLz/Lz₀: 5.26e-04
Step 3250/5000, Sim Time: 6.50s, Rel. ΔLz/Lz₀: 5.42e-04
Step 3500/5000, Sim Time: 7.00s, Rel. ΔLz/Lz₀: 5.32e-04
Step 3750/5000, Sim Time: 7.50s, Rel. ΔLz/Lz

MemoryError: Unable to allocate 9.16 MiB for an array with shape (1200000, 2) and data type float32