In [1]:
import os
import json
import jax
from time import time
from jax import config
import numpy as np
import jax.numpy as jnp
import flax.linen as nn
from termcolor import colored
from src.base import LBMExternalForce
from src.utils import *
from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
jax.config.update('jax_enable_x64', True)


In [2]:
class Cylinder(BGKSim):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def set_boundary_conditions(self):
        # Define the cylinder surface
        coord = np.array([(i, j) for i in range(self.nx) for j in range(self.ny)])
        xx, yy = coord[:, 0], coord[:, 1]
        cx, cy = 2.*_diam, 2.*_diam
        cylinder = (xx - cx)**2 + (yy-cy)**2 <= (_diam/2.)**2
        cylinder = coord[cylinder]
        implicit_distance = np.reshape((xx - cx)**2 + (yy-cy)**2 - (_diam/2.)**2, (self.nx, self.ny))
        self.BCs.append(InterpolatedBounceBackBouzidi(tuple(cylinder.T), implicit_distance, self.gridInfo, self.precisionPolicy))

        # Outflow BC
        outlet = self.boundingBoxIndices['right']
        rho_outlet = np.ones((outlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
        self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy))
        # self.BCs.append(ZouHe(tuple(outlet.T), self.gridInfo, self.precisionPolicy, 'pressure', rho_outlet))

        # Inlet BC
        inlet = self.boundingBoxIndices['left']
        rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
        vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype)
        yy_inlet = yy.reshape(self.nx, self.ny)[tuple(inlet.T)]
        vel_inlet[:, 0] = poiseuille_profile(yy_inlet,
                                             yy_inlet.min(),
                                             yy_inlet.max()-yy_inlet.min(), 3.0 / 2.0 * _prescribed_vel)
        self.BCs.append(Regularized(tuple(inlet.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_inlet))

        # No-slip BC for top and bottom
        wall = np.concatenate([self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']])
        vel_wall = np.zeros(wall.shape, dtype=self.precisionPolicy.compute_dtype)
        self.BCs.append(Regularized(tuple(wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall))

    def output_data(self, **kwargs):
        self.saved_data.append(kwargs)
        # 1:-1 to remove boundary voxels (not needed for visualization when using bounce-back)
    
    def get_force(self):
        pass

# Helper function to specify a parabolic poiseuille profile
poiseuille_profile  = lambda x,x0,d,umax: np.maximum(0.,4.*umax/(d**2)*((x-x0)*d-(x-x0)**2))

In [None]:
from flax import linen as nn

class ConvBlock(nn.Module):
  """Defines a convolutional block with activation and normalization."""
  features: int
  kernel_size: int = (3,3)
  strides: int = 1

  @nn.compact
  def __call__(self, inputs):
    x = nn.Conv(self.features, kernel_size=self.kernel_size, strides=self.strides, padding='SAME')(inputs)
    x = nn.BatchNorm(use_running_average=True)(x)
    x = nn.relu(x)
    return x

class DownBlock(nn.Module):
  """Downsamples feature maps through convolutions and pooling."""
  features: int
  pool_factor: int = 2

  @nn.compact
  def __call__(self, x):
    x = ConvBlock(self.features)(x)
    x = ConvBlock(self.features)(x)
    return x

class UpBlock(nn.Module):
  """Upsamples feature maps and concatenates with features from the contracting path."""
  features: int
  up_factor: int = 2

  @nn.compact
  def __call__(self, x):
    x = ConvBlock(self.features)(x)
    x = ConvBlock(self.features)(x)
    x = nn.ConvTranspose(self.features, kernel_size=(2, 2), strides=self.up_factor, padding='VALID')(x)
    return x

class UNet(nn.Module):
  """UNet architecture with contracting and expanding paths."""
  features_start: int = 64

  @nn.compact
  def __call__(self, x):
    input_shape = x.shape
    # Contracting path
    down1 = DownBlock(self.features_start * 2)(x)
    down1_max_pooled = nn.max_pool(down1, window_shape=(2, 2), strides=(2, 2))
    down2 = DownBlock(self.features_start * 4)(down1_max_pooled)
    down2_max_pooled = nn.max_pool(down2, window_shape=(2, 2), strides=(2, 2))
    down3 = DownBlock(self.features_start * 8)(down2_max_pooled)
    down3_max_pooled = nn.max_pool(down3, window_shape=(2, 2), strides=(2, 2))
    down4 = DownBlock(self.features_start * 16)(down3_max_pooled)
    down4_max_pooled = nn.max_pool(down4, window_shape=(2, 2), strides=(2, 2))
    
    # Expanding path with concatenation
    up1 = UpBlock(self.features_start * 16)(down4_max_pooled)
    down4_sliced = jax.lax.slice(down4, (4, 4, 0),(down4.shape[0]-4, down4.shape[1]-4, down4.shape[2]))
    up1_concatenated = jax.lax.concatenate([down4_sliced, up1], dimension=2)
    up2 = UpBlock(self.features_start * 4)(up1_concatenated)
    down3_sliced = jax.lax.slice(down3, (4, 4, 0), (down3.shape[0]-4, down3.shape[1]-4, down3.shape[2]))
    up2_concatenated = jax.lax.concatenate([down3_sliced, up2], dimension=2)
    up3 = UpBlock(self.features_start * 2)(up2_concatenated)
    print(up3.shape)
    return up3
  
class SimpleNet(nn.Module):
    features: int = 32
    kernel_size: int = (5, 5)
    strides: int = 1
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.features, kernel_size=self.kernel_size, strides=self.strides, padding='SAME')(x)
        x = nn.leaky_relu(x)
        x = nn.Conv(self.features, kernel_size=self.kernel_size, strides=self.strides, padding='SAME')(x)
        x = nn.leaky_relu(x)
        x = nn.Conv(self.features, kernel_size=self.kernel_size, strides=self.strides, padding='SAME')(x)
        x = nn.leaky_relu(x)
        x = nn.Conv(2, kernel_size=self.kernel_size, strides=self.strides, padding='SAME')(x)
        return x


In [None]:
my_unet = SimpleNet()
print(my_unet.tabulate(jax.random.key(0), jnp.ones((220,41,3)))) # check parameters, should be 55,298 in total

In [3]:
class BGKSimForce(LBMExternalForce):
    """
    BGK simulation class.

    This class implements the Bhatnagar-Gross-Krook (BGK) approximation for the collision step in the Lattice Boltzmann Method.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @partial(jit, static_argnums=(0,))
    def collision(self, f, feq, rho, u):
        """
        BGK collision step for lattice.

        The collision step is where the main physics of the LBM is applied. In the BGK approximation, 
        the distribution function is relaxed towards the equilibrium distribution function.
        """
        fneq = f - feq
        fout = f - self.omega * fneq
        return self.precisionPolicy.cast_to_output(fout)

class Cylinder(BGKSimForce):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def set_boundary_conditions(self):
        # Define the cylinder surface
        coord = np.array([(i, j) for i in range(self.nx) for j in range(self.ny)])
        xx, yy = coord[:, 0], coord[:, 1]
        cx, cy = 2.*_diam, 2.*_diam
        cylinder = (xx - cx)**2 + (yy-cy)**2 <= (_diam/2.)**2
        cylinder = coord[cylinder]
        implicit_distance = np.reshape((xx - cx)**2 + (yy-cy)**2 - (_diam/2.)**2, (self.nx, self.ny))
        self.BCs.append(InterpolatedBounceBackBouzidi(tuple(cylinder.T), implicit_distance, self.gridInfo, self.precisionPolicy))

        # Outflow BC
        outlet = self.boundingBoxIndices['right']
        rho_outlet = np.ones((outlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
        self.BCs.append(ExtrapolationOutflow(tuple(outlet.T), self.gridInfo, self.precisionPolicy))
        # self.BCs.append(ZouHe(tuple(outlet.T), self.gridInfo, self.precisionPolicy, 'pressure', rho_outlet))

        # Inlet BC
        inlet = self.boundingBoxIndices['left']
        rho_inlet = np.ones((inlet.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
        vel_inlet = np.zeros(inlet.shape, dtype=self.precisionPolicy.compute_dtype)
        yy_inlet = yy.reshape(self.nx, self.ny)[tuple(inlet.T)]
        vel_inlet[:, 0] = poiseuille_profile(yy_inlet,
                                             yy_inlet.min(),
                                             yy_inlet.max()-yy_inlet.min(), 3.0 / 2.0 * _prescribed_vel)
        self.BCs.append(Regularized(tuple(inlet.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_inlet))

        # No-slip BC for top and bottom
        wall = np.concatenate([self.boundingBoxIndices['top'], self.boundingBoxIndices['bottom']])
        vel_wall = np.zeros(wall.shape, dtype=self.precisionPolicy.compute_dtype)
        self.BCs.append(Regularized(tuple(wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall))

    def output_data(self, **kwargs):
        self.saved_data.append(kwargs)
        # 1:-1 to remove boundary voxels (not needed for visualization when using bounce-back)
    
    @partial(jit, static_argnums=(0,))
    def get_force(self, f_postcollision, feq, rho, u):
        pass

In [6]:
def generate_sim_dataset(diam, t_start, t_end, output_stride, output_offset):
    global _diam
    global _prescribed_vel
    _diam = diam
    precision = 'f64/f64'
    # diam_list = [10, 20, 30, 40, 60, 80]
    scale_factor = 80 / diam
    prescribed_vel = 0.003 * scale_factor
    _prescribed_vel = prescribed_vel
    lattice = LatticeD2Q9(precision)

    nx = int(22*diam)
    ny = int(4.1*diam)

    Re = 100.0
    visc = prescribed_vel * diam / Re
    omega = 1.0 / (3. * visc + 0.5)
    kwargs = {
        'lattice': lattice,
        'omega': omega,
        'nx': nx,
        'ny': ny,
        'nz': 0,
        'precision': precision,
        'return_fpost': True    # Need to retain fpost-collision for computation of lift and drag
    }
    # characteristic time
    tc = prescribed_vel/diam
    if t_end < int(100//tc):
        print(colored("WARNING: timestep_end is too small, Karman flow may not appear. Recommend value is {}".format(int(100//tc)), "red"))
    sim = Cylinder(**kwargs)
    sim.run(t_end, t_start, output_offset, output_stride)
    return sim.saved_data

In [7]:
def generate_sim_dataset_with_profile(diam, t_start, t_end, output_stride, output_offset):
    with jax.profiler.trace("/tmp/tensorboard"):
        generated_data = generate_sim_dataset(diam, t_start, t_end, output_stride, output_offset)
    return generated_data

In [8]:
generated_data = generate_sim_dataset_with_profile(20, 0, 10000, 10000, 0)

[32m**** Simulation Parameters for Cylinder ****[0m
            [34mParameter[0m | [33mValue[0m
--------------------------------------------------
                [34mOmega[0m | [33m1.971608832807571[0m
     [34mGrid Points in X[0m | [33m440[0m
     [34mGrid Points in Y[0m | [33m82[0m
     [34mGrid Points in Z[0m | [33m0[0m
       [34mDimensionality[0m | [33m2[0m
     [34mPrecision Policy[0m | [33mf64/f64[0m
         [34mLattice Type[0m | [33mD2Q9[0m
      [34mCheckpoint Rate[0m | [33m0[0m
 [34mCheckpoint Directory[0m | [33m./checkpoints[0m
  [34mDownsampling Factor[0m | [33m1[0m
      [34mPrint Info Rate[0m | [33m100[0m
             [34mI/O Rate[0m | [33m0[0m
        [34mCompute MLUPS[0m | [33mFalse[0m
   [34mRestore Checkpoint[0m | [33mFalse[0m
              [34mBackend[0m | [33mgpu[0m
    [34mNumber of Devices[0m | [33m1[0m
Time to create the grid mask: 0.13919281959533691
Time to create the local masks and normal

  0%|          | 0/10001 [00:00<?, ?it/s]


TypeError: unsupported operand type(s) for +: 'DynamicJaxprTracer' and 'NoneType'

In [None]:
from tqdm import tqdm
import numpy as np

def read_data():
    res_data = []
    total_batch = 1
    for i in tqdm(range(total_batch)):
        loaded_data = jnp.load('./data/ref_data_diam_80_seq_{}.npy'.format(i))
        res_data.append(loaded_data)
    return jnp.concatenate(res_data, axis=0)

def read_data_and_downsample():
    res_data = []
    total_batch = 10
    for i in tqdm(range(total_batch)):
        loaded_data = np.load('./data/ref_data_diam_80_seq_{}.npy'.format(i))
        downsampled_list = [downsample_field(field, 8) for field in loaded_data]
        res_data.append(downsampled_list)
    return np.concatenate(res_data, axis=0)

In [None]:
import matplotlib.pyplot as plt
def visualize_data(data, imgs=20, field='u'):
    fig, axs = plt.subplots(1, imgs, figsize=(32, 10))
    max_val = generated_data[0]['u'][np.where(generated_data[0]['u'] == generated_data[0]['u'].max())]
    min_val = generated_data[0]['u'][np.where(generated_data[0]['u'] == generated_data[0]['u'].min())]
    for i in range(imgs):
        img = (data[i*(len(data)//imgs)][field] - min_val)/(max_val - min_val)
        img = np.concatenate((img, np.zeros((img.shape[0], img.shape[1], 1))), axis=2)
        axs[i].imshow(img)
        axs[i].set_title("T={}".format(data[i*(len(data)//imgs)]['timestep']))

In [None]:
visualize_data(generated_data, 20)