# In [None]:  (Jupyter cell marker left over from the notebook export)
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib as mpl
from cupyx.scipy.ndimage import gaussian_filter  # If needed for smoothing
from scipy.stats import qmc
import numpy.ma as ma
from scipy.ndimage import gaussian_filter1d
import datashader as ds
import datashader.transfer_functions as tf
import pandas as pd
from datashader.colors import viridis
from matplotlib.colors import LogNorm
import threading
from pythreejs import *
import numpy as np
from IPython.display import display
from IPython.display import clear_output 



#                       ------------------------------------------- V2 -------------------------------------------





#    - TODO: refactor all code (remove unused functions)

#    - Shifted boundary distribution of the negative fluid at the border of the domain

#    - Steady-state Osipkov-Merritt anisotropic isothermal (generalised) distribution function for the galaxy -> truncated version of Jean-Pierre Petit's 2D simulation




# Constants and parameters
# Units used throughout: lengths in kpc, masses in solar masses, times in Myr
# (see the value of G below).
N1 = 1000000 # Number of particles for fluid 1 (galaxy)
N2 = 1000000  # Number of particles for fluid 2 (background)

 
#M1 = 0.00001 # Total mass of fluid 1 (in solar masses)

r1 = 45 # Radius of the galaxy (in kpc)
R_d = r1 /3 # Disk scale length (in kpc)

M1 = 3e10 # Total mass of fluid 1 (in solar masses)


L = 700.0  # Box size (in kpc)

# Initial velocity scale for fluid 2 (units not stated; presumably kpc/Myr - TODO confirm)
init_vel2=0.55

# Initial velocity scale for fluid 1 (same caveat on units)
init_vel1=0

downsample_factor = 1  # Downsample factor for density plots

# Mean surface density of fluid 1: M1 spread uniformly over a disc of radius r1,
# i.e. M1 / (pi * r1**2), in solar masses per kpc^2.
mean_density_m1=M1*(r1**-2)/cp.pi
#mean_density_m1= 16000000

#Rohdensity2 = -mean_density_m1*0.02#  Density of fluid 2 (in solar masses per kpc^2)
# Surface density of fluid 2; the value is negative, so M2 and m2 below are negative too.
Rohdensity2 = - 55000



#M1 = L**2 * density1
# Total mass of fluid 2: its surface density times the box area L**2.
M2 = L**2 * Rohdensity2


print(M2)  # Total mass of fluid 2 (in solar masses)
N_grid = 4096  # Number of grid points in each dimension
dt = 5 # Time step (in Myr)
n_steps = 400 # Number of simulation steps
# Frame-skip setting for the animation (its use is not visible in this part of the file)
skipped_anime_frame=5

# Gravitational constant in kpc^3 M_sun^-1 Myr^-2
G = 4.498502151469554e-12



# Derived parameters
dx = L / N_grid  # Grid spacing (in kpc)
m1 = M1 / N1  # Mass per particle for fluid 1
m2 = M2 / N2  # Mass per particle for fluid 2




def plot_density2(pos1, pos2, title):
    """
    Plot the combined particle-count map of both fluids on a log colour scale.

    Parameters:
        pos1: (N, 2) host array of fluid-1 positions (x in column 0, y in column 1).
        pos2: (N, 2) host array of fluid-2 positions.
        title (str): Title of the figure.

    Uses the module-level ``L`` (box size) and ``N_grid`` (grid resolution);
    the histogram has ``2 * N_grid`` bins per axis over the whole box.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    # Square window centred on the box; half_width = L / 2 covers the full domain.
    mid = L / 2
    half_width = L / 2
    lo, hi = mid - half_width, mid + half_width
    region = [[lo, hi], [lo, hi]]
    n_bins = N_grid * 2

    # One 2D count map per fluid, then summed.
    counts_1 = np.histogram2d(pos1[:, 0], pos1[:, 1], bins=n_bins, range=region)[0]
    counts_2 = np.histogram2d(pos2[:, 0], pos2[:, 1], bins=n_bins, range=region)[0]
    combined = counts_1 + counts_2

    # LogNorm cannot represent zero, so empty cells get a small floor value.
    combined[combined == 0] = 1e-3

    # histogram2d puts x on the first axis; imshow wants rows = y, hence the transpose.
    image = ax.imshow(
        combined.T,
        origin='lower',
        extent=[lo, hi, lo, hi],
        cmap='plasma',
        norm=LogNorm(),
    )

    ax.set_title(title)
    ax.set_xlabel('x (kpc)')
    ax.set_ylabel('y (kpc)')
    plt.colorbar(image, ax=ax, label='Logarithmic Density')
    plt.tight_layout()
    plt.show()


def plot_density(pos1, pos2, title):
    """
    Plot the particle-count maps of fluid 1 and fluid 2 side by side (log scale).

    Parameters:
        pos1: (N, 2) host array of fluid-1 positions (x in column 0, y in column 1).
        pos2: (N, 2) host array of fluid-2 positions.
        title (str): Prefix used for both panel titles.

    Uses the module-level ``L`` (box size) and ``N_grid`` (grid resolution).
    Fluid 1 is binned with ``2 * N_grid`` bins per axis, fluid 2 with 1024.
    """
    #clear_output(wait=True)
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))

    # Define central region
    center = L / 2
    delta = L / 2  # Adjust delta to change the size of the central region
    region = [[center - delta, center + delta], [center - delta, center + delta]]
    extent = [region[0][0], region[0][1], region[1][0], region[1][1]]

    # (positions, bins per axis, colour map, panel label) for each panel
    panels = (
        (pos1, N_grid * 2, 'gist_heat', 'Fluid 1'),
        (pos2, 1024, 'plasma', 'Fluid 2'),
    )

    for ax, (pos, n_bins, cmap, label) in zip(axs, panels):
        hist, _, _ = np.histogram2d(
            pos[:, 0], pos[:, 1],
            bins=n_bins, range=region
        )

        # Avoid zeros for LogNorm
        hist[hist == 0] = 1e-3

        # np.histogram2d returns hist[x_bin, y_bin], while imshow draws
        # rows along the vertical axis.  Transpose so that x really runs
        # along the axis labelled 'x' (same orientation as plot_density2).
        im = ax.imshow(
            hist.T,
            origin='lower',
            extent=extent,
            cmap=cmap,
            norm=LogNorm()
        )
        ax.set_title(f'{title} - {label}')
        ax.set_xlabel('x (kpc)')
        ax.set_ylabel('y (kpc)')
        plt.colorbar(im, ax=ax, label='Logarithmic Density')

    plt.tight_layout()
    plt.show()



def calculate_velocities(pos1_gpu, force_grid, dx, L, bulge_radius):
    """
    Assign tangential (circular-orbit) velocities to particles and plot a
    binned rotation curve.

    Parameters:
        pos1_gpu: (N, 2) CuPy array of particle positions.
        force_grid: grid of force vectors, indexed as
            force_grid[x_index, y_index] -> (Fx, Fy).
        dx: grid spacing used to convert positions to grid indices.
        L: box size; the rotation centre is (L/2, L/2).
        bulge_radius: not used anywhere in the body.

    Returns:
        vel1: (N, 2) CuPy array of velocities with the mean velocity removed.

    Also reads the module-level ``N_grid`` and opens a Matplotlib figure.

    NOTE(review): the velocities written to ``vel1`` use the force-derived
    speed sqrt(|F| * r), whereas the curve that is stored in ``v_circ``,
    binned and plotted is the analytic profile r * exp(-r / R0) / 4.  The
    plot therefore does not show the speeds actually assigned - confirm
    whether this is intended.
    """
    print("Starting velocity calculations...")
    
    # Initial setup
    # Nearest-lower grid cell of each particle, wrapped into [0, N_grid).
    grid_pos = (pos1_gpu / dx).astype(int) % N_grid
    forces = force_grid[grid_pos[:, 0], grid_pos[:, 1]]

    center = cp.array([L/2, L/2]) # Center of the galaxy

    
    # Position relative to the centre and its length.
    r_vec = pos1_gpu - center
    r = cp.sqrt(cp.sum(r_vec**2, axis=1))
    
    vel1 = cp.zeros_like(pos1_gpu)
    # Particles exactly at the centre are skipped (avoids division by r == 0);
    # they keep a zero velocity until the mean is subtracted below.
    valid_mask = r > 0
    v_circ = cp.zeros_like(r)
    # Optimize circular velocity calculation
    print("Computing circular velocities in batches...")
    batch_size = 100000  # Adjust based on available GPU memory
    valid_indices = cp.where(valid_mask)[0]
    # Ceiling division: number of batches needed to cover all valid particles.
    n_batches = (len(valid_indices) + batch_size - 1) // batch_size
    
    for i in range(n_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(valid_indices))
        batch_indices = valid_indices[start_idx:end_idx]
        
        # Optimized force norm calculation
        forces_batch = forces[batch_indices]
        force_norm = cp.sqrt(forces_batch[:, 0]**2 + forces_batch[:, 1]**2)
        
        # Calculate v_circ for batch: v = sqrt(|F| * r)
        r_batch = r[batch_indices]
        v_circ_batch = cp.sqrt(force_norm * r_batch)

        # Test of the same functional form as the JPP - F. Lhanseat rotation curve
        #R0=15  v_circ_batch2 = r_batch * cp.exp(-r_batch / R0) -> spiral
        R0=50
        v_circ_batch2 = r_batch * cp.exp(-r_batch / R0)/4
        # Store v_circ (weights 0 and 1: only the analytic curve is kept;
        # note a NaN/inf in v_circ_batch would still propagate through "* 0")
        v_circ[batch_indices] = v_circ_batch*0 + v_circ_batch2*1

        
        # Set velocities for batch: speed v_circ_batch along (-r_y, r_x) / r,
        # i.e. perpendicular to the radius vector.
        r_vec_batch = r_vec[batch_indices]
        vel1[batch_indices, 0] = -v_circ_batch * r_vec_batch[:, 1] / r_batch
        vel1[batch_indices, 1] = v_circ_batch * r_vec_batch[:, 0] / r_batch
        
        if (i + 1) % 10 == 0:
            print(f"Processed batch {i+1}/{n_batches}")
    # Remove the mean velocity so the particle set has no net drift.
    vel_mean = cp.mean(vel1, axis=0)
    vel1 -= vel_mean
    # Define number of radial bins and set max_radius manually
    n_radial_bins = 50
    max_radius = 100.0  # Set to the desired maximum radius
    print(f"Using fixed maximum radius for binning: {max_radius}")

    radial_bins = cp.linspace(0, max_radius, n_radial_bins + 1)

    # Digitize the radii to assign particles to bins.  Particles with
    # r >= max_radius get index n_radial_bins and fall outside the loop below.
    bin_indices = cp.digitize(r, bins=radial_bins) - 1  # Bin indices start at 0

    # Initialize arrays to hold binned averages
   
    v_circ_binned = cp.zeros(n_radial_bins)
    bin_centers = (radial_bins[:-1] + radial_bins[1:]) / 2

    # Compute binned averages (empty bins are marked with NaN)
    print("Computing binned averages...")
    for b in range(n_radial_bins):
        in_bin = bin_indices == b
        count = cp.sum(in_bin)
        if count > 0:
           
            v_circ_binned[b] = cp.mean(v_circ[in_bin])
        else:
            
            v_circ_binned[b] = cp.nan

    # Transfer data to CPU for plotting
    bin_centers_cpu = bin_centers.get()
   
    v_circ_cpu = v_circ_binned.get()

    # Handle potential NaNs by removing them
    valid_bins = ~np.isnan(v_circ_cpu)
    bin_centers_cpu = bin_centers_cpu[valid_bins]

    v_circ_cpu = v_circ_cpu[valid_bins]

    # NOTE(review): .max() raises ValueError if every bin was empty.
    print(f"Maximum radius in bins: {bin_centers_cpu.max()}")

    # Plotting (two rows are created but only axs[1] is drawn on;
    # axs[0] stays empty)
    fig, axs = plt.subplots(2, 1, figsize=(10, 12))

    

    # Plot v_circ vs radius
    axs[1].plot(bin_centers_cpu, v_circ_cpu, label=r'$v_{\mathrm{circ}}$', color='green')
    axs[1].set_xlabel('Radius (units)')
    axs[1].set_ylabel('Circular Velocity')
    axs[1].set_title(r'Circular Velocity $v_{\mathrm{circ}}$ vs Radius')
    axs[1].legend()
    axs[1].grid(True)
    axs[1].set_xlim(0, max_radius)  # Ensure full radius is displayed

    plt.tight_layout()
    plt.show()
    
    return vel1


def compute_potential_from_density(density, dx, G):
    """
    Solve for the potential of a periodic 2D density field in Fourier space.

    Each Fourier mode of the density is multiplied by -4*pi*G / k^2 and the
    result is transformed back; the k = 0 mode is set to zero, so the
    returned potential has zero mean.

    Parameters:
        density: 2D CuPy array of density values on the grid.
        dx: grid spacing (same along both axes).
        G: gravitational constant.

    Returns:
        2D real CuPy array with the same shape as ``density``.
    """
    n_rows, n_cols = density.shape[0], density.shape[1]

    # Angular wave numbers along each axis, shaped for broadcasting.
    wave_x = cp.fft.fftfreq(n_rows, d=dx)[:, cp.newaxis] * 2 * np.pi
    wave_y = cp.fft.fftfreq(n_cols, d=dx)[cp.newaxis, :] * 2 * np.pi
    wave_sq = wave_x ** 2 + wave_y ** 2

    # Placeholder for the zero mode so the division below is well defined.
    wave_sq[0, 0] = 1.0

    spectrum = cp.fft.fft2(density)
    potential_hat = -4 * cp.pi * G * spectrum / wave_sq

    # Zero mode removed: the mean of the potential is fixed to zero.
    potential_hat[0, 0] = 0.0

    return cp.real(cp.fft.ifft2(potential_hat))


def sample_positions_from_density(density, N_particles, dx):
    """
    Draw particle positions with probability proportional to a density grid.

    Negative density values are treated as zero.  A grid cell is chosen for
    every particle according to the normalised density, then the particle is
    placed uniformly at random inside that cell.

    Parameters:
        density: 2D CuPy array; axis 0 is the x index, axis 1 the y index.
        N_particles (int): number of positions to draw.
        dx (float): grid spacing (same along both axes).

    Returns:
        (N_particles, 2) CuPy array of (x, y) positions.

    Raises:
        ValueError: if the clipped density sums to zero.
    """
    weights = cp.maximum(density.ravel(), 0)
    total = cp.sum(weights)
    if total == 0:
        raise ValueError("Density field sums to zero.")

    # The weighted draw is done on the host with NumPy's generator.
    probabilities_np = cp.asnumpy(weights / total)
    flat_np = np.random.choice(len(probabilities_np), size=N_particles, p=probabilities_np)
    flat = cp.asarray(flat_np)

    # ravel() flattens row-major, so flat = x_index * n_cols + y_index.
    # Decoding with the row length (shape[1]) is correct for non-square grids
    # as well; the previous shape[0] only worked when the grid was square.
    n_cols = density.shape[1]
    x_indices = flat // n_cols
    y_indices = flat % n_cols

    # Uniform jitter inside the chosen cell.
    x_positions = (x_indices + cp.random.uniform(size=N_particles)) * dx
    y_positions = (y_indices + cp.random.uniform(size=N_particles)) * dx

    return cp.stack((x_positions, y_positions), axis=1)



def characteristic_equations(x, y, vx, vy, phi_total, dx):
    """
    Right-hand side of the particle equations of motion.

    The acceleration is minus the gradient of ``phi_total``.  The gradient is
    taken with periodic central differences on the grid and interpolated
    bilinearly (with periodic wrap-around) to the particle positions.

    Parameters:
        x, y: particle coordinates (1D CuPy arrays).
        vx, vy: particle velocity components.
        phi_total: 2D potential grid, indexed [x index, y index].
        dx: grid spacing.

    Returns:
        (dx_dt, dy_dt, dvx_dt, dvy_dt)
    """
    n_x, n_y = phi_total.shape[0], phi_total.shape[1]

    # Periodic central differences along each axis.
    grad_x = (cp.roll(phi_total, -1, axis=0) - cp.roll(phi_total, 1, axis=0)) / (2 * dx)
    grad_y = (cp.roll(phi_total, -1, axis=1) - cp.roll(phi_total, 1, axis=1)) / (2 * dx)

    # Position in grid units, split into lower cell index and offset in the cell.
    gx = x / dx
    gy = y / dx
    i_lo = cp.floor(gx).astype(int)
    j_lo = cp.floor(gy).astype(int)
    tx = gx - i_lo
    ty = gy - j_lo

    # Wrap the lower corner into the grid, then take its periodic neighbour.
    i_lo = cp.mod(i_lo, n_x)
    j_lo = cp.mod(j_lo, n_y)
    i_hi = cp.mod(i_lo + 1, n_x)
    j_hi = cp.mod(j_lo + 1, n_y)

    def _bilinear(field):
        # Weighted sum over the four corners of the enclosing cell.
        return (
            (1 - tx) * (1 - ty) * field[i_lo, j_lo] +
            tx * (1 - ty) * field[i_hi, j_lo] +
            (1 - tx) * ty * field[i_lo, j_hi] +
            tx * ty * field[i_hi, j_hi]
        )

    accel_x = -(_bilinear(grad_x))
    accel_y = -(_bilinear(grad_y))

    # Positions change with the velocity, velocities with the acceleration.
    return vx, vy, accel_x, accel_y

def rk4_step(x, y, vx, vy, phi_total, dx, dt):
    """
    Perform a single classical RK4 integration step for the particles.

    Parameters:
        x, y: particle coordinates.
        vx, vy: particle velocity components.
        phi_total: 2D potential grid passed through to characteristic_equations.
        dx: grid spacing.
        dt: time step.

    Returns:
        (x_new, y_new, vx_new, vy_new) after one step of size ``dt``.
    """
    half_dt = 0.5 * dt

    # Stage 1: slopes at the start of the step.
    k1_x, k1_y, k1_vx, k1_vy = characteristic_equations(x, y, vx, vy, phi_total, dx)

    # Stage 2: slopes at the midpoint, reached with the stage-1 slopes.
    k2_x, k2_y, k2_vx, k2_vy = characteristic_equations(
        x + half_dt * k1_x, y + half_dt * k1_y,
        vx + half_dt * k1_vx, vy + half_dt * k1_vy, phi_total, dx)

    # Stage 3: slopes at the midpoint, reached with the stage-2 slopes.
    # Every component must use k2 here; the y and vy components previously
    # used k1, which breaks the fourth-order scheme.
    k3_x, k3_y, k3_vx, k3_vy = characteristic_equations(
        x + half_dt * k2_x, y + half_dt * k2_y,
        vx + half_dt * k2_vx, vy + half_dt * k2_vy, phi_total, dx)

    # Stage 4: slopes at the end of the step, reached with the stage-3 slopes.
    k4_x, k4_y, k4_vx, k4_vy = characteristic_equations(
        x + k3_x * dt, y + k3_y * dt,
        vx + k3_vx * dt, vy + k3_vy * dt, phi_total, dx)

    # Update positions and velocities using RK4 coefficients (1, 2, 2, 1) / 6
    x_new = x + (k1_x + 2 * k2_x + 2 * k3_x + k4_x) * dt / 6
    y_new = y + (k1_y + 2 * k2_y + 2 * k3_y + k4_y) * dt / 6
    vx_new = vx + (k1_vx + 2 * k2_vx + 2 * k3_vx + k4_vx) * dt / 6
    vy_new = vy + (k1_vy + 2 * k2_vy + 2 * k3_vy + k4_vy) * dt / 6

    return x_new, y_new, vx_new, vy_new



def backward_integrate2(pos, vel, Phi1, Phi2, dx, dt, steps,
                       boundary_injection_func, sigma, init_vel):
    """
    Advance a particle set for ``steps`` RK4 steps, replacing every particle
    that leaves the box by a freshly injected one.

    At each step the particles are deposited on the mesh (``particle_to_mesh``
    with mass ``-m2``), their potential is recomputed, added to ``Phi1`` and
    used for one ``rk4_step``.  Particles outside [0, L] in x or y are dropped
    and the same number is requested from
    ``boundary_injection_func(n, L, dx, init_vel, sigma)`` in batches.

    Parameters:
        pos, vel: (N, 2) CuPy arrays of positions and velocities.
        Phi1: fixed potential grid added to the recomputed one.
        Phi2: overwritten before it is read; the passed value is not used.
        dx, dt: grid spacing and time step.
        steps (int): number of integration steps.
        boundary_injection_func: callable returning (positions, velocities).
        sigma, init_vel: forwarded to ``boundary_injection_func``.

    Returns:
        (pos, vel) after the last step.

    Reads the module-level ``L``, ``m2`` and ``G``.
    """
    target_count = pos.shape[0]  # particle number to maintain
    batch_cap = 10000            # largest number of particles requested per call

    def _append_if_valid(cur_pos, cur_vel, result, bad_arrays_msg, bad_tuple_msg):
        # Append an injection result to the current state.
        # Returns (pos, vel, ok); on a malformed result the state is unchanged.
        if not (isinstance(result, tuple) and len(result) == 2):
            print(bad_tuple_msg)
            return cur_pos, cur_vel, False
        add_pos, add_vel = result
        usable = (add_pos.size > 0 and
                  add_vel.size > 0 and
                  add_pos.shape[0] == add_vel.shape[0])
        if not usable:
            print(bad_arrays_msg)
            return cur_pos, cur_vel, False
        return cp.vstack([cur_pos, add_pos]), cp.vstack([cur_vel, add_vel]), True

    for _ in range(steps):
        # Potential of the current particle distribution plus the fixed one.
        density2 = particle_to_mesh(pos, -m2)
        Phi2 = compute_potential_from_density(density2, dx, G)
        phi_total = Phi1 + Phi2

        # One RK4 step for all particles.
        x_new, y_new, vx_new, vy_new = rk4_step(
            pos[:, 0], pos[:, 1], vel[:, 0], vel[:, 1], phi_total, dx, dt
        )
        pos = cp.column_stack((x_new, y_new))
        vel = cp.column_stack((vx_new, vy_new))

        # Drop everything that crossed a wall of the box.
        escaped = (
            (pos[:, 0] < 0) | (pos[:, 0] > L) |
            (pos[:, 1] < 0) | (pos[:, 1] > L)
        )
        lost = int(cp.count_nonzero(escaped))
        pos = pos[~escaped]
        vel = vel[~escaped]

        if lost > 0:
            # Request replacements in batches; stop at the first bad reply so
            # a misbehaving injector cannot cause an endless loop.
            pending = lost
            while pending > 0:
                chunk = min(pending, batch_cap)
                pos, vel, ok = _append_if_valid(
                    pos, vel,
                    boundary_injection_func(chunk, L, dx, init_vel, sigma),
                    "Warning: Invalid new particles from boundary_injection_func.",
                    "Warning: boundary_injection_func did not return a valid tuple.",
                )
                if not ok:
                    break
                pending -= chunk
        else:
            print("No particles to inject this step.")

        # Final correction so the particle number matches the initial one.
        current = pos.shape[0]
        deficit = target_count - current
        if deficit > 0:
            pos, vel, _ = _append_if_valid(
                pos, vel,
                boundary_injection_func(deficit, L, dx, init_vel, sigma),
                "Warning: Invalid extra particles from boundary_injection_func.",
                "Warning: boundary_injection_func did not return a valid tuple for extra particles.",
            )
        elif deficit < 0:
            # Too many particles: remove a random subset.
            drop = cp.random.choice(current, size=-deficit, replace=False)
            pos = cp.delete(pos, drop, axis=0)
            vel = cp.delete(vel, drop, axis=0)

    return pos, vel


def boundary_injection_func2(num_particles_to_inject, L, dx, init_vel, sigma):
    """
    Create particles on the four walls of the box with random velocities.

    Each particle is put on a randomly chosen wall at a uniformly random
    position along it.  Speeds are drawn by inverse-transform sampling,
    starting at ``v_min = 2 * sigma``, with a uniformly random direction.

    Parameters:
        num_particles_to_inject: number of particles to create.
        L: box size.
        dx: not used in the body.
        init_vel: not used in the body.
        sigma: velocity dispersion used in the speed distribution.

    Returns:
        (positions, velocities): two (n, 2) CuPy arrays.  Two empty (0, 2)
        arrays are returned when n <= 0 or when any exception occurs (the
        error is printed, not raised).
    """
    factor_cutoff = 2
    num_particles_to_inject = int(num_particles_to_inject)
    if num_particles_to_inject <= 0:
        return cp.empty((0, 2)), cp.empty((0, 2))

    try:
        count = num_particles_to_inject
        positions = cp.zeros((count, 2))
        edges = cp.random.choice(4, size=count)

        # (edge id, axis held fixed, value on that axis): left, right, top, bottom
        wall_table = (
            (0, 0, 0),
            (1, 0, L),
            (2, 1, L),
            (3, 1, 0),
        )
        for edge_id, fixed_axis, fixed_value in wall_table:
            on_wall = edges == edge_id
            n_wall = int(cp.count_nonzero(on_wall))
            if n_wall > 0:
                positions[on_wall, fixed_axis] = fixed_value
                positions[on_wall, 1 - fixed_axis] = cp.random.uniform(0, L, size=n_wall)

        # Speeds by inverse-transform sampling, starting at v_min.
        v_min = sigma * factor_cutoff
        U = cp.random.uniform(0, 1, size=count)
        tail = 1 - cp.exp(- (v_min**2) / (2 * sigma**2))
        v = cp.sqrt(v_min**2 - 2 * sigma**2 * cp.log(1 - U * tail))

        # Uniformly random direction in the plane.
        theta = cp.random.uniform(0, 2 * cp.pi, size=count)
        velocities = cp.empty((count, 2))
        velocities[:, 0] = v * cp.cos(theta)
        velocities[:, 1] = v * cp.sin(theta)

        return positions, velocities

    except Exception as e:
        print(f"Error in boundary_injection_func: {e}")
        return cp.empty((0, 2)), cp.empty((0, 2))
    


def backward_integrate(pos, vel, Phi1, Phi2, dx, dt, steps,
                       boundary_injection_func, sigma, init_vel):
    """
    Advance a particle set for ``steps`` RK4 steps, replacing every particle
    that leaves the box by a freshly injected one.

    At each step the particles are deposited on the mesh (``particle_to_mesh``
    with mass ``-m2``), their potential is recomputed, added to ``Phi1`` and
    used for one ``rk4_step``.  Particles outside [0, L] in x or y are dropped
    and the same number is requested from
    ``boundary_injection_func(n, L, dx, init_vel, sigma, phi_total)``, which
    also receives the current total potential.

    Parameters:
        pos, vel: (N, 2) CuPy arrays of positions and velocities.
        Phi1: fixed potential grid added to the recomputed one.
        Phi2: overwritten before it is read; the passed value is not used.
        dx, dt: grid spacing and time step.
        steps (int): number of integration steps.
        boundary_injection_func: callable returning (positions, velocities).
        sigma, init_vel: forwarded to ``boundary_injection_func``.

    Returns:
        (pos, vel) after the last step.

    Reads the module-level ``L``, ``m2`` and ``G``.
    """
    target_count = pos.shape[0]  # particle number to maintain
    batch_cap = 10000            # largest number of particles requested per call

    def _append_if_valid(cur_pos, cur_vel, result, bad_arrays_msg, bad_tuple_msg):
        # Append an injection result to the current state.
        # Returns (pos, vel, ok); on a malformed result the state is unchanged.
        if not (isinstance(result, tuple) and len(result) == 2):
            print(bad_tuple_msg)
            return cur_pos, cur_vel, False
        add_pos, add_vel = result
        usable = (add_pos.size > 0 and
                  add_vel.size > 0 and
                  add_pos.shape[0] == add_vel.shape[0])
        if not usable:
            print(bad_arrays_msg)
            return cur_pos, cur_vel, False
        return cp.vstack([cur_pos, add_pos]), cp.vstack([cur_vel, add_vel]), True

    for _ in range(steps):
        # Potential of the current particle distribution plus the fixed one.
        density2 = particle_to_mesh(pos, -m2)
        Phi2 = compute_potential_from_density(density2, dx, G)
        phi_total = Phi1 + Phi2

        # One RK4 step for all particles.
        x_new, y_new, vx_new, vy_new = rk4_step(
            pos[:, 0], pos[:, 1], vel[:, 0], vel[:, 1], phi_total, dx, dt
        )
        pos = cp.column_stack((x_new, y_new))
        vel = cp.column_stack((vx_new, vy_new))

        # Drop everything that crossed a wall of the box.
        escaped = (
            (pos[:, 0] < 0) | (pos[:, 0] > L) |
            (pos[:, 1] < 0) | (pos[:, 1] > L)
        )
        lost = int(cp.count_nonzero(escaped))
        pos = pos[~escaped]
        vel = vel[~escaped]

        if lost > 0:
            # Request replacements in batches; stop at the first bad reply so
            # a misbehaving injector cannot cause an endless loop.
            pending = lost
            while pending > 0:
                chunk = min(pending, batch_cap)
                pos, vel, ok = _append_if_valid(
                    pos, vel,
                    boundary_injection_func(chunk, L, dx, init_vel, sigma, phi_total),
                    "Warning: Invalid new particles from boundary_injection_func.",
                    "Warning: boundary_injection_func did not return a valid tuple.",
                )
                if not ok:
                    break
                pending -= chunk
        else:
            print("No particles to inject this step.")

        # Final correction so the particle number matches the initial one.
        current = pos.shape[0]
        deficit = target_count - current
        if deficit > 0:
            pos, vel, _ = _append_if_valid(
                pos, vel,
                boundary_injection_func(deficit, L, dx, init_vel, sigma, phi_total),
                "Warning: Invalid extra particles from boundary_injection_func.",
                "Warning: boundary_injection_func did not return a valid tuple for extra particles.",
            )
        elif deficit < 0:
            # Too many particles: remove a random subset.
            drop = cp.random.choice(current, size=-deficit, replace=False)
            pos = cp.delete(pos, drop, axis=0)
            vel = cp.delete(vel, drop, axis=0)

    return pos, vel


def get_phi_bilinear(Phi, x, y, dx):
    """
    Bilinear interpolation of Phi at coordinates (x, y).

    Points outside the grid (including points exactly on the far boundary,
    x == Nx*dx or y == Ny*dx) are clamped to the nearest edge value rather
    than extrapolated.

    Parameters
    ----------
    Phi : cp.ndarray of shape (Ny, Nx)
        2D potential array, indexed here as Phi[iy, ix].
    x, y : cp.ndarray, shape (N,)
        Arrays of particle or boundary positions where we want Phi(x,y).
    dx : float
        Grid spacing (assuming same in x & y).

    Returns
    -------
    phi : cp.ndarray of shape (N,)
        Interpolated potential at each (x[i], y[i]).

    NOTE(review): characteristic_equations in this file indexes its grids as
    [x index, y index], the transpose of the [iy, ix] layout used here.
    Confirm which layout the potential passed by the callers actually has.
    """
    Nx = Phi.shape[1]
    Ny = Phi.shape[0]

    # fractional indices
    fx = x / dx
    fy = y / dx

    # lower cell corner, clamped so that ix+1 and iy+1 stay inside the grid
    ix = cp.clip(cp.floor(fx).astype(int), 0, Nx - 2)
    iy = cp.clip(cp.floor(fy).astype(int), 0, Ny - 2)

    # Interpolation weights.  After the index clamp the raw offset can leave
    # [0, 1] (e.g. x == Nx*dx gives 2.0), which would extrapolate linearly
    # beyond the last cell; clamp the weights so edge points take the edge
    # value instead.
    alpha_x = cp.clip(fx - ix, 0.0, 1.0)
    alpha_y = cp.clip(fy - iy, 0.0, 1.0)

    # corners of the cell
    c00 = Phi[iy,   ix  ]
    c10 = Phi[iy,   ix+1]
    c01 = Phi[iy+1, ix  ]
    c11 = Phi[iy+1, ix+1]

    # linear interp in x
    c0 = c00*(1.0 - alpha_x) + c10*alpha_x
    c1 = c01*(1.0 - alpha_x) + c11*alpha_x

    # linear interp in y
    phi = c0*(1.0 - alpha_y) + c1*alpha_y
    return phi


def sample_positions_on_edge(num_points, L, dx, Phi, sigma, side='left'):
    """
    Sample 'num_points' positions in [0, L] along a single edge ('left', 'right', etc.)
    using an isothermal-like weighting w_s = exp(-Phi / (2*sigma^2)),
    with bilinear interpolation of Phi.

    The edge is discretised into M = max(4, ceil(L/dx)) equally spaced
    candidate points and samples are drawn from those candidates, so the
    returned coordinates are always one of the candidate values.

    Parameters:
        num_points (int): Number of positions to sample along the edge
        L (float): Domain size
        dx (float): Grid spacing for indexing into Phi
        Phi (cp.ndarray): Potential array of shape (Ny, Nx)
        sigma (float): velocity dispersion (units consistent with sqrt of potential)
        side (str): 'left', 'right', 'top', 'bottom'; any other value is
            treated as 'top'

    Returns:
        sampled_s : cp.ndarray of shape (num_points,)
            The coordinate (y-coord if side in {left,right} or x-coord if top,bottom).
    """
    M = max(4, int(cp.ceil(L/dx)))
    s_vals = cp.linspace(0, L, M)

    # For each s in [0,L], define (x_arr, y_arr) depending on the edge
    if side == 'left':
        x_arr = cp.zeros_like(s_vals)
        y_arr = s_vals
    elif side == 'right':
        x_arr = L * cp.ones_like(s_vals)
        y_arr = s_vals
    elif side == 'bottom':
        x_arr = s_vals
        y_arr = cp.zeros_like(s_vals)
    else:  # 'top'
        x_arr = s_vals
        y_arr = L * cp.ones_like(s_vals)

    # Potential along the edge (bilinear interpolation, not nearest-neighbour)
    boundary_phi = get_phi_bilinear(Phi, x_arr, y_arr, dx)

    # if Phi has units ~ v^2, define alpha = 1/(2*sigma^2)
    alpha = 1.0 / (2.0 * sigma**2)

    # Weighted PDF ~ exp(-alpha * Phi).  A constant shift of Phi cancels in
    # the normalisation, so subtract the minimum first: the largest weight is
    # then exactly 1, which prevents overflow to inf and underflow of every
    # weight to 0 (both would make the normalised PDF NaN).
    w_s = cp.exp(-alpha * (boundary_phi - cp.min(boundary_phi)))

    # Build normalized CDF
    pdf = w_s / w_s.sum()
    cdf = cp.cumsum(pdf)

    # Inverse-CDF sampling; the clip guards against round-off in cdf[-1]
    urand = cp.random.random(num_points)
    idx_samples = cp.searchsorted(cdf, urand, side='right')
    idx_samples = cp.clip(idx_samples, 0, M-1)

    sampled_s = s_vals[idx_samples]
    return sampled_s

def boundary_injection_func(
    num_particles_to_inject, 
    L, 
    dx, 
    init_vel, 
    sigma, 
    Phi
):
    """
    Create particles on the four walls of the box, weighted by the potential.

    Each particle is assigned to a randomly chosen wall; its position along
    the wall comes from ``sample_positions_on_edge``.  A speed is drawn by
    inverse-transform sampling starting at ``v_min = 2 * sigma``, then shifted
    by the local potential relative to the lowest potential among the
    injected particles: v_shifted^2 = v^2 + 2 * (Phi_local - Phi_min).
    The direction is uniformly random.

    Parameters:
        num_particles_to_inject: number of particles to create.
        L: box size.
        dx: grid spacing of ``Phi``.
        init_vel: not used in the body.
        sigma: velocity dispersion.
        Phi: 2D potential array passed to ``sample_positions_on_edge`` and
            ``get_phi_bilinear``.

    Returns:
        (positions, velocities): two (n, 2) CuPy arrays.  Two empty (0, 2)
        arrays are returned when n <= 0 or when any exception occurs (the
        error is printed, not raised, so callers can skip the injection).
    """
    factor_cutoff = 2
    num_particles_to_inject = int(num_particles_to_inject)
    if num_particles_to_inject <= 0:
        return cp.empty((0, 2)), cp.empty((0, 2))

    try:
        positions = cp.zeros((num_particles_to_inject, 2))
        edges = cp.random.choice(4, size=num_particles_to_inject)

        # (edge id, side name, axis held fixed, value on that axis)
        edge_layout = (
            (0, 'left', 0, 0.0),
            (1, 'right', 0, L),
            (2, 'top', 1, L),
            (3, 'bottom', 1, 0.0),
        )
        for edge_id, side, fixed_axis, fixed_value in edge_layout:
            on_edge = edges == edge_id
            n_edge = int(cp.count_nonzero(on_edge))
            if n_edge == 0:
                continue
            along_edge = sample_positions_on_edge(
                num_points=n_edge,
                L=L,
                dx=dx,
                Phi=Phi,
                sigma=sigma,
                side=side
            )
            positions[on_edge, fixed_axis] = fixed_value
            positions[on_edge, 1 - fixed_axis] = along_edge

        # Speeds by inverse-transform sampling, starting at v_min.
        v_min = sigma * factor_cutoff
        U = cp.random.uniform(0, 1, size=num_particles_to_inject)
        v = cp.sqrt(
            v_min**2
            - 2 * sigma**2
            * cp.log(1 - U * (1 - cp.exp(- (v_min**2) / (2 * sigma**2))))
        )
        theta = cp.random.uniform(0, 2 * cp.pi, size=num_particles_to_inject)

        # Potential at the injection points (bilinear interpolation).
        local_phi = get_phi_bilinear(Phi, positions[:, 0], positions[:, 1], dx)

        # Shift the speed by the potential difference to the lowest injection
        # point; the clip keeps the squared speed non-negative.
        Phi_min = cp.min(local_phi)
        v_sq = cp.clip(v**2 + 2.0 * (local_phi - Phi_min), 0.0, None)
        v_shifted = cp.sqrt(v_sq)

        # Only the shifted speed is written out; the unshifted velocities
        # that used to be stored first were always overwritten.
        velocities = cp.stack(
            (v_shifted * cp.cos(theta), v_shifted * cp.sin(theta)), axis=1
        )

        return positions, velocities

    except Exception as e:
        print(f"Error in boundary_injection_func: {e}")
        return cp.empty((0, 2)), cp.empty((0, 2))


def radial_average(quantityCP, L, N_grid, r_max):
    """
    Computes the radial average of a 2D quantity up to radius r_max using NumPy.

    The grid is assumed to be centred on the origin, with cell centres at
    -L/2 + dx/2, ..., L/2 - dx/2 along both axes.  Radial bins have width dx.
    Cells whose radius equals the outermost bin edge are counted in the last
    bin; bins that contain no cell yield NaN.

    Parameters:
        quantityCP (cp.ndarray or np.ndarray): (N_grid, N_grid) array of the
            quantity.  A NumPy array is used as is, anything else is copied
            to the host with cp.asnumpy.
        L (float): Size of the domain.
        N_grid (int): Number of grid points in one dimension.
        r_max (float): Maximum radius to compute radial average.

    Returns:
        tuple: (radial_avg, r_bin_centers)
    """
    # Skip the device-to-host copy when the data already lives on the host.
    if isinstance(quantityCP, np.ndarray):
        quantity = quantityCP
    else:
        quantity = cp.asnumpy(quantityCP)

    dx = L / N_grid
    # Create coordinate arrays centered at zero
    axis_coords = np.linspace(-L/2 + dx/2, L/2 - dx/2, N_grid)
    X, Y = np.meshgrid(axis_coords, axis_coords)
    R = np.sqrt(X**2 + Y**2)

    # Keep only the cells inside r_max
    mask = R <= r_max
    R = R[mask]
    quantity = quantity[mask]

    # Define radial bins
    dr = dx  # Bin size equal to grid spacing
    r_bins = np.arange(0, r_max + dr, dr)
    r_bin_centers = (r_bins[:-1] + r_bins[1:]) / 2
    n_bins = len(r_bin_centers)
    if n_bins == 0:
        # Fewer than two bin edges (e.g. r_max == 0): nothing to average.
        return np.zeros(0), r_bin_centers

    # np.digitize treats bins as half-open [lo, hi), so a cell with R equal
    # to the last edge would get index n_bins and overflow the accumulators
    # below.  Put such cells into the last bin instead.
    bin_indices = np.clip(np.digitize(R, r_bins) - 1, 0, n_bins - 1)

    # Accumulate sums and counts
    radial_sum = np.zeros(n_bins)
    radial_count = np.zeros(n_bins)
    np.add.at(radial_sum, bin_indices, quantity)
    np.add.at(radial_count, bin_indices, 1)

    # Empty bins stay NaN without triggering a 0/0 runtime warning.
    radial_avg = np.full(n_bins, np.nan)
    filled = radial_count > 0
    radial_avg[filled] = radial_sum[filled] / radial_count[filled]

    return radial_avg, r_bin_centers

def compute_gradient_smooth(data, dr):
    """
    Differentiate a 1D profile after a light 3-point smoothing pass.

    Parameters:
        data (cp.ndarray): 1D samples of the quantity to differentiate.
        dr (float): Sample spacing passed to ``cp.gradient``.

    Returns:
        cp.ndarray: Gradient of the smoothed profile.
    """
    # Weights [1/4, 1/2, 1/4]; 'same' keeps the output as long as the input.
    weights = cp.array([1/4, 1/2, 1/4])
    smoothed = cp.convolve(data, weights, mode='same')
    return cp.gradient(smoothed, dr)


def compute_circular_velocity(Phi_Rn, Rn):
    """
    Derive a circular-velocity profile from a tabulated radial potential.

    The value returned is ``sqrt(max(R * dPhi/dR, 0))``, forced to zero where
    ``R == 0``. The radial spacing is taken from the first two entries of
    ``Rn``, so the radii are treated as uniformly spaced.

    Parameters:
        Phi_Rn (array-like): Potential sampled at the radii ``Rn``.
        Rn (array-like): Radial positions (at least two entries).

    Returns:
        cp.ndarray: Circular velocity at each radius.
    """
    radii = cp.asarray(Rn)
    potential = cp.asarray(Phi_Rn)
    spacing = radii[1] - radii[0]

    # Radial derivative of the potential.
    dPhi_dR = cp.gradient(potential, spacing)
    print("dPhi_dR",dPhi_dR)

    # v_c^2 = R * dPhi/dR, floored at zero so the square root stays real.
    speed = cp.sqrt(cp.maximum(radii * (dPhi_dR), 0.0))

    # No rotation at the exact centre.
    speed = cp.where(radii == 0, 0.0, speed)
    print(speed,"v_c")

    return speed


def assign_velocities(pos1_gpu, sigma_R, sigma_phi, R, v_phi):
    """
    Draw particle velocities from tabulated radial dispersion profiles.

    Each particle's radius (measured from the coordinate origin) is used to
    interpolate the radial dispersion, azimuthal dispersion and mean
    azimuthal speed from the 1D tables. Gaussian radial and azimuthal
    components are drawn on the CPU, rotated into Cartesian components, and
    the mean velocity is removed so that the net linear momentum vanishes.

    Parameters:
        pos1_gpu: (N, 2) CuPy array of particle positions
        sigma_R: 1D array of radial velocity dispersions, tabulated on R
        sigma_phi: 1D array of azimuthal velocity dispersions, tabulated on R
        R: 1D array of radii for the tables (np.interp expects it increasing)
        v_phi: 1D array of mean azimuthal velocities, tabulated on R

    Returns:
        velocities_gpu: (N, 2) CuPy array of particle velocities
    """
    # Everything is sampled on the host, so pull the inputs off the GPU first.
    xy = cp.asnumpy(pos1_gpu)
    n_particles = xy.shape[0]
    radius_table = cp.asnumpy(R)
    sigma_R_table = cp.asnumpy(sigma_R)
    sigma_phi_table = cp.asnumpy(sigma_phi)
    v_phi_table = cp.asnumpy(v_phi)

    px = xy[:, 0]
    py = xy[:, 1]
    particle_radius = np.sqrt(px**2 + py**2)

    # Profile values at each particle's radius.
    local_sigma_R = np.interp(particle_radius, radius_table, sigma_R_table)
    local_sigma_phi = np.interp(particle_radius, radius_table, sigma_phi_table)
    local_v_phi = np.interp(particle_radius, radius_table, v_phi_table)

    # Random radial part first, then the azimuthal scatter around the mean.
    radial_speed = np.random.normal(0, local_sigma_R)
    azimuthal_scatter = np.random.normal(0, local_sigma_phi)
    azimuthal_speed = local_v_phi + azimuthal_scatter

    # Rotate (v_R, v_phi) into (v_x, v_y) using each particle's polar angle.
    azimuth = np.arctan2(py, px)
    cos_a = np.cos(azimuth)
    sin_a = np.sin(azimuth)
    vX = radial_speed * cos_a - azimuthal_speed * sin_a
    vY = radial_speed * sin_a + azimuthal_speed * cos_a

    velocities_gpu = cp.asarray(np.vstack((vX, vY)).T)

    # Subtract the mean velocity to leave zero net linear momentum.
    velocities_gpu -= cp.sum(velocities_gpu, axis=0) / n_particles

    return velocities_gpu


def interp2d_cp(field, x_pos, y_pos, L, dx):
    """
    Cubic interpolation of a 2D field at particle positions, on the GPU.

    Parameters:
    - field: 2D CuPy array; indexed as field[row, column] with rows taken
      from y and columns from x. Its first dimension is used as the size of
      both axes.
    - x_pos, y_pos: particle positions (CuPy arrays)
    - L: size of the domain (not used by the computation)
    - dx: grid spacing

    Steps:
    1) Positions are divided by dx to obtain fractional cell coordinates.
    2) A 4x4 patch of cells (offsets -1, 0, +1, +2 around the floor index)
       is gathered for every particle, with indices clamped to the grid.
    3) A cubic polynomial is evaluated along x for each of the four rows,
       then once more along y over the four row results.

    Returns a CuPy array with one interpolated value per particle.
    """
    last = field.shape[0] - 1

    # Fractional cell coordinates (single precision, as in the gather below).
    cell_x = (x_pos / dx).astype(cp.float32)
    cell_y = (y_pos / dx).astype(cp.float32)

    floor_i = cp.floor(cell_x).astype(cp.int32)
    floor_j = cp.floor(cell_y).astype(cp.int32)

    # Four clamped stencil indices per direction.
    stencil = (-1, 0, 1, 2)
    cols = [cp.clip(floor_i + shift, 0, last) for shift in stencil]
    rows = [cp.clip(floor_j + shift, 0, last) for shift in stencil]

    # Offset inside the cell, measured from the (clamped) base index and
    # clamped again so boundary particles stay within [0, 1].
    tx = cp.clip(cell_x - cols[1], 0.0, 1.0)
    ty = cp.clip(cell_y - rows[1], 0.0, 1.0)

    def _cubic(p0, p1, p2, p3, t):
        # Cubic through four equally spaced samples; equals p1 at t=0 and
        # p2 at t=1.
        c3 = p3 - p2 - p0 + p1
        c2 = p0 - p1 - c3
        c1 = p2 - p0
        c0 = p1
        return c3*(t**3) + c2*(t**2) + c1*t + c0

    # Interpolate along x within each of the four rows ...
    row_values = [
        _cubic(field[r, cols[0]], field[r, cols[1]],
               field[r, cols[2]], field[r, cols[3]], tx)
        for r in rows
    ]

    # ... then along y across the four row results.
    return _cubic(row_values[0], row_values[1], row_values[2], row_values[3], ty)

def periodic_gradient_x(field, dx):
    """Centred finite difference of `field` along axis 1 with periodic wrap-around."""
    ahead = cp.roll(field, -1, axis=1)
    behind = cp.roll(field, 1, axis=1)
    return (ahead - behind) / (2*dx)

def periodic_gradient_y(field, dx):
    """Centred finite difference of `field` along axis 0 with periodic wrap-around."""
    ahead = cp.roll(field, -1, axis=0)
    behind = cp.roll(field, 1, axis=0)
    return (ahead - behind) / (2*dx)


def sample_positions_from_density(rho, dx, L, Nparticles, rng=None):
    """
    Sample particle positions from a given 2D density distribution.

    Cells are drawn with probability proportional to their density by
    inverting the cumulative distribution of the flattened grid. Every
    particle is placed at the centre of its cell (no sub-cell jitter).

    Parameters
    ----------
    rho : 2D array (CuPy or NumPy)
        Density on the grid. The first axis is mapped to x and the second
        to y. Values are expected to be non-negative.
    dx : float
        Grid spacing.
    L : float
        Size of the simulation box (kept for interface compatibility; the
        coordinates are derived from the cell index and dx only).
    Nparticles : int
        Number of particles to sample.
    rng : numpy.random.Generator, optional
        A NumPy random number generator instance for reproducibility.
        If None, a fresh default generator is created.

    Returns
    -------
    pos : (Nparticles, 2) CuPy array
        Sampled positions, at cell centres ``(index + 0.5) * dx``.

    Raises
    ------
    ValueError
        If the density sums to zero or less.
    """
    rho = cp.asnumpy(rho)
    if rng is None:
        rng = np.random.default_rng()

    weights = rho.ravel()
    weight_sum = np.sum(weights)
    if weight_sum <= 0:
        raise ValueError("Density distribution is non-positive or zero. Cannot sample.")

    # Cumulative distribution over the flattened cells.
    cdf = np.cumsum(weights / weight_sum)

    # Invert the CDF: each uniform draw selects the first cell whose
    # cumulative probability exceeds it.
    u = rng.random(Nparticles)
    flat_index = np.searchsorted(cdf, u, side='right')

    # Rounding can leave cdf[-1] slightly below 1, in which case a draw at or
    # above it maps one past the final cell. Send such draws to the last cell
    # that actually carries probability instead of outside the grid.
    last_populated = np.searchsorted(cdf, cdf[-1], side='left')
    flat_index = np.minimum(flat_index, last_populated)

    # Back to 2D cell indices (row-major), valid for non-square grids too.
    i, j = np.unravel_index(flat_index, rho.shape)

    # Cell-centred coordinates.
    x_coords = (i + 0.5)*dx
    y_coords = (j + 0.5)*dx

    return cp.column_stack((cp.asarray(x_coords), cp.asarray(y_coords)))

def R_grid_index2(x_part, y_part):
    """
    Convert centred particle coordinates to periodic grid cell indices.

    Reads the module-level ``L``, ``dx`` and ``N_grid``. Coordinates are
    shifted by L/2, wrapped into [0, L) and divided by dx; the final modulo
    keeps the index inside the grid.

    Returns a tuple (i, j) of int32 CuPy arrays.
    """
    half_box = L / 2

    def _cell_index(coord):
        # Shift to [0, L) and wrap for periodicity before locating the cell.
        wrapped = cp.mod(coord + half_box, L)
        return cp.floor(wrapped / dx).astype(cp.int32) % N_grid

    return _cell_index(x_part), _cell_index(y_part)
def sample_truncated_normal(mean, sigma, low, high, size):
    """
    Sample from a normal distribution truncated to [low, high].

    Uses naive acceptance-rejection: batches of ``2 * size`` normal draws are
    generated and the ones inside the interval are kept until ``size``
    samples are collected. The loop has no iteration cap, so an interval far
    out in the tails of the distribution can take a very long time.

    Parameters:
        mean (float): Mean of the underlying normal distribution.
        sigma (float): Standard deviation of the underlying distribution.
        low (float): Lower bound (inclusive).
        high (float): Upper bound (inclusive).
        size (int): Number of samples to return.

    Returns:
        cp.ndarray: float64 array of length ``size``.

    Raises:
        ValueError: If ``low > high`` (no draw could ever be accepted).
    """
    if low > high:
        raise ValueError(
            f"Empty truncation interval: low={low} is greater than high={high}."
        )

    out = cp.empty(size, dtype=cp.float64)
    filled = 0
    while filled < size:
        draws = cp.random.normal(mean, sigma, size=size*2)  # oversample
        accepted = draws[(draws >= low) & (draws <= high)]
        # accepted.size is a host-side int, so the slice bounds below are
        # plain Python integers rather than device scalars.
        take = min(size - filled, int(accepted.size))
        if take > 0:
            out[filled:filled + take] = accepted[:take]
            filled += take
    return out

def bilinear_interpolation_2d32(grid, xarr, yarr, L, dx):
    """
    Bilinear interpolation of `grid` at coordinates (xarr, yarr), clamped at
    the grid edges.

    grid : cupy.ndarray of shape (Nx, Ny)
       The scalar field to interpolate; the first axis is indexed from x and
       the second from y.
    xarr, yarr : cupy.ndarray of shape (N_particles,)
       Particle positions. They are shifted by +L/2 before being divided by
       dx, i.e. the code treats them as centred coordinates.
    L : float
       Box size.
    dx : float
       Cell size in the uniform grid.

    Returns cupy.ndarray of shape (N_particles,) with interpolated values.
    """
    n_x, n_y = grid.shape

    # Fractional grid coordinates after shifting the origin by half a box.
    # NOTE(review): no half-cell offset is applied here — confirm this matches
    # how the field passed in was sampled.
    frac_i = (xarr + L/2) / dx
    frac_j = (yarr + L/2) / dx

    # Lower corner of the enclosing cell and the weights measured from it.
    low_i = cp.floor(frac_i).astype(cp.int32)
    low_j = cp.floor(frac_j).astype(cp.int32)
    alpha = frac_i - low_i
    beta = frac_j - low_j

    # Upper corner, then clamp all four indices to the grid.
    high_i = cp.clip(low_i + 1, 0, n_x - 1)
    high_j = cp.clip(low_j + 1, 0, n_y - 1)
    low_i = cp.clip(low_i, 0, n_x - 1)
    low_j = cp.clip(low_j, 0, n_y - 1)

    # Corner values of the cell.
    v_ll = grid[low_i, low_j]
    v_hl = grid[high_i, low_j]
    v_lh = grid[low_i, high_j]
    v_hh = grid[high_i, high_j]

    # Weighted sum of the four corners.
    return (v_ll * (1 - alpha) * (1 - beta) +
            v_hl * alpha * (1 - beta) +
            v_lh * (1 - alpha) * beta +
            v_hh * alpha * beta)

def bilinear_interpolation_2d(grid, xarr, yarr, L, dx):
    """
    Periodic bilinear interpolation of `grid` at coordinates (xarr, yarr).

    Parameters:
    - grid (cp.ndarray): Scalar field to interpolate, shape (Nx, Ny); the
      first axis is indexed from x and the second from y.
    - xarr (cp.ndarray): Particle x-positions, shape (N_particles,).
    - yarr (cp.ndarray): Particle y-positions, shape (N_particles,).
    - L (float): Box size; positions are wrapped into [0, L).
    - dx (float): Grid spacing.

    Returns:
    - cp.ndarray: Interpolated values at particle positions, shape (N_particles,).
    """
    n_x, n_y = grid.shape

    # Wrap into the box, then express in units of cells.
    frac_i = (xarr % L) / dx
    frac_j = (yarr % L) / dx

    floor_i = cp.floor(frac_i)
    floor_j = cp.floor(frac_j)

    # Lower corner indices; the modulo guards against landing on Nx / Ny.
    low_i = floor_i.astype(cp.int32) % n_x
    low_j = floor_j.astype(cp.int32) % n_y

    # Upper corner wraps around the far edge.
    high_i = (low_i + 1) % n_x
    high_j = (low_j + 1) % n_y

    # Fractional position inside the cell.
    alpha = frac_i - floor_i
    beta = frac_j - floor_j

    v_ll = grid[low_i, low_j]
    v_hl = grid[high_i, low_j]
    v_lh = grid[low_i, high_j]
    v_hh = grid[high_i, high_j]

    return (v_ll * (1 - alpha) * (1 - beta) +
            v_hl * alpha * (1 - beta) +
            v_lh * (1 - alpha) * beta +
            v_hh * alpha * beta)


def _sample_exponential_disk_offsets(n_particles):
    """Draw (n_particles, 2) offsets from the centre with radii from an exponential law truncated at r1."""
    u = np.random.rand(n_particles)
    # Inverse CDF of exp(-r / R_d) restricted to [0, r1].
    radius = -R_d * np.log(1 - u * (1 - np.exp(-r1 / R_d)))
    # The inverse CDF already keeps r <= r1; the clamp only absorbs rounding.
    radius = np.minimum(radius, r1)
    angle = np.random.uniform(0, 2 * np.pi, n_particles)
    return np.column_stack((radius * np.cos(angle), radius * np.sin(angle)))


def _relax_fluid2_potential(pos2_gpu, vel2_gpu, Phi1, Phi2, dt_back, steps,
                            sigma, relaxation_factor, max_iterations, tolerance):
    """Iterate backward integration of fluid 2 with an under-relaxed update of its potential."""
    for iteration in range(max_iterations):
        print(f"Starting backward integration for fluid 2, iteration {iteration + 1}...")
        pos2_gpu, vel2_gpu = backward_integrate(
            pos2_gpu, vel2_gpu, Phi1, Phi2, dx, dt_back, steps,
            boundary_injection_func, sigma, init_vel2)

        # Potential generated by the updated fluid 2 distribution.
        density2_new = particle_to_mesh(pos2_gpu, -m2)
        Phi2_new = compute_potential_from_density(density2_new, dx, G)

        # Convergence is measured on the largest pointwise potential change.
        delta_Phi = cp.max(cp.abs(Phi2_new - Phi2))
        print(f"Iteration {iteration + 1}, delta_Phi: {delta_Phi}")

        if delta_Phi < tolerance:
            print(f"Converged after {iteration + 1} iterations.")
            Phi2 = Phi2_new
            break

        # Under-relaxation to damp oscillations between iterations.
        Phi2 = relaxation_factor * Phi2_new + (1 - relaxation_factor) * Phi2
    else:
        print("Maximum iterations reached without convergence.")

    print("Backward integration for fluid 2 completed.")
    return pos2_gpu, vel2_gpu, Phi2


def _truncated_om_density(Psi, R_grid, f0_1, T1, ra1):
    """Density 2*pi*f0*(T1/m1)*(exp(m1*Psi/T1) - 1)/sqrt(1 + (R/ra1)^2) where Psi > 0, zero elsewhere."""
    exp_minus_1 = cp.exp((m1 / T1) * Psi) - 1.0
    truncated_core = cp.where(Psi > 0.0, exp_minus_1, 0.0)
    return (2.0 * cp.pi * f0_1 * T1 / m1) \
        * truncated_core \
        / cp.sqrt(1.0 + (R_grid / ra1)**2)


def _solve_truncated_om_potential(phi_total, Phi2, R_grid, f0_1, T1, ra1, w, tol, max_iter):
    """Relax the total potential towards self-consistency with the truncated density; returns (phi_total, f0_1)."""
    converged = False
    for iteration in range(max_iter):
        # Density implied by the current relative potential Psi = -phi_total.
        rho1_new = _truncated_om_density(-phi_total, R_grid, f0_1, T1, ra1)

        # Potential of that density, combined with fluid 2's contribution.
        phi_new = compute_potential_from_density(rho1_new, dx, G)
        phi_new = phi_new - Phi2

        delta_phi = cp.max(cp.abs(phi_new - phi_total))
        print(f"Iteration {iteration+1}: delta_phi = {delta_phi:.3e}")

        if delta_phi < tol:
            print(f"Potential converged after {iteration+1} iterations.")
            phi_total = phi_new
            converged = True
            break

        # Relaxed update, then shift so the maximum of the potential is zero.
        phi_total = w * phi_new + (1.0 - w) * phi_total
        phi_total -= cp.max(phi_total)

        # Rescale the normalisation so the density integrates to M1.
        mass1_current = cp.sum(rho1_new) * (dx**2)
        print(f"Iteration {iteration+1}: mass1_current = {mass1_current:.3e}")

        if mass1_current > 0:
            scale_f0_1 = M1 / mass1_current
        else:
            scale_f0_1 = 1.0
            print("Warning: mass1_current is zero or negative.")

        f0_1 *= scale_f0_1
        print(f"Iteration {iteration+1}: scale_f0_1 = {scale_f0_1:.3e}")

    if not converged:
        print("Warning: Did not fully converge within max iterations.")

    return phi_total, f0_1


def _sample_truncated_om_velocities(pos1_gpu, Psi_final, T1, ra1):
    """Sample velocities for particles with Psi > 0; returns (v_x, v_y, diagnostics or None)."""
    # Polar coordinates of each particle relative to the box centre.
    x_part = pos1_gpu[:, 0] - L / 2
    y_part = pos1_gpu[:, 1] - L / 2
    R_part = cp.sqrt(x_part**2 + y_part**2)
    theta_part = cp.arctan2(y_part, x_part)

    # Relative potential at each particle.
    Psi_local = bilinear_interpolation_2d(Psi_final, pos1_gpu[:, 0], pos1_gpu[:, 1], L, dx)

    # Anisotropy factor alpha(R) = 1 + R^2 / ra1^2.
    alpha_part = 1.0 + (R_part**2) / (ra1**2)

    # Particles with Psi <= 0 keep zero velocity.
    v_x = cp.zeros_like(x_part)
    v_y = cp.zeros_like(y_part)

    sigma = cp.sqrt(T1 / m1)

    idx_valid = cp.where(Psi_local > 0.0)[0]
    N_valid = len(idx_valid)
    print("N_valid:", N_valid)
    if N_valid == 0:
        return v_x, v_y, None

    Psi_b = Psi_local[idx_valid]
    alpha_b = alpha_part[idx_valid]
    theta_b = theta_part[idx_valid]
    R_b = R_part[idx_valid]

    # Largest allowed speed in (u, w)-space: r_max = sqrt(2 * Psi).
    r_max_b = cp.sqrt(2.0 * Psi_b)

    # Sample r from p(r) ~ r * exp(-r^2 / (2 sigma^2)) on [0, r_max] by
    # inverting the CDF 1 - exp(-r^2 / (2 sigma^2)) rescaled to its value
    # at r_max.
    cdf_max_b = 1.0 - cp.exp(-0.5 * (r_max_b**2) / (sigma**2))
    z = cp.random.random(size=N_valid)
    partial = cp.clip(z * cdf_max_b, 1e-30, 1.0 - 1e-30)  # avoid log(0)
    r_sq = -2.0 * sigma**2 * cp.log(1.0 - partial)
    r_sq = cp.clip(r_sq, 0.0, (r_max_b**2) + 1e-14)  # enforce r <= r_max
    r = cp.sqrt(r_sq)

    # Random direction in (u, w)-space.
    phi = 2.0 * np.pi * cp.random.random(size=N_valid)
    u_b = r * cp.cos(phi)
    w_b = r * cp.sin(phi)

    # Map to physical components: v_R = u, v_theta = w / sqrt(alpha).
    v_R_b = u_b
    v_theta_b = w_b / cp.sqrt(alpha_b)

    # Rotate (v_R, v_theta) into Cartesian components.
    v_x[idx_valid] = v_R_b * cp.cos(theta_b) - v_theta_b * cp.sin(theta_b)
    v_y[idx_valid] = v_R_b * cp.sin(theta_b) + v_theta_b * cp.cos(theta_b)

    return v_x, v_y, (R_b, v_R_b, v_theta_b)


def _plot_velocity_profiles(R_b, v_R_b, v_theta_b, max_radius, n_radial_bins=20):
    """Plot binned mean and standard deviation of v_R and v_theta against radius."""
    v_R_cpu = v_R_b.get()
    v_theta_cpu = v_theta_b.get()
    R_cpu = R_b.get()

    radial_bins = np.linspace(0, max_radius, n_radial_bins + 1)
    bin_centers = (radial_bins[:-1] + radial_bins[1:]) / 2
    bin_indices = np.digitize(R_cpu, bins=radial_bins) - 1

    # Only particles within the plotted radius contribute.
    inside = R_cpu <= max_radius
    bin_indices = bin_indices[inside]
    v_R_cpu = v_R_cpu[inside]
    v_theta_cpu = v_theta_cpu[inside]

    mean_R = np.full(n_radial_bins, np.nan)
    mean_theta = np.full(n_radial_bins, np.nan)
    std_R = np.full(n_radial_bins, np.nan)
    std_theta = np.full(n_radial_bins, np.nan)
    for b in range(n_radial_bins):
        in_bin = bin_indices == b
        if np.sum(in_bin) > 0:
            mean_R[b] = np.mean(v_R_cpu[in_bin])
            mean_theta[b] = np.mean(v_theta_cpu[in_bin])
            std_R[b] = np.std(v_R_cpu[in_bin])
            std_theta[b] = np.std(v_theta_cpu[in_bin])

    # Drop bins without a defined mean.
    valid_bins = ~np.isnan(mean_R) & ~np.isnan(mean_theta)
    bin_centers = bin_centers[valid_bins]

    fig, axs = plt.subplots(2, 1, figsize=(10, 12))

    axs[0].plot(bin_centers, mean_R[valid_bins], label=r'$v_R$', color='blue')
    axs[0].plot(bin_centers, mean_theta[valid_bins], label=r'$v_\theta$', color='red')
    axs[0].set_xlabel('Radius (units)')
    axs[0].set_ylabel('Velocity')
    axs[0].set_title(r'Velocity Components $v_R$ and $v_\theta$ vs Radius')
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(bin_centers, std_R[valid_bins], label=r'$\sigma_r$ (Sampled)', color='green')
    axs[1].plot(bin_centers, std_theta[valid_bins], label=r'$\sigma_t$ (Sampled)', color='orange')
    axs[1].set_xlabel('Radius (units)')
    axs[1].set_ylabel('Velocity Dispersion')
    axs[1].set_title(r'Velocity Dispersions $\sigma_r$ and $\sigma_t$ vs Radius')
    axs[1].legend()
    axs[1].grid(True)

    plt.tight_layout()
    plt.show()

    print("Plots generated successfully.")


def initialize_particles():
    """
    Build initial positions and velocities for both fluids.

    Sequence:
      1. Fluid 1 (galaxy) gets provisional positions from a truncated
         exponential radial law; fluid 2 is spread uniformly over the box
         with Gaussian velocities.
      2. Fluid 2 is relaxed by repeated backward integration with an
         under-relaxed potential update.
      3. A truncated, anisotropic density for fluid 1 is solved
         self-consistently with its potential and fluid 1 is resampled
         from it.
      4. Steps 2 and 3 are repeated once with the resampled fluid 1.
      5. Fluid 1 velocities are sampled from the truncated distribution,
         diagnostic plots are shown, and the result is blended with the
         velocities returned by ``calculate_velocities``.

    Relies on module-level parameters (N1, N2, m1, m2, M1, r1, R_d, L, dx,
    N_grid, dt, G, init_vel2) and on helpers defined elsewhere in the file.

    Returns:
        tuple: (pos1_gpu, vel1, pos2_gpu, vel2_gpu) as CuPy arrays.
    """
    print("Initializing fluid 1 particles...")
    galaxy_stars = _sample_exponential_disk_offsets(N1)
    print("galaxy_stars",galaxy_stars.shape)
    pos1_gpu = cp.asarray(galaxy_stars + np.array([L / 2, L / 2]))
    print("Fluid 1 particles initialized.")

    print("Depositing fluid 1 density onto grid...")
    density1 = particle_to_mesh(pos1_gpu, m1)
    print("Fluid 1 density deposited.")

    print("Computing gravitational potential for fluid 1...")
    Phi1 = compute_potential_from_density(-density1, dx, G)
    print("Gravitational potential for fluid 1 computed.")

    print("Initializing fluid 2 particles...")
    # Uniform over the whole box (the optional central hole is not applied).
    pos2_gpu = cp.random.uniform(0, L, (N2, 2))
    print("pos2Shape",pos2_gpu.shape)

    # Per-component dispersion for fluid 2.
    sigma = init_vel2 / cp.sqrt(2)
    vel2_gpu = cp.random.normal(0, sigma, size=(pos2_gpu.shape[0], 2))
    print("Fluid 2 particles initialized.")

    print("Depositing fluid 2 density onto grid...")
    density2 = particle_to_mesh(pos2_gpu, -m2)
    print("Fluid 2 density deposited and boundary conditions applied.")

    print("Computing gravitational potential for fluid 2...")
    Phi2 = compute_potential_from_density(density2, dx, G)
    print("Gravitational potential for fluid 2 computed.")

    print("Starting backward integration for fluid 2...")
    max_iterations = 25
    tolerance = 1e-4
    dt_back = -dt * 2  # Negative time step for backward integration
    integration_time = 200
    steps = int(abs(integration_time / dt_back))

    pos2_gpu, vel2_gpu, Phi2 = _relax_fluid2_potential(
        pos2_gpu, vel2_gpu, Phi1, Phi2, dt_back, steps, sigma,
        relaxation_factor=0.6, max_iterations=max_iterations, tolerance=tolerance)

    # Parameters of the truncated distribution.
    # NOTE(review): units of T1 are not established in this function — confirm.
    T1 = 0.063 * m1      # temperature-like scale; only m1 / T1 enters below
    ra1 = r1*0.6         # anisotropy radius
    tol = 1e-5           # potential convergence tolerance
    max_iter = 150

    # Starting potential from the provisional fluid 1 positions.
    density1 = particle_to_mesh(pos1_gpu, m1)
    print("Fluid 1 density deposited.")
    Phi1 = compute_potential_from_density(density1, dx, G)
    phi_total = Phi1 - Phi2

    # Distance of every cell centre from the box centre (offset avoids R = 0).
    x = cp.linspace(0, L, N_grid, endpoint=False) + dx/2
    y = cp.linspace(0, L, N_grid, endpoint=False) + dx/2
    X, Y = cp.meshgrid(x, y, indexing='ij')
    R_grid = cp.sqrt((X - L/2)**2 + (Y - L/2)**2) + 1e-10

    # First self-consistent solve; normalisation starts from 1.
    phi_total, f0_1 = _solve_truncated_om_potential(
        phi_total, Phi2, R_grid, 1.0, T1, ra1, w=0.02, tol=tol, max_iter=max_iter)

    rho1_final = _truncated_om_density(-phi_total, R_grid, f0_1, T1, ra1)
    pos1_gpu = cp.asarray(sample_positions_from_density(rho1_final, dx, L, N1))
    density1 = particle_to_mesh(pos1_gpu, m1)
    print("Computing gravitational potential for fluid 1...")
    Phi1 = compute_potential_from_density(-density1, dx, G)

    # Second fluid 2 relaxation, now against the resampled fluid 1.
    pos2_gpu, vel2_gpu, Phi2 = _relax_fluid2_potential(
        pos2_gpu, vel2_gpu, Phi1, Phi2, dt_back, steps, sigma,
        relaxation_factor=0.5, max_iterations=max_iterations, tolerance=tolerance)

    Phi1 = compute_potential_from_density(density1, dx, G)
    phi_total = Phi1 - Phi2
    phi_total -= cp.max(phi_total)

    # Second solve continues from the normalisation found by the first one
    # and reports its own convergence status.
    phi_total, f0_1 = _solve_truncated_om_potential(
        phi_total, Phi2, R_grid, f0_1, T1, ra1, w=0.01, tol=tol, max_iter=max_iter)

    Psi_final = -phi_total
    rho1_final = _truncated_om_density(Psi_final, R_grid, f0_1, T1, ra1)
    pos1_gpu = cp.asarray(sample_positions_from_density(rho1_final, dx, L, N1))

    # Velocities from the truncated distribution, with diagnostic plots.
    v_x_final, v_y_final, diagnostics = _sample_truncated_om_velocities(
        pos1_gpu, Psi_final, T1, ra1)
    if diagnostics is not None:
        R_b, v_R_b, v_theta_b = diagnostics
        _plot_velocity_profiles(R_b, v_R_b, v_theta_b, max_radius=r1)

    print(v_x_final)
    print(v_y_final)
    print("pos",pos1_gpu)

    # Forces from the combined density of both fluids.
    density1 = particle_to_mesh(pos1_gpu, m1)
    density2 = particle_to_mesh(pos2_gpu, m2)
    total_density = density1 + density2
    force_grid = solve_poisson(total_density, dx, G)

    print("start circular velocity")
    vel1_circ = calculate_velocities(pos1_gpu, force_grid, dx, L, r1/3)
    print(vel1_circ)

    # Blend of sampled and circular velocities. A previously used
    # "bar spiral galaxy" variant was: vel1_circ * 0.95 plus the sampled
    # velocities scaled by 0.001.
    vel1 = cp.column_stack((v_x_final, v_y_final))*0.5+ vel1_circ*1.5

    return pos1_gpu, vel1, pos2_gpu, vel2_gpu


def initialize_particles33():
    """Build initial conditions with a strictly truncated Osipkov-Merritt galaxy.

    Pipeline, in order:

    1. Draw provisional galaxy positions inside radius ``r1`` and derive a
       seed potential ``Phi1`` from them.
    2. Fill the box uniformly with fluid 2, give it Gaussian velocities and
       run repeated ``backward_integrate`` passes until its potential
       ``Phi2`` changes by less than a tolerance.
    3. Iterate density <-> potential for the truncated profile
       ``rho ~ (exp(m Psi / T) - 1) / sqrt(1 + (R / ra)^2)`` (zero where
       ``Psi <= 0``), rescaling ``f0_1`` so the grid mass matches ``M1``.
    4. Resample the galaxy positions from the final density and draw, for
       every particle, a velocity from the anisotropic Maxwellian restricted
       to ``Q > 0`` evaluated at that particle's own position.

    Returns:
        tuple: ``(pos1_gpu, vel1, pos2_gpu, vel2_gpu)`` -- positions and
        velocities of fluid 1 followed by those of fluid 2.
    """
    # ------------------------------------------------------------------
    # 1) Provisional galaxy positions (only used to seed the potential)
    # ------------------------------------------------------------------
    print("Initializing fluid 1 particles...")
    # Inverse-CDF draw of the radius. Because u < 1, the radius always stays
    # below r1, so the draw is vectorised and needs no rejection test.
    u = np.random.rand(N1)
    radius = -R_d * np.log(1 - u * (1 - np.exp(-r1 / R_d)))
    angle = np.random.uniform(0, 2 * np.pi, N1)
    galaxy_stars = np.column_stack((radius * np.cos(angle), radius * np.sin(angle)))
    print("galaxy_stars", galaxy_stars.shape)
    pos1_gpu = cp.asarray(galaxy_stars + np.array([L / 2, L / 2]))
    print("Fluid 1 particles initialized.")

    print("Depositing fluid 1 density onto grid...")
    # NOTE(review): the 0.95 mass factor and the sign flip on the density
    # below are kept exactly as found; their intent is not visible here.
    density1 = particle_to_mesh(pos1_gpu, m1*0.95)
    print("Fluid 1 density deposited.")

    print("Computing gravitational potential for fluid 1...")
    Phi1 = compute_potential_from_density(-density1, dx, G)
    print("Gravitational potential for fluid 1 computed.")

    # ------------------------------------------------------------------
    # 2) Fluid 2: uniform positions, Gaussian velocities, relaxation
    # ------------------------------------------------------------------
    print("Initializing fluid 2 particles...")
    pos2_gpu = cp.random.uniform(0, L, (N2, 2))
    print("pos2Shape", pos2_gpu.shape)

    # Per-component standard deviation derived from init_vel2.
    sigma = init_vel2 / cp.sqrt(2)
    vel2_gpu = cp.random.normal(0, sigma, size=(pos2_gpu.shape[0], 2))
    print("Fluid 2 particles initialized.")

    print("Depositing fluid 2 density onto grid...")
    density2 = particle_to_mesh(pos2_gpu, -m2)
    print("Fluid 2 density deposited and boundary conditions applied.")

    print("Computing gravitational potential for fluid 2...")
    Phi2 = compute_potential_from_density(density2, dx, G)
    print("Gravitational potential for fluid 2 computed.")

    print("Starting backward integration for fluid 2...")
    max_iterations = 200
    tolerance = 1e-4
    dt_back = -dt * 1  # negative time step for backward integration
    integration_time = 50
    steps = int(abs(integration_time / dt_back))
    relaxation_factor = 0.8  # under-relaxation weight of the new potential

    for iteration in range(max_iterations):
        print(f"Starting backward integration for fluid 2, iteration {iteration + 1}...")
        pos2_gpu, vel2_gpu = backward_integrate(
            pos2_gpu, vel2_gpu, Phi1, Phi2, dx, dt_back, steps,
            boundary_injection_func, sigma, init_vel2,
        )

        density2_new = particle_to_mesh(pos2_gpu, -m2)
        Phi2_new = compute_potential_from_density(density2_new, dx, G)

        # Convergence is judged on the largest pointwise potential change.
        delta_Phi = cp.max(cp.abs(Phi2_new - Phi2))
        print(f"Iteration {iteration + 1}, delta_Phi: {delta_Phi}")

        if delta_Phi < tolerance:
            print(f"Converged after {iteration + 1} iterations.")
            Phi2 = Phi2_new
            break

        Phi2 = relaxation_factor * Phi2_new + (1 - relaxation_factor) * Phi2
    else:
        print("Maximum iterations reached without convergence.")

    print("Backward integration for fluid 2 completed.")

    # ------------------------------------------------------------------
    # 3) Self-consistent truncated Osipkov-Merritt density
    # ------------------------------------------------------------------
    T1 = 0.2 * m1   # T1 / m1 is the radial velocity variance used below
    ra1 = 2         # anisotropy radius
    f0_1 = 1.0      # normalisation, rescaled every iteration to match M1
    w = 0.04        # relaxation weight of the new potential
    tol = 1e-5      # potential convergence tolerance
    max_iter = 600

    density1 = particle_to_mesh(pos1_gpu, m1)
    print("Fluid 1 density deposited.")

    Phi1 = compute_potential_from_density(density1, dx, G)
    phi_total = Phi1 - Phi2

    # Radius of every cell centre measured from the box centre.
    x = cp.linspace(0, L, N_grid, endpoint=False) + dx/2
    y = cp.linspace(0, L, N_grid, endpoint=False) + dx/2
    X, Y = cp.meshgrid(x, y, indexing='ij')
    R_grid = cp.sqrt((X - L/2)**2 + (Y - L/2)**2) + 1e-10

    def truncated_density(psi, f0):
        # Density of the strictly truncated model: zero wherever psi <= 0.
        core = cp.where(psi > 0.0, cp.exp((m1 / T1) * psi) - 1.0, 0.0)
        return (2.0 * cp.pi * f0 * T1 / m1) * core / cp.sqrt(1.0 + (R_grid / ra1)**2)

    for iteration in range(max_iter):
        Psi = -phi_total  # relative potential
        rho1_new = truncated_density(Psi, f0_1)

        phi_new = compute_potential_from_density(rho1_new, dx, G) - Phi2

        delta_phi = cp.max(cp.abs(phi_new - phi_total))
        print(f"Iteration {iteration+1}: delta_phi = {delta_phi:.3e}")

        if delta_phi < tol:
            print(f"Potential converged after {iteration+1} iterations.")
            phi_total = phi_new
            break

        phi_total = w * phi_new + (1.0 - w) * phi_total

        # Rescale the normalisation so the grid mass tracks M1.
        mass1_current = cp.sum(rho1_new) * dx**2
        print(f"Iteration {iteration+1}: mass1_current = {mass1_current:.3e}")

        if mass1_current > 0:
            scale_f0_1 = M1 / mass1_current
        else:
            scale_f0_1 = 1.0
            print("Warning: mass1_current is zero or negative.")

        f0_1 *= scale_f0_1
        print(f"Iteration {iteration+1}: scale_f0_1 = {scale_f0_1:.3e}")
    else:
        print("Warning: Did not fully converge within max iterations.")

    Psi_final = -phi_total
    rho1_final = truncated_density(Psi_final, f0_1)

    # ------------------------------------------------------------------
    # 4) Final positions and per-particle velocities
    # ------------------------------------------------------------------
    pos1 = sample_positions_from_density(rho1_final, dx, L, N1)
    pos1_gpu = cp.asarray(pos1)

    x_part = pos1_gpu[:, 0] - L/2
    y_part = pos1_gpu[:, 1] - L/2
    R_part = cp.sqrt(x_part**2 + y_part**2)
    theta_part = cp.arctan2(y_part, x_part)

    # Osipkov-Merritt dispersions: constant radially, reduced tangentially.
    alpha_part = 1.0 + (R_part**2 / ra1**2)
    sigma_vR = cp.sqrt(T1 / m1)
    sigma_vtheta = cp.sqrt(T1 / (m1 * alpha_part))

    # Relative potential of the cell containing each particle; indices wrap
    # periodically like the other grid look-ups in this file.
    i_grid = cp.floor(pos1_gpu[:, 0] / dx).astype(cp.int32) % N_grid
    j_grid = cp.floor(pos1_gpu[:, 1] / dx).astype(cp.int32) % N_grid
    Psi_part = Psi_final[i_grid, j_grid]

    n_particles = pos1_gpu.shape[0]
    v_x_final = cp.zeros(n_particles, dtype=cp.float64)
    v_y_final = cp.zeros(n_particles, dtype=cp.float64)

    # Rejection sampling is done per particle so that every velocity is
    # drawn with the angle, dispersion and potential of the particle it is
    # returned with.  Q = Psi - v_R^2/2 - (1 + R^2/ra^2) v_theta^2/2 must be
    # positive, matching the Psi > 0 support of the density above.
    # Particles sitting on a cell with Psi <= 0 can never satisfy Q > 0 and
    # keep zero velocity.
    pending = cp.flatnonzero(Psi_part > 0.0)
    max_rounds = 1000  # hard cap so the loop always terminates
    rounds = 0
    while pending.size > 0 and rounds < max_rounds:
        rounds += 1
        n_draw = int(pending.size)

        v_R = cp.random.normal(0.0, sigma_vR, size=n_draw)
        v_theta = cp.random.normal(0.0, sigma_vtheta[pending], size=n_draw)

        Q_sample = Psi_part[pending] \
            - 0.5 * v_R**2 \
            - 0.5 * alpha_part[pending] * v_theta**2
        bound = Q_sample > 0.0

        accepted = pending[bound]
        theta_acc = theta_part[accepted]
        v_R_acc = v_R[bound]
        v_theta_acc = v_theta[bound]

        # Polar -> Cartesian using the accepted particle's own angle.
        v_x_final[accepted] = v_R_acc * cp.cos(theta_acc) - v_theta_acc * cp.sin(theta_acc)
        v_y_final[accepted] = v_R_acc * cp.sin(theta_acc) + v_theta_acc * cp.cos(theta_acc)

        pending = pending[~bound]

    if pending.size > 0:
        print(f"Warning: {int(pending.size)} particles found no bound velocity "
              f"after {max_rounds} rounds; they keep zero velocity.")

    # ------------------------------------------------------------------
    # 5) Optional rotation on top of the sampled dispersion
    # ------------------------------------------------------------------
    density1 = particle_to_mesh(pos1_gpu, m1)
    density2 = particle_to_mesh(pos2_gpu, m2)
    force_grid = solve_poisson(density1 + density2, dx, G)  # forces from density

    print("start circular velocity")
    rotation_weight = 0    # circular-velocity contribution is switched off
    dispersion_weight = 1
    # A bar/spiral variant tried previously used rotation_weight = 0.95 and
    # dispersion_weight = 0.001.
    vel1_circ = calculate_velocities(pos1_gpu, force_grid, dx, L, r1/3) * rotation_weight
    vel1 = vel1_circ + cp.column_stack((v_x_final, v_y_final)) * dispersion_weight

    return pos1_gpu, vel1, pos2_gpu, vel2_gpu




def initialize_particles122():
    """Build initial conditions with a smoothly truncated Osipkov-Merritt galaxy.

    Pipeline, in order:

    1. Draw provisional galaxy positions inside radius ``r1`` and derive a
       seed potential ``Phi1`` from them.
    2. Fill the box uniformly with fluid 2, give it Gaussian velocities and
       run repeated ``backward_integrate`` passes until its potential
       ``Phi2`` changes by less than a tolerance.
    3. Iterate density <-> potential for the profile
       ``rho ~ exp(m Psi / T) * cutoff(Psi) / sqrt(1 + (R / ra)^2)`` where
       ``cutoff`` is a tanh step of width ``Q_smooth``, rescaling ``f0_1`` so
       the grid mass matches ``M1``.
    4. Resample the galaxy positions from the final density and draw, for
       every particle, a velocity from the anisotropic Maxwellian accepted
       with probability ``0.5 * (1 + tanh(Q / Q_smooth))`` evaluated at that
       particle's own position.

    Returns:
        tuple: ``(pos1_gpu, vel1, pos2_gpu, vel2_gpu)`` -- positions and
        velocities of fluid 1 followed by those of fluid 2.
    """
    # ------------------------------------------------------------------
    # 1) Provisional galaxy positions (only used to seed the potential)
    # ------------------------------------------------------------------
    print("Initializing fluid 1 particles...")
    # Inverse-CDF draw of the radius. Because u < 1, the radius always stays
    # below r1, so the draw is vectorised and needs no rejection test.
    u = np.random.rand(N1)
    radius = -R_d * np.log(1 - u * (1 - np.exp(-r1 / R_d)))
    angle = np.random.uniform(0, 2 * np.pi, N1)
    galaxy_stars = np.column_stack((radius * np.cos(angle), radius * np.sin(angle)))
    print("galaxy_stars", galaxy_stars.shape)
    pos1_gpu = cp.asarray(galaxy_stars + np.array([L / 2, L / 2]))
    print("Fluid 1 particles initialized.")

    print("Depositing fluid 1 density onto grid...")
    # NOTE(review): the 0.9 mass factor and the sign flip on the density
    # below are kept exactly as found; their intent is not visible here.
    density1 = particle_to_mesh(pos1_gpu, m1*0.9)
    print("Fluid 1 density deposited.")

    print("Computing gravitational potential for fluid 1...")
    Phi1 = compute_potential_from_density(-density1, dx, G)
    print("Gravitational potential for fluid 1 computed.")

    # ------------------------------------------------------------------
    # 2) Fluid 2: uniform positions, Gaussian velocities, relaxation
    # ------------------------------------------------------------------
    print("Initializing fluid 2 particles...")
    pos2_gpu = cp.random.uniform(0, L, (N2, 2))
    print("pos2Shape", pos2_gpu.shape)

    # Per-component standard deviation derived from init_vel2.
    sigma = init_vel2 / cp.sqrt(2)
    vel2_gpu = cp.random.normal(0, sigma, size=(pos2_gpu.shape[0], 2))
    print("Fluid 2 particles initialized.")

    print("Depositing fluid 2 density onto grid...")
    density2 = particle_to_mesh(pos2_gpu, -m2)
    print("Fluid 2 density deposited and boundary conditions applied.")

    print("Computing gravitational potential for fluid 2...")
    Phi2 = compute_potential_from_density(density2, dx, G)
    print("Gravitational potential for fluid 2 computed.")

    print("Starting backward integration for fluid 2...")
    max_iterations = 10
    tolerance = 1e-4
    dt_back = -dt * 1  # negative time step for backward integration
    integration_time = 50
    steps = int(abs(integration_time / dt_back))
    relaxation_factor = 0.8  # under-relaxation weight of the new potential

    for iteration in range(max_iterations):
        print(f"Starting backward integration for fluid 2, iteration {iteration + 1}...")
        pos2_gpu, vel2_gpu = backward_integrate(
            pos2_gpu, vel2_gpu, Phi1, Phi2, dx, dt_back, steps,
            boundary_injection_func, sigma, init_vel2,
        )

        density2_new = particle_to_mesh(pos2_gpu, -m2)
        Phi2_new = compute_potential_from_density(density2_new, dx, G)

        # Convergence is judged on the largest pointwise potential change.
        delta_Phi = cp.max(cp.abs(Phi2_new - Phi2))
        print(f"Iteration {iteration + 1}, delta_Phi: {delta_Phi}")

        if delta_Phi < tolerance:
            print(f"Converged after {iteration + 1} iterations.")
            Phi2 = Phi2_new
            break

        Phi2 = relaxation_factor * Phi2_new + (1 - relaxation_factor) * Phi2
    else:
        print("Maximum iterations reached without convergence.")

    print("Backward integration for fluid 2 completed.")

    # ------------------------------------------------------------------
    # 3) Self-consistent smoothly truncated Osipkov-Merritt density
    # ------------------------------------------------------------------
    T1 = 0.25*m1          # T1 / m1 is the radial velocity variance used below
    ra1 = 20              # anisotropy radius
    f0_1 = 1              # normalisation, rescaled every iteration to match M1
    w = 0.02              # relaxation weight of the new potential
    Q_smooth = 0.0000001  # width of the tanh cutoff
    tol = 1e-5
    om_max_iterations = 200

    density1 = particle_to_mesh(pos1_gpu, m1)
    print("Fluid 1 density deposited.")

    print("Computing gravitational potential for fluid 1...")
    Phi1 = compute_potential_from_density(density1, dx, G)
    phi_total = Phi1 - Phi2

    print("Starting iterative solver for OM distribution...")
    # Radius of every cell centre measured from the box centre.
    x = cp.linspace(0, L, N_grid, endpoint=False) + dx/2
    y = cp.linspace(0, L, N_grid, endpoint=False) + dx/2
    X, Y = cp.meshgrid(x, y, indexing='ij')
    R_grid = cp.sqrt((X - L/2)**2 + (Y - L/2)**2) + 1e-10

    def smooth_density(psi, f0):
        # Exponential profile multiplied by a tanh step in psi.
        exp_term = cp.exp((m1/T1)*psi)
        cutoff = 0.5*(1.0 + cp.tanh((m1/T1)*psi/Q_smooth))
        return (2.0*cp.pi*f0*T1/m1) * (exp_term * cutoff) / cp.sqrt(1+(R_grid/ra1)**2)

    for iteration in range(om_max_iterations):
        Psi = -phi_total  # relative potential
        rho1_new = smooth_density(Psi, f0_1)

        # The fluid-2 potential is held fixed during this iteration.
        phi_new = compute_potential_from_density(rho1_new, dx, G) - Phi2

        delta_phi = cp.max(cp.abs(phi_new - phi_total))
        print(f"Iteration {iteration+1}: delta_phi = {delta_phi:.3e}")

        if delta_phi < tol:
            print(f"Potential converged after {iteration+1} iterations.")
            phi_total = phi_new
            break

        phi_total = w * phi_new + (1 - w) * phi_total

        # Rescale the normalisation so the grid mass tracks M1.
        mass1_current = cp.sum(rho1_new)*dx**2
        print(f"Iteration {iteration+1}: mass1_current = {mass1_current:.3e}")

        if mass1_current > 0:
            scale_f0_1 = M1 / mass1_current
        else:
            scale_f0_1 = 1.0
            print("Warning: mass1_current is zero or negative.")

        f0_1 *= scale_f0_1
        print(f"Iteration {iteration+1}: scale_f0_1 = {scale_f0_1:.3e}")
    else:
        print("Warning: Did not fully converge within the maximum number of iterations.")

    # Both factors of the final density are evaluated on the final
    # potential (not on the value left over from the last loop pass).
    Psi_final = -phi_total
    rho1_final = smooth_density(Psi_final, f0_1)

    # ------------------------------------------------------------------
    # 4) Final positions and per-particle velocities
    # ------------------------------------------------------------------
    pos1 = sample_positions_from_density(rho1_final, dx, L, N1)
    print("Final density profile obtained.")
    pos1_gpu = cp.asarray(pos1)

    x_part = pos1_gpu[:, 0] - L/2
    y_part = pos1_gpu[:, 1] - L/2
    R_part = cp.sqrt(x_part**2 + y_part**2)
    theta_part = cp.arctan2(y_part, x_part)

    # Osipkov-Merritt dispersions: constant radially, reduced tangentially.
    alpha_part = 1.0 + (R_part**2 / ra1**2)
    sigma_vR = cp.sqrt(T1 / m1)
    sigma_vtheta = cp.sqrt(T1 / (m1 * alpha_part))

    # Relative potential of the cell containing each particle; indices wrap
    # periodically like the other grid look-ups in this file.
    i_grid = cp.floor(pos1_gpu[:, 0] / dx).astype(cp.int32) % N_grid
    j_grid = cp.floor(pos1_gpu[:, 1] / dx).astype(cp.int32) % N_grid
    Psi_part = Psi_final[i_grid, j_grid]

    n_particles = pos1_gpu.shape[0]
    v_x_final = cp.zeros(n_particles, dtype=cp.float64)
    v_y_final = cp.zeros(n_particles, dtype=cp.float64)

    # Rejection sampling is done per particle so that every velocity is
    # drawn with the angle, dispersion and potential of the particle it is
    # returned with.  Since Q <= Psi, the acceptance probability can never
    # exceed its value at Q = Psi; particles for which that bound is exactly
    # zero are skipped and keep zero velocity.
    max_accept = 0.5 * (1.0 + cp.tanh(Psi_part / Q_smooth))
    pending = cp.flatnonzero(max_accept > 0.0)
    max_rounds = 1000  # hard cap so the loop always terminates
    rounds = 0
    while pending.size > 0 and rounds < max_rounds:
        rounds += 1
        n_draw = int(pending.size)

        v_R = cp.random.normal(0.0, sigma_vR, size=n_draw)
        v_theta = cp.random.normal(0.0, sigma_vtheta[pending], size=n_draw)

        Q_sample = Psi_part[pending] \
            - 0.5 * v_R**2 \
            - 0.5 * alpha_part[pending] * v_theta**2

        # Smooth cutoff used as the acceptance probability.
        smooth_cut = 0.5 * (1.0 + cp.tanh(Q_sample / Q_smooth))
        u_accept = cp.random.uniform(0.0, 1.0, size=n_draw)
        keep = u_accept < smooth_cut

        accepted = pending[keep]
        theta_acc = theta_part[accepted]
        v_R_acc = v_R[keep]
        v_theta_acc = v_theta[keep]

        # Polar -> Cartesian using the accepted particle's own angle.
        v_x_final[accepted] = v_R_acc * cp.cos(theta_acc) - v_theta_acc * cp.sin(theta_acc)
        v_y_final[accepted] = v_R_acc * cp.sin(theta_acc) + v_theta_acc * cp.cos(theta_acc)

        pending = pending[~keep]

    if pending.size > 0:
        print(f"Warning: {int(pending.size)} particles found no accepted velocity "
              f"after {max_rounds} rounds; they keep zero velocity.")

    print("Particle velocities sampled according to the truncated OM distribution.")

    # ------------------------------------------------------------------
    # 5) Optional rotation on top of the sampled dispersion
    # ------------------------------------------------------------------
    density1 = particle_to_mesh(pos1_gpu, m1)
    density2 = particle_to_mesh(pos2_gpu, m2)
    force_grid = solve_poisson(density1 + density2, dx, G)  # forces from density

    print("start circular velocity")
    rotation_weight = 0  # circular-velocity contribution is switched off
    vel1_circ = calculate_velocities(pos1_gpu, force_grid, dx, L, r1/3) * rotation_weight

    # Return the sampled velocities (plus the weighted rotation) instead of
    # discarding them in favour of an all-zero array.
    vel1 = vel1_circ + cp.column_stack((v_x_final, v_y_final))

    return pos1_gpu, vel1, pos2_gpu, vel2_gpu




def solve_poisson(density, dx, G):
    """Return the force field of a periodic 2-D density via an FFT Poisson solve.

    In Fourier space the potential is ``phi_k = -4*pi*G * rho_k / k^2`` and the
    force components are ``-i * k * phi_k``.  The ``k = 0`` mode is set to
    zero, which is equivalent to solving for the density with its mean
    removed, so no explicit mean subtraction is needed.

    Args:
        density: 2-D array of densities on a regular grid.
        dx: Grid spacing, used for both axes.
        G: Gravitational constant.

    Returns:
        Array of shape ``density.shape + (2,)`` holding the x and y force
        components (real part of the inverse transform) on the grid.
    """
    density_k = cp.fft.fft2(density)

    # Angular wave numbers, shaped so that kx varies along axis 0 and ky
    # along axis 1.
    kx = cp.fft.fftfreq(density.shape[0], d=dx)[:, cp.newaxis] * 2 * cp.pi
    ky = cp.fft.fftfreq(density.shape[1], d=dx)[cp.newaxis, :] * 2 * cp.pi
    k_squared = kx**2 + ky**2

    # Placeholder to avoid dividing by zero; that mode is zeroed just below.
    k_squared[0, 0] = 1.0

    phi_k = -4 * cp.pi * G * density_k / k_squared
    phi_k[0, 0] = 0.0  # drop the constant (mean) part of the potential

    # Force = -grad(phi), i.e. multiplication by -i*k in Fourier space.
    force_x = cp.real(cp.fft.ifft2(-1j * kx * phi_k))
    force_y = cp.real(cp.fft.ifft2(-1j * ky * phi_k))

    return cp.stack([force_x, force_y], axis=-1)

def particle_to_mesh(pos, mass):
    """Deposit particles onto the periodic N_grid x N_grid mesh.

    Each particle's mass is shared between the four grid points surrounding
    it, with bilinear weights, and indices wrap around the box edges.

    Args:
        pos: Particle positions; column 0 is x and column 1 is y.
        mass: Mass deposited per particle.

    Returns:
        The accumulated mass grid divided by the cell area ``dx**2``.
    """
    mesh = cp.zeros((N_grid, N_grid), dtype=cp.float32)

    # Positions in grid units and the lower grid point of each particle.
    scaled = (pos / dx).astype(cp.float32)
    lower = cp.floor(scaled).astype(int)
    offset = scaled - lower  # fractional distance to the lower grid point

    weight_hi_x = offset[:, 0]
    weight_lo_x = 1.0 - weight_hi_x
    weight_hi_y = offset[:, 1]
    weight_lo_y = 1.0 - weight_hi_y

    # Periodic wrap of both the lower point and its +1 neighbour.
    ix_lo = lower[:, 0] % N_grid
    ix_hi = (ix_lo + 1) % N_grid
    iy_lo = lower[:, 1] % N_grid
    iy_hi = (iy_lo + 1) % N_grid

    # Scatter in the order: (lo, lo), (hi, lo), (lo, hi), (hi, hi).
    for iy, weight_y in ((iy_lo, weight_lo_y), (iy_hi, weight_hi_y)):
        for ix, weight_x in ((ix_lo, weight_lo_x), (ix_hi, weight_hi_x)):
            cp.add.at(mesh, (ix, iy), mass * weight_x * weight_y)

    return mesh / dx**2

def interpolate_force(force, pos):
    """Bilinearly interpolate a gridded force field at particle positions.

    Args:
        force: Force grid indexed as ``force[ix, iy]``; components 0 and 1
            of the last axis are read.
        pos: Particle positions; column 0 is x and column 1 is y.

    Returns:
        Array with one row per particle holding the interpolated x and y
        force components.
    """
    cell = pos / dx
    ix_lo = cp.floor(cell[:, 0]).astype(int)
    iy_lo = cp.floor(cell[:, 1]).astype(int)

    # Fractional offsets are taken before the indices are wrapped.
    tx = cell[:, 0] - ix_lo
    ty = cell[:, 1] - iy_lo

    ix_lo = ix_lo % N_grid
    iy_lo = iy_lo % N_grid
    ix_hi = (ix_lo + 1) % N_grid
    iy_hi = (iy_lo + 1) % N_grid

    # Bilinear weight and force value at each of the four corners.
    w00 = (1 - tx) * (1 - ty)
    w10 = tx * (1 - ty)
    w01 = (1 - tx) * ty
    w11 = tx * ty

    c00 = force[ix_lo, iy_lo]
    c10 = force[ix_hi, iy_lo]
    c01 = force[ix_lo, iy_hi]
    c11 = force[ix_hi, iy_hi]

    def blend(component):
        # Weighted sum of the four corner values for one force component.
        return (w00 * c00[:, component] + w10 * c10[:, component]
                + w01 * c01[:, component] + w11 * c11[:, component])

    return cp.stack([blend(0), blend(1)], axis=1)



def interpolate_force2(force, pos):
    """Bilinearly interpolate a gridded force field at particle positions.

    Args:
        force: Force grid indexed as ``force[ix, iy]``; components 0 and 1
            of the last axis are read.
        pos: Particle positions; column 0 is x and column 1 is y.

    Returns:
        Array with one row per particle holding the interpolated x and y
        force components.
    """
    scaled = pos / dx
    floor_x = cp.floor(scaled[:, 0])
    floor_y = cp.floor(scaled[:, 1])

    # Wrapped indices of the lower corner and of its +1 neighbour.
    lo_x = floor_x.astype(int) % N_grid
    lo_y = floor_y.astype(int) % N_grid
    hi_x = (lo_x + 1) % N_grid
    hi_y = (lo_y + 1) % N_grid

    wx = scaled[:, 0] - floor_x
    wy = scaled[:, 1] - floor_y

    # (weight, corner values) pairs in the order 00, 10, 01, 11.
    corners = (
        ((1 - wx) * (1 - wy), force[lo_x, lo_y]),
        (wx * (1 - wy), force[hi_x, lo_y]),
        ((1 - wx) * wy, force[lo_x, hi_y]),
        (wx * wy, force[hi_x, hi_y]),
    )

    components = []
    for axis in (0, 1):
        first_weight, first_values = corners[0]
        total = first_weight * first_values[:, axis]
        for weight, values in corners[1:]:
            total = total + weight * values[:, axis]
        components.append(total)

    return cp.stack(components, axis=1)


def get_force(pos, vel, m_gravitational):
    """Compute the force on every particle from the combined mesh density.

    The first ``N1`` rows of ``pos`` are deposited with mass ``m1`` and the
    remaining rows with mass ``m2``; the summed density is passed to
    ``solve_poisson`` and the resulting grid is interpolated at ``pos``.

    Args:
        pos: Positions of all particles, fluid 1 first.
        vel: Not used by this function.
        m_gravitational: Per-particle gravitational mass that scales the
            interpolated field.

    Returns:
        Interpolated field at each particle multiplied by its
        gravitational mass.
    """
    galaxy_density = particle_to_mesh(pos[:N1], m1)
    background_density = particle_to_mesh(pos[N1:], m2)

    field_grid = solve_poisson(galaxy_density + background_density, dx, G)
    field_at_particles = interpolate_force(field_grid, pos)

    return field_at_particles * m_gravitational[:, cp.newaxis]



def run_simulation(pos1, vel1, pos2, vel2):
    """Advance both fluids for ``n_steps`` leapfrog steps and record snapshots.

    Forces are computed with the signed gravitational masses while the
    velocity updates divide by the absolute (inertial) masses.  Positions
    are wrapped into ``[0, L)`` after every drift.  Every
    ``skipped_anime_frame``-th step the positions are copied to the host,
    stored and plotted.

    Args:
        pos1, vel1: Positions and velocities of fluid 1.
        pos2, vel2: Positions and velocities of fluid 2.

    Returns:
        tuple: ``(positions1, positions2)`` -- lists of host-side position
        snapshots for fluid 1 and fluid 2.
    """
    positions1, positions2 = [], []

    # Stack both fluids, fluid 1 first.
    all_pos = cp.concatenate([pos1, pos2])
    all_vel = cp.concatenate([vel1, vel2])

    # Inertial masses are positive; gravitational masses keep their sign.
    m_inertial = cp.concatenate([cp.full(N1, abs(m1)), cp.full(N2, abs(m2))])
    m_gravitational = cp.concatenate([cp.full(N1, m1), cp.full(N2, m2)])

    def half_kick(velocity, applied_force):
        # Half-step velocity update using the inertial mass.
        return velocity + 0.5 * (applied_force / m_inertial[:, cp.newaxis]) * dt

    force = get_force(all_pos, all_vel, m_gravitational)

    for step in range(n_steps):
        # Kick (half), drift with periodic wrap.
        all_vel_half = half_kick(all_vel, force)
        all_pos = (all_pos + all_vel_half * dt) % L

        # Forces at the new positions.
        new_force = get_force(all_pos, all_vel, m_gravitational)
        print("new_force", new_force)

        # Kick (second half).
        all_vel = half_kick(all_vel_half, new_force)
        force = new_force

        if step % skipped_anime_frame != 0:
            continue

        # Snapshot frame: store host copies, then plot the newest pair.
        positions1.append(cp.asnumpy(all_pos[:N1]))
        positions2.append(cp.asnumpy(all_pos[N1:]))
        plot_density(positions1[-1], positions2[-1], f"step : {step}")

    return positions1, positions2

# Matplotlib settings for faster rendering: the 'fast' style plus explicit
# path simplification and Agg chunking.
plt.style.use('fast')
mpl.rcParams.update({
    'path.simplify': True,
    'path.simplify_threshold': 1.0,
    'agg.path.chunksize': 10000,
})

def create_efficient_animation32(positions1, positions2, skip_frames=1, dpi=60, percentile=99):
    """Save a three-panel density GIF: combined map on top, one map per fluid below.

    For every rendered frame both particle sets are binned on an
    ``N_grid`` x ``N_grid`` mesh spanning ``[0, L]`` in x and y. Each
    histogram is divided by its own peak, the two are summed for the top
    panel, and each one is also shown alone in the bottom row. The colour
    scale of every panel runs from 0 to the ``percentile``-th percentile of
    the data drawn in that panel. The animation is written to
    ``2D_PMDST_Janus_Galaxy1.gif`` with the pillow writer at 5 fps.

    Args:
        positions1: Sequence of per-snapshot arrays for fluid 1; columns 0
            and 1 are read as x and y.
        positions2: Same layout, for fluid 2.
        skip_frames: Render every ``skip_frames``-th snapshot.
        dpi: Resolution passed to ``anim.save``.
        percentile: Percentile used as the upper colour limit of each panel.

    Returns:
        None. The figure is closed after the file is written.
    """
    fig = plt.figure(figsize=(12, 10))

    x_edges = np.linspace(0, L, N_grid + 1)
    y_edges = np.linspace(0, L, N_grid + 1)

    def peak_normalized(points):
        # Bin one snapshot and scale it so that its densest cell equals 1.
        counts, _, _ = np.histogram2d(
            points[:, 0], points[:, 1],
            bins=[x_edges, y_edges]
        )
        peak = counts.max()
        # An all-zero histogram is returned unchanged instead of 0/0 -> NaN.
        return counts / peak if peak > 0 else counts

    def draw_panel(ax, density, cmap, cbar_label):
        # histogram2d indexes [x_bin, y_bin]; imshow expects [row (y), col (x)],
        # hence the transpose.
        image = ax.imshow(
            density.T,
            origin='lower',
            extent=[0, L, 0, L],
            cmap=cmap,
            vmin=0,
            vmax=np.percentile(density, percentile)
        )
        ax.set_xlabel('x (kpc)')
        ax.set_ylabel('y (kpc)')
        ax.set_xticks([])
        ax.set_yticks([])
        fig.colorbar(image, ax=ax, label=cbar_label)
        return image

    def animate(i):
        # Clamp so an out-of-range index falls back to the last snapshot.
        frame_idx = min(i * skip_frames, len(positions1) - 1)

        # The colorbars live in their own axes, so the whole figure is rebuilt.
        fig.clf()
        gs = fig.add_gridspec(2, 2, height_ratios=[1, 1])

        hist1_normalized = peak_normalized(positions1[frame_idx])
        hist2_normalized = peak_normalized(positions2[frame_idx])
        hist_total = hist1_normalized + hist2_normalized

        # Top row, spanning both columns: combined density.
        ax0 = fig.add_subplot(gs[0, :])
        im0 = draw_panel(ax0, hist_total, 'plasma', 'Combined Density')
        # frame_idx already includes skip_frames, so it is not multiplied in
        # again. Assumes snapshots were stored every `skipped_anime_frame`
        # integration steps of length `dt`.
        ax0.set_title(f'T={frame_idx * skipped_anime_frame * dt:.1f} Myr')

        # Bottom-left: fluid 1 alone.
        ax1 = fig.add_subplot(gs[1, 0])
        im1 = draw_panel(ax1, hist1_normalized, 'viridis', 'Fluid 1 Density')
        ax1.set_title('Fluid 1 Density')

        # Bottom-right: fluid 2 alone.
        ax2 = fig.add_subplot(gs[1, 1])
        im2 = draw_panel(ax2, hist2_normalized, 'plasma', 'Fluid 2 Density')
        ax2.set_title('Fluid 2 Density')

        # Address this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()

        return [im0, im1, im2]

    n_frames = len(positions1) // skip_frames
    anim = FuncAnimation(fig, animate, frames=n_frames, interval=100, blit=False)

    anim.save(
        '2D_PMDST_Janus_Galaxy1.gif',
        writer='pillow',
        fps=5,
        dpi=dpi
    )
    plt.close(fig)


def create_efficient_animation3(positions1, positions2, skip_frames=1, dpi=72, percentile=99):
    """Save a single-panel GIF of the combined, peak-normalised density.

    Each rendered frame bins both particle sets on an ``N_grid`` x ``N_grid``
    mesh spanning ``[0, L]``, divides each histogram by its own peak, sums
    them, and draws the result with a linear colour scale from 0 to 1.1 times
    the ``percentile``-th percentile of the summed map. The animation is
    written to ``2D_PMDST_Janus_Galaxy322222.gif`` (pillow writer, 5 fps).

    Args:
        positions1: Sequence of per-snapshot arrays for fluid 1; columns 0
            and 1 are read as x and y.
        positions2: Same layout, for fluid 2.
        skip_frames: Render every ``skip_frames``-th snapshot.
        dpi: Resolution passed to ``anim.save``.
        percentile: Percentile of the summed map that sets the upper colour
            limit (before the 1.1 head-room factor).

    Returns:
        None. The figure is closed after the file is written.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    x_edges = np.linspace(0, L, N_grid + 1)
    y_edges = np.linspace(0, L, N_grid + 1)

    def animate(i):
        # Clamp so an out-of-range index falls back to the last snapshot.
        frame_idx = min(i * skip_frames, len(positions1) - 1)

        ax.clear()

        hist1, _, _ = np.histogram2d(
            positions1[frame_idx][:, 0],
            positions1[frame_idx][:, 1],
            bins=[x_edges, y_edges]
        )
        hist2, _, _ = np.histogram2d(
            positions2[frame_idx][:, 0],
            positions2[frame_idx][:, 1],
            bins=[x_edges, y_edges]
        )

        # Diagnostic output only; the medians do not feed the normalisation.
        median1 = np.median(hist1)
        median2 = np.median(hist2)
        print(f"hist1 median: {median1}, hist2 median: {median2}")

        # Scale each fluid by its own peak; an empty histogram is left as-is
        # rather than producing 0/0 -> NaN.
        peak1 = hist1.max()
        peak2 = hist2.max()
        hist1_normalized = hist1 / peak1 if peak1 > 0 else hist1
        hist2_normalized = hist2 / peak2 if peak2 > 0 else hist2

        histTotal = hist1_normalized + hist2_normalized

        # Honour the caller-supplied percentile (previously hard-coded to 99).
        vmax = np.percentile(histTotal, percentile) * 1.1

        # histogram2d indexes [x_bin, y_bin]; imshow expects [row (y), col (x)].
        im = ax.imshow(
            histTotal.T,
            origin='lower',
            extent=[0, L, 0, L],
            cmap='plasma',
            alpha=1,
            vmin=0,
            vmax=vmax
        )

        ax.set_xticks([])
        ax.set_yticks([])

        return [im]

    n_frames = len(positions1) // skip_frames
    anim = FuncAnimation(fig, animate, frames=n_frames, interval=100, blit=True)

    anim.save(
        '2D_PMDST_Janus_Galaxy322222.gif',
        writer='pillow',
        fps=5,
        dpi=dpi
    )
    plt.close(fig)

def create_efficient_animation3232(positions1, positions2, skip_frames=1, dpi=100, crop_size=0.5):
    """Save a GIF of the log-scaled particle count of both fluids together.

    The two particle sets are concatenated, restricted to a square window
    centred on the box, and drawn with ``Axes.hist2d`` (100 x 100 bins,
    ``LogNorm`` colour scale). The animation is written to
    ``2D_PMDST_Janus_Galaxy32222234567.gif`` (pillow writer, 5 fps).

    Args:
        positions1: Sequence of per-snapshot arrays for fluid 1; columns 0
            and 1 are read as x and y.
        positions2: Same layout, for fluid 2.
        skip_frames: Render every ``skip_frames``-th snapshot.
        dpi: Resolution passed to ``anim.save``.
        crop_size: Half-width of the window as a fraction of ``L``; the
            window is ``L/2 +/- L * crop_size`` on each axis, so the default
            0.5 shows the whole box. NOTE(review): this differs from
            ``create_efficient_animation44``, where ``crop_size`` is the full
            window width as a fraction of ``L`` - confirm which is intended.

    Returns:
        None. The figure is closed after the file is written.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    # Square window centred on the middle of the box.
    center_x = L / 2
    center_y = L / 2
    half_crop = (L * crop_size)
    x_min = center_x - half_crop
    x_max = center_x + half_crop
    y_min = center_y - half_crop
    y_max = center_y + half_crop

    def animate(i):
        # Clamp so an out-of-range index falls back to the last snapshot.
        frame_idx = min(i * skip_frames, len(positions1) - 1)

        ax.clear()

        # Merge both fluids into one set of coordinates.
        x = np.concatenate([positions1[frame_idx][:, 0], positions2[frame_idx][:, 0]])
        y = np.concatenate([positions1[frame_idx][:, 1], positions2[frame_idx][:, 1]])

        # Keep only the particles inside the window.
        inside = (
            (x >= x_min) & (x <= x_max) &
            (y >= y_min) & (y <= y_max)
        )

        # hist2d returns (counts, xedges, yedges, image); only the image
        # artist is needed here.
        *_, mesh = ax.hist2d(
            x[inside],
            y[inside],
            bins=100,
            range=[[x_min, x_max], [y_min, y_max]],
            cmap='plasma',
            norm=LogNorm()
        )

        ax.set_xticks([])
        ax.set_yticks([])

        # FuncAnimation expects an iterable of artists, not the Axes itself.
        return [mesh]

    n_frames = len(positions1) // skip_frames
    anim = FuncAnimation(fig, animate, frames=n_frames, interval=100, blit=False)

    anim.save(
        '2D_PMDST_Janus_Galaxy32222234567.gif',
        writer='pillow',
        fps=5,
        dpi=dpi
    )
    plt.close(fig)

def create_efficient_animation44(positions1, positions2, skip_frames=1, dpi=100, percentile=99, crop_size=0.25):
    """Save a GIF of the combined density, cropped to a central window.

    Each rendered frame bins both particle sets on the full
    ``N_grid`` x ``N_grid`` mesh over ``[0, L]``, divides each histogram by
    its own peak, sums them, and then displays only the cells that lie in a
    centred square window, using a ``LogNorm`` colour scale. The animation is
    written to ``2D_PMDST_Janus_Galaxy32222234567.gif`` (pillow writer,
    5 fps).

    Args:
        positions1: Sequence of per-snapshot arrays for fluid 1; columns 0
            and 1 are read as x and y.
        positions2: Same layout, for fluid 2.
        skip_frames: Render every ``skip_frames``-th snapshot.
        dpi: Resolution passed to ``anim.save``.
        percentile: Accepted for signature compatibility only. The image is
            drawn with an autoscaled ``LogNorm``, so no percentile-based
            colour limit is applied.
        crop_size: Full width of the window as a fraction of ``L``.

    Returns:
        None. The figure is closed after the file is written.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    x_edges = np.linspace(0, L, N_grid + 1)
    y_edges = np.linspace(0, L, N_grid + 1)

    # Square window centred on the middle of the box.
    center_x = L / 2
    center_y = L / 2
    half_crop = (L * crop_size) / 2
    x_min = center_x - half_crop
    x_max = center_x + half_crop
    y_min = center_y - half_crop
    y_max = center_y + half_crop

    # Edges inside the window; dropping the last one turns edge indices into
    # indices of the bins that lie fully inside it. Built once because it is
    # the same for every frame.
    x_idx = np.where((x_edges >= x_min) & (x_edges <= x_max))[0]
    y_idx = np.where((y_edges >= y_min) & (y_edges <= y_max))[0]
    crop = np.ix_(x_idx[:-1], y_idx[:-1])

    def animate(i):
        # Clamp so an out-of-range index falls back to the last snapshot.
        frame_idx = min(i * skip_frames, len(positions1) - 1)

        ax.clear()

        hist1, _, _ = np.histogram2d(
            positions1[frame_idx][:, 0],
            positions1[frame_idx][:, 1],
            bins=[x_edges, y_edges]
        )
        hist2, _, _ = np.histogram2d(
            positions2[frame_idx][:, 0],
            positions2[frame_idx][:, 1],
            bins=[x_edges, y_edges]
        )

        # Peaks are taken over the whole box, not just the window; an empty
        # histogram is left as-is rather than producing 0/0 -> NaN.
        peak1 = hist1.max()
        peak2 = hist2.max()
        hist1_normalized = hist1 / peak1 if peak1 > 0 else hist1
        hist2_normalized = hist2 / peak2 if peak2 > 0 else hist2
        histTotal = hist1_normalized + hist2_normalized

        histTotal_cropped = histTotal[crop]

        # histogram2d indexes [x_bin, y_bin]; imshow expects [row (y), col (x)].
        im = ax.imshow(
            histTotal_cropped.T,
            origin='lower',
            extent=[x_min, x_max, y_min, y_max],
            cmap='plasma',
            norm=LogNorm()
        )

        ax.set_xticks([])
        ax.set_yticks([])

        return [im]

    n_frames = len(positions1) // skip_frames
    anim = FuncAnimation(fig, animate, frames=n_frames, interval=100, blit=True)

    anim.save(
        '2D_PMDST_Janus_Galaxy32222234567.gif',
        writer='pillow',
        fps=5,
        dpi=dpi
    )
    plt.close(fig)

def create_efficient_animation333(positions1, positions2, skip_frames=1, dpi=100):
    """Save a GIF of the summed particle counts of both fluids on a log scale.

    Each rendered frame bins both particle sets on an ``N_grid`` x ``N_grid``
    mesh over a square region centred on the box, adds the two count maps,
    and draws the result with ``LogNorm``. The animation is written to
    ``density_animation.gif`` (pillow writer, 5 fps).

    Args:
        positions1: Sequence of per-snapshot arrays for fluid 1; columns 0
            and 1 are read as x and y.
        positions2: Same layout, for fluid 2.
        skip_frames: Render every ``skip_frames``-th snapshot.
        dpi: Resolution passed to ``anim.save``.

    Returns:
        None. The figure is closed after the file is written.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    n_frames = len(positions1) // skip_frames

    center = L / 2
    delta = L / 2  # Half-width of the plotted region; L / 2 covers the whole box
    region = [[center - delta, center + delta], [center - delta, center + delta]]
    extent = [region[0][0], region[0][1], region[1][0], region[1][1]]

    def animate(i):
        # Clamp so an out-of-range index falls back to the last snapshot.
        frame_idx = min(i * skip_frames, len(positions1) - 1)

        ax.clear()

        pos1 = positions1[frame_idx]
        pos2 = positions2[frame_idx]

        hist1, _, _ = np.histogram2d(
            pos1[:, 0], pos1[:, 1],
            bins=N_grid, range=region
        )
        hist2, _, _ = np.histogram2d(
            pos2[:, 0], pos2[:, 1],
            bins=N_grid, range=region
        )

        histTotal = hist1 + hist2
        histTotal[histTotal == 0] = 1e-3  # Avoid zeros for LogNorm

        # histogram2d indexes [x_bin, y_bin] while imshow expects
        # [row (y), col (x)]; transpose so x really runs along the
        # horizontal axis, matching the axis labels.
        im = ax.imshow(
            histTotal.T,
            origin='lower',
            extent=extent,
            cmap='gist_heat',
            norm=LogNorm()
        )

        ax.set_title(f"Time Step {frame_idx}")
        ax.set_xlabel('x (kpc)')
        ax.set_ylabel('y (kpc)')
        # Address this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()

        return [im]

    anim = FuncAnimation(fig, animate, frames=n_frames, interval=100, blit=True)
    anim.save('density_animation.gif', writer='pillow', fps=5, dpi=dpi)
    plt.close(fig)

def create_efficient_animation(positions1, positions2, skip_frames=1, dpi=50):
    """Save a side-by-side GIF of the log-scaled density of the two fluids.

    The image and colorbar artists are created once from the first snapshot;
    every subsequent frame only swaps the image data and the titles. The code
    never rescales the colour limits after creation, so they stay as
    established for the first snapshot. The animation is written to
    ``density_animation.gif`` (pillow writer, 5 fps).

    Args:
        positions1: Non-empty sequence of per-snapshot arrays for fluid 1;
            columns 0 and 1 are read as x and y.
        positions2: Same layout, for fluid 2.
        skip_frames: Render every ``skip_frames``-th snapshot.
        dpi: Resolution passed to ``anim.save``.

    Returns:
        None. The figure is closed after the file is written.
    """
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))

    n_frames = len(positions1) // skip_frames

    center = L / 2
    delta = L / 2  # Half-width of the plotted region; L / 2 covers the whole box
    region = [[center - delta, center + delta], [center - delta, center + delta]]
    extent = [region[0][0], region[0][1], region[1][0], region[1][1]]

    def log_ready_density(points):
        # Bin one snapshot and make it safe to show with LogNorm.
        counts, _, _ = np.histogram2d(
            points[:, 0], points[:, 1],
            bins=N_grid, range=region
        )
        counts[counts == 0] = 1e-3  # Avoid zeros for LogNorm
        # histogram2d indexes [x_bin, y_bin] while imshow expects
        # [row (y), col (x)]; transpose so x runs along the horizontal axis,
        # matching the axis labels.
        return counts.T

    def make_panel(ax, density, title):
        # Create the image plus its colorbar for one fluid.
        image = ax.imshow(
            density,
            origin='lower',
            extent=extent,
            cmap='gist_heat',
            norm=LogNorm()
        )
        ax.set_title(title)
        ax.set_xlabel('x (kpc)')
        ax.set_ylabel('y (kpc)')
        fig.colorbar(image, ax=ax, label='Logarithmic Density')
        return image

    # Build both panels from the first snapshot.
    im1 = make_panel(axs[0], log_ready_density(positions1[0]), "Fluid 1")
    im2 = make_panel(axs[1], log_ready_density(positions2[0]), "Fluid 2")

    # Address this figure explicitly rather than pyplot's "current" one.
    fig.tight_layout()

    def animate(i):
        # Clamp so an out-of-range index falls back to the last snapshot.
        frame_idx = min(i * skip_frames, len(positions1) - 1)

        # Only the pixel data changes; the artists themselves are reused.
        im1.set_data(log_ready_density(positions1[frame_idx]))
        im2.set_data(log_ready_density(positions2[frame_idx]))

        axs[0].set_title(f"Time Step {frame_idx} - Fluid 1")
        axs[1].set_title(f"Time Step {frame_idx} - Fluid 2")
        return [im1, im2]

    anim = FuncAnimation(fig, animate, frames=n_frames, interval=100, blit=True)

    anim.save('density_animation.gif', writer='pillow', fps=5, dpi=dpi)
    plt.close(fig)


# Module-level result holders, one list per fluid, both starting empty.
posfinal1, posfinal2 = [], []
def run_and_save_simulation():
    """Run the whole pipeline: initialise, integrate, plot, and write the GIF.

    Steps, in order: build the initial particle state, plot it, run the
    simulation to collect position snapshots for both fluids, plot the last
    snapshot, and hand all snapshots to ``create_efficient_animation``.
    Progress messages are printed along the way. Returns None.
    """
    print("Initializing particles...")
    start_pos1, start_vel1, start_pos2, start_vel2 = initialize_particles()
    print("Particles initialized!")

    # The initial arrays are converted with cp.asnumpy before plotting.
    plot_density(
        cp.asnumpy(start_pos1),
        cp.asnumpy(start_pos2),
        "Initial State",
    )

    # Integrate and gather the snapshot history of each fluid.
    history1, history2 = run_simulation(
        pos1=start_pos1,
        vel1=start_vel1,
        pos2=start_pos2,
        vel2=start_vel2,
    )

    # The snapshots are passed to plot_density directly (no conversion).
    plot_density(history1[-1], history2[-1], "Final State")

    print("Creating animation...")
    create_efficient_animation(history1, history2, skip_frames=1, dpi=100)

    print("Animation saved!")

# Entry point: start the pipeline on a background thread instead of calling
# run_and_save_simulation() directly.
if __name__ == "__main__":
    simulation_thread = threading.Thread(target=run_and_save_simulation)
    simulation_thread.start()