In [None]:
import jax
import jax.numpy as jnp # transfer the backend from numpy to jax
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [None]:
# Report which devices JAX will dispatch work to (e.g. CPU, GPU or TPU).
devices = list(jax.devices())
print(devices)

In [None]:
# LBM D2Q9 lattice scheme configuration
##############
#  6   2   5
#   \  |  /
# 3 —— 0 —— 1
#   /  |  \
#  7   4   8
##############

# Lattice and run sizes
Nq = 9    # number of discrete velocities (D2Q9)
Nx = 32   # grid points along x
Ny = 32   # grid points along y
Nt = 300  # time steps per run

# Direction indices 0..Nq-1 and their lattice velocities, written as (cx, cy)
# pairs in the numbering of the diagram above; after the transpose ci[0] holds
# every cx and ci[1] every cy, i.e. ci has shape (2, Nq).
ei = jnp.arange(Nq)
ci = jnp.array([
    (0, 0),                               # 0: rest
    (1, 0), (0, 1), (-1, 0), (0, -1),     # 1-4: axis neighbours
    (1, 1), (-1, 1), (-1, -1), (1, -1),   # 5-8: diagonal neighbours
]).T

# BGK relaxation frequencies of the three runs
omega1, omega2, omega3 = 1.0, 1.5, 1.9

# D2Q9 weights: 4/9 (rest), 1/9 (axis directions), 1/36 (diagonal directions)
weights = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)

In [None]:
# Time axis and lattice coordinates of the 2D XY domain. With indexing='ij',
# X[i, j] = xm[i] and Y[i, j] = ym[j], so fields are stored as field[x, y].
t = jnp.arange(Nt)
xm, ym = jnp.arange(Nx), jnp.arange(Ny)
X, Y = jnp.meshgrid(xm, ym, indexing='ij')

In [None]:
# fi field initialization: a sinusoidal shear wave in ux that varies along y.
# Every population is set to w_i * (1 + A_x * cos(2*pi*k_x*y/Ny) * c_ix), so
# rho = sum_i f_i = 1 everywhere (sum_i w_i c_ix = 0).
# NOTE: the linear term has no 1/cs^2 = 3 factor, so the resulting velocity
# amplitude is A_x / 3 (0.1 for A_x = 0.3), consistent with the +/-0.1 colour
# range used by the plots.
A_x = 0.3  # perturbation amplitude
k_x = 1    # number of wavelengths across the domain (along y)

# (Nx, Ny, 1) profile broadcast against the (Nq,) weights / x-velocities.
fi = weights * (1 + (A_x * jnp.cos(2 * jnp.pi * k_x * Y / Ny))[..., None] * ci[0])
fi1, fi2, fi3 = fi, fi, fi

# Second-order Carleman variables: the full outer product
# g_jk[x1, y1, j, x2, y2, k] = fi[x1, y1, j] * fi[x2, y2, k].
g_jk = jnp.tensordot(fi, fi, axes=0)
g_jk1, g_jk2, g_jk3 = g_jk, g_jk, g_jk

In [None]:
def streaming(fi, g_jk, velocities=None):
    """Periodic streaming step for the populations and the Carleman variables.

    Population ``fi[:, :, i]`` is shifted by its lattice velocity c_i with
    periodic wrap-around (``jnp.roll``). For the second-order variables
    ``g_jk[x1, y1, j, x2, y2, k]``, the first spatial pair (axes 0, 1) is
    shifted by c_j and the second spatial pair (axes 3, 4) by c_k.

    Args:
        fi: populations indexed ``fi[x, y, i]``.
        g_jk: second-order variables indexed ``g_jk[x1, y1, j, x2, y2, k]``.
        velocities: optional ``(2, Nq)`` integer array of lattice velocities
            (row 0 = c_x, row 1 = c_y). Defaults to the global ``ci``.

    Returns:
        ``(fi_streamed, g_streamed)`` with the same shapes as the inputs.
    """
    c = ci if velocities is None else jnp.asarray(velocities)
    n_dirs = c.shape[1]

    def shift(q):
        # Joint (x, y) displacement of direction q. Rolling both axes at once
        # is the same permutation as two successive single-axis rolls.
        return (c[0, q], c[1, q])

    # Build all shifted slices first and assemble them with a single jnp.stack.
    # Updating through .at[...].set() per direction returns a new full-size
    # array each time outside jit, which is very costly for g_jk
    # ((Nx*Ny*Nq)**2 elements).
    fi_streamed = jnp.stack(
        [jnp.roll(fi[:, :, q], shift(q), axis=(0, 1)) for q in range(n_dirs)],
        axis=2,
    )
    # First particle of each pair moves with c_j ...
    g_half = jnp.stack(
        [jnp.roll(g_jk[:, :, j], shift(j), axis=(0, 1)) for j in range(n_dirs)],
        axis=2,
    )
    # ... and the second particle with c_k.
    g_streamed = jnp.stack(
        [jnp.roll(g_half[..., k], shift(k), axis=(3, 4)) for k in range(n_dirs)],
        axis=-1,
    )
    return fi_streamed, g_streamed

In [None]:
def get_matrix(omega, velocities=None, lattice_weights=None):
    """Build the collision operators of the 2nd-order Carleman BGK scheme.

    The post-collision populations are
    ``f'_i = sum_j A_ij f_j + sum_jk B_ijk f_j f_k`` with

    * ``A_ij = (1 - omega) delta_ij + omega * w_i (1 + 3 c_i.c_j)``:
      relaxation toward the linear part of the equilibrium;
    * ``B_ijk = omega * 9/2 * w_i [(c_i.c_j)(c_i.c_k) - (c_j.c_k)/3]``:
      the quadratic equilibrium part ``w_i (9/2 (c_i.J)^2 - 3/2 J.J)`` written
      in the momentum ``J = sum_j c_j f_j`` (density taken as 1).

    Args:
        omega: BGK relaxation frequency.
        velocities: optional ``(D, Nq)`` array of lattice velocities.
            Defaults to the global ``ci``.
        lattice_weights: optional ``(Nq,)`` lattice weights. Defaults to the
            global ``weights``.

    Returns:
        ``(A_ij, B_ijk)`` with shapes ``(Nq, Nq)`` and ``(Nq, Nq, Nq)``.
    """
    c = ci if velocities is None else jnp.asarray(velocities)
    w = weights if lattice_weights is None else jnp.asarray(lattice_weights)
    n_dirs = w.shape[0]

    # cc[i, j] = c_i . c_j (symmetric), via elementwise products summed over D.
    cc = jnp.sum(c[:, :, None] * c[:, None, :], axis=0)

    # Linear part: L_ij = w_i (1 + 3 c_i.c_j).
    L_ij = w[:, None] * (1 + 3 * cc)
    A_ij = (1 - omega) * jnp.eye(n_dirs) + omega * L_ij

    # Quadratic part. The prefactor is 9/2 = 1 / (2 cs^4) with cs^2 = 1/3: it
    # makes the momentum flux sum_i c_ia c_ib (sum_jk Q_ijk f_j f_k) equal
    # J_a J_b. A factor 9 would double the nonlinear term.
    Q_ijk = 4.5 * w[:, None, None] * (
        cc[:, :, None] * cc[:, None, :] - cc[None, :, :] / 3
    )
    B_ijk = omega * Q_ijk

    return A_ij, B_ijk

In [None]:
def collision_BGK_2nd_carleman(fi_pre, A_ij, g_jk, B_ijk):
    """BGK collision truncated at second Carleman order.

    With ``g_jk(p, q) = g_jk[px, py, j, qx, qy, k]``:

        f'_i(p)       = sum_j A_ij f_j(p) + sum_jk B_ijk g_jk(p, p)
        g'_ij(p, q)   = sum_kl A_ik A_jl g_kl(p, q)

    Terms of third order and higher (the B contributions to g') are dropped.

    Args:
        fi_pre: populations, shape ``(Nx, Ny, Nq)``.
        A_ij: linear collision matrix, shape ``(Nq, Nq)``.
        g_jk: second-order variables, shape ``(Nx, Ny, Nq, Nx, Ny, Nq)``.
        B_ijk: quadratic collision tensor, shape ``(Nq, Nq, Nq)``.

    Returns:
        ``(fi_collisioned, g_jk_collisioned)`` with the input shapes.
    """
    # fi_pre is used as-is (no squeeze), so a grid with a size-1 axis keeps
    # its rank and still matches the einsum subscripts below.
    nx, ny = fi_pre.shape[0], fi_pre.shape[1]

    # Linear part, site by site.
    A_i = jnp.einsum('ij,xyj->xyi', A_ij, fi_pre)

    # Same-site entries g[x, y, :, x, y, :], using index grids derived from the
    # input shape rather than module-level meshgrids.
    ix, iy = jnp.meshgrid(jnp.arange(nx), jnp.arange(ny), indexing='ij')
    g_local = g_jk[ix, iy, :, ix, iy, :]
    B_i = jnp.einsum('ijk,xyjk->xyi', B_ijk, g_local)
    fi_collisioned = A_i + B_i

    # (A ⊗ A) g: passing A twice (instead of a materialised 4-index A⊗A tensor)
    # lets einsum contract one factor at a time, i.e. 2*Nq instead of Nq**2
    # multiplies per element of g.
    g_jk_collisioned = jnp.einsum('ik,jl,xykuvl->xyiuvj', A_ij, A_ij, g_jk)

    return fi_collisioned, g_jk_collisioned

In [None]:
def get_macro_quantities(fi):
    """Compute density and velocity fields from the populations.

    rho = sum_i f_i and (ux, uy) = (sum_i f_i c_i) / rho, where the direction
    index is axis 2 of ``fi`` and the lattice velocities come from the global
    ``ci`` (row 0 = c_x, row 1 = c_y).

    Returns:
        Tuple ``(rho, ux, uy)``.
    """
    cx, cy = ci[0], ci[1]
    rho = jnp.sum(fi, axis=2)
    momentum_x = jnp.sum(fi * cx, axis=2)
    momentum_y = jnp.sum(fi * cy, axis=2)
    return rho, momentum_x / rho, momentum_y / rho

In [None]:
# Initial velocity field. `ux` is also the reference value u0 for the decay
# ratio log2(u / u0) computed during the simulation runs.
_, ux, uy = get_macro_quantities(fi)

fig, ax = plt.subplots(figsize=(3, 2))
contour = ax.contour(X, Y, ux, levels=500, cmap='viridis', vmin=-0.1, vmax=0.1)
fig.colorbar(contour, ax=ax)
ax.set_title("ux")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.grid(alpha=0.3)
plt.show()

In [None]:
# Run the 2nd-order Carleman BGK scheme once per relaxation frequency. Each run
# records, before every update, the populations, the velocity fields and the
# decay ratio log2(ux(t) / ux(0)) at a probe node (plotted in the next cell).
probe = (Nx // 2, Ny // 2)  # grid centre; equals (16, 16) on the default 32 x 32 grid
progress_every = 50         # print one progress line every this many steps


def _run_carleman(omega, fi0, g0):
    """Evolve ``(fi0, g0)`` for ``Nt`` collide-and-stream steps at relaxation ``omega``.

    Returns ``(fi_hist, log2_ratio, ux_hist, uy_hist, um_hist, fi_final, g_final)``.
    History entry ``n`` is the state before update ``n``, so entry 0 is the
    initial condition. ``log2_ratio`` is normalised by the initial field ``ux``
    computed in the plotting cell above.
    """
    A_ij, B_ijk = get_matrix(omega)
    fi, g = fi0, g0
    fi_hist, log2_ratio, ux_hist, uy_hist, um_hist = [], [], [], [], []
    for it in range(Nt):
        fi_hist.append(fi)
        _, ux_t, uy_t = get_macro_quantities(fi)
        log2_ratio.append(jnp.log2(ux_t[probe] / ux[probe]))
        ux_hist.append(ux_t)
        uy_hist.append(uy_t)
        um_hist.append(jnp.linalg.norm(jnp.stack([ux_t, uy_t]), axis=0))

        fi, g = collision_BGK_2nd_carleman(fi, A_ij, g, B_ijk)
        fi, g = streaming(fi, g)

        if it % progress_every == 0 or it == Nt - 1:
            print(f"omega = {omega}: step {it + 1}/{Nt}")
    return fi_hist, log2_ratio, ux_hist, uy_hist, um_hist, fi, g


# One run per omega; velocity-field histories are kept for every run.
fi1t, u1, ux1, uy1, um1, fi1, g_jk1 = _run_carleman(omega1, fi1, g_jk1)
fi2t, u2, ux2, uy2, um2, fi2, g_jk2 = _run_carleman(omega2, fi2, g_jk2)
fi3t, u3, ux3, uy3, um3, fi3, g_jk3 = _run_carleman(omega3, fi3, g_jk3)

In [None]:
# Decay of ux at the centre node for the three relaxation frequencies.
fig, ax = plt.subplots(figsize=(6, 4))
for series, omega, colour in ((u1, omega1, "blue"), (u2, omega2, "red"), (u3, omega3, "green")):
    # Labels are derived from the omegas actually simulated so they cannot drift.
    ax.plot(t, series, label=rf"$\omega = {omega:g}$", color=colour)
ax.set_title("Decay of $u_x$ at the centre node")
ax.set_xlabel("t")
ax.set_ylabel(r"$\log_2(u_x / u_{x,0})$")
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
ax.legend()
plt.show()

In [None]:
# fi1t = jnp.array(fi1t)
# fi2t = jnp.array(fi2t)
# fi3t = jnp.array(fi3t)
# fi_all = jnp.stack((fi1t, fi2t, fi3t), axis=0)
# jnp.save("fi_2nd carleman", fi_all)

In [None]:
# Animate ux, uy and |u| of the omega1 run (ux1, uy1, um1) over all Nt steps
# and save the result as a GIF.
fig, axes = plt.subplots(3, 1, figsize=(6, 12))
ax1, ax2, ax3 = axes
frames = Nt
contour_kw = dict(levels=500, cmap='viridis', vmin=-0.1, vmax=0.1)
panels = ((ax1, ux1, "Ux"), (ax2, uy1, "Uy"), (ax3, um1, "U magnitude"))

# Attach each colorbar exactly once, here. FuncAnimation may call init_func
# more than once (its save() and the start of interactive display both redo the
# initial draw), so creating colorbars inside init() can stack duplicates.
for panel_ax, history, _title in panels:
    fig.colorbar(panel_ax.contourf(X, Y, history[0], **contour_kw), ax=panel_ax)


def update(frame):
    """Redraw all three panels for time step ``frame``."""
    for panel_ax, history, title in panels:
        panel_ax.clear()
        panel_ax.contourf(X, Y, history[frame], **contour_kw)
        panel_ax.set_title(f"{title} (Time step {frame + 1}/{frames})")


def init():
    """Draw the first frame; the colorbars already exist."""
    update(0)


# Create the animation
anim = FuncAnimation(fig, update, frames, init_func=init, interval=50)
anim.save("velocity_2nd CL.gif", writer="pillow", fps=50)
plt.show()