# In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm

# Turn on 64-bit mode before any array is built, so every jnp array created
# below (and the whole simulation state) uses float64 / int64.
jax.config.update("jax_enable_x64", True)

# Flow selector:
#   'DHIT' -> Decaying Homogeneous Isotropic Turbulence
#   'TGV'  -> Taylor-Green vortex
FLOW_TYPE = 'DHIT'

# Number of lattice nodes along x, y and z.
nx, ny, nz = 64, 64, 64
# Amplitude of the initial velocity field (lattice units).
u0 = 0.1
KINEMATIC_VISCOSITY = 0.00125
# BGK relaxation frequency omega = 1 / tau, where tau = 3 * nu + 1/2.
RELAXATION_OMEGA = 1.0 / (3.0 * KINEMATIC_VISCOSITY + 0.5)

# Output cadence (iterations) and total run length.
PLOT_EVERY_N_STEPS = 10
SAVE_EVERY_N_STEPS = 50000
N_ITERATIONS = 1000

# D3Q19 velocity set: index 0 is the rest population, 1-6 the axis-aligned
# neighbours, 7-18 the twelve edge diagonals. Column q of LATTICE_VELOCITIES
# is the lattice vector c_q.
N_DISCRETE_VELOCITIES = 19
LATICE_VELOCITIES_X = jnp.array([0, 1, -1, 0,  0, 0,  0, 1, -1,  1, -1, 0,  0,  0,  0, 1, -1,  1, -1])
LATICE_VELOCITIES_Y = jnp.array([0, 0,  0, 1, -1, 0,  0, 1, -1, -1,  1, 1, -1,  1, -1, 0,  0,  0,  0])
LATICE_VELOCITIES_Z = jnp.array([0, 0,  0, 0,  0, 1, -1, 0,  0,  0,  0, 1, -1, -1,  1, 1, -1, -1,  1])
LATTICE_VELOCITIES = jnp.stack([LATICE_VELOCITIES_X, LATICE_VELOCITIES_Y, LATICE_VELOCITIES_Z])

# Quadrature weights, ordered like the velocity set:
# rest 1/3, six axis neighbours 1/18 each, twelve diagonals 1/36 each.
LATTICE_WEIGHTS = jnp.array([1/3] + [1/18] * 6 + [1/36] * 12)

# Initialize grid on [-0.5, 0.5) along each axis.
# Streaming wraps with jnp.roll, so the domain is periodic: a node at +0.5
# would be the same physical point as the node at -0.5. endpoint=False yields
# n distinct nodes spaced 1/n apart, so one wavelength of sin(2*pi*x) spans
# exactly the n lattice cells.
x = jnp.linspace(0, 1, nx, endpoint=False) - 0.5
y = jnp.linspace(0, 1, ny, endpoint=False) - 0.5
z = jnp.linspace(0, 1, nz, endpoint=False) - 0.5

X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")

# Velocity profile based on the selected flow type.
# The LBM equation of state is p = c_s^2 * rho with c_s^2 = 1/3, so a pressure
# perturbation dp corresponds to a density perturbation dp / c_s^2 = 3 * dp.
if FLOW_TYPE == 'DHIT':
    # NOTE(review): this velocity field is the classic 3-D Taylor-Green vortex
    # initial condition (with k = 2*pi), which decays into turbulence.
    ux = u0 * jnp.sin(2 * jnp.pi * X) * jnp.cos(2 * jnp.pi * Y) * jnp.cos(2 * jnp.pi * Z)
    uy = -u0 * jnp.cos(2 * jnp.pi * X) * jnp.sin(2 * jnp.pi * Y) * jnp.cos(2 * jnp.pi * Z)
    uz = jnp.zeros_like(ux)
    # p - p0 = rho0 * u0^2 / 16 * (cos 2kx + cos 2ky) * (cos 2kz + 2); rho = 1 + (p - p0) / c_s^2
    rho = 1.0 + 3.0 * (u0**2 / 16.0) * (jnp.cos(4.0 * jnp.pi * X) + jnp.cos(4.0 * jnp.pi * Y)) * (jnp.cos(4.0 * jnp.pi * Z) + 2.0)
elif FLOW_TYPE == 'TGV':
    # 2-D Taylor-Green vortex in the x-z plane (uniform along y).
    ux = u0 * jnp.sin(2 * jnp.pi * X) * jnp.cos(2 * jnp.pi * Z)
    uy = jnp.zeros_like(ux)
    uz = -u0 * jnp.cos(2 * jnp.pi * X) * jnp.sin(2 * jnp.pi * Z)
    # Matching pressure: p - p0 = rho0 * u0^2 / 4 * (cos 2kx + cos 2kz); rho = 1 + (p - p0) / c_s^2.
    # A uniform density would be inconsistent with this velocity field and launch spurious acoustic waves.
    rho = 1.0 + 3.0 * (u0**2 / 4.0) * (jnp.cos(4.0 * jnp.pi * X) + jnp.cos(4.0 * jnp.pi * Z))
else:
    raise ValueError("FLOW_TYPE must be either 'DHIT' or 'TGV'.")

def get_equilibrium_discrete_velocities(macroscopic_velocities, density):
    """Return the second-order D3Q19 equilibrium populations.

    f_q^eq = w_q * rho * (1 + 3 (c_q . u) + 9/2 (c_q . u)^2 - 3/2 |u|^2)

    Args:
        macroscopic_velocities: rank-4 array (i, j, k, d) with the velocity
            components on the last axis.
        density: density per node, broadcastable against the first three
            axes of ``macroscopic_velocities``.

    Returns:
        Equilibrium populations with the 19 lattice directions on the last axis.
    """
    # c_q . u for every node and every lattice direction q.
    cu = jnp.einsum("dQ,ijkd->ijkQ", LATTICE_VELOCITIES, macroscopic_velocities)
    # |u|^2 per node, with a trailing singleton axis so it broadcasts over q.
    u_squared = jnp.linalg.norm(macroscopic_velocities, axis=-1)[..., jnp.newaxis] ** 2
    polynomial = 1 + 3 * cu + 9/2 * cu**2 - 3/2 * u_squared
    return density[..., jnp.newaxis] * LATTICE_WEIGHTS * polynomial

# Initial velocity field with components (ux, uy, uz) stacked on the last axis.
VELOCITY_PROFILE = jnp.stack([ux, uy, uz], axis=-1)
# The simulation starts from the local equilibrium of that field and the initial density.
discrete_velocities_prev = get_equilibrium_discrete_velocities(
    macroscopic_velocities=VELOCITY_PROFILE,
    density=rho,
)

def get_density(discrete_velocities):
    """Return the zeroth moment (density) of the populations.

    The populations are summed over the trailing axis, so the result has the
    input's shape with that last axis removed.
    """
    density = jnp.sum(discrete_velocities, axis=-1)
    return density

def get_macroscopic_velocities(discrete_velocities, density):
    """Return the macroscopic velocity u = (sum_q c_q f_q) / rho at every node.

    Args:
        discrete_velocities: rank-4 populations (i, j, k, q) with the lattice
            directions on the last axis.
        density: density per node, shaped like the first three axes.

    Returns:
        Rank-4 array (i, j, k, d) with the velocity components on the last axis.
    """
    momentum_density = jnp.einsum("dQ,ijkQ->ijkd", LATTICE_VELOCITIES, discrete_velocities)
    return momentum_density / jnp.expand_dims(density, axis=-1)

@jax.jit
def update(discrete_velocities_prev):
    """Advance the populations by one BGK collide-and-stream step.

    Args:
        discrete_velocities_prev: populations with three spatial axes followed
            by the N_DISCRETE_VELOCITIES lattice directions on the last axis.

    Returns:
        Populations after collision and periodic streaming, same shape as the input.
    """
    # Moments of the incoming populations and the matching equilibrium.
    rho = get_density(discrete_velocities_prev)
    velocity = get_macroscopic_velocities(discrete_velocities_prev, rho)
    f_eq = get_equilibrium_discrete_velocities(velocity, rho)

    # BGK collision: relax every population towards equilibrium at rate RELAXATION_OMEGA.
    f_post = discrete_velocities_prev - RELAXATION_OMEGA * (discrete_velocities_prev - f_eq)

    # Streaming: population q is shifted by its lattice vector c_q across all
    # three spatial axes at once; jnp.roll wraps around, i.e. periodic boundaries.
    streamed = [
        jnp.roll(f_post[..., q], shift=LATTICE_VELOCITIES[:, q], axis=(0, 1, 2))
        for q in range(N_DISCRETE_VELOCITIES)
    ]
    return jnp.stack(streamed, axis=-1)

def calculate_kinetic_energy(macroscopic_velocities):
    """Return the total kinetic energy 1/2 * sum(u_d^2), summed over every node and component.

    Density is not included (unit-density kinetic energy).
    """
    squared_components = macroscopic_velocities ** 2
    return jnp.sum(squared_components) * 0.5

def calculate_energy_dissipation_rate(macroscopic_velocities, kinematic_viscosity):
    """
    Calculate the energy dissipation rate using the formula:
    ε = ν * sum(∂u_i/∂x_j * ∂u_i/∂x_j)

    The sum runs over all velocity components i, all three directions j and
    every lattice node. Derivatives use second-order central differences with
    unit lattice spacing and periodic wrap-around (jnp.roll). This matches the
    periodic boundaries of the roll-based streaming step, and unlike
    jnp.gradient it does not fall back to one-sided differences at the array
    edges.

    Args:
        macroscopic_velocities: Velocity field (shape: [nx, ny, nz, 3])
        kinematic_viscosity: Kinematic viscosity ν

    Returns:
        Energy dissipation rate ε (0-d array)
    """
    squared_gradient_sum = 0.0
    for axis in range(3):
        # (u(x + 1) - u(x - 1)) / 2 along this spatial axis, for all three components at once.
        forward = jnp.roll(macroscopic_velocities, -1, axis=axis)
        backward = jnp.roll(macroscopic_velocities, 1, axis=axis)
        derivative = (forward - backward) / 2.0
        squared_gradient_sum = squared_gradient_sum + jnp.sum(jnp.square(derivative))

    # Integrate over the domain (sum over nodes) and scale by the viscosity.
    return kinematic_viscosity * squared_gradient_sum



def _plot_velocity_slice(macroscopic_velocities, iteration):
    """Show |u| / u0 on the z = nz // 2 mid-plane as a filled contour plot."""
    # The slice is taken along axis 2 (z), so it must be indexed with nz, not nx.
    z_mid = nz // 2
    # Normalize the local speed by the reference amplitude u0. Dividing by the
    # node count, as before, has no physical meaning for a pointwise quantity.
    normalized_velocity_magnitude = jnp.linalg.norm(macroscopic_velocities[:, :, z_mid], axis=-1) / u0
    plt.figure(figsize=(4, 4))
    plt.contourf(X[:, :, z_mid], Y[:, :, z_mid], normalized_velocity_magnitude, cmap='inferno')
    plt.colorbar(label="Normalized Velocity Magnitude")
    plt.title(f"Iteration {iteration}")
    plt.axis("scaled")
    plt.show()


def _plot_history(values, label, ylabel, title, **plot_kwargs):
    """Plot a per-iteration series, dividing each value by the number of lattice nodes."""
    n_nodes = nx * ny * nz
    plt.figure()
    plt.plot([value / n_nodes for value in values], label=label, **plot_kwargs)
    plt.xlabel("Iteration")
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(title)
    plt.show()


def run(discrete_velocities_prev):
    """Run the simulation for N_ITERATIONS steps, logging and plotting diagnostics.

    Per iteration:
    - compute the total kinetic energy and dissipation rate and record them;
    - every PLOT_EVERY_N_STEPS steps, show a mid-plane |u| / u0 contour plot;
    - every SAVE_EVERY_N_STEPS steps, append "<iteration> <value / node count>"
      lines to ke_data.dat and epsilon_data.dat (files are opened in append
      mode, so earlier runs are kept).

    After the loop, plot the kinetic-energy and dissipation histories.

    Args:
        discrete_velocities_prev: initial populations passed to ``update``.

    Returns:
        None
    """
    n_nodes = nx * ny * nz
    kinetic_energy_over_time = []
    energy_dissipation_over_time = []

    with open("ke_data.dat", "a", encoding="utf-8") as ke_file, open("epsilon_data.dat", "a", encoding="utf-8") as eps_file:
        for i in tqdm(range(N_ITERATIONS)):
            discrete_velocities_prev = update(discrete_velocities_prev)

            density = get_density(discrete_velocities_prev)
            macroscopic_velocities = get_macroscopic_velocities(discrete_velocities_prev, density)
            kinetic_energy = calculate_kinetic_energy(macroscopic_velocities)
            energy_dissipation_rate = calculate_energy_dissipation_rate(macroscopic_velocities, KINEMATIC_VISCOSITY)

            kinetic_energy_over_time.append(kinetic_energy)
            energy_dissipation_over_time.append(energy_dissipation_rate)

            if i % PLOT_EVERY_N_STEPS == 0:
                _plot_velocity_slice(macroscopic_velocities, i)

            if i % SAVE_EVERY_N_STEPS == 0:
                # float() turns the 0-d JAX results into Python floats for the :e format spec.
                ke_file.write(f"{i:10d} {float(kinetic_energy) / n_nodes:.6e}\n")
                eps_file.write(f"{i:10d} {float(energy_dissipation_rate) / n_nodes:.6e}\n")

    _plot_history(kinetic_energy_over_time, "Kinetic Energy",
                  "Normalized Total Kinetic Energy", "Kinetic Energy Decay")
    _plot_history(energy_dissipation_over_time, "Energy Dissipation Rate",
                  "Normalized Energy Dissipation Rate", "Energy Dissipation Rate Over Time",
                  color='r')

# Launch the simulation from the equilibrium initial state built above.
run(discrete_velocities_prev=discrete_velocities_prev)
