In [None]:
import jax
import jax.numpy as jnp
from   jax import vmap
from   jax import jit
from   jax import jacfwd, jacrev
from   jax import grad, jvp
from   jax import random as jrd

import numpy as np
from numpy.linalg import norm

import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy.optimize import least_squares
from matplotlib import cm
from matplotlib.ticker import LinearLocator
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from random import sample
from time import time
import math
import random

import os
import scipy.io
import sympy as sym
#For the animation
from matplotlib.animation import FuncAnimation
from matplotlib import rc
from IPython.display import HTML

from numpy.polynomial import Polynomial
from scipy.sparse import csr_matrix
from scipy.integrate import dblquad, quad
from scipy import integrate

In [None]:
# ---- Data: geometry, fluid properties and discretization sizes ----
dim = 2  # spatial dimension

# x-breakpoints: inlet channel on [a1, a2], outlet channel on [a3, a4] (total width 8).
a1, a2, a3, a4 = -4., -3., 3., 4.
# Vertical extent of the bounding box: chamber bottom b1, top lids of the channels b2.
b1, b2 = -11., 1.
# Lower end (y) of the vertical inlet/outlet channel walls.
o = 0.
# x-extent of the moving membrane at the chamber bottom.
c1, c2 = -1.8, 1.8

rho = 1.043    # presumably blood density (cgs units) — TODO confirm
mu = 0.0035    # presumably dynamic viscosity (cgs units) — TODO confirm

# JAX PRNG key and four independent sub-keys.
key = jrd.PRNGKey(42)
key1, key2, key3, key4 = jrd.split(key, num=4)

# Quadrature points per direction (n_pts) and RBF centres per direction (n_hat).
n_pts, n_hat = 80, 30
n_x = n_y = n_pts
n_x_hat = n_y_hat = n_hat

# Spacing of a uniform n_x-by-n_y grid over the bounding box.
deltax = (a4 - a1)/(n_x - 1)
deltay = (b2 - b1)/(n_y - 1)

def chamber_bdry_l(x):
  """Left rigid wall of the chamber: height y as a function of x on [a1, c1].

  Parabola in x equal to 0 at x = a1 and to b1 (chamber bottom) at x = c1.
  """
  s = (x - c1)/(a1 - c1)
  return b1*(1 - s**2)


def chamber_bdry_r(x):
  """Right rigid wall of the chamber: height y as a function of x on [c2, a4].

  Parabola in x equal to b1 (chamber bottom) at x = c2 and to 0 at x = a4.
  """
  s = (x - c2)/(a4 - c2)
  return b1*(1 - s**2)


# deformation of the domain
def chamber_membrane_up(x):
  """Upper boundary curve of the chamber between the inlet and outlet channels.

  V-shaped in x: zero at |x| = (a3 - a2)/2 and decreasing with slope 2.5 toward x = 0.
  """
  slope = 2.5
  half_gap = 0.5*(a3 - a2)
  return slope*(jnp.abs(x) - half_gap)

#def chamber_membrane_up(x):                          #b2/((0.5*(a3-a2))**4)
#  return 0.1*np.arctan(10.*(x - a2))*np.arctan(10.*(x - a3))

# ---- Time discretization of one beat ----
period = 6./7.   # duration of one beat
T = period/2     # half-beat: inlet active on (T0, T1), outlet active on (T1, T2) (see ki/ko)
T0 = 0.
T1 = T0 + T
T2 = T0 + 2*T
N = 10           # time steps per half-beat
dt = T/N

print("dt = %.6f" %dt)

SV = 70.         # stroke volume: 70 ml/beat = 70 cm^3/beat
depth = 6.       # depth of the heart in cm
A = SV/depth     # area of a vertical section (~11.6 cm^2)

# Amplitudes of the parabolic velocity profiles: beta for the membrane [c1, c2],
# alpha for the inlet/outlet lids [a1, a2] and [a3, a4].
beta = 33./((c2 - c1)**3)
alpha = 33./((a2 - a1)**3)

@jit
def u2_vect_membrane(x, y, t):
  """Vertical velocity prescribed by the moving membrane at the chamber bottom (y = b1).

  Parabolic profile beta*(x - c1)*(x - c2) modulated by sin(pi*(t - T0)/T), localized to
  c1 < x < c2 by a smooth tanh window in x and to y ~ b1 by a Gaussian bump in y.
  """
  sharpness = 50.
  time_factor = jnp.sin(np.pi*(t - T0)/T)
  profile = beta*(x - c1)*(x - c2)*time_factor
  window_x = (0.5 + 0.5*jnp.tanh(sharpness*(x - c1))) - (0.5 + 0.5*jnp.tanh(sharpness*(x - c2)))
  bump_y = jnp.exp(-sharpness*(y - b1)**2)
  return profile*(window_x*bump_y)

@jit
def ki(t):
  """Inlet time modulation: positive part of sin(pi*(t - T0)/T), nonzero on (T0, T0 + T) mod 2T."""
  s = jnp.sin(np.pi*(t - T0)/T)
  return 0.5*(s + jnp.abs(s))

@jit
def ko(t):
  """Outlet time modulation: positive part of sin(pi*(t - (T0 + T))/T), nonzero on (T0 + T, T0 + 2T) mod 2T."""
  s = jnp.sin(np.pi*(t - (T0 + T))/T)
  return 0.5*(s + jnp.abs(s))

@jit
def u2_vect_inlet_outlet(x, y, t):
  """Vertical velocity prescribed on the inlet [a1, a2] and outlet [a3, a4] lids at y = b2.

  Inlet: alpha*(x - a1)*(x - a2)*ki(t); outlet: -alpha*(x - a3)*(x - a4)*ko(t).
  Each profile is localized by a smooth tanh window in x and a Gaussian bump around y = b2.
  """
  sharpness = 50.

  def window(lo, hi):
    # Smooth indicator of lo < x < hi.
    return (0.5 + 0.5*jnp.tanh(sharpness*(x - lo))) - (0.5 + 0.5*jnp.tanh(sharpness*(x - hi)))

  bump_top = jnp.exp(-sharpness*(y - b2)**2)
  v_in = alpha*(x - a1)*(x - a2)*ki(t)
  v_out = -alpha*(x - a3)*(x - a4)*ko(t)
  return v_in*(window(a1, a2)*bump_top) + v_out*(window(a3, a4)*bump_top)

@jit
def un(x, y, t):
  """Total prescribed vertical velocity: inlet/outlet lids plus the moving membrane."""
  lids = u2_vect_inlet_outlet(x, y, t)
  membrane = u2_vect_membrane(x, y, t)
  return lids + membrane

print("alpha: %.7f:" % alpha)
print("beta:  %.7f:" % beta)
# Spot checks: outlet centre at 3T/2, membrane centre at 3T/2, inlet centre at T/2.
for value in (u2_vect_inlet_outlet((a3 + a4)/2, b2, T0 + T + T/2),
              u2_vect_membrane((c1 + c2)/2, b1, 3*T/2),
              u2_vect_inlet_outlet((a1 + a2)/2, b2, T/2)):
  print(value)
un(1., 1., 1.)

In [None]:
# Prescribed vertical velocity un on a regular grid, zeroed outside the fluid domain.
grid_number = 50
t_plot = 0.5*N*dt   # evaluation time; the title below reports this same value

x, y = np.meshgrid(np.linspace(a1, a4, grid_number), np.linspace(b1, b2, grid_number))
XY = np.vstack((np.ravel(x), np.ravel(y))).T
X, Y = XY[:, 0], XY[:, 1]
one_grid = jnp.ones((len(X),))
un_vector = np.array(un(X, Y, t_plot*one_grid))

# Exterior: above the curve between the channels, and below the curved chamber walls.
above_membrane = (a2 < X) & (X < a3) & (np.asarray(chamber_membrane_up(X)) < Y) & (Y <= b2)
below_walls = (((a1 <= X) & (X <= c1) & (Y < chamber_bdry_l(X)))
               | ((c2 <= X) & (X <= a4) & (Y < chamber_bdry_r(X))))
un_vector[above_membrane | below_walls] = 0.

un_pred = un_vector
z = un_pred.reshape(grid_number, grid_number)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("un_pred(x, y, %.3f)" % t_plot)
cp = ax.contourf(x, y, z, 20, cmap='RdGy')
fig.colorbar(cp, ax=ax)
plt.show()

In [None]:
# Tensor-product Gauss-Legendre quadrature on the bounding box [a1, a4] x [b1, b2].
x_gauss, wx_gauss = np.polynomial.legendre.leggauss(n_x)
y_gauss, wy_gauss = np.polynomial.legendre.leggauss(n_y)
# indexing="ij" makes x the slow index after ravel(): point k is (x_gauss[k // n_y], y_gauss[k % n_y]).
xin_mesh, yin_mesh = np.meshgrid(x_gauss, y_gauss, indexing="ij")
xin = 0.5*(a4 - a1)*xin_mesh.ravel() + 0.5*(a1 + a4)
yin = 0.5*(b2 - b1)*yin_mesh.ravel() + 0.5*(b1 + b2)
# np.outer(wx, wy).ravel() uses the same (x slow, y fast) layout, so weight k belongs to point k
# even when n_x != n_y.
poids = 0.25*(a4 - a1)*(b2 - b1)*np.outer(wx_gauss, wy_gauss).ravel()
assert np.isclose(poids.sum(), (a4 - a1)*(b2 - b1)), "quadrature weights must sum to the box area"

# RBF centres on a uniform n_x_hat-by-n_y_hat grid (same x-slow ordering).
x_hat = np.linspace(a1, a4, n_x_hat)
y_hat = np.linspace(b1, b2, n_y_hat)
xin_hat_mesh, yin_hat_mesh = np.meshgrid(x_hat, y_hat, indexing="ij")
xin_hat = xin_hat_mesh.ravel()
yin_hat = yin_hat_mesh.ravel()

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].scatter(xin, yin, s=1., marker='o', color='black', label="Training points")
ax[1].scatter(xin_hat, yin_hat, s=1., marker='o', color='b', label="Centers of RBFs")
for panel in ax:
  panel.set_xlabel("x")
  panel.set_ylabel("y")
  panel.legend()
plt.show()


In [None]:
def test(x, y):
  """Smooth integrand cos(x)*sin(y) used to check the quadrature rule."""
  wave_x = wave_y = 1
  return jnp.cos(wave_x*x)*jnp.sin(wave_y*y)

exact_value = integrate.nquad(test, [[a1, a4], [b1, b2]])[0]
print("True integral: %.12f" % exact_value)

quadrature_value = jnp.dot(poids, test(xin, yin))
print("Approximated integral: %.12f" % quadrature_value)

In [None]:
# Partition the quadrature points:
#   xin_r, yin_r : points in the excluded zones (above the curve between the channels,
#                  below the curved chamber walls) and points lying exactly on the outlet lid
#                  or on a channel wall x in {a1, a2, a3, a4}; loss_d drives d to 0 there.
#   xin_D, yin_D : every other point.
xin_r, yin_r, xin_D, yin_D = [], [], [], []

for xq, yq in zip(xin, yin):
  excluded = (
      ((a2 <= xq <= a3) and (chamber_membrane_up(xq) <= yq))
      or (a1 <= xq <= c1 and yq <= chamber_bdry_l(xq))
      or (c2 <= xq <= a4 and yq <= chamber_bdry_r(xq))
      or ((a3 <= xq <= a4) and (yq == b2))
      or ((xq in (a1, a2, a3, a4)) and (o <= yq <= b2))
  )
  if excluded:
    xin_r.append(xq)
    yin_r.append(yq)
  else:
    xin_D.append(xq)
    yin_D.append(yq)

xin_r, yin_r, xin_D, yin_D = np.array(xin_r), np.array(yin_r), np.array(xin_D), np.array(yin_D)


In [None]:
# Boundary sample points of the fluid domain, one segment per wall.
# Vertical channel walls: 4*n_y points; lids and curved walls: 16*n_x; membrane: 30*n_x.

# Inlet channel: x in [a1, a2], y in [o, b2]
ybdry_inlet_left = np.linspace(o, b2, 4*n_y)
xbdry_inlet_left = a1*jnp.ones(shape=(4*n_y,))
ybdry_inlet_right = ybdry_inlet_left
xbdry_inlet_right = a2*jnp.ones(shape=(4*n_y,))
xbdry_inlet_up = np.linspace(a1, a2, 16*n_x)
ybdry_inlet_up = b2*jnp.ones(shape=(16*n_x,))

# Outlet channel: x in [a3, a4], y in [o, b2]
ybdry_outlet_left = np.linspace(o, b2, 4*n_y)
xbdry_outlet_left = a3*jnp.ones(shape=(4*n_y,))
ybdry_outlet_right = ybdry_outlet_left
xbdry_outlet_right = a4*jnp.ones(shape=(4*n_y,))
xbdry_outlet_up = np.linspace(a3, a4, 16*n_x)
ybdry_outlet_up = b2*jnp.ones(shape=(16*n_x,))

# Curved rigid chamber walls
xbdry_chamber_left = np.linspace(a1, c1, 16*n_x)
ybdry_chamber_left = chamber_bdry_l(xbdry_chamber_left)
xbdry_chamber_right = np.linspace(c2, a4, 16*n_x)
ybdry_chamber_right = chamber_bdry_r(xbdry_chamber_right)

# Upper curve between the two channels
xbdry_chamber_m_up = np.linspace(a2, a3, 16*n_x)
ybdry_chamber_m_up = chamber_membrane_up(xbdry_chamber_m_up)

# Moving membrane at the chamber bottom (y = b1)
xbdry_chamber_m_down = np.linspace(c1, c2, 30*n_x)
ybdry_chamber_m_down = np.linspace(b1, b1, 30*n_x)

# Whole boundary, in a fixed segment order.
boundary_segments = [
    (xbdry_inlet_left, ybdry_inlet_left),
    (xbdry_inlet_right, ybdry_inlet_right),
    (xbdry_inlet_up, ybdry_inlet_up),
    (xbdry_outlet_left, ybdry_outlet_left),
    (xbdry_outlet_right, ybdry_outlet_right),
    (xbdry_outlet_up, ybdry_outlet_up),
    (xbdry_chamber_left, ybdry_chamber_left),
    (xbdry_chamber_right, ybdry_chamber_right),
    (xbdry_chamber_m_up, ybdry_chamber_m_up),
    (xbdry_chamber_m_down, ybdry_chamber_m_down),
]
xbdry, ybdry = (np.concatenate(part, 0) for part in zip(*boundary_segments))

# Gamma_r: every segment except the inlet lid and the moving membrane.
gamma_r_segments = [
    (xbdry_inlet_left, ybdry_inlet_left),
    (xbdry_inlet_right, ybdry_inlet_right),
    (xbdry_outlet_left, ybdry_outlet_left),
    (xbdry_outlet_right, ybdry_outlet_right),
    (xbdry_outlet_up, ybdry_outlet_up),
    (xbdry_chamber_left, ybdry_chamber_left),
    (xbdry_chamber_right, ybdry_chamber_right),
    (xbdry_chamber_m_up, ybdry_chamber_m_up),
]
xbdry_Gamma_r, ybdry_Gamma_r = (np.concatenate(part, 0) for part in zip(*gamma_r_segments))

In [None]:
# Domain boundary (black) with the prescribed velocity profiles at t = 0.5*N*dt drawn
# as offsets from the wall they act on (red).
fig, ax = plt.subplots(1, 1, figsize=(3, 5))
ax.scatter(xbdry, ybdry, s=1., marker='o', color='black')

profiles = [
    (xbdry_inlet_up, u2_vect_inlet_outlet(xbdry_inlet_up, b2, 0.5*N*dt) + b2),
    (xbdry_outlet_up, u2_vect_inlet_outlet(xbdry_outlet_up, b2, 0.5*N*dt) + b2),
    (xbdry_chamber_m_down, u2_vect_membrane(xbdry_chamber_m_down, b1, 0.5*N*dt) + b1),
]
for px, py in profiles:
  ax.scatter(px, py, s=1., marker='o', color='r')

# One tick per unit length.
xaxis = jnp.linspace(a1, a4, int((a4 - a1) + 1))
yaxis = jnp.linspace(b1, b2, int((b2 - b1) + 1))
ax.set_xticks(xaxis)
ax.set_yticks(yaxis)
plt.show()

In [None]:
# Shape parameter chosen so that RBF(kx*deltax, ky*deltay) == h.
kx, ky = 5, 5
h = 10**(-2)
eps = -jnp.log(h)/(kx**2*deltax**2 + ky**2*deltay**2)
print("eps = %f" % eps)

def RBF(x, y):
  """Gaussian radial basis function exp(-eps*(x**2 + y**2)) of the offsets (x, y) from a centre."""
  return jnp.exp(-eps*(x**2 + y**2))


In [None]:
n_hat = len(xin_hat)  # number of RBF centres (n_x_hat*n_y_hat)
# Initialise from the notebook's own seeded PRNG sub-key so the fit is reproducible
# across kernel restarts (the previous np.random.normal call was unseeded).
params_d = np.asarray(jrd.normal(key1, (n_hat,)), dtype=np.float64)

@jit
def d(params_d, x, y):
  """RBF expansion at a single point (x, y).

  Returns sum_j params_d[j]*RBF(x - xin_hat[j], y - yin_hat[j]). loss_d fits it to vanish
  on the boundary and at the excluded points; u1/u2 multiply their own expansion by it.
  """
  return jnp.dot(params_d[:n_hat], RBF(x - xin_hat, y - yin_hat))

d_vect = vmap(d, (None, 0, 0))  # batch d over points; params_d shared

@jit
def loss_d(params_d):
  """Least-squares fit of d: 1 at the in-domain quadrature points, 0 at the excluded points and on the boundary.

  The target 1 is imposed on xin_D only. xin_r is a subset of xin, so targeting 1 on all of
  xin would contradict the eq_r term (target 0) at exactly those points.
  """
  eq_in = d_vect(params_d, xin_D, yin_D) - 1.
  eq_r = d_vect(params_d, xin_r, yin_r)
  eq_bc = d_vect(params_d, xbdry, ybdry)
  # Zero-set terms are weighted 10x relative to the interior target.
  return jnp.mean(eq_in**2) + 10.*(jnp.mean(eq_r**2) + jnp.mean(eq_bc**2))

grad_loss_d = jit(grad(loss_d, 0))

loss_d(params_d)


In [None]:
print("loss_d = %.12f" %loss_d(params_d))

# epoch_d restarts of L-BFGS-B, 10 iterations each; the loss is logged after every restart.
epoch_d = 300
lbfgs_options = {"maxiter": 10, "gtol": 10**(-8)}
for i in range(epoch_d):
  L_BFGS_B_d = minimize(loss_d, params_d, method="L-BFGS-B", jac=grad_loss_d,
                        tol=10**(-8), options=lbfgs_options)
  params_d = L_BFGS_B_d.x
  print("i = %d, loss_d = %.12f" %(i, loss_d(params_d)))

In [None]:

# Trained function d on a regular grid over the whole bounding box (no exterior mask,
# so its behaviour outside the fluid domain is visible too).
grid_number = 50
x, y = np.meshgrid(np.linspace(a1, a4, grid_number), np.linspace(b1, b2, grid_number))
XY = np.vstack((np.ravel(x), np.ravel(y))).T
X, Y = XY[:, 0], XY[:, 1]
d_vector = np.array(d_vect(params_d, X, Y))
d_pred = d_vector
z = d_pred.reshape(grid_number, grid_number)

fig, ax = plt.subplots(1, 1, figsize=(5, 8))
ax.set_ylabel("y")
ax.set_title("d_pred(x, y)")
cp = ax.contourf(x, y, z, 20, cmap='RdGy')
# One tick per unit length.
xaxis = jnp.linspace(a1, a4, int((a4 - a1) + 1))
yaxis = jnp.linspace(b1, b2, int((b2 - b1) + 1))
ax.set_xticks(xaxis)
ax.set_yticks(yaxis)
plt.colorbar(cp, ax=ax)

In [None]:
# Velocity weights: params[:n_hat] for u1, params[n_hat:] for u2. Drawn from the notebook's
# seeded PRNG sub-key so runs are reproducible (the previous np.random.normal was unseeded).
params = np.asarray(jrd.normal(key2, (2*n_hat,)), dtype=np.float64)

def u1(params, x, y, t):
  """Horizontal velocity at a single point (x, y).

  RBF expansion with weights params[:n_hat], multiplied by d(params_d, x, y).
  t is accepted for a signature uniform with u2 but is not used.
  """
  weights = params[:n_hat]
  return jnp.dot(weights, RBF(x - xin_hat, y - yin_hat))*d(params_d, x, y)

def u2(params, x, y, t):
  """Vertical velocity at a single point (x, y) and time t.

  RBF expansion with weights params[n_hat:2*n_hat], multiplied by d(params_d, x, y),
  plus the prescribed inlet/outlet/membrane velocity un(x, y, t).
  """
  weights = params[n_hat:2*n_hat]
  return jnp.dot(weights, RBF(x - xin_hat, y - yin_hat))*d(params_d, x, y) + un(x, y, t)


In [None]:

# Spatial derivatives of the velocity components at a single point.
# argnums=1 differentiates with respect to x, argnums=2 with respect to y.

# u1
u1_x1 = jit(grad(u1, argnums=1))
u1_x2 = jit(grad(u1, argnums=2))
u1_x1x1 = jit(grad(u1_x1, argnums=1))
u1_x1x2 = jit(grad(u1_x1, argnums=2))
u1_x2x2 = jit(grad(u1_x2, argnums=2))

# u2
u2_x1 = jit(grad(u2, argnums=1))
u2_x2 = jit(grad(u2, argnums=2))
u2_x1x1 = jit(grad(u2_x1, argnums=1))
u2_x1x2 = jit(grad(u2_x1, argnums=2))
u2_x2x2 = jit(grad(u2_x2, argnums=2))

# Batched versions: params shared (None); x, y, t mapped along their leading axis.
point_axes = (None, 0, 0, 0)

u1_vect = vmap(u1, in_axes=point_axes)
u1_x1_vect = vmap(u1_x1, in_axes=point_axes)
u1_x2_vect = vmap(u1_x2, in_axes=point_axes)
u1_x1x1_vect = vmap(u1_x1x1, in_axes=point_axes)
u1_x1x2_vect = vmap(u1_x1x2, in_axes=point_axes)
u1_x2x2_vect = vmap(u1_x2x2, in_axes=point_axes)

u2_vect = vmap(u2, in_axes=point_axes)
u2_x1_vect = vmap(u2_x1, in_axes=point_axes)
u2_x2_vect = vmap(u2_x2, in_axes=point_axes)
u2_x1x1_vect = vmap(u2_x1x1, in_axes=point_axes)
u2_x1x2_vect = vmap(u2_x1x2, in_axes=point_axes)
u2_x2x2_vect = vmap(u2_x2x2, in_axes=point_axes)


In [None]:
# Velocity field given by the initial (untrained) weights.
grid_number = 100
t_plot = 0.25*N*dt   # evaluation time; the title reports this same value

x, y = np.meshgrid(np.linspace(a1, a4, grid_number), np.linspace(b1, b2, grid_number))
XY = np.vstack((np.ravel(x), np.ravel(y))).T
X = XY[:, 0]
Y = XY[:, 1]
# Local name: rebinding the global `one` here would change the vector length seen by
# norm_div_L2 / lagrangean_NS if this cell were re-run later.
one_grid = jnp.ones((len(X),))
u1_vector = np.array(u1_vect(params, X, Y, t_plot*one_grid))
u2_vector = np.array(u2_vect(params, X, Y, t_plot*one_grid))

# Exterior: above the curve between the channels, and below the curved chamber walls.
outside = (((a2 < X) & (X < a3) & (np.asarray(chamber_membrane_up(X)) < Y))
           | ((a1 <= X) & (X <= c1) & (Y < chamber_bdry_l(X)))
           | ((c2 <= X) & (X <= a4) & (Y < chamber_bdry_r(X))))
u1_vector[outside] = 0.
u2_vector[outside] = 0.

u1_pred = u1_vector
u2_pred = u2_vector
z1 = u1_pred.reshape(grid_number, grid_number)
z2 = u2_pred.reshape(grid_number, grid_number)

fig, ax = plt.subplots(1, 1, figsize=(5, 6))
ax.set_title("IC (untrained): t = %.3f " % t_plot)
ax.streamplot(x, y, z1, z2, density=3)
cp = ax.quiver(X, Y, u1_pred, u2_pred, np.sqrt(u1_pred**2 + u2_pred**2), color='g',
               label='(u1_pred, u2_pred)')
ax.scatter(xbdry, ybdry, s=2., marker='o', color='r')
fig.colorbar(cp, ax=ax)
plt.show()

In [None]:

def norm_div_L2(params, n):
  """Quadrature L2 norm of div(u) at time step n.

  Returns sqrt(sum_k poids[k]*div(u)(xin[k], yin[k], n*dt)**2). The time vector is built
  locally, so the result does not depend on whatever length the global `one` has.
  """
  t_n = n*dt*jnp.ones((len(xin),))
  eq_div = u1_x1_vect(params, xin, yin, t_n) + u2_x2_vect(params, xin, yin, t_n)
  return jnp.sqrt(jnp.dot(poids, eq_div**2))

# Global time-vector template at the quadrature points, used by the NS solver cells.
one = jnp.ones((len(xin),))
norm_div_L2(params, 1)

In [None]:
Re = 298.


@jit
def lagrangean_NS(params, params0, params1, X0, Y0, X1, Y1, n):

  inertia_term   = (0.5*3.*(u1_vect(params, xin, yin,     n*dt*one)**2 + \
                            u2_vect(params, xin, yin,     n*dt*one)**2)  \
                 -      4.*(u1_vect(params1, X1,  Y1, (n-1)*dt*one)*\
                            u1_vect(params, xin, yin,     n*dt*one) +  \
                            u2_vect(params1, X1,  Y1, (n-1)*dt*one)*\
                            u2_vect(params, xin, yin,     n*dt*one))\
                 +         (u1_vect(params0, X0,  Y0, (n-2)*dt*one)*\
                            u1_vect(params, xin, yin,     n*dt*one) +\
                            u2_vect(params0, X0,  Y0, (n-2)*dt*one)*\
                            u2_vect(params, xin, yin,     n*dt*one)))/(2*dt)

  diffusion_term1 = 0.5*\
                   (u1_x1_vect(params, xin, yin, n*dt*one)**2 + \
                    u1_x2_vect(params, xin, yin, n*dt*one)**2 + \
                    u2_x1_vect(params, xin, yin, n*dt*one)**2 + \
                    u2_x2_vect(params, xin, yin, n*dt*one)**2)/(2*Re)

  diffusion_term2 = (u1_x1_vect(params1,  xin, yin, (n-1)*dt*one)* \
                     u1_x1_vect(params,   xin, yin,     n*dt*one) +\
                     u1_x2_vect(params1,  xin, yin, (n-1)*dt*one)*  \
                     u1_x2_vect(params,   xin, yin,     n*dt*one) +\
                     u2_x1_vect(params1,  xin, yin, (n-1)*dt*one)*  \
                     u2_x1_vect(params,   xin, yin,     n*dt*one) +\
                     u2_x2_vect(params1,  xin, yin, (n-1)*dt*one)* \
                     u2_x2_vect(params,   xin, yin,     n*dt*one))/(2*Re)

  augmented_term1 = 0.25*penalty*(u1_x1_vect(params, xin, yin, n*dt*one) +\
                                  u2_x2_vect(params, xin, yin, n*dt*one))**2

  augmented_term2 = 0.5*penalty*(u1_x1_vect(params1,  xin, yin,  (n-1)*dt*one) +\
                                 u2_x2_vect(params1,  xin, yin,  (n-1)*dt*one))*\
                                 (u1_x1_vect(params,  xin, yin,      n*dt*one) +\
                                  u2_x2_vect(params,  xin, yin,      n*dt*one))

  pdiv_term      =   pressure*(u1_x1_vect(params, xin, yin, n*dt*one) +\
                               u2_x2_vect(params, xin, yin, n*dt*one))

  L =  jnp.dot(poids, inertia_term)     +\
       jnp.dot(poids, diffusion_term1)  +\
       jnp.dot(poids, diffusion_term2)  +\
       jnp.dot(poids, augmented_term1)  +\
       jnp.dot(poids, augmented_term2)  -\
      jnp.dot(poids,  pdiv_term)
  return L



In [None]:
'''
# NS
pressure = jnp.ones((len(xin),))
one      = jnp.ones((len(xin),))
penalty  = 100
epoch    = 30
params0  = 0.*params
params1  = 0.*params
L        =  lagrangean_NS(params, params0, params1, xin, yin, xin, yin, 1)
div_u    =  norm_div_L2(params, 1)
print("k : Initial, Lagrangean_NS: %.12f,  loss_div_u: %.12f"  %(L, div_u))
one = jnp.ones((len(xin),))
'''
n_per = 2
n_iter = n_per*(2*N)

for n in range(2, n_iter + 1):

  print("n = %d" %n)

  # Runge-Kutta Method of order 2 (l = 1.)
  # Compute X^(n-1),n.

  # we compute K1 in both directions
  #K1x =  2*u1_vect(params1, xin, yin) - u1_vect(params0, xin, xin)    #we approach U^n(., .) by 2U^{n-1}(., .) - U^{n-2}(., .)
  #K1y =  2*u2_vect(params1, xin, yin) - u2_vect(params0, xin, yin)

  K1x = u1_vect(params, xin, yin, n*dt*one)
  K1y = u2_vect(params, xin, yin, n*dt*one)

  #we compute K2 in both direction using K1
  K2x = u1_vect(params1, xin - dt*K1x, yin - dt*K1y, (n-1)*dt*one)
  K2y = u2_vect(params1, xin - dt*K1x, yin - dt*K1y, (n-1)*dt*one)

  # We compute the approximation of the position. At previous time n-1 in both direction
  # This is the position at time n-1
  X1 = xin - 0.5*dt*(K1x + K2x)
  Y1 = yin - 0.5*dt*(K1y + K2y)

  #Compute X^(n-2),n
  # using the approximation at time n-1. We compute the approximation at time n-2
  # we compute K1 in both direction for the approximation of X^(n-2),n

  K1x = u1_vect(params1, X1, Y1, (n-1)*dt*one)
  K1y = u2_vect(params1, X1, Y1, (n-1)*dt*one)

  # we compute K2 in both direction, for the approximation of X^(n-2),n
  K2x = u1_vect(params0, X1 - dt*K1x, Y1 - dt*K1y, (n-2)*dt*one)
  K2y = u2_vect(params0, X1 - dt*K1x, Y1 - dt*K1y, (n-2)*dt*one)
  # Th is is the approximation of the position at time n-2
  X0 = X1 - 0.5*dt*(K1x + K2x)
  Y0 = Y1 - 0.5*dt*(K1y + K2y)

  @jit
  def lagrang_NS(params):
    return lagrangean_NS(params, params0, params1, X0, Y0, X1, Y1, n)

  grad_lagrang_NS   = jit(grad(lagrang_NS, 0))

  for k in range(epoch):

    L_BFGS_B_NS = minimize(lagrang_NS, params, args=(), method="L-BFGS-B", jac=grad_lagrang_NS,\
                      hess= None, hessp= None,\
                      bounds=None, constraints=(), tol= 10**(-8), callback=None, \
                      options= {"maxiter" : 100, "gtol" : 10**(-8)})

    params_new  = L_BFGS_B_NS.x

    pressure = pressure - (penalty + 4)*(u1_x1_vect(params_new, xin, yin, n*dt*one) +\
                                         u2_x2_vect(params_new, xin, yin, n*dt*one))


    L     = lagrang_NS(params_new)
    div_u   = norm_div_L2(params_new, n)
    print("k : %d, Lagrangean_NS: %.12f,   loss_div_u: %.12f"  %(k, L, div_u))

    if (params_new == params).all():
      break
    params  = params_new

  params0 = params1
  params1 = params


  #Plotting of BDF1 at each step
  '''
  if (n%1 == 0):

    params_NS = params
    grid_number = 100
    x, y = np.meshgrid(np.linspace(a1, a4, grid_number), np.linspace(b1, b2, grid_number) )
    XY = np.vstack((np.ravel(x), np.ravel(y))).T
    X = XY[:, 0]
    Y = XY[:, 1]
    one_grid = jnp.ones((len(X),))
    u1_vector = np.array(u1_vect(params, X, Y, n*dt*one_grid))
    u2_vector = np.array(u2_vect(params, X, Y, n*dt*one_grid))

    for i in range(len(X)):
      if (a2 < X[i] < a3) and (chamber_membrane_up(X[i]) < Y[i]):
        u1_vector[i] = 0.
        u2_vector[i] = 0.

      if (a1 <= X[i] <= c1 and Y[i] < chamber_bdry_l(X[i])) or (c2 <= X[i] <= a4 and Y[i] < chamber_bdry_r(X[i])):
        u1_vector[i] = 0.
        u2_vector[i] = 0.

    fig, ax = plt.subplots(1, 1, figsize=(5, 6))
    u1_pred = u1_vector
    u2_pred = u2_vector
    #ax = plt.axes(xlim=(0, R), ylim=(0, R))
    ax.set_title("AH, Re= %d, t =  %d*dt" %(Re, n))
    z1 = u1_pred.reshape(grid_number, grid_number)
    s1 = np.array(z1)
    z2 = u2_pred.reshape(grid_number, grid_number)
    s2 = np.array(z2)
    #plt.clabel(cp, inline=True, fontsize=15)
    strm = plt.streamplot(x, y,  z1, z2, density = 5, color= np.sqrt(s1**2 + s2**2), cmap = "plasma")
    fig.colorbar(strm.lines)
    ax.scatter(xbdry, ybdry, s = 2., marker = 'o', color = 'r')
    plt.colorbar(cp)
    '''
plt.show()





In [None]:
# Velocity field after the time loop. `params` holds the solution of its last step n = n_iter,
# so the prescribed part un must be evaluated at that same time.
grid_number = 100
t_plot = n_iter*dt

x, y = np.meshgrid(np.linspace(a1, a4, grid_number), np.linspace(b1, b2, grid_number))
XY = np.vstack((np.ravel(x), np.ravel(y))).T
X = XY[:, 0]
Y = XY[:, 1]
one_grid = jnp.ones((len(X),))
u1_vector = np.array(u1_vect(params, X, Y, t_plot*one_grid))
u2_vector = np.array(u2_vect(params, X, Y, t_plot*one_grid))

# Exterior: above the curve between the channels, and below the curved chamber walls.
outside = (((a2 < X) & (X < a3) & (np.asarray(chamber_membrane_up(X)) < Y))
           | ((a1 <= X) & (X <= c1) & (Y < chamber_bdry_l(X)))
           | ((c2 <= X) & (X <= a4) & (Y < chamber_bdry_r(X))))
u1_vector[outside] = 0.
u2_vector[outside] = 0.

u1_pred = u1_vector
u2_pred = u2_vector
z1 = u1_pred.reshape(grid_number, grid_number)
z2 = u2_pred.reshape(grid_number, grid_number)
speed = np.sqrt(z1**2 + z2**2)

fig, ax = plt.subplots(1, 1, figsize=(5, 6))
ax.set_title("AH, Re= %d, t =  %f" %(Re, t_plot))
strm = ax.streamplot(x, y, z1, z2, density=5, color=speed, cmap="plasma")
# Only the streamline colorbar belongs to this figure (the former plt.colorbar(cp) reused a
# mappable left over from another cell).
fig.colorbar(strm.lines, ax=ax)
ax.scatter(xbdry, ybdry, s=2., marker='o', color='r')
plt.show()