## 1 Environment setup

### 1.1 Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from torchsummary import summary
import time

### 1.2 Select compute device (GPU if available)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

## 2 Functions

### 2.1 Generate training samples

In [None]:
def sample(N):
    """Build the N x N grid of training coordinates in grid-index units.

    Sets the global ``x`` to a tensor of shape ``(N, N, 2)`` on ``device``
    (default floating dtype, as ``torch.zeros`` would give) with
    ``x[i, j] == (i, j)`` and ``requires_grad=True``; ``predict_NN``
    evaluates its network on this global.  The tensor is also returned.
    """
    global x
    idx = torch.arange(N, dtype=torch.get_default_dtype(), device=device)
    # One vectorised build instead of 2*N*N scalar writes (each of which is a
    # separate kernel launch when ``device`` is CUDA).
    rows = idx.view(N, 1).expand(N, N)   # x[i, j, 0] = i
    cols = idx.view(1, N).expand(N, N)   # x[i, j, 1] = j
    x = torch.stack((rows, cols), dim=-1)
    x.requires_grad = True
    return x

### 2.2 Generate benchmark solution

In [None]:
def compute_FD(N, Pe, w, tol = 1e-6, max_iter = None):
    """Solve the benchmark problem on an N x N grid with relaxed Gauss-Seidel.

    Each interior node is updated from the five-point stencil

        T[i,j] = (T[i+1,j] + T[i-1,j] + T[i,j+1] + T[i,j-1]
                  - Pe/2 * (T[i+1,j] - T[i-1,j] + T[i,j+1] - T[i,j-1])) / 4

    and then relaxed towards that value by the factor ``w``.  Boundary values
    are T = 1 on ``Tn[0, :]`` and ``Tn[:, N-1]`` and T = 0 on the other sides.

    Args:
        N: number of grid points per side.
        Pe: Peclet number multiplying the central-difference advection term.
        w: relaxation factor; a node moves ``w`` of the way from its previous
            value to the Gauss-Seidel value (``w < 1`` under-relaxes).
        tol: stop once the largest relative change ``|(T_gs - T_prev)/T_prev|``
            of the (unrelaxed) Gauss-Seidel update in a sweep is ``<= tol``.
        max_iter: optional cap on the number of sweeps.  ``None`` (default)
            iterates until ``tol`` is met.

    Returns:
        The N x N solution array, also stored in the global ``Tn``.  A contour
        plot and the ``Tn[(N-1)//2, :]`` profile are drawn as a side effect.

    Raises:
        FloatingPointError: if a sweep produces non-finite values.  ``max``
            silently ignores NaN, so without this check a diverged run would
            end the loop as if it had converged.
    """
    global Tn
    Tn = np.zeros((N, N))
    Tn[:, N - 1] = 1
    Tn[0, :] = 1

    iteration = 0
    res = 1

    print("tol: ", tol)

    while res > tol:
        if max_iter is not None and iteration >= max_iter:
            print("stopped after max_iter =", max_iter, "sweeps without reaching tol; res:", res)
            break
        iteration = iteration + 1
        res = 0
        for i in range(1, N - 1):
            for j in range(1, N - 1):
                prev = Tn[i, j]
                # Gauss-Seidel value (uses already-updated neighbours in place).
                Tn[i, j] = Tn[i + 1 , j] + Tn[i - 1, j] + Tn[i, j + 1] + Tn[i, j - 1]
                Tn[i, j] = Tn[i, j] - 1 * Pe * (Tn[i + 1 , j] - Tn[i - 1, j] + Tn[i, j + 1] - Tn[i, j - 1]) * 0.5
                Tn[i, j] = Tn[i, j] * 0.25
                # Relative change of the unrelaxed update; 1e-20 avoids 0/0.
                res = max(res, np.abs((Tn[i, j] - prev) / (prev + 1e-20)))
                Tn[i, j] = prev + w * (Tn[i, j] - prev)

        print("iteration: ", iteration, "res: ", res)
        if not np.isfinite(Tn).all():
            raise FloatingPointError(
                f"compute_FD diverged at iteration {iteration} "
                f"(non-finite values); try a smaller relaxation factor w")

    plt.subplots()[1].set_box_aspect(1)
    plt.contourf(np.transpose(Tn), vmax = Tn.max(), vmin = Tn.min(), levels = 20)
    plt.colorbar()

    fig, ax = plt.subplots(1, 1, figsize=(17,5))
    ax.plot(np.linspace(0, N - 1, N), Tn[(N - 1) // 2, :], '-o')
    ax.set_box_aspect(1)
    plt.xlabel('j')
    plt.ylabel('T')
    plt.legend(['FDM'])  # only the FDM profile is drawn here
    plt.xlim((0, N - 1))
    plt.ylim((0, 1))

    return Tn

In [None]:
##### Special case: Pe = inf
def gnd_Peinf(N):
    """Build the Pe = inf reference field on an N x N grid.

    Nodes with ``i < j`` are 1, nodes with ``i > j`` are 0 and the diagonal
    is 0.5; the boundaries ``Tn[0, :]`` and ``Tn[:, N-1]`` are then set to 1.

    Args:
        N: number of grid points per side.

    Returns:
        The N x N array, also stored in the global ``Tn``.  A contour plot and
        the ``Tn[(N-1)//2, :]`` profile are drawn as a side effect.
    """
    global Tn
    Tn = np.zeros((N, N))
    Tn[np.triu_indices(N, 1)] = 1   # strictly upper triangle: i < j
    np.fill_diagonal(Tn, 0.5)        # i == j
    Tn[:, N - 1] = 1
    Tn[0, :] = 1

    plt.subplots()[1].set_box_aspect(1)
    plt.contourf(np.transpose(Tn), vmax = Tn.max(), vmin = Tn.min(), levels = 20)
    plt.colorbar()

    fig, ax = plt.subplots(1, 1, figsize=(17,5))
    ax.plot(np.linspace(0, N - 1, N), Tn[(N - 1) // 2, :], '-o')
    ax.set_box_aspect(1)
    plt.xlabel('j')
    plt.ylabel('T')
    # This is the analytic Pe = inf field, not an FDM result; one curve only.
    plt.legend(['Analytic (Pe = inf)'])
    plt.xlim((0, N - 1))
    plt.ylim((0, 1))
    return Tn

### 2.3 PINN

##### Variables in focus:
+ N (number of grid points per side of the N × N mesh).
+ DAB (**d**imensional-**a**nalysis-**b**ased loss weighting, i.e. the weighting scheme):
    0. loss weight: equal weights | metric: MSE
    1. loss weight: weighted by the dimensional-analysis result | metric: MSE
    2. loss weight: weighted by the square root of the dimensional-analysis result | metric: MSE
    3. loss weight: weighted by the dimensional-analysis result ($Pe = \infty$) | metric: MSE
    4. loss weight: weighted by the square root of the dimensional-analysis result ($Pe = \infty$) | metric: MSE
    5. loss weight: equal weights | metric: MAE
    6. loss weight: weighted by the dimensional-analysis result | metric: MAE
    7. loss weight: weighted by the square root of the dimensional-analysis result | metric: MAE
    8. loss weight: weighted by the dimensional-analysis result ($Pe = \infty$) | metric: MAE
    9. loss weight: weighted by the square root of the dimensional-analysis result ($Pe = \infty$) | metric: MAE

In [None]:
def _loss_weights(N, Pe, DAB, w_DE, w_BC):
    """Return the raw ``(b, a)`` weights of the DE and boundary losses for scheme ``DAB``.

    Schemes 0 and 5 use ``w_DE`` / ``w_BC`` unchanged; the others divide the
    DE weight by a scale built from the mesh size ``N`` and Peclet number ``Pe``.

    Raises:
        ValueError: if ``DAB`` is not one of 0-9.
    """
    a = w_BC
    if DAB in (0, 5):      # MSE(1) / MAE(1): equal weights
        b = w_DE
    elif DAB == 1:         # MSE(DA)
        b = w_DE /((N-1)*(N-1)*(N-1)*(N-1) * Pe * Pe)
    elif DAB == 2:         # MSE(weak)
        b = w_DE /((N-1)*(N-1)*Pe)
    elif DAB == 3:         # MSE(DA)(inf)
        b = w_DE /((N-1)*(N-1))
    elif DAB == 4:         # MSE(weak)(inf)
        b = w_DE /(N-1)
    elif DAB == 6:         # MAE(DA)
        b = w_DE /((N-1)*(N-1)*Pe)
    elif DAB == 7:         # MAE(weak)
        b = w_DE /((N-1)*(Pe**0.5))
    elif DAB == 8:         # MAE(DA)(inf)
        b = w_DE /((N-1))
    elif DAB == 9:         # MAE(weak)(inf)
        b = w_DE /((N-1)**0.5)
    else:
        raise ValueError(f"DAB must be an integer in 0..9, got {DAB!r}")
    return b, a


def predict_NN(N, Pe, Gamma = 1, w_DE = 1, w_BC = 1, DAB = 0, start_epoch = 0, train_epoch = 50000, pl = 64):
    """Train a sine-activated MLP on the discretised PDE and compare it with ``Tn``.

    The network is evaluated on the global coordinate tensor ``x`` (built by
    ``sample``).  Its loss combines a finite-difference residual on interior
    nodes with Dirichlet boundary terms (T = 1 on ``T[0, :]`` and ``T[:, -1]``,
    T = 0 on ``T[-1, :]`` and ``T[:, 0]``), normalised with the ``DAB`` weights.
    After training it plots the prediction, loss history, error field and the
    ``(N-1)//2`` row profile, and prints summary metrics.

    Args:
        N: grid points per side; must match the ``N`` passed to ``sample``.
        Pe: Peclet number used in the residual and the DAB weights.
        Gamma: ``0`` selects the Pe = inf residual and the analytic midline in
            the comparison plot; any other value uses the full residual.
        w_DE, w_BC: base weights of the DE and boundary loss terms.
        DAB: weighting scheme 0-9 (0-4 use MSE, 5-9 use MAE).
        start_epoch, train_epoch: training runs ``range(start_epoch, train_epoch)``.
        pl: width of the first hidden layer (the other hidden layers have 64).

    Returns:
        The final N x N network prediction as a NumPy array.

    Raises:
        ValueError: if ``DAB`` is not in 0-9 or the epoch range is empty.

    Note:
        Reads the module globals ``x``, ``device`` and ``Tn`` (reference field
        from ``compute_FD`` / ``gnd_Peinf``).
    """
    if train_epoch <= start_epoch:
        raise ValueError(
            f"train_epoch ({train_epoch}) must be greater than start_epoch ({start_epoch})")
    # Validate DAB before any network/optimizer work is done.
    b, a = _loss_weights(N, Pe, DAB, w_DE, w_BC)

    class Net(torch.nn.Module):
        def __init__(self, n_feature, n_hidden, n_hidden1, n_hidden2, n_hidden3, n_output):
            super(Net, self).__init__()
            self.hidden = torch.nn.Linear(n_feature, n_hidden)
            self.hidden1 = torch.nn.Linear(n_hidden, n_hidden1)
            self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden2)
            self.hidden3 = torch.nn.Linear(n_hidden2, n_hidden3)
            self.predict = torch.nn.Linear(n_hidden3, n_output)

        def forward(self, x):
            x = torch.sin(self.hidden(x))
            x = torch.sin(self.hidden1(x))
            x = torch.sin(self.hidden2(x))
            x = torch.sin(self.hidden3(x))
            x = self.predict(x)
            return x

    pl_1 = 64
    net = Net(n_feature=2, n_hidden=pl, n_hidden1=pl_1, n_hidden2=pl_1, n_hidden3=pl_1, n_output=1)
    net.to(device)
    summary(net, (1,1, 2))

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    # Schemes 5-9 use mean absolute error, 0-4 mean squared error.
    loss_func = torch.nn.L1Loss() if DAB in (5, 6, 7, 8, 9) else torch.nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.8)
    plt.ion()

    # The network is always freshly initialised, so always start fresh
    # histories (previously they were undefined when start_epoch != 0).
    loss_epoch = []
    loss_values = []
    loss_DE_values = []
    loss_DBC_values = []

    h = 1 / (N - 1)

    print(b, a)

    # Normalise so the two loss weights sum to 1.
    lambda_DE = b / (a + b)
    lambda_DBC = a / (a + b)

    # Boundary targets with the same (N,) shape as the edge slices, so the
    # loss does not silently broadcast an (N,) input against an (N, 1) target.
    zeros_N = torch.zeros(N, device=device)

    start = time.perf_counter()
    for t in range(start_epoch, train_epoch):
        T = net(x)[:, :, 0]

        # Five-point Laplacian and central first differences; the slices below
        # align all three terms on the (N-2, N-2) interior.
        lap_T = (-4 * T[1:-1, 1:-1] + T[2:, 1:-1] + T[:-2, 1:-1] + T[1:-1, 2:] + T[1:-1, :-2]) / (h * h)
        Tx = (T[2:, :] - T[:-2, :]) / (h * 2)
        Ty = (T[:, 2:] - T[:, :-2]) / (h * 2)
        grad_sum = Tx[:, 1:-1] + Ty[1:-1, :]

        if Gamma != 0:
            loss_DE = loss_func(lap_T, grad_sum * Pe / h)
        else:  # Pe = inf
            loss_DE = loss_func(lap_T * 0, grad_sum)

        loss_DBC = (loss_func(T[:, 0], zeros_N)            # s: T = 0
                    + loss_func(T[:, -1] - 1, zeros_N)     # n: T = 1
                    + loss_func(T[0, :] - 1, zeros_N)      # w: T = 1
                    + loss_func(T[-1, :], zeros_N))        # e: T = 0

        loss = lambda_DE * loss_DE + lambda_DBC * loss_DBC

        loss_epoch.append(t)
        loss_DE_values.append(loss_DE.item())
        loss_DBC_values.append(loss_DBC.item())
        loss_values.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler must step after the optimizer.
        scheduler.step()

    T_np = T.detach().cpu().numpy()
    mid = (N - 1) // 2

    #####

    print("Epoch: ", t, " ,loss: ", loss)
    plt.subplots()[1].set_box_aspect(1)
    plt.cla()
    plt.contourf(np.transpose(T_np), vmax = Tn.max(), vmin = Tn.min(), levels = 20)
    plt.colorbar()
    plt.pause(0.1)

    plt.ioff()
    plt.show()

    #####

    print("Loss curve: ")
    plt.plot(loss_epoch, loss_values, loss_epoch, loss_DE_values, loss_epoch, loss_DBC_values)
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend([r'$\mathcal{L}$', r'$\mathcal{L}_{DE}$', r'$\mathcal{L}_{DBC}$'])
    plt.semilogy()
    plt.pause(0.1)

    #####

    print("Error: ")
    plt.subplots()[1].set_box_aspect(1)
    plt.cla()
    plt.contourf(np.transpose(T_np - Tn), vmax = Tn.max(), vmin = Tn.min(), levels = 20)
    plt.colorbar()
    plt.pause(0.1)

    plt.ioff()
    plt.show()

    #####

    print("x = 0.5: ")
    fig, ax = plt.subplots(1, 1, figsize=(17,5))
    plot_x = np.linspace(0, N - 1, N)
    if Gamma == 0:
        # Pe = inf midline: 0 before the diagonal node, 0.5 on it, 1 after it
        # (the same values gnd_Peinf puts in row ``mid``).
        ana = (np.arange(N) > mid).astype(float)
        ana[mid] = 0.5
        ax.plot(plot_x, ana, plot_x, T_np[mid, :], 'r--')
    else:
        ax.plot(plot_x, Tn[mid, :], plot_x, T_np[mid, :], 'r--')

    ax.set(xlim=(0, N - 1))
    ax.set_box_aspect(1)
    plt.xlabel('j')
    plt.ylabel('T')
    plt.legend(['Ground truth', 'PINN'])
    plt.pause(0.1)

    #####

    # MSE of the PINN prediction against the reference field, whole grid and midline.
    mse = np.mean((T_np - Tn) ** 2)
    mse_1 = np.mean((T_np[mid, :] - Tn[mid, :]) ** 2)

    print(f" ||||| {loss.item() :.4} | {loss_DE.item() :.4} | {loss_DBC.item() :.4} | {time.perf_counter() - start :.2f} sec | {mse :.4} | {mse_1 :.4} |")

    return T_np

### 2.4 Post-processing

In [None]:
def plotmid(Tn = np.zeros((11, 11)), T1 = np.zeros((11, 11)), T2 = np.zeros((11, 11)), Tx = np.zeros((51, 51)), N = 11, Gamma = 1):
    """Plot row ``(N-1)//2`` of a reference field and three predictions against y in [0, 1].

    Args:
        Tn: reference field, drawn as a black line when ``Gamma != 0``.
        T1, T2, Tx: predictions, drawn with red squares, blue triangles and
            green crosses and labelled hat T_0, hat T_NM and (hat T_NM)^2.
        N: number of grid points per side.
        Gamma: ``0`` replaces ``Tn`` with the Pe = inf midline (0 before the
            diagonal node, 0.5 on it, 1 after it).
    """
    print("x = 0.5: ")
    fig, ax = plt.subplots(figsize=(10,10))
    plot_x = np.linspace(0, 1, N)
    # Same row index for the reference and every prediction, so even N agrees too.
    mid = (N - 1) // 2

    if Gamma == 0:
        ref = (np.arange(N) > mid).astype(float)
        ref[mid] = 0.5
        ref_label = r'$T_{true}$'
    else:
        ref = Tn[mid, :]
        ref_label = r'$T_{FDM}$'

    ax.plot(plot_x, ref, '-k', markerfacecolor='none')
    for field, fmt in ((T1, 'sr'), (T2, '^b'), (Tx, '+g')):
        ax.plot(plot_x, field[mid, :], fmt, ms = 10, markerfacecolor='none')
    # Raw strings avoid invalid "\h" escapes; mathtext needs \hat{T}, not \hatT.
    plt.legend([ref_label, r'$\hat{T}_{0}$', r'$\hat{T}_{NM}$', r'${\hat{T}_{NM}}^2$'], fontsize="14", frameon=False)

    ax.set(xlim=(0, 1))
    ax.set_box_aspect(1)
    plt.xlabel('y')
    plt.ylabel('T')
    plt.pause(0.1)

## 3. Experiment

In [None]:
###############################################################
# Sample                                                      #
###############################################################

In [None]:
# Experiment parameters: Peclet number and grid points per side.
Pe, N = 100, 31

In [None]:
# FDM reference: heavily under-relaxed (w = 0.01), loose relative-change tolerance.
Tn = compute_FD(N, Pe, w=0.01, tol=1e-2)

In [None]:
# Build the global training-coordinate tensor ``x`` used by predict_NN.
sample(N=N)

In [None]:
# DAB 0: equal DE / BC weights, MSE metric.
T0 = predict_NN(N=N, Pe=Pe, DAB=0, train_epoch=1000)

In [None]:
# DAB 1: DE weight scaled by the dimensional-analysis result, MSE metric.
T1 = predict_NN(N=N, Pe=Pe, DAB=1, train_epoch=1000)

In [None]:
# DAB 2: DE weight scaled by the square root of that result, MSE metric.
T2 = predict_NN(N=N, Pe=Pe, DAB=2, train_epoch=1000)

In [None]:
# Midline comparison: FDM reference vs. the DAB 0 / 1 / 2 predictions.
plotmid(Tn=Tn, T1=T0, T2=T1, Tx=T2, N=N)