# In [None]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib as mpl
from cupyx.scipy.ndimage import gaussian_filter  # If needed for smoothing

from matplotlib.colors import LogNorm
import threading

import numpy as np
from IPython.display import display


#------------------- Toy simulation of the anti-gravity effect in J.-P. Petit's Janus Cosmological Model, pseudo-2D with a 3D Poisson equation on a plane -----------------------------#

#     ------         Adjust the parameters to produce different types of galaxies (they need to be listed and made more accessible in the future)

#       --           If you find an error, open a GitHub issue; the code is not "peer reviewed"

#                    See the ReadMe.md file for more details



# Constants and parameters.
# Lengths are in kpc, masses in solar masses and times in Myr (see the
# per-line notes); the functions below read these names as module globals.
N1 = 1000000 # Number of particles for fluid 1 (galaxy)
N2 = 1000000  # Number of particles for fluid 2 (background)

# NOTE(review): not referenced in the visible part of the file - confirm it is still needed.
variable_m1=[]



# Radius of the galaxy (in kpc)
r1 = 95
R_d = r1*4.5  # Disk scale length (in kpc) for the exponentially shaped galaxy

# Anisotropy radius r_a used by calculate_velocities in
# beta(r) = r^2 / (r_a^2 + r^2); presumably in kpc like the other lengths.
R_a = 20
M1 = 5e10 # Total mass of fluid 1 (in solar masses)

L = 450.0  # Box size (in kpc)


# Velocity scale of fluid 2: initialize_particles draws each velocity component
# with standard deviation init_vel2 / sqrt(2) and forwards the value to the
# boundary injection.  Presumably in kpc/Myr - TODO confirm.
init_vel2= 0.58

# Downsample factor for density plots.
# NOTE(review): not referenced in the visible part of the file.
downsample_factor = 1

# Mean surface density of fluid 1 inside r1: M1 / (pi * r1^2)
mean_density_m1=M1*(r1**-2)/cp.pi
#mean_density_m1= 16000000

# Surface density of fluid 2 (in solar masses per kpc^2).  It is negative:
# -2e-5 times the mean surface density of fluid 1.
Rohdensity2 = -mean_density_m1*0.00002
#Rohdensity2=0.000000000000000000000000000001



#M1 = L**2 * density1
# Total mass of fluid 2 over the whole L x L box (negative, like Rohdensity2)
M2 = L**2 * Rohdensity2


print(M2)  # Total mass of fluid 2 (in solar masses)
N_grid = 4096  # Number of grid points in each dimension
dt = 5 # Time step (in Myr)
n_steps = 1000 # Number of simulation steps
# NOTE(review): not referenced in the visible part of the file; presumably the
# number of steps skipped between animation frames.
skipped_anime_frame= 3

# Gravitational constant in kpc^3 M_sun^-1 Myr^-2.
# The value agrees with G = 4.30091e-6 kpc (km/s)^2 / M_sun converted with
# 1 km/s = 1.0227e-3 kpc/Myr.
G = 4.498502151469554e-12



#G = 4.302e-15

# Derived parameters
dx = L / N_grid  # Grid spacing (in kpc)
m1 = M1 / N1  # Mass per particle for fluid 1
m2 = M2 / N2  # Mass per particle for fluid 2 (negative, because M2 is)




def plot_density(pos1, pos2, title):
    """
    Plot log-scaled particle-count maps of both fluids side by side.

    Only the central half of the box (L/4 .. 3L/4 on each axis) is
    histogrammed; fluid 1 uses N_grid bins per axis and fluid 2 uses 1024.

    Parameters:
        pos1: (N, 2) array of fluid 1 positions; column 0 is x, column 1 is y.
        pos2: (N, 2) array of fluid 2 positions, same layout.
        title (str): Prefix used for the two panel titles.

    NOTE(review): np.histogram2d is a NumPy routine, so the positions are
    presumably expected to be host arrays already - confirm against callers.
    """
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))

    # Central region that is displayed; change delta to resize it
    center = L / 2
    delta = L / 4
    region = [[center - delta, center + delta], [center - delta, center + delta]]
    extent = [region[0][0], region[0][1], region[1][0], region[1][1]]

    # Particle counts per cell for both fluids
    hist1, _, _ = np.histogram2d(
        pos1[:, 0], pos1[:, 1],
        bins=N_grid, range=region
    )
    hist2, _, _ = np.histogram2d(
        pos2[:, 0], pos2[:, 1],
        bins=1024, range=region
    )

    # LogNorm cannot represent zero, so empty cells get a small floor value
    hist1[hist1 == 0] = 1e-3
    hist2[hist2 == 0] = 1e-3

    panels = (
        (hist1, 'gist_heat', 'Fluid 1'),
        (hist2, 'plasma', 'Fluid 2'),
    )
    for ax, (hist, cmap, label) in zip(axs, panels):
        # histogram2d returns counts indexed [x_bin, y_bin], while imshow draws
        # the first axis vertically; transpose so x runs along the horizontal
        # axis and the axis labels below are correct.
        image = ax.imshow(
            hist.T,
            origin='lower',
            extent=extent,
            cmap=cmap,
            norm=LogNorm()
        )
        ax.set_title(f'{title} - {label}')
        ax.set_xlabel('x (kpc)')
        ax.set_ylabel('y (kpc)')
        plt.colorbar(image, ax=ax, label='Logarithmic Density')

    plt.tight_layout()
    plt.show()

def calculate_velocities(pos1_gpu, force_grid, dx, L, bulge_radius):
    """
    Give every fluid 1 particle a circular velocity plus a random dispersion.

    The speed of each particle is sqrt(|F| * r), where F is the force
    interpolated from force_grid at the particle and r is its distance from
    the box centre (L/2, L/2); the velocity is perpendicular to the radius,
    in the counter-clockwise sense.  A Gaussian dispersion is added on top
    with radial width sigma_0 and tangential width sqrt(1 - beta) * sigma_0,
    where beta(r) = r^2 / (r_a^2 + r^2) and r_a is the module-level R_a.
    The radially binned dispersions and circular velocity are plotted as a
    side effect.

    Parameters:
        pos1_gpu (cp.ndarray): (N, 2) particle positions.
        force_grid (cp.ndarray): Gridded force indexed [ix, iy, component],
            as consumed by interpolate_force.
        dx (float): Grid spacing.  Not read here: interpolate_force uses the
            module-level dx.
        L (float): Box size; the centre of rotation is (L/2, L/2).
        bulge_radius (float): Not used; kept so existing callers still work.

    Returns:
        cp.ndarray: (N, 2) velocities.  Particles sitting exactly on the
        centre (r == 0) keep zero velocity.
    """
    print(" velocity calcul with random Osipkov-Merritt anisotropy")

    # Force at every particle, bilinearly interpolated from the grid
    forces = interpolate_force(force_grid, pos1_gpu)

    center = cp.array([L / 2, L / 2])  # Center of the galaxy

    r_vec = pos1_gpu - center  # Vector from center for each particle
    r = cp.sqrt(cp.sum(r_vec**2, axis=1))  # Radial distance for each particle

    # Dispersion scale and anisotropy radius (r_a is the module-level R_a)
    sigma_0, r_a = 1e-7, R_a

    vel1 = cp.zeros_like(pos1_gpu)  # Initialize velocity array
    valid_mask = r > 0  # r == 0 has no defined tangential direction

    # Anisotropy parameter beta(r) = r^2 / (r_a^2 + r^2)
    beta = cp.zeros_like(r)
    beta[valid_mask] = (r[valid_mask]**2) / (r_a**2 + r[valid_mask]**2)

    # Dispersions follow from beta = 1 - sigma_t^2 / sigma_r^2
    sigma_r = sigma_0 * cp.ones_like(r)
    sigma_t = cp.sqrt((1 - beta)) * sigma_r  # Tangential dispersion

    # Circular speed per particle, kept for the diagnostic plot
    v_circ = cp.zeros_like(r)

    print("Computing circular velocities in batches with anisotropic dispersion...")
    batch_size = 100000  # Adjust based on available GPU memory

    valid_indices = cp.where(valid_mask)[0]
    print("Valid indices", valid_indices)
    n_batches = (len(valid_indices) + batch_size - 1) // batch_size

    # Scale radius of the alternative test rotation curve below
    R0 = 1.5

    for i in range(n_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(valid_indices))
        batch_indices = valid_indices[start_idx:end_idx]

        forces_batch = forces[batch_indices]
        force_norm = cp.sqrt(forces_batch[:, 0]**2 + forces_batch[:, 1]**2)

        # Rotation equilibrium: v^2 / r = |F|
        r_batch = r[batch_indices]
        v_circ_batch = cp.sqrt(force_norm * r_batch)*1  # Perturbation on rotation equilibrium

        # Test curve of the same form as the JPP - F. Lhanseat rotation curve;
        # it currently enters the stored profile with weight 0.
        v_circ_batch2 = r_batch * cp.exp(-r_batch / R0)

        v_circ[batch_indices] = v_circ_batch*1 + v_circ_batch2*0

        # Tangential unit vector is (-y, x) / r
        r_vec_batch = r_vec[batch_indices]
        vel_circ_x = -v_circ_batch * r_vec_batch[:, 1] / r_batch
        vel_circ_y = v_circ_batch * r_vec_batch[:, 0] / r_batch

        # Polar angle of each particle, used to rotate the dispersion
        theta = cp.arctan2(r_vec_batch[:, 1], r_vec_batch[:, 0])

        # Random radial and tangential dispersion components
        disp_r = cp.random.normal(0, sigma_r[batch_indices])
        disp_t = cp.random.normal(0, sigma_t[batch_indices])

        # Rotate (radial, tangential) into Cartesian components
        disp_x = disp_r * cp.cos(theta) - disp_t * cp.sin(theta)
        disp_y = disp_r * cp.sin(theta) + disp_t * cp.cos(theta)

        vel1[batch_indices, 0] = vel_circ_x + disp_x
        vel1[batch_indices, 1] = vel_circ_y + disp_y

        if (i + 1) % 10 == 0 or (i + 1) == n_batches:
            print(f"Processed batch {i+1}/{n_batches}")

    # Radial binning for the diagnostic plots; max_radius is set by hand
    n_radial_bins = 50
    max_radius = 50
    print(f"Using fixed maximum radius for binning: {max_radius}")

    radial_bins = cp.linspace(0, max_radius, n_radial_bins + 1)

    # Particles beyond max_radius get an index outside 0..n_radial_bins-1
    # and are therefore ignored by the loop below.
    bin_indices = cp.digitize(r, bins=radial_bins) - 1

    sigma_r_binned = cp.zeros(n_radial_bins)
    sigma_t_binned = cp.zeros(n_radial_bins)
    v_circ_binned = cp.zeros(n_radial_bins)
    bin_centers = (radial_bins[:-1] + radial_bins[1:]) / 2

    print("Computing binned averages...")
    for b in range(n_radial_bins):
        in_bin = bin_indices == b
        count = cp.sum(in_bin)
        if count > 0:
            sigma_r_binned[b] = cp.mean(sigma_r[in_bin])
            sigma_t_binned[b] = cp.mean(sigma_t[in_bin])
            v_circ_binned[b] = cp.mean(v_circ[in_bin])
        else:
            sigma_r_binned[b] = cp.nan
            sigma_t_binned[b] = cp.nan
            v_circ_binned[b] = cp.nan

    # Transfer data to CPU for plotting
    bin_centers_cpu = bin_centers.get()
    sigma_r_cpu = sigma_r_binned.get()
    sigma_t_cpu = sigma_t_binned.get()
    v_circ_cpu = v_circ_binned.get()

    # Drop the empty (NaN) bins
    valid_bins = ~np.isnan(sigma_r_cpu) & ~np.isnan(sigma_t_cpu) & ~np.isnan(v_circ_cpu)
    bin_centers_cpu = bin_centers_cpu[valid_bins]
    sigma_r_cpu = sigma_r_cpu[valid_bins]
    sigma_t_cpu = sigma_t_cpu[valid_bins]
    v_circ_cpu = v_circ_cpu[valid_bins]

    # max() raises on an empty array, which happens when no particle lies
    # inside max_radius
    if bin_centers_cpu.size > 0:
        print(f"Maximum radius in bins: {bin_centers_cpu.max()}")
    else:
        print("No particles inside the binning radius; the profile plots will be empty.")

    fig, axs = plt.subplots(2, 1, figsize=(10, 12))

    # sigma_r and sigma_t vs radius
    axs[0].plot(bin_centers_cpu, sigma_r_cpu, label=r'$\sigma_r$', color='blue')
    axs[0].plot(bin_centers_cpu, sigma_t_cpu, label=r'$\sigma_t$', color='red')
    axs[0].set_xlabel('Radius (units)')
    axs[0].set_ylabel('Velocity Dispersion')
    axs[0].set_title(r'Velocity Dispersions $\sigma_r$ and $\sigma_t$ vs Radius')
    axs[0].legend()
    axs[0].grid(True)
    axs[0].set_xlim(0, max_radius)  # Ensure full radius is displayed

    # v_circ vs radius
    axs[1].plot(bin_centers_cpu, v_circ_cpu, label=r'$v_{\mathrm{circ}}$', color='green')
    axs[1].set_xlabel('Radius (units)')
    axs[1].set_ylabel('Circular Velocity')
    axs[1].set_title(r'Circular Velocity $v_{\mathrm{circ}}$ vs Radius')
    axs[1].legend()
    axs[1].grid(True)
    axs[1].set_xlim(0, max_radius)  # Ensure full radius is displayed

    plt.tight_layout()
    plt.show()

    print("Plots generated successfully.")

    return vel1



def compute_potential_from_density(density, dx, G):
    """
    Solve for the potential of a periodic 2D density field with an FFT.

    In Fourier space the potential is -4 * pi * G * density_k / k^2.  The
    k = 0 mode is set to zero, so the returned potential has zero mean.

    Parameters:
        density: 2D array of the density on the grid.
        dx (float): Grid spacing.
        G (float): Gravitational constant.

    Returns:
        Real 2D array of the potential, same shape as density.
    """
    n_rows, n_cols = density.shape[0], density.shape[1]

    # Angular wavenumbers along each axis, shaped for broadcasting
    wave_x = cp.fft.fftfreq(n_rows, d=dx)[:, cp.newaxis] * 2 * np.pi
    wave_y = cp.fft.fftfreq(n_cols, d=dx)[cp.newaxis, :] * 2 * np.pi
    wave_sq = wave_x ** 2 + wave_y ** 2

    # Placeholder so the division below is defined at k = 0;
    # that mode is overwritten afterwards.
    wave_sq[0, 0] = 1.0

    spectrum = cp.fft.fft2(density)
    potential_hat = -4 * cp.pi * G * spectrum / wave_sq
    potential_hat[0, 0] = 0.0  # zero-mean potential

    return cp.real(cp.fft.ifft2(potential_hat))





def characteristic_equations(x, y, vx, vy, phi_total, dx):
    """
    Right-hand side of the particle equations of motion.

    The potential gradient is taken with periodic central differences on the
    grid and bilinearly interpolated to the particle positions; the
    acceleration is minus that gradient.

    Parameters:
        x, y: Particle coordinates (1D arrays).
        vx, vy: Particle velocity components (1D arrays).
        phi_total: 2D potential, indexed [ix, iy].
        dx (float): Grid spacing.

    Returns:
        tuple: (dx_dt, dy_dt, dvx_dt, dvy_dt)
    """
    n_x, n_y = phi_total.shape[0], phi_total.shape[1]

    # Periodic central differences of the potential
    grad_x = (cp.roll(phi_total, -1, axis=0) - cp.roll(phi_total, 1, axis=0)) / (2 * dx)
    grad_y = (cp.roll(phi_total, -1, axis=1) - cp.roll(phi_total, 1, axis=1)) / (2 * dx)

    # Position in grid units, split into cell index and in-cell offset
    col = x / dx
    row = y / dx
    i_lo = cp.floor(col).astype(int)
    j_lo = cp.floor(row).astype(int)
    tx = col - i_lo
    ty = row - j_lo

    # Wrap the four surrounding nodes periodically
    i_lo = cp.mod(i_lo, n_x)
    j_lo = cp.mod(j_lo, n_y)
    i_hi = cp.mod(i_lo + 1, n_x)
    j_hi = cp.mod(j_lo + 1, n_y)

    def _bilinear(field):
        # Weighted sum of the four nodes around each particle
        return (
            (1 - tx) * (1 - ty) * field[i_lo, j_lo] +
            tx * (1 - ty) * field[i_hi, j_lo] +
            (1 - tx) * ty * field[i_lo, j_hi] +
            tx * ty * field[i_hi, j_hi]
        )

    dvx_dt = -_bilinear(grad_x)
    dvy_dt = -_bilinear(grad_y)

    # Positions simply advance with the velocities
    return vx, vy, dvx_dt, dvy_dt

def rk4_step(x, y, vx, vy, phi_total, dx, dt):
    """
    Advance the particles by one classical fourth-order Runge-Kutta step.

    The potential phi_total is held fixed over the step; the derivatives come
    from characteristic_equations.

    Parameters:
        x, y: Particle coordinates.
        vx, vy: Particle velocity components.
        phi_total: 2D potential on the grid.
        dx (float): Grid spacing.
        dt (float): Time step; a negative value integrates backward in time.

    Returns:
        tuple: (x_new, y_new, vx_new, vy_new)
    """
    half_dt = 0.5 * dt

    # Stage 1: slope at the start of the step
    k1_x, k1_y, k1_vx, k1_vy = characteristic_equations(x, y, vx, vy, phi_total, dx)

    # Stage 2: slope at the midpoint, reached with the stage-1 slope
    k2_x, k2_y, k2_vx, k2_vy = characteristic_equations(
        x + half_dt * k1_x, y + half_dt * k1_y,
        vx + half_dt * k1_vx, vy + half_dt * k1_vy, phi_total, dx)

    # Stage 3: slope at the midpoint, reached with the stage-2 slope.
    # All four state components must use k2 here; mixing in k1 for y and vy
    # breaks the fourth-order accuracy of the scheme.
    k3_x, k3_y, k3_vx, k3_vy = characteristic_equations(
        x + half_dt * k2_x, y + half_dt * k2_y,
        vx + half_dt * k2_vx, vy + half_dt * k2_vy, phi_total, dx)

    # Stage 4: slope at the end of the step, reached with the stage-3 slope
    k4_x, k4_y, k4_vx, k4_vy = characteristic_equations(
        x + k3_x * dt, y + k3_y * dt,
        vx + k3_vx * dt, vy + k3_vy * dt, phi_total, dx)

    # Weighted combination of the four slopes (1, 2, 2, 1) / 6
    x_new = x + (k1_x + 2 * k2_x + 2 * k3_x + k4_x) * dt / 6
    y_new = y + (k1_y + 2 * k2_y + 2 * k3_y + k4_y) * dt / 6
    vx_new = vx + (k1_vx + 2 * k2_vx + 2 * k3_vx + k4_vx) * dt / 6
    vy_new = vy + (k1_vy + 2 * k2_vy + 2 * k3_vy + k4_vy) * dt / 6

    return x_new, y_new, vx_new, vy_new



def backward_integrate(pos, vel, Phi1, Phi2, dx, dt, steps,
                       boundary_injection_func, sigma, init_vel):
    """
    Evolve the fluid 2 particles for a number of RK4 steps in the combined
    potential, re-injecting particles that leave the box.

    At every step the fluid 2 potential is recomputed from the current
    particles (deposited with mass -m2) and added to Phi1.  Particles that end
    up outside [0, L] on either axis are dropped and replaced through
    boundary_injection_func, in batches of at most 10000.  Afterwards the
    particle count is forced back to its initial value.

    Parameters:
        pos, vel: (N, 2) positions and velocities.
        Phi1: Potential of fluid 1 on the grid, kept fixed.
        Phi2: Never read; the fluid 2 potential is rebuilt each step.
        dx (float): Grid spacing.
        dt (float): Time step (negative for backward integration).
        steps (int): Number of steps to take.
        boundary_injection_func: Callable (count, L, dx, init_vel, sigma)
            returning a (positions, velocities) tuple.
        sigma, init_vel: Forwarded to boundary_injection_func.

    Returns:
        tuple: (pos, vel) after the last step.
    """
    target_count = pos.shape[0]  # particle number to restore after each step

    def _usable(batch_pos, batch_vel):
        # Non-empty arrays with matching particle counts
        return (batch_pos.size > 0 and
                batch_vel.size > 0 and
                batch_pos.shape[0] == batch_vel.shape[0])

    for _ in range(steps):
        # Total potential for this step
        mesh_density = particle_to_mesh(pos, -m2)
        Phi2 = compute_potential_from_density(mesh_density, dx, G)
        combined_phi = Phi1 + Phi2

        # One RK4 step for all particles
        xs, ys, vxs, vys = rk4_step(
            pos[:, 0], pos[:, 1], vel[:, 0], vel[:, 1], combined_phi, dx, dt
        )
        pos = cp.column_stack((xs, ys))
        vel = cp.column_stack((vxs, vys))

        # Drop everything that left the box
        escaped = (
            (pos[:, 0] < 0) | (pos[:, 0] > L) |
            (pos[:, 1] < 0) | (pos[:, 1] > L)
        )
        pending = int(cp.count_nonzero(escaped))
        kept = ~escaped
        pos = pos[kept]
        vel = vel[kept]

        if pending == 0:
            print("No particles to inject this step.")

        # Replace the escaped particles batch by batch
        batch_cap = 10000  # Adjust this value based on system capability
        while pending > 0:
            batch = min(pending, batch_cap)
            injected = boundary_injection_func(batch, L, dx, init_vel, sigma)

            if not (isinstance(injected, tuple) and len(injected) == 2):
                print("Warning: boundary_injection_func did not return a valid tuple.")
                break  # avoid looping forever on a broken injector

            fresh_pos, fresh_vel = injected
            if not _usable(fresh_pos, fresh_vel):
                print("Warning: Invalid new particles from boundary_injection_func.")
                break  # avoid looping forever on a broken injector

            pos = cp.vstack([pos, fresh_pos])
            vel = cp.vstack([vel, fresh_vel])
            pending -= batch

        # Bring the particle count back to the initial value
        current_count = pos.shape[0]
        shortfall = target_count - current_count
        if shortfall > 0:
            topup = boundary_injection_func(shortfall, L, dx, init_vel, sigma)
            if isinstance(topup, tuple) and len(topup) == 2:
                topup_pos, topup_vel = topup
                if _usable(topup_pos, topup_vel):
                    pos = cp.vstack([pos, topup_pos])
                    vel = cp.vstack([vel, topup_vel])
                else:
                    print("Warning: Invalid extra particles from boundary_injection_func.")
            else:
                print("Warning: boundary_injection_func did not return a valid tuple for extra particles.")
        elif shortfall < 0:
            # Too many particles: discard a random selection of the surplus
            surplus_rows = cp.random.choice(
                current_count, size=-shortfall, replace=False
            )
            pos = cp.delete(pos, surplus_rows, axis=0)
            vel = cp.delete(vel, surplus_rows, axis=0)

    return pos, vel


def boundary_injection_func(num_particles_to_inject, L, dx, init_vel, sigma):
    """
    Create new particles on the edges of the box.

    Each particle is placed on one of the four edges, chosen at random, at a
    uniformly random position along that edge.  Speeds are drawn by inverse
    transform so that they are never below v_min = 2 * sigma, and directions
    are uniform in [0, 2*pi).

    Parameters:
        num_particles_to_inject: Number of particles to create.
        L (float): Box size.
        dx: Not used.
        init_vel: Not used.
        sigma: Velocity scale of the speed distribution.

    Returns:
        tuple: (positions, velocities), each of shape (n, 2).  Two empty
        (0, 2) arrays are returned when n <= 0 or when an error occurs.
    """
    cutoff_multiplier = 2
    count = int(num_particles_to_inject)
    if count <= 0:
        return cp.empty((0, 2)), cp.empty((0, 2))

    try:
        positions = cp.zeros((count, 2))
        edge_ids = cp.random.choice(4, size=count)

        # (edge id, pinned axis, pinned value): left, right, top, bottom.
        # The other coordinate is drawn uniformly along the edge.
        edge_layout = ((0, 0, 0), (1, 0, L), (2, 1, L), (3, 1, 0))
        for edge_id, pinned_axis, pinned_value in edge_layout:
            on_edge = edge_ids == edge_id
            n_on_edge = int(cp.count_nonzero(on_edge))
            if n_on_edge > 0:
                positions[on_edge, pinned_axis] = pinned_value
                positions[on_edge, 1 - pinned_axis] = cp.random.uniform(0, L, size=n_on_edge)

        # Speeds: inverse-transform sampling with a lower cut-off at v_min
        v_min = sigma * cutoff_multiplier
        U = cp.random.uniform(0, 1, size=count)
        v = cp.sqrt(v_min**2 - 2 * sigma**2 * cp.log(1 - U * (1 - cp.exp(- (v_min**2) / (2 * sigma**2)))))

        # Directions: uniform angle
        theta = cp.random.uniform(0, 2 * cp.pi, size=count)
        velocities = cp.empty((count, 2))
        velocities[:, 0] = v * cp.cos(theta)
        velocities[:, 1] = v * cp.sin(theta)

        return positions, velocities

    except Exception as e:
        print(f"Error in boundary_injection_func: {e}")
        return cp.empty((0, 2)), cp.empty((0, 2))

def radial_average(quantityCP, L, N_grid, r_max):
    """
    Computes the radial average of a 2D quantity up to radius r_max using NumPy.

    Cell centres are placed on a grid centred on zero, and every cell with
    radius <= r_max is assigned to a radial bin of width dx = L / N_grid.

    Parameters:
        quantityCP (cp.ndarray or array-like): (N_grid, N_grid) array of the
            quantity.  CuPy arrays are copied to the host; NumPy arrays are
            used as they are.
        L (float): Size of the domain.
        N_grid (int): Number of grid points in one dimension.
        r_max (float): Maximum radius to compute radial average.

    Returns:
        tuple: (radial_avg, r_bin_centers).  Bins that contain no cell are
        NaN in radial_avg.
    """
    # CuPy arrays expose .get() for the device-to-host copy
    if hasattr(quantityCP, "get"):
        quantity = quantityCP.get()
    else:
        quantity = np.asarray(quantityCP)

    dx = L / N_grid
    # Cell-centre coordinates, centred at zero
    axis = np.linspace(-L / 2 + dx / 2, L / 2 - dx / 2, N_grid)
    X, Y = np.meshgrid(axis, axis)
    R = np.sqrt(X**2 + Y**2)

    # Keep only the cells inside r_max
    inside = R <= r_max
    R = R[inside]
    quantity = quantity[inside]

    # Radial bins, one grid spacing wide
    dr = dx
    r_bins = np.arange(0, r_max + dr, dr)
    r_bin_centers = (r_bins[:-1] + r_bins[1:]) / 2
    n_bins = len(r_bin_centers)
    if n_bins == 0:
        return np.zeros(0), r_bin_centers

    # A cell lying exactly on the last bin edge would get index n_bins;
    # clip it into the last bin instead of indexing out of range.
    bin_indices = np.clip(np.digitize(R, r_bins) - 1, 0, n_bins - 1)

    # Per-bin cell counts and sums
    radial_count = np.bincount(bin_indices, minlength=n_bins)
    radial_sum = np.bincount(bin_indices, weights=quantity, minlength=n_bins)

    # Empty bins stay NaN; only populated bins are divided
    radial_avg = np.full(n_bins, np.nan)
    np.divide(radial_sum, radial_count, out=radial_avg, where=radial_count > 0)

    return radial_avg, r_bin_centers


def calculate_sigma0(M1, R_d, R_max):
    """
    Central surface mass density of an exponential disk truncated at R_max.

    The mass of a disk Sigma_0 * exp(-R / R_d) inside R_max is
    2 * pi * Sigma_0 * R_d^2 * (1 - exp(-R_max / R_d) * (1 + R_max / R_d));
    this function solves that expression for Sigma_0.

    Parameters:
    - M1 (float): Total mass of the disk (Msun)
    - R_d (float): Scale length of the disk (kpc)
    - R_max (float): Maximum radius of the disk (kpc)

    Returns:
    - Sigma_0 (float): Central surface mass density (Msun/kpc^2)
    """
    # Fraction of the untruncated disk mass that lies inside R_max
    enclosed_fraction = 1 - np.exp(-R_max / R_d) * (1 + R_max / R_d)
    return M1 / (2 * np.pi * R_d**2 * enclosed_fraction)


def initialize_particles():
    """
    Build the initial positions and velocities of both fluids.

    Steps:
      1. Fluid 1 (galaxy) positions are sampled around the box centre and its
         potential Phi1 is computed from the sign-flipped density.
      2. Fluid 2 starts uniform over the box with Gaussian velocities, and is
         then relaxed by repeated calls to backward_integrate (negative time
         step) until the change of its potential drops below a tolerance or
         the iteration limit is reached.
      3. The force field of both fluids together is computed and used by
         calculate_velocities to set the fluid 1 velocities.

    All sizes and masses are read from the module-level parameters.

    Returns:
        tuple: (pos1_gpu, vel1, pos2_gpu, vel2_gpu)
    """
    print("Initializing fluid 1 particles...")

    # NOTE: alternative radial samplers (a truncated exponential, a bulge-like
    # profile and a truncated Gaussian) used to sit here inside an inert
    # string literal; recover them from version control if they are needed.
    #Sigma_0 = calculate_sigma0(M1, R_d, r1)        # Central surface mass density (Msun/kpc^2)

    # Uniformly distributed angles
    theta = np.random.uniform(0, 2 * np.pi, N1)

    # Radial distances by inverse transform sampling.
    # NOTE(review): this inverts the CDF 1 - exp(-r / R_d), not the enclosed
    # mass 1 - exp(-r / R_d) * (1 + r / R_d) of an exponential disk, so the
    # radii reach at most r1 - R_d * ln(1 + r1 / R_d) rather than r1.
    # Confirm that this profile is intended.
    u = np.random.uniform(0, 1, N1)
    r = -R_d * np.log(1 - u * (1 - np.exp(-r1 / R_d) * (1 + r1 / R_d)))

    # Polar to Cartesian, then shift to the centre of the box
    posX = r * np.cos(theta)
    posY = r * np.sin(theta)
    galaxy_stars = np.column_stack((posX, posY))
    pos1 = galaxy_stars + np.array([L / 2, L / 2])

    # Transfer to GPU
    pos1_gpu = cp.asarray(pos1)

    print("Fluid 1 particles initialized.")

    print("Depositing fluid 1 density onto grid...")
    density1 = particle_to_mesh(pos1_gpu, m1)
    print("Fluid 1 density deposited.")

    print("Computing gravitational potential for fluid 1...")
    # Fluid 2 moves in the potential of the sign-flipped fluid 1 density
    Phi1 = compute_potential_from_density(-density1, dx, G)
    print("Gravitational potential for fluid 1 computed.")

    print("Initializing fluid 2 particles...")
    # Fluid 2 starts uniformly distributed across the whole domain
    pos2_gpu = cp.random.uniform(0, L, (N2, 2))
    print("pos2Shape",pos2_gpu.shape)

    # Gaussian velocities with per-component width init_vel2 / sqrt(2)
    sigma = init_vel2 / cp.sqrt(2)
    vel2_gpu = cp.random.normal(0, sigma, size=(pos2_gpu.shape[0], 2))
    print("Fluid 2 particles initialized.")

    print("Depositing fluid 2 density onto grid...")
    density2 = particle_to_mesh(pos2_gpu, -m2)
    #density2 = apply_boundary_conditions_density(density2, -Rohdensity2)
    print("Fluid 2 density deposited and boundary conditions applied.")

    print("Computing gravitational potential for fluid 2...")
    Phi2 = compute_potential_from_density(density2, dx, G)
    print("Gravitational potential for fluid 2 computed.")

    print("Starting backward integration for fluid 2...")
    max_iterations = 50
    tolerance = 1e-4
    dt_back = -dt * 1  # Negative time step for backward integration
    integration_time = 100
    steps = int(abs(integration_time / dt_back))
    relaxation_factor = 0.2  # Under-relaxation of the potential; adjust as needed

    for iteration in range(max_iterations):

        print(f"Starting backward integration for fluid 2, iteration {iteration + 1}...")
        pos2_gpu, vel2_gpu = backward_integrate(pos2_gpu, vel2_gpu, Phi1,Phi2, dx, dt_back, steps,boundary_injection_func ,sigma,init_vel2)

        # Potential of the updated fluid 2 distribution
        density2_new = particle_to_mesh(pos2_gpu, -m2)
        #density2_new = apply_boundary_conditions_density(density2_new, -Rohdensity2)
        #vel2_gpu = apply_boundary_conditions_velocity(vel2Boundary, pos2_gpu, init_vel2)
        Phi2_new = compute_potential_from_density(density2_new, dx, G)

        # Convergence is judged on the largest change of the potential
        delta_Phi = cp.max(cp.abs(Phi2_new - Phi2))

        print(f"Iteration {iteration + 1}, delta_Phi: {delta_Phi}")

        if delta_Phi < tolerance:
            print(f"Converged after {iteration + 1} iterations.")
            Phi2 = Phi2_new
            break

        # Under-relaxed reference potential for the next comparison
        Phi2 = relaxation_factor * Phi2_new + (1 - relaxation_factor) * Phi2

    else:
        print("Maximum iterations reached without convergence.")

    print("Backward integration for fluid 2 completed.")

    # Total density with the signed fluid 2 mass.  The fluid 1 positions have
    # not changed, so the density deposited above is reused.
    density2 = particle_to_mesh(pos2_gpu, m2 )

    #total_density = density1
    total_density = density1 + density2
    force_grid = solve_poisson(total_density, dx, G)  # solve_poisson returns forces from density

    # Set circular velocities for galaxy particles from the force field
    print("start circular velocity")
    vel1_circ = calculate_velocities(pos1_gpu, force_grid, dx, L, r1/3)
    print(vel1_circ)

    vel1 = vel1_circ

    return pos1_gpu, vel1, pos2_gpu, vel2_gpu






def solve_poisson(density, dx, G):
    """
    Compute the force field of a periodic 2D density with an FFT.

    The mean density is subtracted, the potential is obtained in Fourier
    space as -4 * pi * G * density_k / k^2 (with the k = 0 mode set to zero),
    and the force is minus the spectral gradient of that potential.

    Parameters:
        density: 2D array of the density on the grid.
        dx (float): Grid spacing.
        G (float): Gravitational constant.

    Returns:
        Array of shape density.shape + (2,); [..., 0] is the force along
        axis 0 and [..., 1] the force along axis 1.
    """
    # Work with the density contrast only
    contrast = density - cp.mean(density)
    contrast_hat = cp.fft.fft2(contrast)

    # Angular wavenumbers along each axis
    n_rows, n_cols = density.shape[0], density.shape[1]
    kx = cp.fft.fftfreq(n_rows, d=dx)[:, cp.newaxis] * 2 * np.pi
    ky = cp.fft.fftfreq(n_cols, d=dx)[cp.newaxis, :] * 2 * np.pi
    k_norm_sq = kx**2 + ky**2

    # Placeholder so the division is defined at k = 0;
    # that mode is zeroed right after.
    k_norm_sq[0, 0] = 1.0

    potential_hat = -4 * cp.pi * G * contrast_hat / k_norm_sq
    potential_hat[0, 0] = 0.0  # Remove constant potential

    # Force = -grad(phi); in Fourier space the gradient is a factor i*k
    components = [
        cp.real(cp.fft.ifft2(-1j * wave * potential_hat))
        for wave in (kx, ky)
    ]

    return cp.stack(components, axis=-1)

def particle_to_mesh(pos, mass):
    """
    Deposit particles onto the N_grid x N_grid mesh with cloud-in-cell weights.

    Every particle shares its mass between the four grid cells around it,
    with bilinear weights; indices wrap around periodically.  The grid size
    and spacing are the module-level N_grid and dx.

    Parameters:
        pos: (N, 2) particle positions.
        mass: Mass carried by each particle.

    Returns:
        float32 array of shape (N_grid, N_grid): deposited mass divided by
        the cell area dx**2.
    """
    # Positions in grid units, the cell each particle falls in, and its
    # offset inside that cell
    cell_coords = (pos / dx).astype(cp.float32)
    lower = cp.floor(cell_coords).astype(int)
    offset = cell_coords - lower

    # 1D weights along each axis
    wx_hi = offset[:, 0]
    wx_lo = 1.0 - wx_hi
    wy_hi = offset[:, 1]
    wy_lo = 1.0 - wy_hi

    # Periodic indices of the lower and upper neighbours
    ix_lo = lower[:, 0] % N_grid
    ix_hi = (ix_lo + 1) % N_grid
    iy_lo = lower[:, 1] % N_grid
    iy_hi = (iy_lo + 1) % N_grid

    density = cp.zeros((N_grid, N_grid), dtype=cp.float32)

    # Scatter the mass to the four surrounding cells
    for iy, wy in ((iy_lo, wy_lo), (iy_hi, wy_hi)):
        for ix, wx in ((ix_lo, wx_lo), (ix_hi, wx_hi)):
            cp.add.at(density, (ix, iy), mass * wx * wy)

    return density / dx**2

def interpolate_force(force, pos):
    """
    Bilinearly interpolate a gridded force field to particle positions.

    The grid is periodic; its size and spacing are the module-level N_grid
    and dx.

    Parameters:
        force: Array indexed [ix, iy, component] with two components.
        pos: (N, 2) particle positions.

    Returns:
        (N, 2) array with the interpolated force components per particle.
    """
    # Position in grid units, split into cell index and in-cell offset
    gx = pos[:, 0] / dx
    gy = pos[:, 1] / dx
    ix_lo = cp.floor(gx).astype(int)
    iy_lo = cp.floor(gy).astype(int)
    wx = gx - ix_lo
    wy = gy - iy_lo

    # Periodic indices of the four surrounding nodes
    ix_lo = ix_lo % N_grid
    iy_lo = iy_lo % N_grid
    ix_hi = (ix_lo + 1) % N_grid
    iy_hi = (iy_lo + 1) % N_grid

    corner_ll = force[ix_lo, iy_lo]
    corner_hl = force[ix_hi, iy_lo]
    corner_lh = force[ix_lo, iy_hi]
    corner_hh = force[ix_hi, iy_hi]

    # Bilinear weights of the four nodes
    w_ll = (1 - wx) * (1 - wy)
    w_hl = wx * (1 - wy)
    w_lh = (1 - wx) * wy
    w_hh = wx * wy

    def _blend(component):
        # Weighted sum of one force component over the four nodes
        return (w_ll * corner_ll[:, component] + w_hl * corner_hl[:, component]
                + w_lh * corner_lh[:, component] + w_hh * corner_hh[:, component])

    return cp.stack([_blend(0), _blend(1)], axis=1)




def get_force(pos,vel,m_gravitational):
    """Compute the mesh force acting on every particle.

    The first ``N1`` rows of ``pos`` are deposited on the grid with the
    per-particle mass ``m1`` and the remaining rows with ``m2``. The summed
    density is handed to ``solve_poisson``; the grid it returns is
    interpolated back to the particle positions and scaled by each
    particle's gravitational mass.

    Args:
        pos: CuPy array of particle positions, fluid 1 first, then fluid 2.
        vel: Not used by this function; kept so existing callers keep working.
        m_gravitational: 1-D CuPy array with one gravitational mass per row
            of ``pos``.

    Returns:
        CuPy array with one two-component force row per particle.
    """
    combined_density = particle_to_mesh(pos[:N1], m1) + particle_to_mesh(pos[N1:], m2)

    # Grid field obtained directly from the total density.
    field_on_grid = solve_poisson(combined_density, dx, G)
    field_at_particles = interpolate_force(field_on_grid, pos)

    return field_at_particles * m_gravitational[:, cp.newaxis]



def run_simulation(pos1, vel1, pos2, vel2):
    """Advance both particle populations and record position snapshots.

    The two fluids are concatenated (fluid 1 first) and integrated together
    for ``n_steps`` steps of size ``dt`` with a leapfrog scheme: half
    velocity update, position update wrapped into ``[0, L)``, force
    re-evaluation, second half velocity update. Forces use the signed
    gravitational mass while accelerations divide by the absolute
    (inertial) mass.

    Every ``skipped_anime_frame`` steps the positions are copied to host
    memory and stored; every ``5 * skipped_anime_frame`` steps the latest
    snapshot is also passed to ``plot_density``. The freshly computed force
    array is printed on every step.

    Args:
        pos1: CuPy positions of fluid 1 (``N1`` rows).
        vel1: CuPy velocities of fluid 1.
        pos2: CuPy positions of fluid 2 (``N2`` rows).
        vel2: CuPy velocities of fluid 2.

    Returns:
        Tuple ``(positions1, positions2)`` of lists of NumPy arrays, one
        entry per stored snapshot, for fluid 1 and fluid 2 respectively.
    """
    snapshots1, snapshots2 = [], []

    # Both fluids are advanced together; fluid 1 occupies the first N1 rows.
    all_pos = cp.concatenate([pos1, pos2])
    all_vel = cp.concatenate([vel1, vel2])

    # NOTE: an experimental split of fluid 1 into heavier "star" particles
    # (fraction x, mass alpha * m1) and lighter "gas" particles (mass chosen
    # so the total stays N1 * m1) used to be sketched here as disabled code.
    # Every fluid-1 particle currently carries the same mass m1.

    # Gravitational mass keeps its sign, inertial mass is its magnitude.
    m_gravitational = cp.concatenate([cp.full(N1, m1), cp.full(N2, m2)])
    m_inertial = cp.concatenate([cp.full(N1, abs(m1)), cp.full(N2, abs(m2))])
    inertia_col = m_inertial[:, cp.newaxis]

    # Force at the starting positions.
    force = get_force(all_pos, all_vel, m_gravitational)

    for step in range(n_steps):
        # First half of the velocity update, then move the particles and
        # wrap them back into the periodic box.
        half_step_vel = all_vel + 0.5 * (force / inertia_col) * dt
        all_pos = (all_pos + half_step_vel * dt) % L

        # Force at the new positions (Poisson solve on the mesh).
        new_force = get_force(all_pos, all_vel, m_gravitational)
        print("new_force",new_force)

        # Second half of the velocity update with the new force.
        all_vel = half_step_vel + 0.5 * (new_force / inertia_col) * dt
        force = new_force

        if step % skipped_anime_frame == 0:
            # Keep a host-side copy of the full-resolution positions.
            snapshots1.append(cp.asnumpy(all_pos[:N1]))
            snapshots2.append(cp.asnumpy(all_pos[N1:]))

        if step % (skipped_anime_frame*5) == 0:
            plot_density(snapshots1[-1], snapshots2[-1], f"step : {step}")

    return snapshots1, snapshots2

# Matplotlib tuning for faster rendering: the 'fast' style plus explicit
# path-simplification and Agg chunking settings.
plt.style.use('fast')
mpl.rcParams.update({
    'path.simplify': True,
    'path.simplify_threshold': 1.0,
    'agg.path.chunksize': 10000,
})


def create_efficient_animation(positions1, positions2, skip_frames=1, dpi=50):
    """Render a log-scaled density animation of fluid 1 and save it as a GIF.

    Each frame is a 2-D histogram (``N_grid`` bins per axis) of the fluid-1
    particle positions inside the central square of side ``L / 2``, drawn
    with a logarithmic colour scale.

    Args:
        positions1: Sequence of NumPy arrays, one per stored snapshot, with
            x in column 0 and y in column 1.
        positions2: Snapshots of fluid 2. Currently not drawn; accepted so
            existing callers keep working.
        skip_frames: Use every ``skip_frames``-th snapshot (must be >= 1).
        dpi: Resolution handed to the GIF writer.

    Raises:
        ValueError: If ``skip_frames`` is smaller than 1.

    Side effects:
        Writes ``density_animation.gif`` to the current working directory
        and closes the figure it created.
    """
    if skip_frames < 1:
        raise ValueError("skip_frames must be >= 1")

    fig, ax = plt.subplots(figsize=(10, 10))

    n_frames = len(positions1) // skip_frames
    last_idx = len(positions1) - 1

    center = L / 2
    delta = L / 4  # Half-width of the displayed window; adjust if needed
    lo, hi = center - delta, center + delta
    region = [[lo, hi], [lo, hi]]

    def animate(i):
        """Draw snapshot ``i * skip_frames`` (clamped to the last one)."""
        frame_idx = min(i * skip_frames, last_idx)

        ax.clear()

        pos1 = positions1[frame_idx]
        hist1, _, _ = np.histogram2d(
            pos1[:, 0], pos1[:, 1],
            bins=N_grid, range=region
        )

        # histogram2d bins x along axis 0, whereas imshow draws axis 0
        # vertically. Transpose so x runs horizontally, matching the axis
        # labels and the extent below.
        density = hist1.T
        density[density == 0] = 1e-3  # Avoid zeros for LogNorm

        im = ax.imshow(
            density,
            origin='lower',
            extent=[region[0][0], region[0][1], region[1][0], region[1][1]],
            cmap='gist_heat',
            norm=LogNorm()
        )

        ax.set_title(f"Time Step {frame_idx}")
        ax.set_xlabel('x (kpc)')
        ax.set_ylabel('y (kpc)')
        # Lay out this figure explicitly rather than whatever pyplot
        # currently considers the "current" figure.
        fig.tight_layout()

        return [im]

    try:
        anim = FuncAnimation(fig, animate, frames=n_frames, interval=100, blit=True)
        anim.save('density_animation.gif', writer='pillow', fps=5, dpi=dpi)
    finally:
        # Always release this specific figure, even if saving fails.
        plt.close(fig)

# Module-level placeholder lists (two independent, initially empty lists).
posfinal1, posfinal2 = [], []
def run_and_save_simulation():
    """Run the whole pipeline: initialise, simulate, plot and export the GIF.

    Steps, in order:
      1. Create the initial particle state with ``initialize_particles``.
      2. Plot the initial density.
      3. Integrate the system with ``run_simulation``.
      4. Plot the last stored snapshot.
      5. Build and save the animation with ``create_efficient_animation``.

    Progress messages are printed along the way. Nothing is returned.
    """
    print("Initializing particles...")
    pos1, vel1, pos2, vel2 = initialize_particles()
    print("Particles initialized!")

    # Initial state (plot_density receives host-side copies).
    plot_density(cp.asnumpy(pos1), cp.asnumpy(pos2), "Initial State")

    # Integrate and collect the stored snapshots of both fluids.
    history1, history2 = run_simulation(pos1, vel1, pos2, vel2)

    # Final state = last stored snapshot.
    plot_density(history1[-1], history2[-1], "Final State")

    # NOTE: exporting the snapshots to 'positions.json' is currently disabled.

    print("Creating animation...")
    create_efficient_animation(
        history1,
        history2,
        skip_frames=1,  # use every stored snapshot
        dpi=50,         # low resolution keeps the GIF small
    )
    print("Animation saved!")

#if __name__ == "__main__":
    #run_and_save_simulation()
if __name__ == "__main__":
    # Launch the full pipeline on its own thread instead of calling it inline.
    simulation_thread = threading.Thread(target=run_and_save_simulation)
    simulation_thread.start()