## 1 Environment setup

### 1.1 Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from torchsummary import summary
import time

### 1.2 Select the compute device

In [None]:
# Use the GPU whenever PyTorch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

## 2 Functions

### 2.1 Generate training samples

In [None]:
def sample(N, *, target_device=None):
    """Build the global grid of collocation points fed to the PINN.

    Stores in the module-level ``x`` a float32 tensor of shape ``(N, N, 2)``
    whose entry ``x[i, j]`` is ``(i, j)`` -- raw integer grid indices, not
    physical coordinates -- with ``requires_grad`` enabled.

    Args:
        N: number of grid points per side.
        target_device: device for the tensor; defaults to the module-level
            ``device`` chosen in the setup cell.

    Returns:
        The tensor that was stored in the global ``x`` (callers that ignore
        the return value, as before, are unaffected).
    """
    global x
    if target_device is None:
        target_device = device
    # Vectorised build: the former double loop issued 2*N*N scalar writes,
    # each one a separate kernel launch when the tensor lives on CUDA.
    idx = torch.arange(N, dtype=torch.float32)
    rows = idx.view(N, 1).expand(N, N)  # x[i, j, 0] = i
    cols = idx.view(1, N).expand(N, N)  # x[i, j, 1] = j
    x = torch.stack((rows, cols), dim=-1).to(target_device)
    x.requires_grad = True
    return x

### 2.2 Generate benchmark solution

In [None]:
def compute_FD(N, Re, dt = 0.01, tol_v = 1e-6, log_file = 'iteration.csv'):
    """Solve the lid-driven cavity on a staggered grid as the FD benchmark.

    Each time step is a projection scheme:
      1. explicit momentum predictor U_1, V_1 using the current pressure;
      2. U_2, V_2 add the old pressure gradient back (pressure-free velocity);
      3. one Gauss-Seidel sweep of the pressure Poisson equation;
      4. correction U, V = U_2, V_2 - dt * grad(P).
    The north wall (lid) moves with u = 1; the other walls are no-slip, and
    pressure has zero-normal-gradient conditions via ghost cells.

    Layout of the (N+1)x(N+1) arrays (ghost layers included):
      * U[i, j]: x at node i, y midway between nodes j-1 and j;
      * V[i, j]: x midway between nodes i-1 and i, y at node j;
      * P[i, j]: midway between nodes in both directions.

    Marching runs for at least 10 steps and stops once the largest relative
    change of interior U/V between consecutive steps falls below ``tol_v``.

    Side effects: sets the globals ``UC``, ``VC``, ``PC`` (N x N collocated
    fields indexed [x, y]), prints progress, draws four contour plots, and
    (unless ``log_file`` is None) appends ``timestep`` and the Poisson
    iteration count to ``log_file`` every step.

    Args:
        N: number of grid nodes per side (spacing 1/(N-1)).
        Re: Reynolds number.
        dt: time step.
        tol_v: steady-state threshold on the max relative velocity change.
        log_file: CSV path for the per-step log, or None to disable it.

    Returns:
        (UC.T, VC.T, PC.T): collocated fields indexed [y, x].
    """
    global UC, VC, PC

    tol_p = 5.077674e-12  # residual threshold, used only by meet_poisson()

    dx = 1.0 / (N - 1)
    dy = 1.0 / (N - 1)

    P = np.zeros((N + 1, N + 1))
    U = np.zeros((N + 1, N + 1))
    V = np.zeros((N + 1, N + 1))
    U_1 = np.zeros((N + 1, N + 1))
    V_1 = np.zeros((N + 1, N + 1))
    U_2 = np.zeros((N + 1, N + 1))
    V_2 = np.zeros((N + 1, N + 1))
    U_prev = np.zeros((N + 1, N + 1))
    V_prev = np.zeros((N + 1, N + 1))

    UC = np.zeros((N, N))
    VC = np.zeros((N, N))
    PC = np.zeros((N, N))

    ###############################################################
    # Tools                                                       #
    ###############################################################

    def meet_poisson():
        """Diagnostic (not called by the main loop): is the summed Poisson residual below tol_p?"""
        residual = 0
        for i in range(1, N):
            for j in range(1, N):
                ux = (U_2[i, j] - U_2[i - 1, j]) / dx
                vy = (V_2[i, j] - V_2[i, j - 1]) / dy
                poisson_LHP = (ux + vy) / dt
                poisson_RHP = (P[i + 1, j] + P[i - 1, j] + P[i, j + 1] + P[i, j - 1] - 4 * P[i, j]) / (dx * dx)
                residual += abs(poisson_LHP - poisson_RHP)
        print("[meet Poisson]residual: ", residual)
        return bool(residual < tol_p)

    def is_steady(t):
        """Return True when the max relative change of U and V is below tol_v.

        Also snapshots the current U/V into U_prev/V_prev for the next call.
        """
        vt = 0
        for i in range(1, N):
            for j in range(1, N):
                # 1e-20 guards against division by zero where U_prev/V_prev is 0.
                vt = max(vt, np.abs((U[i, j] - U_prev[i, j]) / (U_prev[i, j] + 1e-20)))
                vt = max(vt, np.abs((V[i, j] - V_prev[i, j]) / (V_prev[i, j] + 1e-20)))
                U_prev[i, j] = U[i, j]
                V_prev[i, j] = V[i, j]

        if t % 100 == 0:
            print("[is_steady] velocity deviation: ", vt)

        return bool(vt < tol_v)

    def collocate():
        """Interpolate the staggered U, V, P onto the N x N nodes (UC, VC, PC)."""
        for i in range(N):
            for j in range(N):
                UC[i, j] = 0.5 * (U[i, j] + U[i, j + 1])
                VC[i, j] = 0.5 * (V[i, j] + V[i + 1, j])
                PC[i, j] = (P[i, j] + P[i + 1, j] + P[i, j + 1] + P[i + 1, j + 1]) * 0.25

    ###############################################################
    # Boundary conditions                                         #
    ###############################################################

    def setBC(option):
        """Apply BCs to P ('P'), to U/U_1/U_2 ('U'), or to V/V_1/V_2 ('V')."""
        if option == 'P':
            # Zero normal gradient through the ghost layers.
            P[0, :] = P[1, :]        # west
            P[N, :] = P[N - 1, :]    # east
            P[:, 0] = P[:, 1]        # south
            P[:, N] = P[:, N - 1]    # north

        elif option == 'U':
            # The same conditions on every u-field. (Previously the third group
            # re-applied U_1's lid condition and never set U_2[:, N].)
            for arr in (U, U_1, U_2):
                arr[:, N] = 2 - arr[:, N - 1]  # north: ghost makes the face average 1 (lid)
                arr[:, 0] = -arr[:, 1]         # south: ghost makes the face average 0
                arr[0, :] = 0                  # west wall
                arr[N - 1, :] = 0              # east wall

        elif option == 'V':
            for arr in (V, V_1, V_2):
                arr[0, :] = -arr[1, :]         # west: ghost makes the face average 0
                arr[N, :] = -arr[N - 1, :]     # east: ghost makes the face average 0
                arr[:, 0] = 0                  # south wall
                arr[:, N - 1] = 0              # north wall

    ###############################################################
    # Sub-steps                                                   #
    ###############################################################

    def solve_U1():
        """Explicit momentum predictor for u at interior U points, using the current P."""
        for i in range(1, N - 1):
            for j in range(1, N):
                u = U[i, j]
                # v interpolated to the U point from the four surrounding V points.
                v = (V[i, j] + V[i + 1, j] + V[i, j - 1] + V[i + 1, j - 1]) / 4
                ux = (U[i + 1, j] - U[i - 1, j]) / (2 * dx)
                uy = (U[i, j + 1] - U[i, j - 1]) / (2 * dy)
                u2x = (U[i + 1, j] + U[i - 1, j] - 2 * U[i, j]) / (dx * dx)
                u2y = (U[i, j + 1] + U[i, j - 1] - 2 * U[i, j]) / (dy * dy)

                C = u * ux + v * uy
                D = (u2x + u2y) / Re

                px = (P[i + 1, j] - P[i, j]) / dx
                U_1[i, j] = (-C + D - px) * dt + U[i, j]

    def solve_V1():
        """Explicit momentum predictor for v at interior V points, using the current P."""
        for i in range(1, N):
            for j in range(1, N - 1):
                # u interpolated to the V point from the four surrounding U points.
                u = (U[i - 1, j + 1] + U[i, j + 1] + U[i - 1, j] + U[i, j]) / 4.0
                v = V[i, j]
                vx = (V[i + 1, j] - V[i - 1, j]) / (2 * dx)
                vy = (V[i, j + 1] - V[i, j - 1]) / (2 * dy)
                v2x = (V[i + 1, j] + V[i - 1, j] - 2 * V[i, j]) / (dx * dx)
                v2y = (V[i, j + 1] + V[i, j - 1] - 2 * V[i, j]) / (dy * dy)

                C = u * vx + v * vy
                D = (v2x + v2y) / Re

                py = (P[i, j + 1] - P[i, j]) / dy
                V_1[i, j] = (-C + D - py) * dt + V[i, j]

    def solve_U2():
        """Add the old pressure gradient back: U_2 is the pressure-free intermediate u."""
        for i in range(1, N - 1):
            for j in range(1, N):
                px = (P[i + 1, j] - P[i, j]) / dx
                U_2[i, j] = px * dt + U_1[i, j]

    def solve_V2():
        """Add the old pressure gradient back: V_2 is the pressure-free intermediate v."""
        for i in range(1, N):
            for j in range(1, N - 1):
                py = (P[i, j + 1] - P[i, j]) / dy
                V_2[i, j] = py * dt + V_1[i, j]

    def quicksolve_P():
        """One in-place Gauss-Seidel sweep of lap(P) = div(U_2, V_2) / dt (assumes dx == dy)."""
        iteration = 0
        while iteration < 1:  # a single sweep per time step
            iteration += 1
            for i in range(1, N):
                for j in range(1, N):
                    ux = (U_2[i, j] - U_2[i - 1, j]) / dx
                    vy = (V_2[i, j] - V_2[i, j - 1]) / dy
                    poisson_LHP = (ux + vy) / dt
                    P[i, j] = 0.25 * (P[i + 1, j] + P[i - 1, j] + P[i, j + 1] + P[i, j - 1] - poisson_LHP * dx * dx)
        if log_file is not None:
            # `timestep` is the enclosing function's step counter.
            pd.DataFrame(np.array([timestep, iteration])).to_csv(log_file, mode='a', header=False, index=False)

    def solve_U():
        """Project u with the new pressure gradient."""
        for i in range(1, N - 1):
            for j in range(1, N):
                px = (P[i + 1, j] - P[i, j]) / dx
                U[i, j] = -px * dt + U_2[i, j]

    def solve_V():
        """Project v with the new pressure gradient."""
        for i in range(1, N):
            for j in range(1, N - 1):
                py = (P[i, j + 1] - P[i, j]) / dy
                V[i, j] = -py * dt + V_2[i, j]

    ###############################################################
    # Run algo                                                    #
    ###############################################################

    timestep = 0
    setBC('U')
    setBC('V')
    setBC('P')
    # `|` rather than `or` is deliberate: is_steady() must run on every pass,
    # because it also snapshots U_prev/V_prev for the next comparison.
    while (timestep < 10) | (not is_steady(timestep)):
        timestep += 1

        if timestep % 100 == 0:
            print("timestep: ", timestep)
            # Central differences of the last collocated field: (f[i+1] - f[i-1]) / (2 dx).
            conti = np.abs((UC[2:, :] - UC[:-2, :])[:, 1:-1] + (VC[:, 2:] - VC[:, :-2])[1:-1, :]).mean() * (N - 1) / 2
            print("Continuity residual: ", conti)

        solve_U1()
        solve_V1()

        solve_U2()
        solve_V2()

        setBC('U')
        setBC('V')

        quicksolve_P()
        setBC('P')

        solve_U()
        solve_V()
        setBC('U')
        setBC('V')

        collocate()

    def show_field(label, field):
        """Print `label` and draw a square contour plot of `field` (indexed [x, y])."""
        print(label)
        plt.subplots()[1].set_box_aspect(1)
        plt.contourf(np.transpose(field), levels = 20)
        plt.colorbar()
        plt.pause(0.1)

    show_field("U: ", np.sqrt(np.multiply(UC, UC) + np.multiply(VC, VC)))
    show_field("u: ", UC)
    show_field("v: ", VC)
    show_field("P: ", PC)

    return np.transpose(UC), np.transpose(VC), np.transpose(PC)
    

### 2.3 PINN

##### Variables of interest:
+ N (number of mesh points per side).
+ DAB ("d"imension "a"nalysis "b"ased, i.e. the loss-weighting scheme):
    0. loss weight: equal weights | metric: MSE
    1. loss weight: weighted by the dimension-analysis result | metric: MSE
    2. loss weight: weighted by the square root of the dimension-analysis result | metric: MSE
    3. loss weight: equal weights | metric: MAE
    4. loss weight: weighted by the dimension-analysis result | metric: MAE
    5. loss weight: weighted by the square root of the dimension-analysis result | metric: MAE

In [None]:
def predict_NN(N, Re, w_ns_x = 1, w_ns_y = 1, w_conti = 1, w_dbc = 1, w_nbc = 1, DAB = 0, start_epoch = 0, train_epoch = 50000, pl_1 = 64):
    """Train a PINN for the steady lid-driven cavity and compare it with the FD benchmark.

    The network maps the global grid ``x`` (built by ``sample(N)``) to
    (u, v, p). The physics enters through finite-difference residuals on that
    grid:
      * x/y momentum and continuity residuals;
      * Dirichlet velocity BCs on the ring one node in from the edge
        (u = 1 on the north side, otherwise u = v = 0);
      * zero-normal-gradient pressure BCs.
    The outermost ring of nodes is only a halo for the stencils. The physical
    grid is the inner (N-2)x(N-2) nodes with spacing h = 1/(N-3).

    Reads the globals ``x``, ``device`` and ``UC``/``VC``/``PC`` (as set by
    compute_FD, indexed [x, y]) for plots and error metrics.

    Args:
        N: grid points per side of ``x`` (FD grid size + 2).
        Re: Reynolds number.
        w_ns_x, w_ns_y, w_conti, w_dbc, w_nbc: user weights of the loss terms.
        DAB: loss-weighting scheme, 0-5 (0-2 use MSE, 3-5 use MAE).
        start_epoch, train_epoch: epoch range. A fresh network is always built,
            so start_epoch only offsets the epoch counter.
        pl_1: width of the first hidden layer (the other hidden layers have
            width 20).

    Returns:
        (u, v, p) on the inner (N-2)x(N-2) nodes as numpy arrays indexed [y, x].

    Raises:
        ValueError: if DAB is not in 0..5 or train_epoch <= start_epoch.
    """
    if DAB not in (0, 1, 2, 3, 4, 5):
        raise ValueError(f"DAB must be one of 0..5, got {DAB!r}")
    if train_epoch <= start_epoch:
        raise ValueError("train_epoch must be greater than start_epoch")

    class Net(torch.nn.Module):
        def __init__(self, n_feature, n_hidden, n_hidden1, n_output):
            super(Net, self).__init__()
            # Shared trunk.
            self.hidden = torch.nn.Linear(n_feature, n_hidden)
            self.hidden1 = torch.nn.Linear(n_hidden, n_hidden1)
            self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.hidden3 = torch.nn.Linear(n_hidden1, n_hidden1)

            # u head.
            self.hidden4 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.hidden5 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.hidden6 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.predictu = torch.nn.Linear(n_hidden1, 1)

            # v head.
            self.hidden7 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.hidden8 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.hidden9 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.predictv = torch.nn.Linear(n_hidden1, 1)

            # p head.
            self.hidden10 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.hidden11 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.hidden12 = torch.nn.Linear(n_hidden1, n_hidden1)
            self.predictp = torch.nn.Linear(n_hidden1, 1)

        def forward(self, x):
            x = torch.sin(self.hidden(x))
            x = torch.sin(self.hidden1(x))
            x = torch.sin(self.hidden2(x))
            x = torch.sin(self.hidden3(x))

            u = torch.sin(self.hidden4(x))
            u = torch.sin(self.hidden5(u))
            u = torch.sin(self.hidden6(u))
            u = self.predictu(u)

            v = torch.sin(self.hidden7(x))
            v = torch.sin(self.hidden8(v))
            v = torch.sin(self.hidden9(v))
            v = self.predictv(v)

            p = torch.sin(self.hidden10(x))
            p = torch.sin(self.hidden11(p))
            p = torch.sin(self.hidden12(p))
            p = self.predictp(p)

            # Concatenate on dim 2: the input is the 3-D (N, N, 2) grid.
            y = torch.cat((u, v, p), 2)
            return y

    pl = 20  # width of every hidden layer after the first
    net = Net(n_feature=2, n_hidden = pl_1, n_hidden1 = pl , n_output=3)     # define the network
    net.to(device)
    summary(net, (1,1, 2))

    #####

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
    loss_func = torch.nn.L1Loss() if DAB in (3, 4, 5) else torch.nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.8)
    plt.ion()

    #####

    # Always start fresh histories. They used to be created only when
    # start_epoch == 0, so any other value crashed with NameError.
    loss_epoch = []
    loss_values = []
    loss_ns_x_values = []
    loss_ns_y_values = []
    loss_conti_values = []
    loss_DBC_values = []
    loss_NBC_values = []

    h = 1 / (N - 3)  # spacing of the inner (N-2)-node grid

    ###############################################################
    # Define loss weight                                          #
    ###############################################################

    if DAB == 0: # MSE(1)
        a = w_ns_x
        b = w_ns_y
        c = w_conti
        d = w_dbc
        e = w_nbc
    elif DAB == 1: # MSE(DA)
        a = w_ns_x * (h**4) * (Re**2)
        b = w_ns_y * (h**4) * (Re**2)
        c = w_conti * (h**2)
        d = w_dbc * (1)
        e = w_nbc * (h**4) * (Re**2)
    elif DAB == 2: # MSE(Weak)
        a = w_ns_x * (h**2) * (Re)
        b = w_ns_y * (h**2) * (Re)
        c = w_conti * (h)
        d = w_dbc * (1)
        e = w_nbc * (h**2) * (Re)
    elif DAB == 3: # MAE(1)
        a = w_ns_x
        b = w_ns_y
        c = w_conti
        d = w_dbc
        e = w_nbc
    elif DAB == 4: # MAE(DA)
        a = w_ns_x * (h**2) * (Re)
        b = w_ns_y * (h**2) * (Re)
        c = w_conti * (h)
        d = w_dbc * (1)
        e = w_nbc * (h**2) * (Re)
    else: # DAB == 5, MAE(Weak)
        a = w_ns_x * (h) * (Re**0.5)
        b = w_ns_y * (h) * (Re**0.5)
        c = w_conti * (h**0.5)
        d = w_dbc * (1)
        e = w_nbc * (h) * (Re**0.5)

    ###############################################################
    # Loss weight normalization                                   #
    ###############################################################

    print(a, b, c, d, e)

    weight_sum = a + b + c + d + e
    lambda_ns_x = a / weight_sum
    lambda_ns_y = b / weight_sum
    lambda_conti = c / weight_sum
    lambda_DBC = d / weight_sum
    lambda_NBC = e / weight_sum

    def _to_zero(residual):
        """Loss of `residual` against an all-zero target of the SAME shape.

        The former targets had shape (N-2, 1) for (N-2,) inputs. That
        broadcast to (N-2)^2 elements and triggered a size-mismatch warning.
        """
        return loss_func(residual, torch.zeros_like(residual))

    ###############################################################
    # Training                                                    #
    ###############################################################

    start = time.perf_counter()
    for t in range(start_epoch, train_epoch):
        uvp = net(x)
        u = uvp[:, :, 0]
        v = uvp[:, :, 1]
        p = uvp[:, :, 2]

        # ns-x: u u_x + v u_y = -p_x + lap(u) / Re  (central differences)
        ux = (u[2:, :] - u[:-2, :]) / (2 * h)                 # (N-2, N)
        uux = torch.mul(u[1:-1, 1:-1], ux[:, 1:-1])           # (N-2, N-2)
        uy = (u[:, 2:] - u[:, :-2]) / (2 * h)                 # (N, N-2)
        vuy = torch.mul(v[1:-1, 1:-1], uy[1:-1, :])           # (N-2, N-2)
        px = (p[2:, :] - p[:-2, :])[:, 1:-1] / (2 * h)
        lap_u = (-4 * u[1:-1, 1:-1] + u[2:, 1:-1] + u[:-2, 1:-1] + u[1:-1, 2:] + u[1:-1, :-2]) / (h * h)

        LHP_x = uux + vuy
        RHP_x = - px + 1 / Re * (lap_u)
        # Residuals only at nodes strictly inside the boundary ring.
        loss_ns_x = loss_func(LHP_x[1:-1, 1:-1], RHP_x[1:-1, 1:-1])

        # ns-y: u v_x + v v_y = -p_y + lap(v) / Re
        vx = (v[2:, :] - v[:-2, :]) / (2 * h)                 # (N-2, N)
        uvx = torch.mul(u[1:-1, 1:-1], vx[:, 1:-1])           # (N-2, N-2)
        vy = (v[:, 2:] - v[:, :-2]) / (2 * h)                 # (N, N-2)
        vvy = torch.mul(v[1:-1, 1:-1], vy[1:-1, :])           # (N-2, N-2)
        py = (p[:, 2:] - p[:, :-2])[1:-1, :] / (2 * h)
        lap_v = (-4 * v[1:-1, 1:-1] + v[2:, 1:-1] + v[:-2, 1:-1] + v[1:-1, 2:] + v[1:-1, :-2]) / (h * h)

        LHP_y = uvx + vvy
        RHP_y = - py + 1 / Re * (lap_v)
        loss_ns_y = loss_func(LHP_y[1:-1, 1:-1], RHP_y[1:-1, 1:-1])

        # continuity: u_x + v_y = 0 on the same interior nodes, shape (N-4, N-4)
        loss_conti = _to_zero(ux[1:-1, 2:-2] + vy[2:-2, 1:-1])

        # DBC on the ring one node in from the edge.
        loss_DBC = (_to_zero(u[1:-1, 1])           # s
                    + _to_zero(u[1:-1, -2] - 1)    # n (lid, u = 1)
                    + _to_zero(u[1, 1:-1])         # w
                    + _to_zero(u[-2, 1:-1])        # e
                    + _to_zero(v[1:-1, 1])         # s
                    + _to_zero(v[1:-1, -2])        # n
                    + _to_zero(v[1, 1:-1])         # w
                    + _to_zero(v[-2, 1:-1]))       # e

        # NBC for p at the boundary ring: (p[0] - p[2]) / h is twice the
        # central-difference normal derivative; the factor only rescales this term.
        loss_NBC = (_to_zero((p[1:-1, 0] - p[1:-1, 2]) / h)      # s
                    + _to_zero((p[1:-1, -1] - p[1:-1, -3]) / h)  # n
                    + _to_zero((p[0, 1:-1] - p[2, 1:-1]) / h)    # w
                    + _to_zero((p[-1, 1:-1] - p[-3, 1:-1]) / h)) # e

        # total loss
        loss = lambda_ns_x * loss_ns_x
        loss = loss + lambda_ns_y * loss_ns_y
        loss = loss + lambda_conti * loss_conti
        loss = loss + lambda_DBC * loss_DBC
        loss = loss + lambda_NBC * loss_NBC

        ###############################################################
        # Save loss                                                   #
        ###############################################################

        loss_epoch.append(t)
        loss_ns_x_values.append(loss_ns_x.item())
        loss_ns_y_values.append(loss_ns_y.item())
        loss_conti_values.append(loss_conti.item())
        loss_DBC_values.append(loss_DBC.item())
        loss_NBC_values.append(loss_NBC.item())
        loss_values.append(loss.item())

        ###############################################################
        # Weight update                                               #
        ###############################################################

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The LR schedule must advance AFTER optimizer.step() (PyTorch >= 1.1);
        # stepping it first skipped the initial learning rate.
        scheduler.step()

        ###############################################################
        # Monitor                                                     #
        ###############################################################

        if t % 5000 == 0:
            print("Epoch: ", t, " ,loss: ", loss)
            print("loss_ns_x: ", loss_ns_x)
            print("loss_ns_y: ", loss_ns_y)
            print("loss_conti: ", loss_conti)
            print("loss_DBC: ", loss_DBC)
            print("loss_NBC: ", loss_NBC)
            print("Timer: {0:.2f} sec".format(time.perf_counter() - start))
            plt.subplots()[1].set_box_aspect(1)
            plt.cla()
            U = torch.sqrt(torch.mul(u, u) + torch.mul(v, v))
            plt.contourf(np.transpose(torch.reshape(U, (N, N)).cpu().detach().numpy()), levels=20)
            plt.colorbar()
            plt.pause(0.1)

    # Pure training time, measured instead of subtracting a hard-coded
    # plotting allowance from the total later on.
    train_time = time.perf_counter() - start
    print("Timer: {0:.2f} sec".format(train_time))

    #####

    # Drop the one-node halo so the fields line up with the FD grid.
    u = u[1:-1, 1:-1]
    v = v[1:-1, 1:-1]
    p = p[1:-1, 1:-1]
    N = N - 2

    #####

    print("Epoch: ", t, " ,loss: ", loss)

    U = torch.sqrt(torch.mul(u, u) + torch.mul(v, v))
    U_np = torch.reshape(U, (N, N)).cpu().detach().numpy()
    u_np = torch.reshape(u, (N, N)).cpu().detach().numpy()
    v_np = torch.reshape(v, (N, N)).cpu().detach().numpy()
    p_np = torch.reshape(p, (N, N)).cpu().detach().numpy()
    UUC = np.sqrt(np.multiply(UC, UC) + np.multiply(VC, VC))

    def _contour(label, field, ref):
        """Print `label` and contour `field` (indexed [x, y]) using the colour range of `ref`."""
        print(label)
        plt.subplots()[1].set_box_aspect(1)
        plt.cla()
        plt.contourf(np.transpose(field), vmax = ref.max(), vmin = ref.min(), levels=20)
        plt.colorbar()
        plt.pause(0.1)

    _contour("U", U_np, UUC)
    _contour("u", u_np, UC)
    _contour("v", v_np, VC)
    _contour("p", p_np, PC)

    plt.ioff()
    plt.show()

    #####

    print("Loss curve: ")
    plt.plot(loss_epoch, loss_values, loss_epoch, loss_ns_x_values, loss_epoch, loss_ns_y_values, loss_epoch, loss_conti_values, loss_epoch, loss_DBC_values, loss_epoch, loss_NBC_values)
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend([r'$\mathcal{L}$', r'$\mathcal{L}_{ns_x}$', r'$\mathcal{L}_{ns_y}$', r'$\mathcal{L}_{conti}$', r'$\mathcal{L}_{DBC}$', r'$\mathcal{L}_{NBC}$'])
    plt.semilogy()
    plt.pause(0.1)

    #####

    print("Error: ")
    _contour("U", U_np - UUC, UUC)
    _contour("u", u_np - UC, UC)
    _contour("v", v_np - VC, VC)
    _contour("p", p_np - PC, PC)

    plt.ioff()
    plt.show()

    #####

    # Centre-line index, shared by ground truth and PINN. The ground truth
    # previously used (N+1)/2, one node off from the PINN's (N-1)/2.
    mid = (N - 1) // 2

    print("x = 0.5: ")
    fig, ax = plt.subplots(1, 1, figsize=(17,5))
    ax.plot(UC[mid, :], range(N), u_np[mid, :], range(N), 'r--')
    ax.set(ylim=(0, N - 1))
    ax.set_box_aspect(1)
    plt.xlabel('u')
    plt.ylabel('j')
    plt.legend(['Ground truth', 'PINN'])
    plt.pause(0.1)

    fig, ax = plt.subplots(1, 1, figsize=(17,5))
    ax.plot(range(N), VC[:, mid], range(N), v_np[:, mid], 'r--')
    ax.set(xlim=(0, N - 1))
    ax.set_box_aspect(1)
    plt.xlabel('i')
    plt.ylabel('v')
    plt.legend(['Ground truth', 'PINN'])
    plt.pause(0.1)

    #####

    # MSE between the FD reference and the PINN.
    mse = np.mean(np.power((UUC - U_np), 2))
    mse_1 = np.mean(np.power((UC[mid, :] - u_np[mid, :]), 2))
    mse_2 = np.mean(np.power((VC[:, mid] - v_np[:, mid]), 2))

    print(f" ||||| {loss.item() :.4} | {loss_ns_x.item() :.4} | {loss_ns_y.item() :.4} | {loss_conti.item() :.4} | {loss_DBC.item() :.4} | {loss_NBC.item() :.4} | {train_time :.2f} sec | {mse :.4} | {mse_1 :.4} | {mse_2 :.4} |")

    return np.transpose(u_np), np.transpose(v_np), np.transpose(p_np)

### 2.4 Post-processing

In [None]:
# Line styles for the FDM curve and the three PINN variants, in plotting order.
_RED_STYLES = (('r-', {}), ('r--', {}), ('r:', {}), (':', {'c': 'hotpink'}))
_GREEN_STYLES = (('g-', {}), ('g--', {}), ('g:', {}), (':', {'c': 'lime'}))


def _profile_vs_y(mesh, series, label, legend, anchor, reference=None):
    """Plot 1-D profiles with the value on the horizontal axis and y on the vertical axis (red theme)."""
    fig, ax = plt.subplots(figsize=(10,10))
    color = 'tab:red'
    for values, (fmt, extra) in zip(series, _RED_STYLES):
        ax.plot(values, mesh, fmt, markerfacecolor='none', **extra)
    if reference is not None:
        ax.plot(*reference, 'ro', markerfacecolor='none')

    ax.set(ylim=(0, 1))
    ax.set_box_aspect(1)
    ax.tick_params(axis ='x', labelcolor = color)
    ax.tick_params(axis ='y', labelcolor = color)
    plt.xlabel(label, color = color)
    plt.ylabel('y', color = color)
    plt.legend(legend, loc = 'center', bbox_to_anchor=anchor, fontsize="14", frameon=False)
    plt.pause(0.1)


def _profile_vs_x(mesh, series, label, legend, anchor, reference=None):
    """Plot 1-D profiles with x on the horizontal axis and the value on the vertical axis (green theme, top/right axes)."""
    fig, ax = plt.subplots(figsize=(10,10))
    color = 'tab:green'
    for values, (fmt, extra) in zip(series, _GREEN_STYLES):
        ax.plot(mesh, values, fmt, markerfacecolor='none', **extra)
    if reference is not None:
        ax.plot(*reference, 'go', markerfacecolor='none')

    ax.set(xlim=(0, 1))
    ax.set_box_aspect(1)
    ax.xaxis.set_label_position('top')
    ax.yaxis.set_label_position('right')
    ax.tick_params(axis ='x', labelcolor = color, top=True, labeltop=True, bottom=False, labelbottom=False)
    ax.tick_params(axis ='y', labelcolor = color, left=False, labelleft=False, right=True, labelright=True)
    plt.xlabel('x', color = color)
    plt.ylabel(label, color = color)
    plt.legend(legend, loc = 'center', bbox_to_anchor=anchor, fontsize="14", frameon=False)
    plt.pause(0.1)


def plotmid(UC, VC, PC, u1, v1, p1, u2, v2, p2, u3, v3, p3, N, Re, A = 0.5, B = 0.5, C = 0.5, a = 0.1, b = 0.1, c = 0.3):
    """Compare centre-line profiles of the FD solution and three PINN runs.

    The fields are expected as returned by compute_FD / predict_NN, i.e.
    N x N arrays indexed [y, x] on a uniform grid over [0, 1]. The function
    plots, in order: u(y) at x = 0.5, v(x) at y = 0.5, p(y) at x = 0.5,
    p(x) at y = 0.5, p_x(y) at x = 0.5 and p_y(x) at y = 0.5. When Re == 100
    the u and v plots also show Ghia et al.'s benchmark points.

    Args:
        UC, VC, PC: FD reference fields.
        u1..p3: fields of the three PINN runs (legend labels 0, NM, NM^2).
        N: grid points per side.
        Re: Reynolds number (selects whether the Ghia data is shown).
        A, B, C: legend anchor x-centres for the u/v, p and gradient plots.
        a, b, c: legend anchor y-positions for the same three groups.
    """

    ############################################

    x_b=[0, 0.0625, 0.0703, 0.0781, 0.0938, 0.1563, 0.2266, 0.2344, 0.5000, 0.8047, 0.8594, 0.9063, 0.9453, 0.9531, 0.9609, 0.9688, 1.0]
    y_b=[0, 0.0547, 0.0625, 0.0703, 0.1016, 0.1719, 0.2813, 0.4531, 0.5000, 0.6172, 0.7344, 0.8516, 0.9531, 0.9609, 0.9688, 0.9766, 1.0]
    Re_100_v=[0, 0.09233, 0.10091, 0.10890, 0.12317, 0.16077, 0.17507, 0.17527, 0.05454, -0.24533, -0.22445, -0.16914, -0.10313, -0.08864, -0.07391, -0.05906, 0]
    Re_100_u=[0, -0.03717, -0.04192, -0.04775, -0.06434, -0.10150, -0.15662, -0.21090, -0.20581, -0.13641, 0.00332, 0.23151, 0.68717, 0.73722, 0.78871, 0.84123, 1]

    mesh = np.linspace(0, 1, num = N)
    # Index of the 0.5 line. It is exact for odd N; the former (N+1)/2 landed
    # one node past the centre (x = 0.52 for N = 51).
    mid = (N - 1) // 2
    h = 1 / (N - 1)
    with_ghia = Re == 100

    ############################################

    print("x = 0.5: ")

    # u at x = 0.5
    legend = [r'$u_{FDM}$', r'$\hat{u}_{0}$', r'$\hat{u}_{NM}$', r'${\hat{u}_{NM}}^2$']
    if with_ghia:
        legend = legend + ['Ghia et al.']
    _profile_vs_y(mesh, [f[:, mid] for f in (UC, u1, u2, u3)], 'u', legend, (A - 0.1, a),
                  reference=(Re_100_u, y_b) if with_ghia else None)

    # v at y = 0.5
    legend = [r'$v_{FDM}$', r'$\hat{v}_{0}$', r'$\hat{v}_{NM}$', r'${\hat{v}_{NM}}^2$']
    if with_ghia:
        legend = legend + ['Ghia et al.']
    _profile_vs_x(mesh, [f[mid] for f in (VC, v1, v2, v3)], 'v', legend, (A + 0.1, a),
                  reference=(x_b, Re_100_v) if with_ghia else None)

    # p at x = 0.5
    p_legend = [r'$p_{FDM}$', r'$\hat{p}_{0}$', r'$\hat{p}_{NM}$', r'${\hat{p}_{NM}}^2$']
    _profile_vs_y(mesh, [f[:, mid] for f in (PC, p1, p2, p3)], 'p', p_legend, (B - 0.1, b))

    # p at y = 0.5
    _profile_vs_x(mesh, [f[mid] for f in (PC, p1, p2, p3)], 'p', p_legend, (B + 0.1, b))

    # px at x = 0.5 (central difference across the centre line; now includes p3,
    # which the four-entry legend already listed)
    _profile_vs_y(mesh,
                  [(f[:, mid + 1] - f[:, mid - 1]) / (2 * h) for f in (PC, p1, p2, p3)],
                  '$p_x$',
                  [r'${p_x}_{FDM}$', r'$\hat{p_x}_{0}$', r'$\hat{p_x}_{NM}$', r'${\hat{p_x}_{NM}}^2$'],
                  (C - 0.1, c))

    # py at y = 0.5
    _profile_vs_x(mesh,
                  [(f[mid + 1] - f[mid - 1]) / (2 * h) for f in (PC, p1, p2, p3)],
                  '$p_y$',
                  [r'${p_y}_{FDM} $', r'$\hat{p_y}_{0}$', r'$\hat{p_y}_{NM}$', r'${\hat{p_y}_{NM}}^2$'],
                  (C + 0.1, c))

## 3. Experiment

In [None]:
# Experiment configuration: Reynolds number and FD grid points per side.
Re, N = 100, 51

In [None]:
# Finite-difference benchmark with a loose steady-state tolerance.
UC_100_51, VC_100_51, PC_100_51 = compute_FD(
    N,
    Re,
    tol_v=1e-2,
)

In [None]:
# PINN grid: the FD grid plus a one-node halo on every side.
sample(N=N + 2)

In [None]:
# PINN with equal loss weights (DAB = 0, MSE).
u_100_51_0, v_100_51_0, p_100_51_0 = predict_NN(
    N + 2, Re, DAB=0, train_epoch=1000
)

In [None]:
# PINN with dimension-analysis loss weights (DAB = 1, MSE).
u_100_51_1, v_100_51_1, p_100_51_1 = predict_NN(
    N + 2, Re, DAB=1, train_epoch=1000
)

In [None]:
# PINN with square-root dimension-analysis loss weights (DAB = 2, MSE).
u_100_51_2, v_100_51_2, p_100_51_2 = predict_NN(
    N + 2, Re, DAB=2, train_epoch=1000
)

In [None]:
# Centre-line comparison: FD reference followed by the three PINN runs.
plotmid(
    UC_100_51, VC_100_51, PC_100_51,
    u_100_51_0, v_100_51_0, p_100_51_0,
    u_100_51_1, v_100_51_1, p_100_51_1,
    u_100_51_2, v_100_51_2, p_100_51_2,
    N=51, Re=100,
    a=0.3, b=0.3, C=0.2, c=0.2,
)