In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K

import datetime
import numpy as np
import matplotlib.pyplot as plt

from utils import *

In [None]:
# Keras builds tensors in 'float32' unless told otherwise; switch the global
# default to double precision (presumably so the network matches the float64
# NumPy arrays used for the lattice populations below).
K.set_floatx('float64')

In [None]:
##########################################################
# Restore the trained collision network from disk.
# `rmsre` (from utils) presumably is the custom loss/metric the model was
# compiled with; Keras needs it registered to deserialize the saved model.

model = tf.keras.models.load_model(
    "example_network.tf",
    custom_objects={'rmsre': rmsre},
)
model.summary()

In [None]:
###########################################################
# Simulation Parameters

nx      = 32    # grid points along x
ny      = 32    # grid points along y
niter   = 1000  # time levels t = 0 .. niter-1, i.e. niter-1 update steps
dumpit  = 100   # store the macroscopic fields every `dumpit` time steps
tau     = 1.0   # relaxation time; in this notebook it only sets the reference
                # viscosity nu = (tau - 0.5)*cs2 for the analytic curve -- the
                # ML collision does not read it (the network presumably was
                # trained for a fixed tau)
u0      = 0.01  # amplitude of the initial velocity field

verbose = 0     # verbosity forwarded to model.predict

# Fail fast on settings that would break the dump bookkeeping
# (division by dumpit) or give a non-positive reference viscosity.
assert nx > 0 and ny > 0, "grid sizes must be positive"
assert niter > 0 and dumpit > 0, "niter and dumpit must be positive"
assert tau > 0.5, "tau must exceed 0.5 for a positive viscosity nu = (tau - 0.5)*cs2"

In [None]:
###########################################################
# Collect stats
ndumps   = int(niter//dumpit)
dumpfile = np.zeros( (ndumps*nx*ny, 4 ) ) 
###########################################################


def data_collector(dumpfile, t, ux, uy, rho):
    it   = t // dumpit
    idx0 =  it   *(nx*ny)
    idx1 = (it+1)*(nx*ny)
    dumpfile[idx0:idx1, 0] = t
    dumpfile[idx0:idx1, 1] = rho.reshape(nx*ny)
    dumpfile[idx0:idx1, 2] = ux.reshape( nx*ny)
    dumpfile[idx0:idx1, 3] = uy.reshape( nx*ny)


In [None]:
##########################################################
# Initial conditions: Taylor-Green vortex on a periodic nx x ny grid.

a = b = 1.0  # number of periods of the vortex pattern along x and y

# Node indices laid out so that axis 0 is x and axis 1 is y.
ix, iy = np.meshgrid(range(nx), range(ny), indexing='ij')

# Node positions mapped onto [0, 2*pi).
x = 2.0*np.pi*(ix / nx)
y = 2.0*np.pi*(iy / ny)

ux = u0 * np.sin(a*x) * np.cos(b*y)
uy = -u0 * np.cos(a*x) * np.sin(b*y)

rho = np.ones((nx, ny))  # uniform unit density

###########################################################
# Lattice velocities and weights: c, w, cs2 (presumably the squared lattice
# sound speed) and the equilibrium function all come from utils.LB_stencil().
Q = 9
c, w, cs2, compute_feq = LB_stencil()

###########################################################
# Start from equilibrium: both population buffers hold feq(rho, ux, uy).
feq = compute_feq(np.zeros((nx, ny, Q)), rho, ux, uy, c, w)

f1 = np.copy(feq)
f2 = np.copy(feq)

In [None]:
###########################################################
# Run the simulation.
# NOTE: assumes the initial-condition cell was run just before: f1/f2 hold
# feq and ux, uy, rho hold the t = 0 fields. Re-running this cell on its own
# continues from the previous final state instead of restarting.

data_collector(dumpfile, 0, ux, uy, rho)

# Total mass at t = 0, for the conservation check at the end.
m_initial = np.sum(f1)

###########################################################
# Time loop: stream f2 -> f1, take moments of f1, then apply the learned
# collision f1 -> f2.
for t in range(1, niter):

    # Streaming: shift population ip by its lattice velocity c[ip];
    # np.roll wraps around, i.e. periodic boundaries.
    for ip in range(Q):
        f1[:, :, ip] = np.roll(np.roll(f2[:, :, ip], c[ip, 0], axis=0), c[ip, 1], axis=1)

    # Density and velocity of the streamed populations.
    rho = np.sum(f1, axis=2)
    ux = (1./rho)*np.einsum('ijk,k', f1, c[:,0])
    uy = (1./rho)*np.einsum('ijk,k', f1, c[:,1])

    #########################################
    # ML collision step
    #########################################

    # The network operates on populations normalised to unit sum per node.
    fpre = f1.reshape((nx*ny, Q))
    norm = np.sum(fpre, axis=1)[:, np.newaxis]
    fpre = fpre / norm

    # Predict the whole grid in one batch: predict() otherwise splits the
    # nx*ny rows into batches of 32, adding per-batch overhead to every step.
    f2 = model.predict(fpre, batch_size=nx*ny, verbose=verbose)

    # Undo the normalisation and restore the (nx, ny, Q) layout.
    f2 = (norm*f2).reshape((nx, ny, Q))

    #########################################

    # Store the fields every dumpit steps.
    if t % dumpit == 0:
        data_collector(dumpfile, t, ux, uy, rho)

m_final = np.sum(f2)

print('Sim ended. Mass err:', np.abs(m_initial-m_final)/m_initial)

In [None]:
# Figure size in inches. Deliberately not named `w`: that name holds the
# lattice weights returned by LB_stencil().
fig_w, fig_h = 3.46*3, 2.14*3

###################################################################

def sol(t, L, F0, nu):
    """Analytic mean speed of a decaying Taylor-Green vortex.

    Returns ``F0 * exp(-2 * nu * k**2 * t)`` with ``k = 2*pi / L`` along both
    axes. It is called below with ``L = nx``, which assumes a square domain
    (nx == ny).
    """
    return F0*np.exp(-2*nu*t / (L / (2*np.pi))**2)


def _mean_speed(t):
    """Grid average of |u| for the dump stored at time step ``t``."""
    rows = dumpfile[:, 0] == t
    return np.average((dumpfile[rows, 2]**2 + dumpfile[rows, 3]**2)**0.5)

###################################################################

tLst = np.arange(0, niter, dumpit)

Ft = np.array([_mean_speed(t) for t in tLst])
F0 = Ft[0]  # the analytic curve is anchored to the t = 0 value

nu = (tau-0.5)*(cs2)

fig, ax = plt.subplots(figsize=(fig_w, fig_h))

ax.semilogy(tLst, Ft, 'ob', label='lbm')
ax.semilogy(tLst, sol(tLst, nx, F0, nu), linewidth=2.0, linestyle='--', color='r', label='analytic')

###################################################################

ax.set_xlabel(r'$t~\rm{[L.U.]}$'      , fontsize=16)
ax.set_ylabel(r'$\langle |u| \rangle$', fontsize=16, rotation=90, labelpad=0)

ax.legend(loc='best', frameon=False, prop={'size' : 16})

# Booleans, not the deprecated "on"/"off" strings.
ax.tick_params(which="both", direction="in", top=True, right=True, labelsize=14)

plt.show()

In [None]:
# Figure size in inches. Deliberately not named `w`: that name holds the
# lattice weights returned by LB_stencil().
fig_w, fig_h = 3.46*3, 2.14*3

# The dumped fields come from (nx, ny) arrays built with indexing='ij', so
# after reshape((nx, ny)) axis 0 is x and axis 1 is y. imshow and streamplot
# expect arrays of shape (ny, nx) with rows = y and columns = x, matching
# this default ('xy') meshgrid. The fields are therefore transposed below,
# and origin='lower' puts y = 0 at the bottom.
X, Y = np.meshgrid(np.arange(0, nx),
                   np.arange(0, ny)
                   )

tLst = np.arange(0, niter, dumpit)

for t in tLst:

    rows = dumpfile[:, 0] == t
    ux_t = dumpfile[rows, 2].reshape((nx, ny)).T  # shape (ny, nx)
    uy_t = dumpfile[rows, 3].reshape((nx, ny)).T  # shape (ny, nx)

    speed = (ux_t**2 + uy_t**2)**0.5

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))

    im = ax.imshow(speed, origin='lower')
    ax.streamplot(X, Y, ux_t, uy_t, density=0.5, color='w')

    fig.colorbar(im, ax=ax, orientation='vertical', pad=0, shrink=0.69)

    ax.set_title(f"Iteration {t}", size=16)
    ax.set_xlabel('x')
    ax.set_ylabel('y')

    plt.show()
    plt.close(fig)  # don't accumulate one open figure per dump