# This notebook explores how the finite element method (FEM) is implemented in JAX-FEM
- https://www.dealii.org/current/doxygen/deal.II/index.html (Overall class structure of a FE program)
- https://www.dealii.org/current/doxygen/deal.II/step_2.html (For meshing)
- https://www.dealii.org/current/doxygen/deal.II/step_3.html (Executing a full FE program)

## Step 1: Weak formulation
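- Consider $\Omega \subset \mathbb{R}^{dim}$, the domain of the PDE
    - `dim` = spatial dimension (3 for 3D problems, 2 for 2D, etc.)
- We seek the solution $u$ of the following weak form of the PDE:
    - $$ F = (f(\nabla u), \nabla v) - (t,v)_{\Gamma_N} - (b,v) = 0$$
        - valid for all test functions $v$ (with $v = 0$ on the Dirichlet boundary $\Gamma_D$)
    - where
        - $u$ = solution field, $u(x) \in \mathbb{R}^{vec}$
        - `vec` = number of components of the solution (1 for a scalar field such as temperature, 3 for a displacement field in 3D, etc.)
        - $v$ = test function
        - $(a,b) = \int_\Omega a : b \, d\Omega$, where $:$ denotes full contraction (a dot product for vectors); $(a,b)_{\Gamma_N}$ is the same product integrated over the Neumann boundary $\Gamma_N$
        - $\nabla u$ = `grad_u`, the gradient of $u$
        - $t$ = Neumann B.C. value (e.g. a traction), prescribed on $\Gamma_N$
        - $b$ = body force or source term in the PDE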
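The weak form can be traced back to a strong form. As a reconstruction (the notebook itself states only the weak form), assume $-\nabla \cdot f(\nabla u) = b$ in $\Omega$, $f(\nabla u)\, n = t$ on $\Gamma_N$ and $u = u_D$ on $\Gamma_D$, with $n$ the outward unit normal and $\partial\Omega = \Gamma_D \cup \Gamma_N$. Multiplying by a test function $v$, integrating over $\Omega$ and integrating by parts gives:

$$
\begin{aligned}
0 &= \int_\Omega \big(-\nabla \cdot f(\nabla u) - b\big) \cdot v \, d\Omega \\
  &= (f(\nabla u), \nabla v) - \int_{\partial\Omega} \big(f(\nabla u)\, n\big) \cdot v \, d\Gamma - (b, v) \\
  &= (f(\nabla u), \nabla v) - (t, v)_{\Gamma_N} - (b, v)
\end{aligned}
$$

The last line uses $v = 0$ on $\Gamma_D$ and $f(\nabla u)\, n = t$ on $\Gamma_N$, which is the weak form above with the traction term written explicitly as a boundary integral. For example, $f(\nabla u) = \nabla u$ recovers the Poisson equation $-\Delta u = b$.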

Thus, we have to solve $F = 0$ using a finite element approximation.
- We choose basis (shape) functions $\phi_j$, which also serve as the basis for representing any test function $v$ (Galerkin approach), and approximate the solution as

 $$u (x) \approx u^h (x) = \sum_j U_j \phi_j(x)$$
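To make $u^h$ concrete, here is a minimal 1D sketch with piecewise-linear "hat" functions (our own choice for illustration; JAX-FEM works with 2D/3D elements). Each $\phi_j$ equals 1 at node $j$ and 0 at every other node, so $U_j$ is simply the value of $u^h$ at node $j$. The names `phi` and `u_h` are hypothetical, and the nodal values are sampled from $u(x) = x^2$ rather than obtained by solving $F = 0$.

```python
import numpy as np

# Nodes of a 1D mesh on Omega = [0, 1] with 4 linear elements
nodes = np.linspace(0.0, 1.0, 5)

def phi(j, x):
    """Piecewise-linear 'hat' function: 1 at nodes[j], 0 at all other nodes."""
    e_j = np.zeros_like(nodes)
    e_j[j] = 1.0
    return np.interp(x, nodes, e_j)  # linear interpolation of the unit vector e_j

def u_h(U, x):
    """Finite element approximation u^h(x) = sum_j U_j * phi_j(x)."""
    return sum(U[j] * phi(j, x) for j in range(len(nodes)))

# Nodal coefficients: here simply u(x) = x**2 sampled at the nodes
U = nodes**2
x = np.array([0.0, 0.125, 0.5, 0.6])
print(u_h(U, x))  # piecewise-linear interpolant of x**2 at these points
```

Because the $\phi_j$ sum to 1 everywhere (partition of unity), $u^h$ reproduces constant fields exactly.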

## Step 2: Setting up the shape (basis) functions

We need the following ingredients:

1. A mesh
    - A mesh is a collection of cells (each defined by its nodes, which carry the dofs) and their connectivity
        - It is a discretization of the domain $\Omega$
    - In the code, the `gmsh` package is used to create meshes (see `generate_mesh.py`)
    - The `meshio` package makes it easy to read and write various mesh formats
2. A reference finite element
    - The shape functions are defined locally with respect to the reference cell!
    - The `basix` package provides the reference cells
3. A mapping utility
    - Maps the reference cell onto each physical cell in the mesh
    - This mapping transforms the local shape functions (and their gradients)
4. A way to enumerate every node (and hence every dof) in the mesh
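A mesh boils down to two arrays: node coordinates and, for every cell, the global indices of its nodes. The sketch below builds both by hand with NumPy for a structured quadrilateral grid; `structured_quad_mesh` is a hypothetical helper, not a JAX-FEM function, although its node ordering matches the `rectangle_mesh` points printed at the end of this notebook.

```python
import numpy as np

def structured_quad_mesh(Nx, Ny, Lx, Ly):
    """Hypothetical helper: Nx*Ny four-node quad cells on [0, Lx] x [0, Ly].

    Returns
      points: (num_nodes, 2) node coordinates; node (i, j) gets the global
              index i*(Ny+1) + j, where i indexes x and j indexes y.
      cells:  (num_cells, 4) global node indices of each cell, listed
              counter-clockwise starting from the lower-left corner.
    """
    xs = np.linspace(0.0, Lx, Nx + 1)
    ys = np.linspace(0.0, Ly, Ny + 1)
    X, Y = np.meshgrid(xs, ys, indexing="ij")  # both of shape (Nx+1, Ny+1)
    points = np.stack([X.ravel(), Y.ravel()], axis=1)

    def node(i, j):
        return i * (Ny + 1) + j  # global node numbering

    cells = np.array([[node(i, j), node(i + 1, j), node(i + 1, j + 1), node(i, j + 1)]
                      for i in range(Nx) for j in range(Ny)])
    return points, cells

points, cells = structured_quad_mesh(2, 2, 1.0, 1.0)
print(points.shape, cells.shape)  # (9, 2) (4, 4)
print(cells[0])                   # [0 3 4 1]
```

The global numbering in `node(i, j)` is exactly ingredient 4: every node gets one integer, and cells refer to nodes only through those integers.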

### Step 2.1: Mesh

In [4]:
from jax_am.fem import generate_mesh

# Two things matter for us here:
# 1. The Mesh class, the interface between gmsh/meshio and JAX-FEM
print(generate_mesh.Mesh)
# It has two attributes: `points` (node coordinates) and `cells` (each cell given by the indices of its nodes)

# 2. The function that creates a cuboidal mesh [for now, the only (and simplest) domain used for TopOpt]
print(generate_mesh.box_mesh)

<class 'jax_am.fem.generate_mesh.Mesh'>
<function box_mesh at 0x7f4aa98bff40>


In [5]:
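# Example: a unit cube meshed with 3 x 3 x 3 hexahedral cells
Lx, Ly, Lz = 1., 1., 1.  # domain size along x, y, z
Nx, Ny, Nz = 3, 3, 3     # number of cells along x, y, z
data_path = "../data/"   # output directory for the generated mesh files
meshio_mesh = generate_mesh.box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_path)
# Wrap the meshio object in JAX-FEM's Mesh class (hexahedral cells only)
jax_mesh = generate_mesh.Mesh(meshio_mesh.points, meshio_mesh.cells_dict['hexahedron'])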

Info    : Meshing 1D...
Info    : [  0%] Meshing curve 1 (Extruded)
Info    : [ 10%] Meshing curve 2 (Extruded)
Info    : [ 20%] Meshing curve 3 (Extruded)
Info    : [ 30%] Meshing curve 4 (Extruded)
Info    : [ 40%] Meshing curve 7 (Extruded)
Info    : [ 50%] Meshing curve 8 (Extruded)
Info    : [ 50%] Meshing curve 9 (Extruded)
Info    : [ 60%] Meshing curve 10 (Extruded)
Info    : [ 70%] Meshing curve 12 (Extruded)
Info    : [ 80%] Meshing curve 13 (Extruded)
Info    : [ 90%] Meshing curve 17 (Extruded)
Info    : [100%] Meshing curve 21 (Extruded)
Info    : Done meshing 1D (Wall 0.000177388s, CPU 0.000153s)
Info    : Meshing 2D...
Info    : [  0%] Meshing surface 5 (Extruded)
Info    : [ 20%] Meshing surface 14 (Extruded)
Info    : [ 40%] Meshing surface 18 (Extruded)
Info    : [ 50%] Meshing surface 22 (Extruded)
Info    : [ 70%] Meshing surface 26 (Extruded)
Info    : [ 90%] Meshing surface 27 (Extruded)
Info    : Done meshing 2D (Wall 0.000602827s, CPU 0.000486s)
Info    : Meshin

In [11]:
print('Some points (nodes) in mesh : ', jax_mesh.points[:3,:])  # points = global node coordinates
print('Connectivity of first cell in mesh : ', jax_mesh.cells[:1,:])  # cells = global node indices of each cell

Some points (nodes) in mesh :  [[0. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]]
Connectivity of first cell in mesh :  [[ 0  8 32 12 24 36 56 48]]
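Because `cells` stores global node indices, NumPy fancy indexing `points[cells]` collects the coordinates of every cell's nodes at once, giving an array of shape (num_cells, nodes_per_cell, dim). A small sketch on a hand-written two-cell mesh (the arrays are made up for illustration and are not the box mesh above):

```python
import numpy as np

# Two unit quad cells side by side: [0,1]x[0,1] and [1,2]x[0,1]
points = np.array([[0., 0.], [1., 0.], [2., 0.],
                   [0., 1.], [1., 1.], [2., 1.]])
cells = np.array([[0, 1, 4, 3],
                  [1, 2, 5, 4]])  # global node indices, counter-clockwise

cell_coords = points[cells]       # (num_cells, nodes_per_cell, dim)
print(cell_coords.shape)          # (2, 4, 2)

# Nodes shared between the two cells appear in both rows of `cells`
shared = np.intersect1d(cells[0], cells[1])
print(shared)                     # [1 4]

# Cell centroids: average over the node axis
centroids = cell_coords.mean(axis=1)
print(centroids)
```

Shared indices are what couple neighbouring cells when element contributions are later summed into global arrays.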


### Step 2.2: Finite element
- Uses `basis.py`
- All computations are done on the reference cell (and on its faces too!), which gives us
    - quadrature points and weights
    - shape function values
    - shape function gradients, etc.
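As a sketch of what such reference-cell data look like, the block below tabulates a 4-node bilinear quadrilateral (QUAD4) with 2x2 Gauss quadrature on the unit square $[0,1]^2$ (to our understanding also basix's reference quadrilateral). The notebook's cell below uses the 8-node serendipity QUAD8 instead; QUAD4 is used here only because it is shorter. The function name and node ordering are our own choices and need not match basix or JAX-FEM.

```python
import numpy as np

def quad4_shape_vals_and_grads():
    """Bilinear QUAD4 element on the reference square [0, 1]^2, 2x2 Gauss quadrature.

    Returns values (n_quads, n_nodes), gradients (n_quads, n_nodes, dim) with respect
    to the reference coordinates (xi, eta), and quadrature weights (n_quads,).
    """
    # 1D Gauss-Legendre points/weights on [-1, 1], mapped to [0, 1]
    g, w = np.polynomial.legendre.leggauss(2)
    g, w = 0.5 * (g + 1.0), 0.5 * w
    xi, eta = np.meshgrid(g, g, indexing="ij")
    xi, eta = xi.ravel(), eta.ravel()
    weights = np.outer(w, w).ravel()  # tensor-product weights, same ordering as xi, eta

    # Node order (counter-clockwise): (0,0), (1,0), (1,1), (0,1)
    values = np.stack([(1 - xi) * (1 - eta),
                       xi * (1 - eta),
                       xi * eta,
                       (1 - xi) * eta], axis=1)
    d_xi = np.stack([-(1 - eta), 1 - eta, eta, -eta], axis=1)
    d_eta = np.stack([-(1 - xi), -xi, xi, 1 - xi], axis=1)
    gradients = np.stack([d_xi, d_eta], axis=2)
    return values, gradients, weights

values, gradients, weights = quad4_shape_vals_and_grads()
print(values.shape, gradients.shape, weights.shape)  # (4, 4) (4, 4, 2) (4,)
```

Two quick sanity checks hold for any Lagrange element: at every quadrature point the shape function values sum to 1 and their gradients sum to 0, and the weights sum to the reference cell's area (1 here).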


In [21]:
from jax_am.fem import basis
values, gradients, quadrature_weights = basis.get_shape_vals_and_grads('QUAD8')
print(values.shape) # (number of quadrature points, number of nodes per cell)
print(quadrature_weights.shape) # (number of quadrature points,)
# At each quadrature point, every shape function contributes a value (one shape function per node),
# so `values` has n_quads * n_nodes entries.
# The gradients carry an extra axis for the spatial directions (d/dx, d/dy, and d/dz in 3D).

ele_type = QUAD8, quad_points.shape = (4, 2)
(4, 8)
(4,)


### Step 2.3: Mapping
- Implemented in `core.py`:
    - `get_shape_grads()`
    - `get_face_shape_grads()`
    - `get_physical_quad_points()`
    - `get_physical_surface_quad_points()`
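The mapping step can be sketched for a bilinear (QUAD4) cell: physical quadrature points are $x_q = \sum_a N_a(\xi_q)\, x_a$, the Jacobian is $J = \partial x / \partial \xi = \sum_a x_a \otimes \nabla_\xi N_a$, physical gradients are $\nabla_x N_a = J^{-T} \nabla_\xi N_a$, and integrals pick up the factor $|\det J|\, w_q$ ("JxW"). This is our reading of what functions like `get_shape_grads()` and `get_physical_quad_points()` compute, not their actual code; `map_cell` is a hypothetical name, and the reference data are rebuilt inline so the block runs on its own.

```python
import numpy as np

# Reference QUAD4 data on [0, 1]^2 with 2x2 Gauss points;
# node order (0,0), (1,0), (1,1), (0,1)
g = 0.5 + np.array([-0.5, 0.5]) / np.sqrt(3.0)
xi, eta = [a.ravel() for a in np.meshgrid(g, g, indexing="ij")]
w_ref = np.full(4, 0.25)
N = np.stack([(1 - xi) * (1 - eta), xi * (1 - eta), xi * eta, (1 - xi) * eta], axis=1)
dN_ref = np.stack([np.stack([-(1 - eta), 1 - eta, eta, -eta], axis=1),
                   np.stack([-(1 - xi), -xi, xi, 1 - xi], axis=1)], axis=2)

def map_cell(coords):
    """Map reference data onto one physical cell with nodal coordinates `coords` (n_nodes, dim).

    Returns physical quadrature points (n_quads, dim), physical shape function
    gradients (n_quads, n_nodes, dim) and JxW = |det J| * w (n_quads,).
    """
    quad_points = N @ coords                            # x_q = sum_a N_a(xi_q) x_a
    J = np.einsum("ai,qak->qik", coords, dN_ref)        # J[q, i, k] = dx_i / dxi_k
    J_inv = np.linalg.inv(J)
    grads = np.einsum("qak,qki->qai", dN_ref, J_inv)    # dN/dx_i = sum_k dN/dxi_k * dxi_k/dx_i
    JxW = np.abs(np.linalg.det(J)) * w_ref
    return quad_points, grads, JxW

# A physical parallelogram cell (counter-clockwise nodes)
coords = np.array([[0., 0.], [2., 0.], [3., 1.], [1., 1.]])
quad_points, grads, JxW = map_cell(coords)
print(JxW.sum())  # approximately 2.0, the parallelogram's area
```

A useful check on the mapped gradients is that they reproduce linear fields: $\sum_a x_a \otimes \nabla_x N_a$ should equal the identity at every quadrature point.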

In [1]:
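import jax.numpy as np  # jax.numpy, so these functions also work when vectorized with jax.vmap

# Uses Lx and Lz from the box-mesh example above
def fixed_location(point):  # Dirichlet B.C.: the face x = 0
    return np.isclose(point[0], 0., atol=1e-5)

def load_location(point):  # Neumann B.C.: strip of the face x = Lx with 0 <= z <= 0.1*Lz
    return np.logical_and(np.isclose(point[0], Lx, atol=1e-5), np.isclose(point[2], 0., atol=0.1*Lz+1e-5))

def dirichlet_val(point):  # prescribed value of each fixed component
    return 0.

def neumann_val(point):  # prescribed traction vector (pointing in -z)
    return np.array([0., 0., -1e6])

# [location functions, fixed solution components (x, y, z), value functions], one entry per component
dirichlet_bc_info = [[fixed_location]*3, [0, 1, 2], [dirichlet_val]*3]
# [location functions, value functions]
neumann_bc_info = [[load_location], [neumann_val]]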
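Each location function takes a single point and returns a boolean, so evaluating it on every mesh node yields the nodes lying on that boundary region. The sketch below does this with plain NumPy on a hand-made 27-node grid of the unit cube (not the mesh generated earlier). How JAX-FEM itself applies these functions (for example, whether Neumann locations are tested on nodes or on boundary faces) is not shown here.

```python
import numpy as np

Lx, Lz = 1., 1.

def fixed_location(point):  # same test as above: the face x = 0
    return np.isclose(point[0], 0., atol=1e-5)

def load_location(point):   # same test as above: x = Lx and 0 <= z <= 0.1*Lz
    return np.logical_and(np.isclose(point[0], Lx, atol=1e-5),
                          np.isclose(point[2], 0., atol=0.1 * Lz + 1e-5))

# Node coordinates of a 2 x 2 x 2-cell grid on [0, 1]^3 (3 nodes per axis, 27 nodes)
g = np.linspace(0., 1., 3)
points = np.stack(np.meshgrid(g, g, g, indexing="ij"), axis=-1).reshape(-1, 3)

# Evaluate each location function node by node and keep the matching global indices
fixed_nodes = np.array([i for i, p in enumerate(points) if fixed_location(p)])
load_nodes = np.array([i for i, p in enumerate(points) if load_location(p)])
print(len(fixed_nodes), len(load_nodes))  # 9 3
```

The x = 0 face contributes 9 nodes; only the 3 nodes with z = 0 on the x = 1 face fall inside the load strip, since the next node layer sits at z = 0.5.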

In [8]:
from jax_am.fem.generate_mesh import create_2d_mesh, Mesh
# 2D unit square meshed with 2 x 2 quadrilateral cells using gmsh
Lx, Ly = 1., 1.
Nx, Ny = 2, 2
data_path = "../data/"
meshio_mesh = create_2d_mesh(Nx, Ny, Lx, Ly, data_path)
jax_mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['quad'])

Info    : Meshing 1D...
Info    : [  0%] Meshing curve 1 (Extruded)
Info    : [ 30%] Meshing curve 2 (Extruded)
Info    : [ 50%] Meshing curve 3 (Extruded)
Info    : [ 80%] Meshing curve 4 (Extruded)
Info    : Done meshing 1D (Wall 8.3031e-05s, CPU 7.8e-05s)
Info    : Meshing 2D...
Info    : Meshing surface 5 (Extruded)
Info    : Done meshing 2D (Wall 8.1198e-05s, CPU 0.00022s)
Info    : 9 nodes 16 elements
Info    : Writing '../data/msh/2d.msh'...
Info    : Done writing '../data/msh/2d.msh'



In [16]:
jax_mesh.points


array([[0. , 0. , 0. ],
       [1. , 0. , 0. ],
       [0. , 1. , 0. ],
       [1. , 1. , 0. ],
       [0.5, 0. , 0. ],
       [0.5, 1. , 0. ],
       [0. , 0.5, 0. ],
       [1. , 0.5, 0. ],
       [0.5, 0.5, 0. ]])

In [12]:
from jax_am.common import rectangle_mesh
# The same 2 x 2 quad mesh, generated without gmsh
Lx, Ly = 1., 1.
Nx, Ny = 2, 2
mesh_without_gmsh = rectangle_mesh(Nx, Ny, Lx, Ly)
jax_mesh_no_gmsh = Mesh(mesh_without_gmsh.points, mesh_without_gmsh.cells_dict['quad'])

In [15]:
jax_mesh_no_gmsh.points

array([[0. , 0. ],
       [0. , 0.5],
       [0. , 1. ],
       [0.5, 0. ],
       [0.5, 0.5],
       [0.5, 1. ],
       [1. , 0. ],
       [1. , 0.5],
       [1. , 1. ]])
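The two 2D meshes differ in two ways: the gmsh-based mesh stores 3 coordinates per node (with z = 0), while `rectangle_mesh` stores 2, and the node numbering differs. The sketch below copies both printed arrays verbatim and checks, with plain NumPy, that they contain the same nodes once the z column is dropped; `sort_nodes` is a hypothetical helper.

```python
import numpy as np

# Node coordinates printed above: gmsh-based mesh, then rectangle_mesh
gmsh_points = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [1., 1., 0.],
                        [0.5, 0., 0.], [0.5, 1., 0.], [0., 0.5, 0.], [1., 0.5, 0.],
                        [0.5, 0.5, 0.]])
rect_points = np.array([[0., 0.], [0., 0.5], [0., 1.], [0.5, 0.], [0.5, 0.5],
                        [0.5, 1.], [1., 0.], [1., 0.5], [1., 1.]])

# Drop the (all-zero) z column so both arrays have shape (num_nodes, 2)
assert np.allclose(gmsh_points[:, 2], 0.0)
gmsh_2d = gmsh_points[:, :2]

def sort_nodes(p):
    """Sort nodes by x, then y (np.lexsort treats the last key as the primary one)."""
    return p[np.lexsort((p[:, 1], p[:, 0]))]

same_nodes = np.allclose(sort_nodes(gmsh_2d), sort_nodes(rect_points))
print(same_nodes)  # True

# For each gmsh node, the index of the matching rectangle_mesh node
perm = [int(np.argmin(np.linalg.norm(rect_points - p, axis=1))) for p in gmsh_2d]
print(perm)  # [0, 6, 2, 8, 3, 5, 1, 7, 4]
```

If code were switched from one mesh to the other, any arrays indexed by node (including `cells`) would have to be renumbered with such a permutation.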