In [1]:
import os
import timeit
import numpy as np
import jax.numpy as jnp
from jax import jit, lax
import matplotlib.pyplot as plt

from scipy.interpolate import griddata
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.animation as animation

# CUDA device selection, applied through environment variables.
os.environ.update(
    {
        # see issue #152 # comment this entry out for mac cpu
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

### Solve a transient diffusion equation

The transient diffusion equation reads

$$\frac{\partial \rho \phi}{\partial t}+\nabla \cdot \left(-\Gamma\nabla \phi\right)=S_Q.$$

For simplicity, we set $S_Q = 0$.

In [2]:
# Load the mesh archive from disk and build the grid object from it.
from grids import Grid,BoundaryConditions

# NOTE(review): this is an absolute, machine-specific path; the notebook will
# only run where this file exists at exactly this location.
mesh_archive = np.load("/home/yongqi/JAX-IGA/jax_torch_fvm/raw_data/mesh/mesh_basic_np.npz")
grid = Grid(mesh_archive)

Now let us set up the ground truth and the physical quantities $\rho = 1$ and $\Gamma = 1$.

First, let us validate that in our case $$\frac{\partial \phi}{\partial t} = \nabla \cdot \left(\nabla \phi\right).$$

In [3]:
import sympy as sp

# Symbolic space coordinates and time
x, y, z, t = sp.symbols('x y z t')

# Scalar field phi: a product of sine modes in x, y, z with an exponential
# decay in t. The prefactor 10e4 is the Python float 100000.0 (i.e. 1e5).
spatial_modes = sp.sin(sp.pi * x) * sp.sin(sp.pi * y) * sp.sin(sp.pi * z)
phi = 10e4*spatial_modes*sp.exp(-sp.pi*3*t)

# Laplacian of phi: sum of the second derivatives along each space axis,
# simplified in one go.
laplacian_phi = sp.simplify(sum(sp.diff(phi, axis, 2) for axis in (x, y, z)))

# Display the Laplacian
laplacian_phi



-300000.0*pi**2*exp(-3*pi*t)*sin(pi*x)*sin(pi*y)*sin(pi*z)

In [4]:
# Time derivative of phi, simplified (stored under the name `grad_phi`).
# NOTE(review): the recorded output of this cell carries a factor pi, while the
# recorded Laplacian carries pi**2, so the two differ by a factor of pi and phi
# as defined does not satisfy d(phi)/dt = laplacian(phi). Confirm whether the
# decay rate should be exp(-3*pi**2*t) or the diffusivity should be 1/pi.
grad_phi = sp.simplify(sp.diff(phi, t))
grad_phi

-300000.0*pi*exp(-3*pi*t)*sin(pi*x)*sin(pi*y)*sin(pi*z)

In [5]:
# Point (x, y, z, t) at which the symbolic field is evaluated
sample_point = {x: 0.51682734, y: 0.9421734, z: 0.5146519, t: 1.0}

# Plug the point into phi, then reduce the result to a floating-point number
phi_value = phi.subs(sample_point)
phi_value_numeric = phi_value.evalf()

phi_value_numeric

1.45441920393019

In [6]:
# diffusion coefficient field (a single uniform value)
Gamma = 1.0


def _sine_mode(coord):
    """Return sin(pi * coord), the spatial factor shared by all three axes."""
    return jnp.sin(jnp.pi * coord)


def fx(x):
    """Spatial factor along the first coordinate: sin(pi * x)."""
    return _sine_mode(x)


def fy(x):
    """Spatial factor along the second coordinate: sin(pi * x)."""
    return _sine_mode(x)


def fz(x):
    """Spatial factor along the third coordinate: sin(pi * x)."""
    return _sine_mode(x)


def ft(x):
    """Temporal factor exp(-3 * pi * x), where ``x`` is the time value.

    NOTE(review): for a sin*sin*sin mode with unit diffusivity the analytic
    decay would be exp(-3 * pi**2 * t); confirm the rate used here is intended.
    """
    decay_rate = -jnp.pi*3
    return jnp.exp(decay_rate * x)


def get_gt_foi(X):
    """Evaluate the reference field at every row of ``X``.

    Columns 0, 1, 2 and 3 of ``X`` are read as x, y, z and t. The result is
    10e4 (= 1e5) * fx(x) * fy(y) * fz(z) * ft(t), one value per row.
    """
    columns = [X[:, k] for k in range(4)]
    spatial_part = 10e4 * fx(columns[0]) * fy(columns[1]) * fz(columns[2])
    return spatial_part * ft(columns[3])


def get_source(X):
    """Return a zero source term with one entry per row of ``X``."""
    zeroed = 0.0 * X
    return jnp.sum(zeroed, axis=1)


def get_gamma(X):
    """Return the uniform diffusion coefficient ``Gamma``, one entry per row of ``X``."""
    # An earlier variant returned the constant 1.004 instead of Gamma.
    per_row_zero = jnp.sum(0.0 * X, axis=1)
    return Gamma + per_row_zero


def get_initial_guess(X):
    """Evaluate the reference field at the fixed time 1 for every row of ``X``.

    Columns 0..3 of ``X`` are read as x, y, z, t, but the t column is ignored:
    the temporal factor is always ft(1).
    """
    x, y, z, t = (X[:, k] for k in range(4))
    spatial_part = 10e4 * fx(x) * fy(y) * fz(z)
    return spatial_part * ft(1)



# Discretisation scheme selection handed to the simulation setup.
# NOTE(review): "ddtSchemes" is "steadyState" although the run is set up as a
# transient problem, and the key "interpolateionSchemes" looks misspelled; both
# are left exactly as they are because the consumer of this dict is not visible
# here. Confirm against the code that reads these keys.
# (An alternative 'gradSchemes' value was 'Gauss linear corrected'.)
fvSchemes = dict(
    ddtSchemes="steadyState",
    gradSchemes="Gauss linear",
    divSchemes="none",
    laplacianSchemes="Gauss linear corrected",
    interpolateionSchemes="linear",
    snGradSchemes="corrected",
)

# Boundary specification for the field "T": every named face gets the same
# (name, 0, "Uniform", 0.0) entry, listed in this order.
_boundary_faces = ("top", "bottom", "back", "right", "front", "left")
bd_infos = {
    "T": tuple((face, 0, "Uniform", 0.0) for face in _boundary_faces)
}

# Time controls: start time, end time and time-step size.
controlDict = dict(
    startTime=1.0,
    endTime=1.1,
    deltaT=0.01,
)

# Keyword arguments forwarded to the simulation setup.
# "source" and "gt" are passed as callables, while "gamma" and "rho" are passed
# as plain scalars.
Simulation_input = dict(
    controlDict=controlDict,
    source=get_source,
    gamma=Gamma,
    rho=1.0,
    gt=get_gt_foi,
    fvSchemes=fvSchemes,
)


In [7]:
# Configure the grid for the transient diffusion run
grid.SetUpSimulation("transient Diffusion", **Simulation_input)

# For each field: print it, then print whether its length matches the expected
# count attribute on the grid.
# Entries are (label, grid attribute, check message, count attribute).
_field_report = (
    ("Gamma:", "gamma", "Shape of gamma array == Num of faces:", "N_f"),
    ("Source Term:", "source", "Shape of source array == Num of cell center nodes:", "N_c"),
    ("rho:", "rho", "Shape of rho array == Num of cell center nodes:", "N_c"),
)
for _label, _field_name, _check_msg, _count_name in _field_report:
    _values = getattr(grid, _field_name)
    print(_label, _values)
    print(_check_msg, _values.shape[0] == getattr(grid, _count_name), "\n")

# Create the initial field object together with its boundary conditions
from initial_condition import InitialScalarField

# The boundary specification for "T" is applied on all listed faces of the
# domain (the recorded output reports them as Dirichlet boundaries).
mybc = BoundaryConditions(bd_infos["T"], grid)
initial_field = InitialScalarField(get_initial_guess, grid, mybc, name="T")
v0 = initial_field.UpdateBoundaryPhi(grid)
# v0 = grid.gt[0]

Gamma: [1. 1. 1. ... 1. 1. 1.]
Shape of gamma array == Num of faces: True 

Source Term: [0. 0. 0. ... 0. 0. 0.]
Shape of source array == Num of cell center nodes: True 

rho: [1. 1. 1. ... 1. 1. 1.]
Shape of rho array == Num of cell center nodes: True 



In [8]:
# Display the initial field object built in the previous cell.
v0

GridVariable(cell_phi=Array([1.454418  , 0.9884212 , 1.7486279 , ..., 0.91478175, 0.3516839 ,
       0.6350101 ], dtype=float32), bd_phi=Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), bc=BoundaryConditions(bd_names=('bottom', 'top', 'left', 'right', 'front', 'back'), bd_types=('Dirichlet', 'Dirichlet', 'Dirichlet', 'Dirichlet', 'Dirichlet', 'Dirichlet'), bd_infos=(('bottom', 0, 'Uniform', 0.0, 17895, 18301), ('top', 0, 'Uniform', 0.0, 18301, 18707), ('left', 0, 'Uniform', 0.0, 18707, 19111), ('right', 0, 'Uniform', 0.0, 19111, 19517), ('front', 0, 'Uniform', 0.0, 19517, 19921), ('back', 0, 'Uniform', 0.0, 19921, 20325))), name='T')

First, let us generate an animation of the ground truth (gt).

In [9]:
# # regular points:
# mx, my, mz = jnp.meshgrid(
#     jnp.linspace(0, 1, 50), jnp.linspace(0, 1, 50), jnp.linspace(0, 1, 50)
# )
# mxyz = jnp.concatenate((mx[..., None], my[..., None], mz[..., None]), axis=-1)

# cx = 10
# cy = 20
# cz = 30

# # Prepare the figure and axes
# fig, ax = plt.subplots(1, 3, figsize=(30, 24))

# # Set up the plot with initial data
# gt_field = grid.gt[0]
# gt_inter_foi_yz = griddata(grid.c_pos, gt_field, mxyz[cx, :, :], method="linear")
# gt_inter_foi_xz = griddata(grid.c_pos, gt_field, mxyz[:, cy, :], method="linear")
# gt_inter_foi_xy = griddata(grid.c_pos, gt_field, mxyz[:, :, cz], method="linear")

# # Initial data for images
# im_yz = ax[0].imshow(gt_inter_foi_yz, extent=(0, 1, 0, 1), cmap='coolwarm')
# im_xz = ax[1].imshow(gt_inter_foi_xz, extent=(0, 1, 0, 1), cmap='coolwarm')
# im_xy = ax[2].imshow(gt_inter_foi_xy, extent=(0, 1, 0, 1), cmap='coolwarm')

# # Create colorbars with fixed limits
# divider_yz = make_axes_locatable(ax[0])
# divider_xz = make_axes_locatable(ax[1])
# divider_xy = make_axes_locatable(ax[2])

# cax_yz = divider_yz.append_axes("right", size="5%", pad=0.05)
# cax_xz = divider_xz.append_axes("right", size="5%", pad=0.05)
# cax_xy = divider_xy.append_axes("right", size="5%", pad=0.05)

# cbar_yz = fig.colorbar(im_yz, cax=cax_yz)
# cbar_xz = fig.colorbar(im_xz, cax=cax_xz)
# cbar_xy = fig.colorbar(im_xy, cax=cax_xy)

# # Label axes
# ax[0].set_xlabel("z")
# ax[0].set_ylabel("y")
# ax[1].set_xlabel("z")
# ax[1].set_ylabel("x")
# ax[2].set_xlabel("y")
# ax[2].set_ylabel("x")

# # Function to update each frame of the animation
# def update(frame_idx):
#     gt_field = grid.gt[frame_idx]
#     gt_inter_foi_yz = griddata(grid.c_pos, gt_field, mxyz[cx, :, :], method="linear")
#     gt_inter_foi_xz = griddata(grid.c_pos, gt_field, mxyz[:, cy, :], method="linear")
#     gt_inter_foi_xy = griddata(grid.c_pos, gt_field, mxyz[:, :, cz], method="linear")

#     # Update image data
#     im_yz.set_data(gt_inter_foi_yz)
#     im_xz.set_data(gt_inter_foi_xz)
#     im_xy.set_data(gt_inter_foi_xy)

#     return im_yz, im_xz, im_xy

# # Create the animation
# ani = animation.FuncAnimation(fig, update, frames=len(grid.gt), blit=True)

# # Save the animation as a GIF
# ani.save('grid_animation.gif', writer='imagemagick', fps=2)

# plt.show()


In [10]:
# # Solve the problem in 3D
# `partial` and `Callable` are imported explicitly so this cell does not rely
# on the wildcard imports below happening to re-export them. They come first,
# so any same-named export from the project modules still takes precedence.
from functools import partial
from typing import Callable

from assemble import *
from advection_diffusion import *
from pdeSolver import *


def _num_time_steps(grid) -> int:
    """Return the number of time steps implied by the grid's time controls.

    The count is the length of ``arange(startTime, endTime, deltaT)``.
    """
    # NOTE(review): with floating-point bounds the length of an arange can be
    # sensitive to rounding at the end point — confirm the resulting count is
    # the intended number of steps.
    return len(jnp.arange(grid.startTime, grid.endTime, grid.deltaT))


# Single time-step function with the grid bound in.
step_fn = partial(Transient_Downwind_jit, grid=grid)


def solvePDE_explicit_jit(f: Callable) -> Callable:
    """Return a function that applies ``f`` repeatedly via ``lax.scan``.

    Args:
        f: single-step function. It is used as the body of ``lax.scan``, so it
            must map a state to a ``(new_state, per_step_output)`` pair.

    Returns:
        A function ``f_repeated(x_initial)`` returning ``(x_final, residual)``:
        the state after all steps and the per-step outputs collected by the
        scan. The number of steps is read from the module-level ``grid`` at the
        time this factory is called.
    """
    steps = _num_time_steps(grid)

    def f_repeated(x_initial):
        def scan_body(state, _):
            # The scanned-over input is unused (xs=None); only the carry matters.
            return f(state)

        x_final, residual = lax.scan(scan_body, x_initial, xs=None, length=steps)
        return x_final, residual

    return f_repeated


repeated_fn = jit(solvePDE_explicit_jit(step_fn))


def solvePDE_explicit(v0, grid):
    """Step the solution forward in a plain Python loop.

    Args:
        v0: initial field passed to the first ``Transient_Downwind`` call.
        grid: grid object providing ``startTime``, ``endTime`` and ``deltaT``
            and passed through to ``Transient_Downwind``.

    Returns:
        A list with one ``(field, auxiliary)`` tuple per time step, in order.
    """
    analysis = []
    state = v0
    for _ in range(_num_time_steps(grid)):
        state, auxiliary = Transient_Downwind(state, grid)
        analysis.append((state, auxiliary))
    return analysis


# def wrapper1():
#     return solvePDE_explicit(v0, grid)  # .block_until_ready() 


# def wrapper2():
#     return repeated_fn(v0)  # .block_until_ready()


# print("Before:", timeit.timeit(wrapper1, number=7))
# print("After:", timeit.timeit(wrapper2, number=7))

result = solvePDE_explicit(v0,grid)

In [11]:
# from scipy.interpolate import griddata

# # regular points:
# mx, my, mz = jnp.meshgrid(
#     jnp.linspace(0, 1, 50), jnp.linspace(0, 1, 50), jnp.linspace(0, 1, 50)
# )
# mxyz = jnp.concatenate((mx[..., None], my[..., None], mz[..., None]), axis=-1)

# solution = result[-1][0].cell_phi
# gt_field = grid.gt[-1]

# fig, ax = plt.subplots(3, 3, figsize=(30, 24))
# # lw=1
# cx = 10
# cy = 20
# cz = 30

# solution_inter_foi_yz = griddata(grid.c_pos, solution, mxyz[cx, :, :], method="linear")
# solution_inter_foi_xz = griddata(grid.c_pos, solution, mxyz[:, cy, :], method="linear")
# solution_inter_foi_xy = griddata(grid.c_pos, solution, mxyz[:, :, cz], method="linear")


# gt_inter_foi_yz = griddata(grid.c_pos, gt_field, mxyz[cx, :, :], method="linear")
# gt_inter_foi_xz = griddata(grid.c_pos, gt_field, mxyz[:, cy, :], method="linear")
# gt_inter_foi_xy = griddata(grid.c_pos, gt_field, mxyz[:, :, cz], method="linear")

# diff_inter_foi_yz = griddata(
#     grid.c_pos, solution - gt_field, mxyz[cx, :, :], method="linear"
# )
# diff_inter_foi_xz = griddata(
#     grid.c_pos, solution - gt_field, mxyz[:, cy, :], method="linear"
# )
# diff_inter_foi_xy = griddata(
#     grid.c_pos, solution - gt_field, mxyz[:, :, cz], method="linear"
# )


# data = [
#     [solution_inter_foi_yz, solution_inter_foi_xz, solution_inter_foi_xy],
#     [gt_inter_foi_yz, gt_inter_foi_xz, gt_inter_foi_xy],
#     [diff_inter_foi_yz, diff_inter_foi_xz, diff_inter_foi_xy],
# ]
# data_labels = ["solution", "gt", "difference"]
# for i in range(len(data)):
#     imx = ax[i, 0].imshow(data[i][0], label="yz", extent=(0, 1, 0, 1), cmap='coolwarm',vmin=0,vmax=4)
#     imy = ax[i, 1].imshow(data[i][1], label="xz", extent=(0, 1, 0, 1),  cmap='coolwarm',vmin=0,vmax=7)
#     imz = ax[i, 2].imshow(data[i][2], label="xy", extent=(0, 1, 0, 1),  cmap='coolwarm',vmin=0,vmax=7)
#     cbarx = fig.colorbar(imx)
#     cbary = fig.colorbar(imy)
#     cbarz = fig.colorbar(imz)
#     ax[i, 0].set_xlabel("z")
#     ax[i, 0].set_ylabel("y")
#     ax[i, 0].set_title(data_labels[i] + "@ x = " + str(cx / 50))
#     ax[i, 1].set_xlabel("z")
#     ax[i, 1].set_ylabel("x")
#     ax[i, 1].set_title(data_labels[i] + "@ y = " + str(cy / 50))
#     ax[i, 2].set_xlabel("y")
#     ax[i, 2].set_ylabel("x")
#     ax[i, 2].set_title(data_labels[i] + "@ z = " + str(cz / 50))

In [12]:
# residuals = result[1]
# errors_dict = {"residual": np.array(jnp.mean(abs(residuals), axis=1))}
# fig, ax = plt.subplots(figsize=(12, 8))
# # lw=1
# for key, value in errors_dict.items():
#     ax.plot(value, label=key, linewidth=3)
# ax.set_yscale("log")
# ax.set_xlabel("iteration")
# ax.set_ylabel("Residual")
# leg = ax.legend(loc="upper right", frameon=True)
# leg.get_frame().set_edgecolor("black")

In [13]:
# # # Solve the problem in 3D
# from assemble import *
# from advection_diffusion import *
# from pdeSolver import *

# step_fn = partial(Transient_upwind_jit, grid=grid)
# def solvePDE_implicit_jit(f: Callable, ) -> Callable:
#     """Returns a repeatedly applied version of f()."""
#     steps = len(jnp.arange(grid.startTime, grid.endTime, grid.deltaT))
    
#     def f_repeated(x_initial):
#         g = lambda x, _: f(x)
#         x_final, residual = lax.scan(g, x_initial, xs=None, length=steps)
#         return x_final,residual
#     return f_repeated

# repeated_fn_implicit = jit(solvePDE_implicit_jit(step_fn))

# def solvePDE_implicit(v0, grid):
#     analysis = []
#     steps = len(jnp.arange(grid.startTime, grid.endTime, grid.deltaT))
#     for _ in range (steps):
#         v0, axillary = Transient_upwind(v0, grid)
#         analysis.append(axillary)
#     return v0, analysis

# def wrapper3():
#     return solvePDE_implicit(v0, grid)  # .block_until_ready() 

# def wrapper4():
#     return repeated_fn_implicit(v0)  # .block_until_ready()


# print("Before:", timeit.timeit(wrapper3, number=7))
# print("After:", timeit.timeit(wrapper4, number=7))
# # result_implicit = solvePDE_implicit(v0,grid)
# result_implicit = wrapper4()

In [14]:
# result_implicit

In [15]:
# from scipy.interpolate import griddata

# # regular points:
# mx, my, mz = jnp.meshgrid(
#     jnp.linspace(0, 1, 50), jnp.linspace(0, 1, 50), jnp.linspace(0, 1, 50)
# )
# mxyz = jnp.concatenate((mx[..., None], my[..., None], mz[..., None]), axis=-1)

# solution = result_implicit[0].cell_phi
# gt_field = grid.gt

# fig, ax = plt.subplots(3, 3, figsize=(30, 24))
# # lw=1
# cx = 10
# cy = 20
# cz = 30

# solution_inter_foi_yz = griddata(grid.c_pos, solution, mxyz[cx, :, :], method="linear")
# solution_inter_foi_xz = griddata(grid.c_pos, solution, mxyz[:, cy, :], method="linear")
# solution_inter_foi_xy = griddata(grid.c_pos, solution, mxyz[:, :, cz], method="linear")


# gt_inter_foi_yz = griddata(grid.c_pos, gt_field, mxyz[cx, :, :], method="linear")
# gt_inter_foi_xz = griddata(grid.c_pos, gt_field, mxyz[:, cy, :], method="linear")
# gt_inter_foi_xy = griddata(grid.c_pos, gt_field, mxyz[:, :, cz], method="linear")

# diff_inter_foi_yz = griddata(
#     grid.c_pos, solution - gt_field, mxyz[cx, :, :], method="linear"
# )
# diff_inter_foi_xz = griddata(
#     grid.c_pos, solution - gt_field, mxyz[:, cy, :], method="linear"
# )
# diff_inter_foi_xy = griddata(
#     grid.c_pos, solution - gt_field, mxyz[:, :, cz], method="linear"
# )


# data = [
#     [solution_inter_foi_yz, solution_inter_foi_xz, solution_inter_foi_xy],
#     [gt_inter_foi_yz, gt_inter_foi_xz, gt_inter_foi_xy],
#     [diff_inter_foi_yz, diff_inter_foi_xz, diff_inter_foi_xy],
# ]
# data_labels = ["solution", "gt", "difference"]
# for i in range(len(data)):
#     imx = ax[i, 0].imshow(data[i][0], label="yz", extent=(0, 1, 0, 1), cmap='coolwarm')
#     imy = ax[i, 1].imshow(data[i][1], label="xz", extent=(0, 1, 0, 1),  cmap='coolwarm')
#     imz = ax[i, 2].imshow(data[i][2], label="xy", extent=(0, 1, 0, 1),  cmap='coolwarm')
#     cbarx = fig.colorbar(imx)
#     cbary = fig.colorbar(imy)
#     cbarz = fig.colorbar(imz)
#     ax[i, 0].set_xlabel("z")
#     ax[i, 0].set_ylabel("y")
#     ax[i, 0].set_title(data_labels[i] + "@ x = " + str(cx / 50))
#     ax[i, 1].set_xlabel("z")
#     ax[i, 1].set_ylabel("x")
#     ax[i, 1].set_title(data_labels[i] + "@ y = " + str(cy / 50))
#     ax[i, 2].set_xlabel("y")
#     ax[i, 2].set_ylabel("x")
#     ax[i, 2].set_title(data_labels[i] + "@ z = " + str(cz / 50))

In [16]:
# residuals = np.array(result_implicit[1])
# errors_dict = {"residual": np.array(jnp.mean(abs(residuals), axis=1))}
# fig, ax = plt.subplots(figsize=(12, 8))
# # lw=1
# for key, value in errors_dict.items():
#     ax.plot(value, label=key, linewidth=3)
# ax.set_yscale("log")
# ax.set_xlabel("iteration")
# ax.set_ylabel("Residual")
# leg = ax.legend(loc="upper right", frameon=True)
# leg.get_frame().set_edgecolor("black")