# LDC Implementation in Python/JAX
Written By: Sungje Park

In [1]:
import jax
from jax import numpy as jnp
from jax import Array
from jax.typing import ArrayLike
from jax.tree_util import register_pytree_node_class

import sys
sys.path.append('..')
from utils.utils import *

Defining Vars.

In [2]:
rho_ref = 8 # reference density
Mu = .1 # kinematic Viscosity
Re = 100 # reynolds number
N_x = 100 # number of cells in the x dir
N_y = N_x # number of cells in the y dir
dx = 1 # spacing in x
dy = 1 # spacing in y
U_lid = Re*Mu/N_x # velocity of lid

In [3]:
class Dynamics():

    #Static 
    DIM: int
    NUM_QUIVERS: int
    KSI: Array   # Matrix containing Define in sub-class
    W: Array     # Define in sub-class

    def calc_eq(self,pdf: ArrayLike):# Calculate Eq
        pass

    def ones_pdf(self):
        return jnp.ones(self.NUM_QUIVERS)

class D2Q9(Dynamics):
    DIM = 2
    NUM_QUIVERS = 9
    KSI = jnp.array([[0,0],  # center
                     [1,0],  # right
                     [0,1],  # top
                     [-1,0], # left
                     [0,-1], # bottom
                     [1,1],  # top-right
                     [-1,1], # top-left
                     [-1,-1],# bot-left
                     [1,-1]])# bot-right
    W = jnp.array([4/9,1/9,1/9,1/9,1/9,1/36,1/36,1/36,1/36])
    C = 1/jnp.sqrt(3)
    def __init__(self):
        super().__init__()
    
    def calc_eq(self,rho: ArrayLike,vel:ArrayLike):
        return self.W*rho*(1+(jnp.dot(self.KSI,vel))/(self.C**2)+(jnp.dot(self.KSI,vel))**2/(2*self.C**4)-(jnp.dot(vel,vel))/(2*self.C**2))
dynamics = D2Q9()

In [25]:
class Element:
    def __init__(self):
        pass

class ElementContainer:
    def __init__(self,elements: list[Element]):
        self.elements = elements

    def flatten(self):
        assert all(self.elements[0].__class__ == element.__class__ for element in self.elements)

        return pad_stack_tree([element.__dict__ for element in self.elements])

class Cell(Element):
    def __init__(self):
        self.f = jnp.ones((dynamics.NUM_QUIVERS,))
        self.u = jnp.zeros((dynamics.DIM,))
        self.rho = jnp.zeros((1,))
        self.faces = jnp.asarray([])

class Face:
    def __init__(self):
        self.f = jnp.ones((dynamics.NUM_QUIVERS,))
        self.cells = []

class LDC:
    def __init__(self):
        self.cells = ElementContainer([Cell() for i in range(N_y*N_x)])
        self.faces = ElementContainer([Face() for i in range(N_y*(N_x+1)*2)])

    def cell_tuples(self,cell):
        return cell.__dict__
        
    def init(self):
        return {"cells": self.cells.flatten(),"faces": self.cells.flatten()}

In [26]:
model = LDC()
model.cells.elements[0].faces = jnp.asarray([1,2,3,4])

In [29]:
params = model.init()
params

{'cells': {'f': Array([[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]], dtype=float32),
  'faces': Array([[ 1.,  2.,  3.,  4.],
         [-1., -1., -1., -1.],
         [-1., -1., -1., -1.],
         ...,
         [-1., -1., -1., -1.],
         [-1., -1., -1., -1.],
         [-1., -1., -1., -1.]], dtype=float32),
  'rho': Array([[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]], dtype=float32),
  'u': Array([[0., 0.],
         [0., 0.],
         [0., 0.],
         ...,
         [0., 0.],
         [0., 0.],
         [0., 0.]], dtype=float32)},
 'faces': {'f': Array([[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., 