In [None]:
# Mount Google Drive at /content/drive (requires the Google Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd /content/drive/MyDrive/NPDG/RD

/content/drive/MyDrive/NPDG_Github/RD


In [None]:
# @title import
import math
import random
import scipy

import matplotlib
matplotlib.use('agg')

from matplotlib.pyplot import figure
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

import numpy as np

from torch.func import grad, vmap
from torch.func import jacrev
import time

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from numpy import *
from torch import Tensor
from torch.nn import Parameter
from torch.optim.lr_scheduler import ExponentialLR

import os
import argparse



This notebook computes the numerical solution to a 1D Allen-Cahn equation with $\epsilon_0=0.01$ using the NPDG method.

See Section 5.4 and Appendix F for a more detailed description.



Consider the reaction-diffusion (RD) equation (Allen-Cahn eq.) $$\partial_t u = a\Delta u - b f(u)$$ defined on $\Omega = [0,2]$ with $a=0.01$ (the $\epsilon_0$ above) and $b=100$, where $$f(u) = W'(u) = u^3 - u$$
and $W(u)=\frac14(u^2-1)^2$ is the double-well potential.

Suppose we impose the Neumann boundary condition and initial condition $$u(x,0)=(1-\cos(\pi (x-1)))\cos(\pi (x-1)).$$

<!-- Or
$$u(x,0)=(1-\cos(\pi (x-1)))\cos(\pi (x-1))+0.5.$$ -->

In [None]:
# @title set dimension

# Spatial dimension; used as the default `dim` / `d` argument of the samplers and plotting helpers below.
dim = 1

In [None]:
# @title Check CUDA availability

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [None]:
# @title Sampler


L = 2.0  # length of the spatial domain [0, L]

def rho_1_sampler(n, dim=dim):
    """Return the fixed set of interior collocation points.

    The points are deterministic: a uniform grid with spacing 0.001 over
    [0, L), followed by a second grid with the same spacing over
    [L/2 - 0.2, L/2 + 0.2), so the centre of the domain is covered twice.

    Args:
        n: Unused; kept so the signature matches the other samplers.
        dim: Unused; the returned points always have one coordinate.

    Returns:
        Tensor of shape (num_points, 1) on the notebook-level `device`.
    """
    samples = torch.arange(0, L, 0.001)
    center_samples = torch.arange(L/2-0.2, L/2+0.2, 0.001)
    totsample = torch.cat((samples, center_samples), 0).unsqueeze(1)
    # Follow the device chosen earlier in the notebook instead of forcing CUDA,
    # so the CPU fallback actually works.
    return totsample.to(device)


def rho_bdry_sampler(n, dim=dim):
    """Draw `n` random points on the boundary of the box [0, L]^dim.

    For every point one coordinate is pinned to 0 or L (chosen at random);
    for dim > 1 the remaining coordinates are taken from `rho_1_sampler`.

    Args:
        n: Number of boundary points.
        dim: Spatial dimension.

    Returns:
        Tensor of shape (n, dim) on the notebook-level `device`.
    """
    boundary_coord = np.random.randint(2, size=n) * L  # each entry is either 0 or L
    if dim == 1:
        # The boundary of an interval is just its two endpoints, so there are
        # no tangential coordinates to fill in. The generic path below would
        # build a length-2 vector and fail to store it in a length-1 row.
        samples = torch.as_tensor(boundary_coord, dtype=torch.float32).reshape(n, 1)
        return samples.to(device)

    index_randint = np.random.randint(dim, size=n)  # which coordinate is pinned to the boundary
    low_dim_sample = rho_1_sampler(n, dim-1).cpu()
    samples = torch.zeros(n, dim)
    for i in range(n):
        x = np.array(low_dim_sample[i, :])
        if index_randint[i] < dim-1:
            y = np.insert(x, index_randint[i], boundary_coord[i])
        else:
            y = np.append(x, boundary_coord[i])
        samples[i] = torch.tensor(y)
    # Follow the notebook-level device instead of forcing CUDA.
    return samples.to(device)


def rho_bdry_sampler_with_directional_vector(n, dim=dim, L=L):
    """Sample boundary points of [0, L]^dim together with their outward unit normals.

    Args:
        n: Number of boundary points for dim > 1. Ignored when dim == 1.
        dim: Spatial dimension.
        L: Side length of the domain.

    Returns:
        (samples, outward_direction), both on the notebook-level `device`.
        For dim == 1 both have shape (2, 1): the endpoints 0 and L with
        normals -1 and +1. Otherwise both have shape (n, dim).
    """
    if dim == 1:
        samples = torch.zeros(2, 1)
        samples[1] = L
        outward_direction = torch.zeros(2, 1)
        outward_direction[0] = -1
        outward_direction[1] = 1
    else:
        boundary_coord = np.random.randint(2, size=n) * L  # each entry is either 0 or L
        index_randint = np.random.randint(dim, size=n)  # which coordinate is pinned to the boundary
        low_dim_sample = rho_1_sampler(n, dim-1).cpu()
        samples = torch.zeros(n, dim)
        outward_direction = torch.zeros(n, dim)
        for i in range(n):
            x = np.array(low_dim_sample[i, :])
            if index_randint[i] < dim-1:
                y = np.insert(x, index_randint[i], boundary_coord[i])
            else:
                y = np.append(x, boundary_coord[i])
            samples[i] = torch.tensor(y)
            # The normal points along the pinned axis: +1 on the face at L, -1 on the face at 0.
            if boundary_coord[i] == L:
                outward_direction[i, index_randint[i]] = 1
            else:
                outward_direction[i, index_randint[i]] = -1
    # Follow the notebook-level device instead of forcing CUDA.
    return samples.to(device), outward_direction.to(device)


In [None]:
# @title model (MLP, activation=Tanh)


class network_prim(nn.Module):
    """Fully connected Tanh network whose output layer has no bias.

    Layers, in order: Linear(input, hidden), then `network_length - 2`
    Linear(hidden, hidden) layers, then Linear(hidden, output, bias=False).
    Tanh is applied before every layer except the first.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super().__init__()
        self.network_length = network_length

        # Layers are created in input -> hidden -> output order.
        layers = [nn.Linear(input_dimension, hidden_dimension)]
        layers += [nn.Linear(hidden_dimension, hidden_dimension) for _ in range(network_length - 2)]
        layers.append(nn.Linear(hidden_dimension, output_dimension, bias=False))
        self.linears = nn.ModuleList(layers)
        self.tanh = nn.Tanh()

    def initialization(self):
        """Re-draw every weight and bias from a standard normal distribution."""
        for layer in self.linears:
            for tensor in (layer.weight, layer.bias):
                if tensor is not None:
                    tensor.data.normal_()

    def forward(self, x):
        """Evaluate the network on `x` (last axis = input_dimension)."""
        last_index = self.network_length - 1
        hidden = self.linears[0](x)
        for layer in self.linears[1:last_index]:
            hidden = layer(self.tanh(hidden))
        return self.linears[last_index](self.tanh(hidden))


class network_dual(nn.Module):
    """Fully connected Tanh network with a bias on every layer, including the output.

    Layers, in order: Linear(input, hidden), then `network_length - 2`
    Linear(hidden, hidden) layers, then Linear(hidden, output).
    `L` is stored on the instance but not used by `forward`.
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension, L):
        super().__init__()
        self.network_length = network_length

        # Layers are created in input -> hidden -> output order.
        layers = [nn.Linear(input_dimension, hidden_dimension)]
        layers += [nn.Linear(hidden_dimension, hidden_dimension) for _ in range(network_length - 2)]
        layers.append(nn.Linear(hidden_dimension, output_dimension))
        self.linears = nn.ModuleList(layers)
        self.tanh = nn.Tanh()
        self.L = L

    def initialization(self):
        """Re-draw every weight and bias from a standard normal distribution."""
        for layer in self.linears:
            for tensor in (layer.weight, layer.bias):
                if tensor is not None:
                    tensor.data.normal_()

    def forward(self, x):
        """Evaluate the network on `x` (last axis = input_dimension)."""
        last_index = self.network_length - 1
        hidden = self.linears[0](x)
        for layer in self.linears[1:last_index]:
            hidden = layer(self.tanh(hidden))
        return self.linears[last_index](self.tanh(hidden))


class network_dual_on_bdry(nn.Module):
    """Fully connected Tanh network (bias-free output layer) for the boundary dual variable.

    Layers, in order: Linear(input, hidden), then `network_length - 2`
    Linear(hidden, hidden) layers, then Linear(hidden, output, bias=False).
    """

    def __init__(self, network_length, input_dimension, hidden_dimension, output_dimension):
        super().__init__()
        self.network_length = network_length

        # Layers are created in input -> hidden -> output order.
        layers = [nn.Linear(input_dimension, hidden_dimension)]
        layers += [nn.Linear(hidden_dimension, hidden_dimension) for _ in range(network_length - 2)]
        layers.append(nn.Linear(hidden_dimension, output_dimension, bias=False))
        self.linears = nn.ModuleList(layers)
        self.tanh = nn.Tanh()

    def initialization(self):
        """Re-draw every weight and bias from a standard normal distribution."""
        for layer in self.linears:
            for tensor in (layer.weight, layer.bias):
                if tensor is not None:
                    tensor.data.normal_()

    def forward(self, x):
        """Evaluate the network on `x` (last axis = input_dimension)."""
        last_index = self.network_length - 1
        hidden = self.linears[0](x)
        for layer in self.linears[1:last_index]:
            hidden = layer(self.tanh(hidden))
        return self.linears[last_index](self.tanh(hidden))




In [None]:
# @title Plotting function


def Plot_graph_nn_primal(t, net_primal, L, num_of_intervals, Iter, flag_plot_real, u_real, save_path, z_min, z_max, device, chosen_dim_0, chosen_dim_1, d=dim):
    """Plot the primal network u(t) against the initial condition and save it as a PDF.

    Args:
        t: Physical time shown in the title and file name.
        net_primal: Network evaluated on the plotting grid.
        L: Right end of the plotted interval [0, L].
        num_of_intervals: Currently overridden by 100 (see note below).
        Iter: Iteration counter shown in the title and file name.
        flag_plot_real: When equal to 1, `u_real` is drawn as a reference curve.
        u_real: Reference values on the plotting grid (tensor); only read when
            `flag_plot_real == 1`.
        save_path: Directory the PDF is written to.
        z_min, z_max: Limits of the vertical axis.
        device: Device the plotting grid is moved to before calling the network.
        chosen_dim_0, chosen_dim_1, d: Unused here.
    """
    # NOTE(review): the caller-supplied value is discarded and a 101-node grid is
    # always used; kept as-is because `u_real` presumably has that length — confirm.
    num_of_intervals = 100
    h = L / num_of_intervals

    coord = np.arange(num_of_intervals+1) * h
    nodes = torch.zeros(num_of_intervals+1, 1)
    nodes[:, 0] = torch.tensor(coord)
    # Honour the `device` argument instead of forcing CUDA.
    nodes = nodes.to(device)
    u_nodes = net_primal(nodes)
    u_0_nodes = u_0(nodes)

    x_values = nodes.cpu().detach().numpy()
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111)
    ax.set_xlim([0.0, L])
    ax.set_ylim([z_min, z_max])
    ax.plot(x_values, u_nodes.cpu().detach().numpy(), 'b-', linewidth = 2, label='u(t)')
    ax.plot(x_values, u_0_nodes.cpu().detach().numpy(), 'g--', linewidth = 2, label='u(0)')
    if flag_plot_real == 1:
        ax.plot(x_values, u_real.detach().cpu(), 'r--', linewidth = 2, label='u_IMEX(t)')
    ax.set_xlabel("x-axis")
    ax.set_ylabel("y-axis")
    ax.legend(fontsize=50, loc="upper right")
    ax.set_title('Graph of u(t) at physical time t={} (Iteration = {})'.format(t, Iter), fontsize = 20)
    filename = os.path.join(save_path, "({}th Iteration) Graph of u_theta at physical time t={}".format(Iter, t)+ '.pdf')
    fig.savefig(filename)
    # Close this figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def Plot_graph_nn_dual(t, net_dual, L, num_of_intervals, Iter, save_path, z_min, z_max, device, chosen_dim_0, chosen_dim_1, d=dim):
    """Plot the dual network phi(t) on [0, L] and save it as a PDF.

    Args:
        t: Physical time shown in the title and file name.
        net_dual: Network evaluated on the plotting grid.
        L: Right end of the plotted interval [0, L].
        num_of_intervals: Currently overridden by 100 (see note below).
        Iter: Iteration counter shown in the title and file name.
        save_path: Directory the PDF is written to.
        z_min, z_max: Limits of the vertical axis.
        device: Device the plotting grid is moved to before calling the network.
        chosen_dim_0, chosen_dim_1: Only used in the figure title.
        d: Unused here.
    """
    # NOTE(review): the caller-supplied value is discarded; a 101-node grid is always used.
    num_of_intervals = 100
    h = L / num_of_intervals

    coord = np.arange(num_of_intervals+1) * h
    nodes = torch.zeros(num_of_intervals+1, 1)
    nodes[:, 0] = torch.tensor(coord)
    # Honour the `device` argument instead of forcing CUDA.
    nodes = nodes.to(device)
    u_nodes = net_dual(nodes)

    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111)
    ax.set_xlim([0.0, L])
    ax.set_ylim([z_min, z_max])
    ax.plot(nodes.cpu().detach().numpy(), u_nodes.cpu().detach().numpy(), 'b-', linewidth = 2, label='phi(t)')
    ax.set_xlabel("x-axis")
    ax.set_ylabel("y-axis")
    ax.legend(fontsize=50, loc="upper right")
    ax.set_title('Graph of dual network phi at physical time t={} on {}-{} plane. (Iteration = {})'.format(t, chosen_dim_0, chosen_dim_1, Iter), fontsize = 20)
    filename = os.path.join(save_path, "({}th Iteration) Graph of phi_eta at physical time t={}".format(Iter, t)+ '.pdf')
    fig.savefig(filename)
    # Close this figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def Plot_numerical_solution(t, u, L, num_of_intervals, save_path, z_min, z_max, device, chosen_dim_0, chosen_dim_1, d=dim):
    """Plot a grid solution `u` against the initial condition and save it as a PDF.

    Args:
        t: Physical time shown in the title and file name.
        u: Tensor of solution values on the 101-node plotting grid.
        L: Right end of the plotted interval [0, L].
        num_of_intervals: Currently overridden by 100 (see note below).
        save_path: Directory the PDF is written to.
        z_min, z_max: Limits of the vertical axis.
        device: Device the plotting grid is moved to before evaluating `u_0`.
        chosen_dim_0, chosen_dim_1, d: Unused here.
    """
    # NOTE(review): the caller-supplied value is discarded, so `u` must hold 101 values.
    num_of_intervals = 100
    h = L / num_of_intervals

    coord = np.arange(num_of_intervals+1) * h
    nodes = torch.zeros(num_of_intervals+1, 1)
    nodes[:, 0] = torch.tensor(coord)
    # Honour the `device` argument instead of forcing CUDA.
    nodes = nodes.to(device)
    u_nodes = u
    u_0_nodes = u_0(nodes)

    x_values = nodes.cpu().detach().numpy()
    fig = plt.figure(figsize=(20, 20))
    ax = fig.add_subplot(111)
    ax.set_xlim([0.0, L])
    ax.set_ylim([z_min, z_max])
    ax.plot(x_values, u_nodes.cpu().detach().numpy(), 'b-', linewidth = 2, label='u(t) discrete numerical scheme')
    ax.plot(x_values, u_0_nodes.cpu().detach().numpy(), 'g--', linewidth = 2, label='u(0)')
    ax.set_xlabel("x-axis")
    ax.set_ylabel("y-axis")
    ax.legend(fontsize=50, loc="upper right")
    ax.set_title('Graph of implicit numerical scheme solution u(t) at physical time t={}'.format(t), fontsize = 20)
    filename = os.path.join(save_path, "  Graph of implicit numerical scheme solution u_theta at physical time t={}".format( t)+ '.pdf')
    fig.savefig(filename)
    # Close this figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)




In [None]:
# @title Initial value, necessary constants

def u_0(x):
    """Initial condition u(x, 0) = (1 - cos(pi (x - 1))) * cos(pi (x - 1)), evaluated elementwise on tensor `x`."""
    cos_term = torch.cos(math.pi * (x - 1))
    return (1 - cos_term) * cos_term

a = 0.01  # diffusion coefficient; in this code, epsilon_0 is denoted as a
b = 100.0  # reaction coefficient multiplying f(u) = W'(u)
def W(x):
    """Double-well potential W(x) = (x^2 - 1)^2 / 4 (elementwise)."""
    well = x*x - 1
    return 1/4 * well**2

def dW(x):
    """First derivative of the potential, W'(x) = x^3 - x (elementwise)."""
    cube = x*x*x
    return cube - x

def ddW(x):
    """Second derivative of the potential, W''(x) = 3 x^2 - 1 (elementwise)."""
    three_x_sq = 3 * x*x
    return three_x_sq - 1


In [None]:
# @title Initial loss
def Initial_loss(net_u, N):
    """Root-mean-square misfit between `net_u` and the initial condition `u_0`.

    Args:
        net_u: Network approximating the initial condition.
        N: Forwarded to `rho_1_sampler`.

    Returns:
        Scalar tensor sqrt(mean((u_0(x) - net_u(x))^2)) over the sampled points.
    """
    samples = rho_1_sampler(N, dim)
    realsolution = u_0(samples)
    # The result already lives on the samples' device; the former explicit
    # .cuda() calls only made the function fail on CPU-only runtimes.
    u_x = net_u(samples)
    diff = realsolution - u_x
    initial_loss = torch.sqrt((diff * diff).mean())

    return initial_loss



In [None]:
# @title L2 error

def L2_error(net_u, ureal, N_x, L=L):
    """Root-mean-square difference between `net_u` and reference values on a uniform grid.

    Args:
        net_u: Network evaluated on the N_x + 1 grid nodes i * L / N_x.
        ureal: Reference tensor of shape (N_x + 1,), one value per grid node.
        N_x: Number of grid intervals.
        L: Length of the domain.

    Returns:
        Scalar tensor sqrt(mean((net_u(x_i) - ureal_i)^2)).
    """
    h_x = L / N_x
    # Build the grid on the same device as the reference values instead of
    # forcing CUDA; the subtraction below needs both on one device anyway.
    samples = (torch.arange(N_x + 1) * h_x).unsqueeze(-1).to(ureal.device)
    ux = net_u(samples)
    diff = ux - ureal.unsqueeze(-1)
    error = torch.sqrt((diff*diff).mean())
    return error


In [None]:
# @title fixed point solver used in solving the RD equation (1D)

ave_ddW_bar_u = 2.0  # constant c used to stabilise the fixed-point iteration below
# Solve the RD equation on [0, t] with time step h_t and grid spacing h_x = L/N_x using the implicit scheme:
# at every time level t_i solve
#   (I - a*h_t*Δ_hx) u + b*h_t*W'(u) = u_{t_{i-1}}
# for u.

# Stabilised fixed-point iteration (c = ave_ddW_bar_u):
#   U_{k+1} = ((1 + b*h_t*c) I - a*h_t*Δ_hx)^{-1} (u_{t_{i-1}} - b*h_t*(W'(U_k) - c*U_k))
def Fixed_pt_solver_1D(t, h_t, L, N_x, iter_num_fixed_pt):
    """Integrate the 1D RD equation with implicit Euler, solving each step by fixed-point iteration.

    Args:
        t: Final time.
        h_t: Time step.
        L: Length of the domain.
        N_x: Number of grid intervals (N_x + 1 nodes).
        iter_num_fixed_pt: Maximum number of fixed-point iterations per time step;
            the iteration stops early once the residual norm drops below 1e-6.

    Returns:
        Tensor of shape (N_t, N_x + 1) on the notebook-level `device`; row i is
        the solution after i + 1 time steps.
    """
    # NOTE(review): int() truncates, so a quotient like 0.3 / 0.1 = 2.999... yields
    # one step fewer than intended; left unchanged because callers may rely on the row count.
    N_t = int(t / h_t)

    h_x = L / N_x
    diffusion_ratio = h_t / (h_x * h_x)
    nodes = torch.arange(N_x + 1) * h_x
    u0 = u_0(nodes)

    # Tridiagonal second-difference matrix (not yet divided by h_x^2). The corner
    # entries are -1 rather than -2, presumably encoding the Neumann boundary condition.
    matrix_upperdiag = np.diag(np.ones(N_x), 1)
    matrix_lowerdiag = np.diag(np.ones(N_x), -1)
    Id = np.eye(N_x+1)
    discrete_Lap = - 2 * Id + matrix_upperdiag + matrix_lowerdiag
    discrete_Lap[0, 0] = -1
    discrete_Lap[N_x, N_x] = -1
    discrete_Lap_tensorized = torch.from_numpy(discrete_Lap)
    # matrix = (1 + b*h_t*c) I - a*h_t*Δ_hx
    matrix = (1 + b * h_t * ave_ddW_bar_u) * np.eye(N_x+1) - a * diffusion_ratio * discrete_Lap
    inv_matrix = np.linalg.inv(matrix)
    inv_matrix = torch.from_numpy(inv_matrix)

    ut = torch.zeros(N_t, N_x+1)
    u_laststep = u0
    for i in range(N_t):
        # Fixed-point iteration for this time level, started from the previous solution.
        u_fixed_pt_k = u_laststep
        for _ in range(iter_num_fixed_pt):
            rhs = u_laststep - b * h_t * (dW(u_fixed_pt_k) - ave_ddW_bar_u * u_fixed_pt_k)
            u_fixed_pt_k = torch.matmul(inv_matrix, rhs.double())
            # Residual of the implicit Euler equation at the current iterate.
            res = (u_fixed_pt_k - u_laststep) / h_t - a * torch.matmul(discrete_Lap_tensorized, u_fixed_pt_k) / (h_x * h_x) + b * dW(u_fixed_pt_k)
            res_norm = torch.norm(res)
            print(res_norm)
            if res_norm < 1e-6:
                break
        u_laststep = u_fixed_pt_k
        ut[i, :] = u_fixed_pt_k
    # Follow the notebook-level device instead of forcing CUDA.
    return ut.to(device)



In [None]:
# @title IMEX solver for solving the reaction-diffusion equation (1D)

# solve the numerical solution on [0, t] with time stepsize h_t and space discretization h_x = L/N_x using the IMEX scheme
# every t_i solve
# (I - ah_tΔ_hx) u = u_t_i-1 - bh_tf(u_t_i-1)
# for u
def IMEX_solver_1D(t, h_t, L, N_x):
    """Integrate the 1D RD equation with a first-order IMEX scheme.

    Diffusion is treated implicitly and the reaction term explicitly:
    (I - a*h_t*Δ_hx) u_i = u_{i-1} - b*h_t*W'(u_{i-1}).

    Args:
        t: Final time.
        h_t: Time step.
        L: Length of the domain.
        N_x: Number of grid intervals (N_x + 1 nodes).

    Returns:
        Tensor of shape (N_t, N_x + 1) on the notebook-level `device`; row i is
        the solution after i + 1 time steps.
    """
    # NOTE(review): int() truncates, so a quotient like 0.3 / 0.1 = 2.999... yields
    # one step fewer than intended; left unchanged because callers may rely on the row count.
    N_t = int(t / h_t)

    h_x = L / N_x
    k = h_t / (h_x * h_x)
    nodes = torch.arange(N_x + 1) * h_x
    u0 = u_0(nodes)

    # Tridiagonal second-difference matrix (not yet divided by h_x^2). The corner
    # entries are -1 rather than -2, presumably encoding the Neumann boundary condition.
    matrix_upperdiag = np.diag(np.ones(N_x), 1)
    matrix_lowerdiag = np.diag(np.ones(N_x), -1)
    Id = np.eye(N_x+1)
    discrete_Lap = - 2 * Id + matrix_upperdiag + matrix_lowerdiag
    discrete_Lap[0, 0] = -1
    discrete_Lap[N_x, N_x] = -1
    # matrix = I - a*k*Δ_hx; it is the same for every step, so it is inverted once.
    matrix = np.eye(N_x+1) - a * k * discrete_Lap
    inv_matrix = np.linalg.inv(matrix)
    inv_matrix = torch.from_numpy(inv_matrix)

    ut = torch.zeros(N_t, N_x+1)
    u_laststep = u0
    for i in range(N_t):
        rhs = u_laststep - b * h_t * dW(u_laststep)
        u_current = torch.matmul(inv_matrix, rhs.double())
        ut[i, :] = u_current
        u_laststep = u_current

    # Follow the notebook-level device instead of forcing CUDA.
    return ut.to(device)


In [None]:
# @title computing Laplacian
from torch.func import hessian, vmap, functional_call


def v_compute_Laplacian(net, samples):
    """Evaluate the Laplacian of `net` at every row of `samples`.

    `net` is called on one sample at a time (vmap hides the batch axis); the
    Laplacian is the trace of its Hessian over the last two axes.
    """
    def laplacian_at_point(point):
        point_hessian = hessian(net, argnums=0)(point)
        # Relative dims are used because this function never sees the batch axis.
        return torch.diagonal(point_hessian, 0, -2, -1).sum(-1)

    return vmap(laplacian_at_point)(samples)


In [None]:
# @title PDHG loss


def gradient_nn(network, x):
    """Gradient of `network` with respect to its input, evaluated at `x`.

    The outputs are contracted with a tensor of ones, so for a scalar-valued
    network row i of the result is the gradient at x[i].

    Args:
        network: Callable mapping a tensor to a tensor.
        x: Input points.

    Returns:
        Tensor with the shape of `x`. It stays attached to the autograd graph
        (create_graph=True), so it can be differentiated again.
    """
    input_variable = autograd.Variable(x, requires_grad=True)
    output_value = network(input_variable)
    gradients_x = autograd.grad(
        outputs=output_value,
        inputs=input_variable,
        # ones_like follows the output's device and dtype, so this also works on CPU-only runtimes.
        grad_outputs=torch.ones_like(output_value),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradients_x


# PDHG loss with no boundary error
# u_laststep = u_t-1(in_samples)
def PDHG_loss_without_bd(net_u, u_laststep, net_phi, in_samples, h_t, ave_value_ddW, a=a, b=b):
    """Interior-only saddle-point loss for one implicit time step (weak form, divided by h_t).

    Returns mean(((u + b*h_t*W'(u) - u_laststep) / h_t) * phi) + a * mean(grad u . grad phi)
    over `in_samples`. `ave_value_ddW` is accepted but not used.
    """
    primal_vals = net_u(in_samples)
    dual_vals = net_phi(in_samples)
    primal_grad = gradient_nn(net_u, in_samples)
    dual_grad = gradient_nn(net_phi, in_samples)

    reaction_term = ((primal_vals + b * h_t * dW(primal_vals) - u_laststep)/h_t * dual_vals).mean()
    diffusion_term = a * (torch.sum(primal_grad * dual_grad, -1).unsqueeze(-1)).mean()
    return reaction_term + diffusion_term


# PDHG loss
def PDHG_loss(net_u, u_laststep, net_phi, net_psi, in_samples, bd_samples, h_t, ave_value_ddW, bd_lambda, a=a, b=b):
    """Saddle-point loss for one implicit time step: weak-form interior term plus boundary term.

    Interior: mean((u + b*h_t*W'(u) - u_laststep) * phi) + a*h_t*mean(grad u . grad phi).
    Boundary: bd_lambda * mean((outward normal derivative of u) * psi), on boundary points
    drawn inside this function. `bd_samples` and `ave_value_ddW` are accepted but not used.
    """
    # Interior part.
    primal_vals = net_u(in_samples)
    dual_vals = net_phi(in_samples)
    primal_grad = gradient_nn(net_u, in_samples)
    dual_grad = gradient_nn(net_phi, in_samples)
    reaction_term = ((primal_vals + b * h_t * dW(primal_vals) - u_laststep) * dual_vals).mean()
    diffusion_term = a * h_t * (torch.sum(primal_grad * dual_grad, -1).unsqueeze(-1)).mean()
    interior = reaction_term + diffusion_term

    # Boundary part: normal derivative of u tested against psi.
    bdry_points, normals = rho_bdry_sampler_with_directional_vector(2)
    psi_vals = net_psi(bdry_points)
    bdry_grad = gradient_nn(net_u, bdry_points)
    normal_derivative = torch.sum(normals * bdry_grad, -1).unsqueeze(-1)
    boundary = (normal_derivative * psi_vals).mean()

    return interior + bd_lambda * boundary


def PDHG_loss_typePINN(net_u, u_laststep, net_phi, net_psi, in_samples, bd_samples, h_t, ave_value_ddW, bd_lambda, a=a, b=b):
    """Strong-form variant of `PDHG_loss`: the Laplacian of u is computed directly.

    Interior: mean((u + h_t*(-a*Lap u + b*W'(u)) - u_laststep) * phi).
    Boundary: bd_lambda * mean((outward normal derivative of u) * psi), on boundary points
    drawn inside this function. `bd_samples` and `ave_value_ddW` are accepted but not used.
    """
    # Interior part.
    primal_vals = net_u(in_samples)
    dual_vals = net_phi(in_samples)
    primal_lap = v_compute_Laplacian(net_u, in_samples)
    step_residual = primal_vals + h_t * (- a * primal_lap + b * dW(primal_vals)) - u_laststep
    interior = (step_residual * dual_vals).mean()

    # Boundary part.
    bdry_points, normals = rho_bdry_sampler_with_directional_vector(2)
    psi_vals = net_psi(bdry_points)
    bdry_grad = gradient_nn(net_u, bdry_points)
    normal_derivative = torch.sum(normals * bdry_grad, -1).unsqueeze(-1)
    boundary = (normal_derivative * psi_vals).mean()
    return interior + bd_lambda * boundary


# PDHG loss with extrapolation, i.e., replace phi in PDHG_loss by (1+\omega) * phi_k+1 - \omega * phi_k
#                                     replace psi in PDHG_loss by (1+\omega) * psi_k+1 - \omega * psi_k
def PDHG_loss_with_extraplt(net_u, u_laststep, net_phi_1, net_phi_0, net_psi_1, net_psi_0, in_samples, bd_samples, h_t, ave_value_ddW, bd_lambda, omega, a=a, b=b):
    """`PDHG_loss` evaluated with extrapolated dual variables.

    phi is replaced by phi_1 + omega*(phi_1 - phi_0) and psi by psi_1 + omega*(psi_1 - psi_0).
    `bd_samples` and `ave_value_ddW` are accepted but not used.
    """
    # Interior part with the extrapolated phi and its gradient.
    primal_vals = net_u(in_samples)
    phi_new = net_phi_1(in_samples)
    phi_old = net_phi_0(in_samples)
    phi_extrap = phi_new + omega * (phi_new - phi_old)
    primal_grad = gradient_nn(net_u, in_samples)
    grad_phi_new = gradient_nn(net_phi_1, in_samples)
    grad_phi_old = gradient_nn(net_phi_0, in_samples)
    grad_phi_extrap = grad_phi_new + omega * (grad_phi_new - grad_phi_old)
    reaction_term = ((primal_vals + b * h_t * dW(primal_vals) - u_laststep)  * phi_extrap).mean()
    diffusion_term = a * h_t * (torch.sum(primal_grad * grad_phi_extrap, -1).unsqueeze(-1)).mean()
    interior = reaction_term + diffusion_term

    # Boundary part with the extrapolated psi.
    bdry_points, normals = rho_bdry_sampler_with_directional_vector(2)
    psi_new = net_psi_1(bdry_points)
    psi_old = net_psi_0(bdry_points)
    psi_extrap = psi_new + omega * (psi_new - psi_old)
    bdry_grad = gradient_nn(net_u, bdry_points)
    normal_derivative = torch.sum(normals * bdry_grad, -1).unsqueeze(-1)
    boundary = (normal_derivative * psi_extrap).mean()

    return interior + bd_lambda * boundary


def PDHG_loss_with_extraplt_typePINN(net_u, u_laststep, net_phi_1, net_phi_0, net_psi_1, net_psi_0, in_samples, bd_samples, h_t, ave_value_ddW, bd_lambda, omega, a=a, b=b):
    """Strong-form loss (`PDHG_loss_typePINN`) evaluated with extrapolated dual variables.

    phi is replaced by phi_1 + omega*(phi_1 - phi_0) and psi by psi_1 + omega*(psi_1 - psi_0).
    `bd_samples` and `ave_value_ddW` are accepted but not used.
    """
    # Interior part with the extrapolated phi.
    primal_vals = net_u(in_samples)
    phi_new = net_phi_1(in_samples)
    phi_old = net_phi_0(in_samples)
    phi_extrap = phi_new + omega * (phi_new - phi_old)

    primal_lap = v_compute_Laplacian(net_u, in_samples)
    step_residual = primal_vals + h_t * (-a * primal_lap + b * dW(primal_vals)) - u_laststep
    interior = (step_residual * phi_extrap).mean()

    # Boundary part with the extrapolated psi.
    bdry_points, normals = rho_bdry_sampler_with_directional_vector(2)
    psi_new = net_psi_1(bdry_points)
    psi_old = net_psi_0(bdry_points)
    psi_extrap = psi_new + omega * (psi_new - psi_old)
    bdry_grad = gradient_nn(net_u, bdry_points)
    normal_derivative = torch.sum(normals * bdry_grad, -1).unsqueeze(-1)
    boundary = (normal_derivative * psi_extrap).mean()

    return interior + bd_lambda * boundary


def L2_norm_sq_phi(net_phi, samples):
    """Sample estimate of the squared L2 norm of `net_phi`: mean of phi(x)^2 over `samples`."""
    values = net_phi(samples)
    return (values * values).mean()


def L2_norm_sq_Lap_phi(net_phi, samples):
    """Sample estimate of the squared L2 norm of the Laplacian of `net_phi` over `samples`."""
    laplacian_values = v_compute_Laplacian(net_phi, samples)
    per_sample_sq = torch.sum(laplacian_values * laplacian_values, -1)
    return per_sample_sq.mean()


def L2_norm_sq_nabla_phi(net_phi, samples):
    """Sample estimate of the squared L2 norm of the gradient of `net_phi` over `samples`."""
    gradients = gradient_nn(net_phi, samples)
    per_sample_sq = torch.sum(gradients * gradients, -1)
    return per_sample_sq.mean()


def L2_norm_D_phi(net_phi, samples, h_t, ave_value_ddW, a=a, b=b):
    """Weighted squared norm of `net_phi`.

    Returns (1 + b*h_t*ave_value_ddW) * mean(phi^2) + a*h_t * mean(|grad phi|^2)
    over `samples`.

    Args:
        net_phi: Network to measure.
        samples: Points the means are taken over.
        h_t: Time step.
        ave_value_ddW: Constant multiplying b*h_t in the weight of the phi^2 term.
        a, b: Diffusion and reaction coefficients.
    """
    phi_x = net_phi(samples)
    phi_sqr = (phi_x * phi_x).mean()
    grad_phi = gradient_nn(net_phi, samples)
    # Squared gradient norm, as in L2_norm_sq_nabla_phi. Summing the raw gradient
    # components (the previous behaviour) is not a norm and can even be negative.
    grad_phi_sq = torch.sum(grad_phi * grad_phi, -1).mean()
    return (1 + b * h_t * ave_value_ddW) * phi_sqr + a * h_t * grad_phi_sq


def L2_norm_sq_psi(net_psi, samples):
    """Sample estimate of the squared L2 norm of `net_psi`: mean of psi(x)^2 over `samples`."""
    values = net_psi(samples)
    return (values * values).mean()



In [None]:
# @title boundary loss


def Bd_loss(net_u, N):
    """Mean squared outward normal derivative of `net_u` on freshly sampled boundary points.

    `N` is forwarded to `rho_bdry_sampler_with_directional_vector`.
    """
    points, normals = rho_bdry_sampler_with_directional_vector(N)
    normal_derivative = torch.sum(normals * gradient_nn(net_u, points), -1).unsqueeze(-1)
    return (normal_derivative**2).mean()


def Loss_2(net_u, N):
    """Mean squared value of `net_u` on `N` boundary points from `rho_bdry_sampler`."""
    boundary_values = net_u(rho_bdry_sampler(N))
    return (boundary_values * boundary_values).mean()


def Bd_loss_Neumann_use_samples(net_u, bd_samples, outward_direction):
    """Mean squared outward normal derivative of `net_u` on caller-supplied boundary points.

    `outward_direction` holds one direction vector per row of `bd_samples`.
    """
    boundary_grad = gradient_nn(net_u, bd_samples)
    normal_derivative = torch.sum(outward_direction * boundary_grad, -1).unsqueeze(-1)
    return (normal_derivative**2).mean()


In [None]:
# @title PINN loss
import torch.autograd.functional as functional


# PINN loss for one implicit Euler step of the reaction-diffusion equation
def PINN_Loss(net_u, u_laststep, in_samples, h_t, bd_lambda, a=a, b=b):
    """Least-squares residual of the implicit time step plus a Neumann boundary penalty.

    Interior: mean((u - u_laststep - a*h_t*Lap u + b*h_t*W'(u))^2) over `in_samples`.
    Boundary: bd_lambda * mean((outward normal derivative of u)^2) on boundary points
    drawn inside this function.
    """
    # Interior residual (not divided by h_t).
    primal_vals = net_u(in_samples)
    primal_lap = v_compute_Laplacian(net_u, in_samples)
    step_residual = (primal_vals - u_laststep - a * h_t * primal_lap + b * h_t * dW(primal_vals))
    interior = (step_residual * step_residual).mean()

    # Boundary penalty.
    bdry_points, normals = rho_bdry_sampler_with_directional_vector(2)
    bdry_grad = gradient_nn(net_u, bdry_points)
    normal_derivative = torch.sum(normals * bdry_grad, -1).unsqueeze(-1)
    boundary = (normal_derivative**2).mean()
    return interior + bd_lambda * boundary



In [None]:
# @title G(\theta) as a linear operator
import scipy
from scipy.sparse.linalg import LinearOperator


def tensor_to_numpy(u):
    """Return tensor `u` as a NumPy array, detached from autograd and moved to host memory.

    The former `u.device == "cpu"` test compared a torch.device with a string;
    both of its branches produced the same array, so a single expression is used
    (`.cpu()` is a no-op for tensors that are already on the CPU).
    """
    return u.detach().cpu().numpy()


#################################################################################
# In this document, we define various forms of the precondition matrix M(\theta),
# matrix M(\theta) can be viewed as a "metric tensor" in the parameter space,
# we denote the precondition matrix as "G" throughout the implementation.
#################################################################################

# explicitly form the Gram matrix M(\theta). Only for verification.
def form_metric_tensor(input_dim, net, G_samples, device):
    """Assemble the dense Gram matrix of the parameter-gradients of grad_x(net). Verification only.

    For every sample, the mixed derivative d/dx d/dtheta net(x) is arranged as a
    (num_params, input_dim) matrix J; the result is the mean of J @ J^T over
    `G_samples`. Prints the number of parameters. `device` is accepted but not used.
    """
    num_samples = G_samples.size()[0]
    # Differentiate w.r.t. the parameters, then w.r.t. the input; vmap supplies the sample axis.
    param_jacobian = jacrev(functional_call, argnums = 1)
    mixed_jacobian = vmap(jacrev(param_jacobian, argnums = 2), in_dims = (None, None, 0))
    mixed_derivs = mixed_jacobian(net, dict(net.named_parameters()), G_samples)

    num_params = torch.nn.utils.parameters_to_vector(net.parameters()).size()[0]
    print("Number of params = {}".format(num_params))

    # One (num_samples, params_in_tensor, input_dim) block per parameter tensor, stacked along axis 1.
    per_param_blocks = [block.view(num_samples, -1, input_dim) for block in dict(mixed_derivs).values()]
    stacked = torch.cat(per_param_blocks, 1)
    batched_gram = torch.matmul(stacked, torch.transpose(stacked, 1, 2))
    return torch.mean(batched_gram, 0)


def metric_tensor_as_Laplace_op(net, net_auxil, G_samples, vec, device):
    """Matrix-free product G @ vec for the metric induced by the Laplacian.

    Uses the bilinear form B(theta, theta') = mean_x[Lap net(x) * Lap net_auxil(x)]:
    differentiating B w.r.t. the parameters of `net_auxil`, contracting with `vec`,
    and differentiating w.r.t. the parameters of `net` gives G @ vec.
    Presumably `net_auxil` is a copy of `net` with identical parameters — confirm against caller.

    Args:
        net, net_auxil: Networks with the same architecture.
        G_samples: Points the mean is taken over.
        vec: Flat tensor with one entry per parameter.
        device: Unused.

    Returns:
        Flat tensor with one entry per parameter of `net`.
    """
    # Unused locals from the previous version (including a full parameter flatten
    # on every operator application) were removed.
    net.zero_grad()
    net_auxil.zero_grad()
    laplace_net_x = v_compute_Laplacian(net, G_samples)
    laplace_net_auxil_x = v_compute_Laplacian(net_auxil, G_samples)
    ave_sqr_laplace_net = torch.sum(laplace_net_x * laplace_net_auxil_x) / G_samples.size()[0]
    # dB/dtheta'
    nabla_theta_ave_sqr_laplace_net = torch.autograd.grad(ave_sqr_laplace_net, net_auxil.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_ave_sqr_laplace_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_sqr_laplace_net)
    vec_dot_nabla_theta_ave_sqr_laplace_net = vectorize_nabla_theta_ave_sqr_laplace_net.dot(vec)
    # d/dtheta of (dB/dtheta' . vec) = G @ vec
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_sqr_laplace_net, net.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)
    return vectorize_metric_tensor_mult_vec


def metric_tensor_as_nabla_op(net, net_auxil, G_samples, vec, device):
    """Matrix-free product G @ vec for the metric induced by the gradient operator.

    Uses the bilinear form B(theta, theta') = mean_x[grad net(x) . grad net_auxil(x)]:
    differentiating B w.r.t. the parameters of `net_auxil`, contracting with `vec`,
    and differentiating w.r.t. the parameters of `net` gives G @ vec.
    Presumably `net_auxil` is a copy of `net` with identical parameters — confirm against caller.

    Args:
        net, net_auxil: Networks with the same architecture.
        G_samples: Points the mean is taken over.
        vec: Flat tensor with one entry per parameter.
        device: Unused.

    Returns:
        Flat tensor with one entry per parameter of `net`.
    """
    # Unused locals from the previous version (including a full parameter flatten
    # on every operator application) were removed.
    net.zero_grad()
    net_auxil.zero_grad()
    grad_net_x = gradient_nn(net, G_samples)
    grad_net_auxil_x = gradient_nn(net_auxil, G_samples)
    ave_sqr_grad_net = torch.sum(grad_net_x * grad_net_auxil_x) / G_samples.size()[0]
    # dB/dtheta'
    nabla_theta_ave_sqr_grad_net = torch.autograd.grad(ave_sqr_grad_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_ave_sqr_grad_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_sqr_grad_net)
    vec_dot_nabla_theta_ave_sqr_grad_net = vectorize_nabla_theta_ave_sqr_grad_net.dot(vec)
    # d/dtheta of (dB/dtheta' . vec) = G @ vec
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_sqr_grad_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)
    return vectorize_metric_tensor_mult_vec


def metric_tensor_as_op_identity_part(net, net_auxil, G_samples, vec, device):
    """Matrix-free product G @ vec for the metric induced by the identity operator.

    Uses the bilinear form B(theta, theta') = mean_x[net(x) * net_auxil(x)]:
    differentiating B w.r.t. the parameters of `net_auxil`, contracting with `vec`,
    and differentiating w.r.t. the parameters of `net` gives G @ vec.
    Presumably `net_auxil` is a copy of `net` with identical parameters — confirm against caller.

    Args:
        net, net_auxil: Networks with the same architecture.
        G_samples: Points the mean is taken over.
        vec: Flat tensor with one entry per parameter.
        device: Unused.

    Returns:
        Flat tensor with one entry per parameter of `net`.
    """
    # Unused locals from the previous version (including a full parameter flatten
    # on every operator application) were removed.
    net_x = net(G_samples)
    net_auxil_x = net_auxil(G_samples)
    ave_net = torch.sum(net_x * net_auxil_x) / G_samples.size()[0]
    # dB/dtheta'
    nabla_theta_ave_net = torch.autograd.grad(ave_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_net)
    vec_dot_nabla_theta_ave_net = vectorize_nabla_theta_net.dot(vec)
    # d/dtheta of (dB/dtheta' . vec) = G @ vec
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)
    return vectorize_metric_tensor_mult_vec


# \mathcal M_p = Id + h_t * W''(u) operator
def metric_tensor_as_op_DF_part(net, net_auxil, net_auxil_auxil, G_samples, vec, device):
    """Matrix-free product G @ vec for the metric weighted by the linearised reaction operator.

    Uses the bilinear form
    B(theta, theta') = mean_x[(D(x) * net(x)) * (D(x) * net_auxil(x))],
    where the pointwise weight D is built from `net_auxil_auxil` (the state the
    operator is linearised at).

    Args:
        net, net_auxil: Networks with the same architecture (presumably identical
            parameters — confirm against caller).
        net_auxil_auxil: Network providing the state u used in the weight D.
        G_samples: Points the mean is taken over.
        vec: Flat tensor with one entry per parameter.
        device: Unused.

    Returns:
        Flat tensor with one entry per parameter of `net`.
    """
    # Unused locals from the previous version (including a full parameter flatten
    # on every operator application) were removed.
    net_auxil_auxil_x = net_auxil_auxil(G_samples)
    # NOTE(review): `h_t` is not a parameter here; it is read from the enclosing
    # (notebook-level) scope and must be defined before this function is called.
    # NOTE(review): the comment above this function describes Id + h_t * W''(u),
    # i.e. a weight 1 + b*h_t*W''(u), but the code uses u + b*h_t*W''(u).
    # Left unchanged; confirm which one is intended.
    DFu_net = net_auxil_auxil_x + b * h_t * ddW(net_auxil_auxil_x)
    net_x = net(G_samples)
    net_auxil_x = net_auxil(G_samples)
    ave_net = torch.sum((DFu_net * net_x) * (DFu_net * net_auxil_x)) / G_samples.size()[0]
    # dB/dtheta'
    nabla_theta_ave_net = torch.autograd.grad(ave_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_net)
    vec_dot_nabla_theta_ave_net = vectorize_nabla_theta_net.dot(vec)
    # d/dtheta of (dB/dtheta' . vec) = G @ vec
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)
    return vectorize_metric_tensor_mult_vec


def metric_tensor_as_trace_op(net, net_auxil, G_samples, vec, device):
    """Matrix-free product G @ vec for the metric induced by evaluating the network on `G_samples`.

    Same computation as `metric_tensor_as_op_identity_part`; it is meant to be
    called with boundary points (trace operator). Uses the bilinear form
    B(theta, theta') = mean_x[net(x) * net_auxil(x)].
    Presumably `net_auxil` is a copy of `net` with identical parameters — confirm against caller.

    Args:
        net, net_auxil: Networks with the same architecture.
        G_samples: Points the mean is taken over.
        vec: Flat tensor with one entry per parameter.
        device: Unused.

    Returns:
        Flat tensor with one entry per parameter of `net`.
    """
    # Unused locals from the previous version (including a full parameter flatten
    # on every operator application) were removed.
    net_x = net(G_samples)
    net_auxil_x = net_auxil(G_samples)
    ave_net = torch.sum(net_x * net_auxil_x) / G_samples.size()[0]
    # dB/dtheta'
    nabla_theta_ave_net = torch.autograd.grad(ave_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_net)
    vec_dot_nabla_theta_ave_net = vectorize_nabla_theta_net.dot(vec)
    # d/dtheta of (dB/dtheta' . vec) = G @ vec
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)
    return vectorize_metric_tensor_mult_vec


def metric_tensor_as_Neumann_trace_op(net, net_auxil, G_samples, outward_direction, vec, device):
    """Matrix-free product G @ vec for the metric induced by the outward normal derivative.

    Uses the bilinear form
    B(theta, theta') = mean_x[(n . grad net(x)) * (n . grad net_auxil(x))],
    with n = `outward_direction` given row by row for `G_samples`.
    Presumably `net_auxil` is a copy of `net` with identical parameters — confirm against caller.

    Args:
        net, net_auxil: Networks with the same architecture.
        G_samples: Boundary points the mean is taken over.
        outward_direction: One direction vector per row of `G_samples`.
        vec: Flat tensor with one entry per parameter.
        device: Unused.

    Returns:
        Flat tensor with one entry per parameter of `net`.
    """
    # Unused locals and an unused forward pass `net(G_samples)` from the previous
    # version were removed; only the input gradients are needed.
    grad_net_x = gradient_nn(net, G_samples)
    directional_grad_net_x = torch.sum(outward_direction * grad_net_x, -1).unsqueeze(-1)
    net_auxil_x = net_auxil(G_samples)
    grad_net_auxil_x = gradient_nn(net_auxil, G_samples)
    directional_grad_net_auxil_x = torch.sum(outward_direction * grad_net_auxil_x, -1).unsqueeze(-1)
    ave_net = torch.sum(directional_grad_net_x * directional_grad_net_auxil_x) / G_samples.size()[0]
    # dB/dtheta'
    nabla_theta_ave_net = torch.autograd.grad(ave_net, net_auxil.parameters(), grad_outputs=None ,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_nabla_theta_net = torch.nn.utils.parameters_to_vector(nabla_theta_ave_net)
    vec_dot_nabla_theta_ave_net = vectorize_nabla_theta_net.dot(vec)
    # d/dtheta of (dB/dtheta' . vec) = G @ vec
    metric_tensor_mult_vec = torch.autograd.grad(vec_dot_nabla_theta_ave_net, net.parameters(), grad_outputs=None,allow_unused=True, retain_graph=True, create_graph=True)
    vectorize_metric_tensor_mult_vec = torch.nn.utils.parameters_to_vector(metric_tensor_mult_vec)
    return vectorize_metric_tensor_mult_vec


# G1: \mathcal M = Id operator
# G2: \mathcal M = ▽ operator
# G3: \mathcal M = -△ operator
# G4: \mathcal M = lambda T (Trace) operator
# G5: \mathcal M: u \mapsto \partial u \partial n on \partial\Omega
# G6: \mathcal M: u \mapsto (I + h_t*W''(u))u

# In this implementation,
# as epsilon_0 = 0.01,
# we use G65 for u, G1 for \phi, G4 for \psi
# in the current paper.
def minres_solver_G(net, net_auxil, interior_samples, boundary_samples, outward_direction, RHS_vec, device, bd_lambda, max_iternum, minres_tolerance, G_type, ave_value_ddW, h_t=0.1, a=a, b=b, net_auxil_auxil=0):
    """Solve G x = RHS_vec with MINRES, where G is a matrix-free metric tensor.

    Args:
        net, net_auxil: network and its copy, forwarded to the metric-tensor
            operators.
        interior_samples, boundary_samples: sample points for the interior and
            boundary operators respectively.
        outward_direction: direction vectors used by the Neumann-trace operator.
        RHS_vec: right-hand side as a flat tensor.
        device: device on which each matvec is evaluated.
        bd_lambda: weight multiplying every boundary (trace / Neumann) term.
        max_iternum, minres_tolerance: ``maxiter`` and ``rtol`` passed to MINRES.
        G_type: one of "1", "2", "3", "4", "5", "6", "14", "24", "65", "125",
            "12"; selects which operators are combined.
        ave_value_ddW, h_t, a, b: scalars used by types "125" and "12".
        net_auxil_auxil: extra network copy forwarded to the DF-part operator
            (types "6" and "65").

    Returns:
        (solution tensor on ``device``, MINRES info flag). If the solution
        contains NaN the right-hand side is returned instead, with info 0.

    Raises:
        ValueError: if ``G_type`` is not one of the supported strings.
    """
    num_params = torch.nn.utils.parameters_to_vector(net.parameters()).size()[0]

    # Elementary operators: each maps a tensor v on `device` to a tensor M v.
    def identity_part(v):
        return metric_tensor_as_op_identity_part(net, net_auxil, interior_samples, v, device)

    def nabla_part(v):
        return metric_tensor_as_nabla_op(net, net_auxil, interior_samples, v, device)

    def laplace_part(v):
        return metric_tensor_as_Laplace_op(net, net_auxil, interior_samples, v, device)

    def trace_part(v):
        return metric_tensor_as_trace_op(net, net_auxil, boundary_samples, v, device)

    def neumann_part(v):
        return metric_tensor_as_Neumann_trace_op(net, net_auxil, boundary_samples, outward_direction, v, device)

    def df_part(v):
        return metric_tensor_as_op_DF_part(net, net_auxil, net_auxil_auxil, interior_samples, v, device)

    # Lambdas keep evaluation lazy: only the selected combination is ever computed.
    operator_table = {
        "1": identity_part,
        "2": nabla_part,
        "3": laplace_part,
        "4": lambda v: bd_lambda * trace_part(v),
        "5": lambda v: bd_lambda * neumann_part(v),
        "6": df_part,
        "14": lambda v: identity_part(v) + bd_lambda * trace_part(v),
        "24": lambda v: nabla_part(v) + bd_lambda * trace_part(v),
        "65": lambda v: df_part(v) + bd_lambda * neumann_part(v),
        "125": lambda v: (1+b*h_t*ave_value_ddW) * identity_part(v) + a * h_t * nabla_part(v) + bd_lambda * neumann_part(v),
        "12": lambda v: (1+b*h_t*ave_value_ddW) * identity_part(v) + a * h_t * nabla_part(v),
    }
    if G_type not in operator_table:
        # Previously this only printed and then failed on the unbound G_operator.
        raise ValueError("Wrong G_type: {!r}; expected one of {}".format(G_type, sorted(operator_table)))
    apply_G = operator_table[G_type]

    def matvec(vec):  # MINRES passes a NumPy vector; return G @ vec as NumPy
        tensorized_vec = torch.Tensor(vec).to(device)
        return tensor_to_numpy(apply_G(tensorized_vec))

    G_operator = LinearOperator((num_params, num_params), matvec=matvec)

    np_RHS_vec = tensor_to_numpy(RHS_vec)
    sol_vec, info = scipy.sparse.linalg.minres(G_operator, np_RHS_vec, rtol=minres_tolerance, maxiter=max_iternum)
    if np.isnan(sol_vec).any():
        # Fall back to the un-preconditioned direction rather than propagate NaNs.
        print("Got the NAN!!!")
        sol_vec = np_RHS_vec
        info = 0
    tensorized_sol_vec = torch.Tensor(sol_vec).to(device)

    return tensorized_sol_vec, info



In [None]:
# @title update param via line search

def update_param(number_stepsizes, base_stepsize, theta_0, tangent_theta, net_u, net_phi, net_psi, in_samples, bd_samples, bd_lambda, descent_or_ascent, net_type, epsilon, loss_phi_or_psi=None):
    """Pick a stepsize along ``tangent_theta`` by a geometric line search.

    Candidate stepsizes are ``0.2 * base_stepsize**k`` for
    ``k = 0 .. number_stepsizes-1``. For each one the candidate parameters
    ``theta_0 +/- stepsize * tangent_theta`` are loaded into a scratch network
    of the kind named by ``net_type``, and the regularised PDHG loss is
    evaluated with that scratch network substituted for the matching one of
    ``net_u`` / ``net_phi`` / ``net_psi``. The inputs networks are not modified.

    Args:
        number_stepsizes, base_stepsize: define the candidate stepsizes.
        theta_0: current flat parameter vector of the network being updated.
        tangent_theta: flat update direction.
        net_u, net_phi, net_psi: current networks.
        in_samples, bd_samples: interior and boundary sample points.
        bd_lambda: boundary weight forwarded to ``PDHG_loss``.
        descent_or_ascent: "descent" steps against the direction and minimises
            the loss; any other value steps along it and maximises the loss.
        net_type: "u", "phi" or "psi".
        epsilon: weight of the regularisation term (``epsilon / 2``).
        loss_phi_or_psi: regulariser, required for "phi" and "psi".

    Returns:
        (best parameter vector, best stepsize, list of ``-flag * loss`` per
        candidate, i.e. the quantity that was minimised).

    Raises:
        ValueError: if ``net_type`` is not "u", "phi" or "psi".
    """
    stepsize_list = 0.2 * base_stepsize**np.arange(number_stepsizes)
    flag = -1 if descent_or_ascent == "descent" else 1

    # Scratch network that receives every candidate parameter vector.
    if net_type == "u":
        net_test = network_prim(network_length, dim, hidden_dimension_net_u, 1).to(device)
    elif net_type == "phi":
        net_test = network_dual(network_length, dim, hidden_dimension_net_phi, 1, L).to(device)
    elif net_type == "psi":
        net_test = network_dual_on_bdry(network_length, dim, hidden_dimension_net_psi, 1).to(device)
    else:
        raise ValueError("Wrong net_type: {!r}; expected 'u', 'phi' or 'psi'".format(net_type))

    min_index = 0
    best_objective = float("inf")
    loss_along_stepsizes = []

    for index, stepsize in enumerate(stepsize_list):
        updated_theta = theta_0 + flag * stepsize * tangent_theta
        torch.nn.utils.vector_to_parameters(updated_theta, net_test.parameters())

        # Evaluate on the trial network; the original code evaluated the
        # unchanged networks, which made every stepsize score identically.
        if net_type == "u":
            pdhgloss = PDHG_loss(net_test, net_phi, net_psi, in_samples, bd_samples, bd_lambda)
            loss_regularization = 0
        elif net_type == "phi":
            pdhgloss = PDHG_loss(net_u, net_test, net_psi, in_samples, bd_samples, bd_lambda)
            loss_regularization = loss_phi_or_psi(net_test, in_samples)
        else:
            pdhgloss = PDHG_loss(net_u, net_phi, net_test, in_samples, bd_samples, bd_lambda)
            loss_regularization = loss_phi_or_psi(net_test, bd_samples)

        loss = pdhgloss - epsilon / 2 * loss_regularization

        # -flag * loss is what we minimise: the loss itself for descent,
        # its negative for ascent.
        objective = - flag * loss.cpu().detach().numpy()
        loss_along_stepsizes.append(objective)
        if objective < best_objective:
            min_index = index
            best_objective = objective

    optimal_stepsize = stepsize_list[min_index]
    optimal_updated_theta = theta_0 + flag * optimal_stepsize * tangent_theta

    return optimal_updated_theta, optimal_stepsize, loss_along_stepsizes



In [None]:
# @title NPDHG solver on a single time interval
import pickle


def NPDHG_solver_on_single_time_interval(device, save_path, net_u_laststep, phy_time, h_t, L, ave_value_ddW, N_r, N_b,
                minres_max_iter, minres_tol,
                network_length, hidden_dimension_net_u, hidden_dimension_net_phi, hidden_dimension_net_psi, flag_init,
                iter, phi_psi_iter, u_iter, omega, epsilon,
                plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, flag_plot_real, u_real, N_IMEX,
                number_stepsizes, base_stepsize,
                precond_type,
                bd_lambda,
                stepsize_0=0.2,
                adaptive_or_fixed_stepsize ="fixed", tau_u = 0.5 * 1e-1, tau_phi = 0.9 * 1e-1, tau_psi = 0.9 * 1e-1):
    """Advance the solution by one time step with preconditioned primal-dual iterations.

    ``net_u`` starts from the weights of ``net_u_laststep``; ``net_phi`` and
    ``net_psi`` are freshly constructed. Each outer iteration performs
    ``phi_psi_iter`` preconditioned ascent steps on the dual networks followed
    by ``u_iter`` preconditioned descent steps on ``net_u``. Every
    preconditioned direction is obtained from ``minres_solver_G``.

    Args:
        device: torch device for all networks.
        save_path: directory handed to the plotting helper. Model checkpoints
            and the final error plot are written to the current working
            directory regardless of this value.
        net_u_laststep: network representing the solution at the previous step.
        phy_time: physical time of this step; used in plots and file names.
        h_t: time step size.
        L: domain size forwarded to the dual network and the plotting helper.
        ave_value_ddW: scalar forwarded to the losses and MINRES operators.
        N_r, N_b: requested numbers of interior / boundary samples.
        minres_max_iter, minres_tol: MINRES iteration cap and tolerance.
        network_length, hidden_dimension_net_*: network architecture settings.
        flag_init: if true, call ``initialization()`` on all three networks.
        iter: maximum number of outer iterations.
        phi_psi_iter, u_iter: inner iteration counts for dual / primal updates.
        omega: extrapolation parameter forwarded to the primal loss.
        epsilon: weight of the dual regularisation terms.
        plot_period, print_period: plot / print every this many iterations.
        N_plot, chosen_dim_0, chosen_dim_1, flag_plot_real: plotting options.
        u_real: reference solution used for the L2 error.
        N_IMEX: grid size forwarded to ``L2_error``.
        number_stepsizes, base_stepsize: line-search settings ("adaptive" mode).
        precond_type: "MpMd_Id", "MpMd_nabla", "Mp_Id_Md_Laplace",
            "RD_eps_small" or "RD_eps_large".
        bd_lambda: weight of boundary terms.
        stepsize_0: unused.
        adaptive_or_fixed_stepsize: "fixed" or "adaptive" dual stepsizes.
        tau_u, tau_phi, tau_psi: stepsizes for u, phi and psi.

    Returns:
        The trained ``net_u`` for this time step.

    Raises:
        ValueError: on an unknown ``precond_type`` or
            ``adaptive_or_fixed_stepsize``.
    """
    torch.manual_seed(50)

    # initialize nets
    net_u = network_prim(network_length, dim, hidden_dimension_net_u, 1).to(device)
    net_u.load_state_dict(net_u_laststep.state_dict())
    net_phi = network_dual(network_length, dim, hidden_dimension_net_phi, 1, L).to(device)
    net_psi = network_dual_on_bdry(network_length, dim, hidden_dimension_net_psi, 1).to(device)

    if flag_init:
        net_u.initialization()
        net_phi.initialization()
        net_psi.initialization()

    # Choose the metric tensors (see minres_solver_G) and the matching phi regulariser.
    if precond_type == "MpMd_Id":
        G_u_type = "14"
        loss_phi = L2_norm_sq_phi
        G_phi_type = "1"
    elif precond_type == "MpMd_nabla":
        G_u_type = "24"
        loss_phi = L2_norm_sq_nabla_phi
        G_phi_type = "2"
    elif precond_type == "Mp_Id_Md_Laplace":
        G_u_type = "14"
        loss_phi = L2_norm_sq_Lap_phi
        G_phi_type = "3"
    elif precond_type == "RD_eps_small": # use in this test
        G_u_type = "65"
        G_phi_type = "1"
        loss_phi = L2_norm_sq_phi
    elif precond_type == "RD_eps_large":
        G_u_type = "125"
        G_phi_type = "12"
        loss_phi = L2_norm_D_phi
    else:
        # Without this, an unknown value surfaced later as an unbound-local error.
        raise ValueError("Unknown precond_type: {!r}".format(precond_type))
    loss_psi = L2_norm_sq_psi
    G_psi_type = "4"

    comp_time = []
    total_time = 0
    l2error_list = []
    bdryerr_list = []
    ######################################################### PDHG iterations START HERE ###################################################################################################################
    for t in range(iter):

        t_0 = time.time()

        in_samples = rho_1_sampler(N_r)
        bd_samples, outward_direction = rho_bdry_sampler_with_directional_vector(N_b)
        u_laststep = net_u_laststep(in_samples)

        net_u.zero_grad()
        net_phi.zero_grad()
        net_psi.zero_grad()
        ############################# update phi_\eta #####################################
        # Snapshot taken up front so it exists even when phi_psi_iter == 0.
        original_eta = torch.nn.utils.parameters_to_vector(net_phi.parameters())
        original_eta2 = torch.nn.utils.parameters_to_vector(net_psi.parameters())
        for inner_iter in range(phi_psi_iter):
            ############### compute G(\eta)^{-1} \nabla_\eta loss() #########################
            lossa = PDHG_loss_typePINN(net_u, u_laststep, net_phi, net_psi, in_samples, bd_samples, h_t, ave_value_ddW, bd_lambda) - epsilon/2 * loss_phi(net_phi, in_samples)
            nabla_eta_loss = torch.autograd.grad(lossa, net_phi.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_eta_loss = torch.nn.utils.parameters_to_vector(nabla_eta_loss)

            # copy net_phi for G(\eta) computation
            net_phi_auxil = network_dual(network_length, dim, hidden_dimension_net_phi, 1, L).to(device)
            net_phi_auxil.load_state_dict(net_phi.state_dict())

            # compute G(\eta)^{-1} \nabla_\eta loss()
            G_inv_nabla_eta_loss, info_phi = minres_solver_G(net_phi, net_phi_auxil, in_samples, bd_samples, outward_direction, vectorized_nabla_eta_loss, device, bd_lambda, minres_max_iter, minres_tol, G_phi_type, ave_value_ddW, h_t)

            ############### compute G(\eta2)^{-1} \nabla_\eta2 loss() #########################
            lossa2 = PDHG_loss_typePINN(net_u, u_laststep, net_phi, net_psi, in_samples, bd_samples, h_t, ave_value_ddW, bd_lambda) - epsilon/2 * bd_lambda * loss_psi(net_psi, bd_samples)
            nabla_eta2_loss = torch.autograd.grad(lossa2, net_psi.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_eta2_loss = torch.nn.utils.parameters_to_vector(nabla_eta2_loss)

            # copy net_psi for G(\eta2) computation
            net_psi_auxil = network_dual_on_bdry(network_length, dim, hidden_dimension_net_psi, 1).to(device)
            net_psi_auxil.load_state_dict(net_psi.state_dict())

            # compute G(\eta2)^{-1} \nabla_\eta2 loss()
            G_inv_nabla_eta2_loss, info_psi = minres_solver_G(net_psi, net_psi_auxil, in_samples, bd_samples, outward_direction, vectorized_nabla_eta2_loss, device, bd_lambda, minres_max_iter, minres_tol, G_psi_type, ave_value_ddW, h_t)

            # update \eta and \eta2; the snapshots keep the pre-update duals
            # for the extrapolation in the primal step below
            original_eta = torch.nn.utils.parameters_to_vector(net_phi.parameters())
            original_eta2 = torch.nn.utils.parameters_to_vector(net_psi.parameters())
            if adaptive_or_fixed_stepsize == "adaptive":
                updated_eta, tau_phi, value_along_tau_phis = update_param(number_stepsizes, base_stepsize, original_eta, G_inv_nabla_eta_loss, net_u, net_phi, net_psi, in_samples, bd_samples, bd_lambda, "ascent", "phi", epsilon, loss_phi)
                updated_eta2, tau_psi, value_along_tau_psis = update_param(number_stepsizes, base_stepsize, original_eta2, G_inv_nabla_eta2_loss, net_u, net_phi, net_psi, in_samples, bd_samples, bd_lambda, "ascent", "psi", epsilon, loss_psi)
            elif adaptive_or_fixed_stepsize == "fixed":
                updated_eta = original_eta + tau_phi * G_inv_nabla_eta_loss
                updated_eta2 = original_eta2 + tau_psi * G_inv_nabla_eta2_loss
            else:
                raise ValueError("adaptive_or_fixed_stepsize must be 'adaptive' or 'fixed'")
            torch.nn.utils.vector_to_parameters(updated_eta, net_phi.parameters())
            torch.nn.utils.vector_to_parameters(updated_eta2, net_psi.parameters())

        ######################## update theta ##################################
        # Networks holding the dual variables from before the latest update.
        net_phi_0 = network_dual(network_length, dim, hidden_dimension_net_phi, 1, L).to(device)
        torch.nn.utils.vector_to_parameters(original_eta, net_phi_0.parameters())
        net_psi_0 = network_dual_on_bdry(network_length, dim, hidden_dimension_net_psi, 1).to(device)
        torch.nn.utils.vector_to_parameters(original_eta2, net_psi_0.parameters())
        for inner_iter in range(u_iter):
            # compute G(\theta)^{-1} \nabla_\theta loss()
            loss_pinn = PINN_Loss(net_u, u_laststep, in_samples, h_t, bd_lambda)
            loss_pd = PDHG_loss_with_extraplt_typePINN(net_u, u_laststep, net_phi, net_phi_0, net_psi , net_psi_0, in_samples, bd_samples, h_t, ave_value_ddW, bd_lambda, omega)
            lossb = 1 * loss_pd + loss_pinn
            nabla_theta_loss = torch.autograd.grad(lossb, net_u.parameters(), grad_outputs=None, allow_unused=True, retain_graph=True, create_graph=True)
            vectorized_nabla_theta_loss = torch.nn.utils.parameters_to_vector(nabla_theta_loss)

            # copy net_u for G(\theta) computation
            net_u_auxil = network_prim(network_length, dim, hidden_dimension_net_u, 1).to(device)
            net_u_auxil.load_state_dict(net_u.state_dict())
            net_u_auxil_auxil = network_prim(network_length, dim, hidden_dimension_net_u, 1).to(device)
            net_u_auxil_auxil.load_state_dict(net_u.state_dict())
            # compute G(\theta)^{-1} \nabla_\theta loss()
            G_inv_nabla_theta_loss, info_u = minres_solver_G(net_u, net_u_auxil, in_samples, bd_samples, outward_direction, vectorized_nabla_theta_loss, device, bd_lambda, minres_max_iter, minres_tol, G_u_type, ave_value_ddW, h_t, net_auxil_auxil=net_u_auxil_auxil)

            ############# update theta ####################
            original_theta = torch.nn.utils.parameters_to_vector(net_u.parameters())
            updated_theta = original_theta - tau_u * G_inv_nabla_theta_loss
            torch.nn.utils.vector_to_parameters(updated_theta, net_u.parameters())

        t_1 = time.time()
        total_time = total_time + (t_1 - t_0)
        comp_time.append(total_time)

        ################ plot ##################
        if (t+1) % plot_period == 0:
            z_min = -2.4
            z_max = 1.5
            Plot_graph_nn_primal(phy_time, net_u, L, N_plot, t, flag_plot_real, u_real, save_path, z_min, z_max, device, chosen_dim_0, chosen_dim_1)

        ########################################
        L2error = L2_error(net_u, u_real, N_IMEX)
        l2error_list.append(L2error.cpu().detach())
        bd_samples_for_loss, outward_direction_for_loss = rho_bdry_sampler_with_directional_vector(2000)
        boundary_error = torch.sqrt(Bd_loss_Neumann_use_samples(net_u, bd_samples_for_loss, outward_direction_for_loss))
        bdryerr_list.append(boundary_error.cpu().detach())
        if (t+1) % print_period == 0:
            print("Iter: {}, ".format(t))
            print("L2 error = {}".format(L2error))
            print("boundary loss = {}".format(boundary_error))
        # Early exit once the error against the reference is small enough.
        if L2error < 0.5 * 1e-3:
            break

    ######################################################### PDHG iterations END HERE #################################################################################################################
    # save the models (always into the current working directory)
    save_path = os.getcwd()
    torch.save(net_u.state_dict(), os.path.join(save_path, 'time={} netu.pt'.format(phy_time)))
    torch.save(net_phi.state_dict(), os.path.join(save_path, 'time={} netphi.pt'.format(phy_time)))
    torch.save(net_psi.state_dict(), os.path.join(save_path, 'time={} netpsi.pt'.format(phy_time)))

    # write down the error
    with open('time={} l2error_list'.format(phy_time), 'wb') as file1:
        pickle.dump(l2error_list, file1)
    with open('time={} boundary_error'.format(phy_time), 'wb') as file3:
        # Store the whole CPU history, not just the last (device) tensor.
        pickle.dump(bdryerr_list, file3)
    with open('time={} comp_time'.format(phy_time), 'wb') as file_x:
        pickle.dump(comp_time, file_x)

    fig_plot = plt.figure(figsize=(15, 15))
    plt.plot(comp_time, np.log(l2error_list) / np.log(10), color="blue")
    # Labels describe what is actually plotted: log10 error against wall time.
    plt.xlabel("Computation time (s)", fontsize=30)
    plt.ylabel("log_10(L2 error)", fontsize=30)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.title("plot of log_10(Rel L2 error) vs. computation time\n", fontsize=40)
    # NOTE(review): the file name has no time stamp, so each time step overwrites it.
    fig_plot.savefig(os.path.join(save_path, "Plot of the log l2 error vs. comptime.pdf"))
    plt.show()
    plt.close()

    return net_u



In [None]:
# @title Compute RD equation on the entire time interval
# compute sequentially

def NPDHG_solver_on_entire_time_interval(device, save_path, T, h_t, L, ave_value_ddW, N_r, N_b,
                              iter_init_training,
                              minres_max_iter, minres_tol,
                              network_length, hidden_dimension_net_u, hidden_dimension_net_phi, hidden_dimension_net_psi, flag_init,
                              iter, phi_psi_iter, u_iter, omega, epsilon,
                              plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, flag_plot_real, m, N_schm, Iter_num_fixed_pt, # m is the number of subintervals for IMEX scheme
                              number_stepsizes, base_stepsize,
                              precond_type,
                              bd_lambda,
                              stepsize_0=0.2,
                              adaptive_or_fixed_stepsize ="fixed", tau_u = 0.5 * 1e-1, tau_phi = 0.9 * 1e-1, tau_psi = 0.9 * 1e-1):
    """Solve the equation on [0, T] by stepping sequentially through time.

    A reference solution is first computed with ``Fixed_pt_solver_1D`` on a
    time grid of size ``h_t / m``. A primal network is then fitted to the
    initial condition with Adam (at most ``iter_init_training`` steps, stopping
    once ``Initial_loss`` drops below 1e-4). Finally
    ``NPDHG_solver_on_single_time_interval`` is called for each of the
    ``int(T / h_t)`` steps, each one starting from the previous step's network
    and using the matching row of the reference solution as benchmark.

    ``stepsize_0`` and ``adaptive_or_fixed_stepsize`` are forwarded to the
    single-interval solver (they were previously overridden with literals).
    All remaining arguments are forwarded unchanged. Returns nothing.
    """
    N_t = int(T/h_t)

    # compute the numerical solution by fully implicit scheme (FIS)
    ut_implicit_schm = Fixed_pt_solver_1D(T, h_t/m, L, N_schm, Iter_num_fixed_pt )

    net_u_init = network_prim(network_length, dim, hidden_dimension_net_u, 1).to(device)
    optim_u = torch.optim.Adam(net_u_init.parameters(), lr=1e-4)

    # train for the initial condition
    for k in range(iter_init_training):
        optim_u.zero_grad()
        u0loss = Initial_loss(net_u_init, 500)
        u0loss.backward()
        optim_u.step()
        print(u0loss)
        if u0loss < 1e-4:
            break

    # NPDHG solver
    net_u_laststep = net_u_init
    for k in range(N_t):

        print("-----------------------------------------------------------------------")
        print("Solve on interval [{}, {}]".format(round(k*h_t, 3), round((k+1)*h_t, 3)))
        print("-----------------------------------------------------------------------")

        # use the solution solved from fully implicit scheme as benchmark;
        # row k*m+m-1 is the last fine sub-step inside coarse step k.
        # .to(device) honours the device argument instead of forcing CUDA.
        benchmark_u = ut_implicit_schm[k*m+m-1, :].to(device)

        phy_time = (k+1) * h_t

        net_u = NPDHG_solver_on_single_time_interval(device, save_path, net_u_laststep, phy_time, h_t, L, ave_value_ddW, N_r, N_b,
                      minres_max_iter, minres_tol,
                      network_length, hidden_dimension_net_u, hidden_dimension_net_phi, hidden_dimension_net_psi, flag_init,
                      iter, phi_psi_iter, u_iter, omega, epsilon,
                      plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, flag_plot_real, benchmark_u, N_schm,
                      number_stepsizes, base_stepsize,
                      precond_type,
                      bd_lambda,
                      stepsize_0=stepsize_0, adaptive_or_fixed_stepsize=adaptive_or_fixed_stepsize, tau_u = tau_u, tau_phi = tau_phi, tau_psi = tau_psi )
        net_u_laststep = net_u



In [None]:
# @title  apply NPDHG_solver (with time causality)
# Directory handed to the solver for its plots.
save_path = os.getcwd()

# --- time grid and domain ---
T = 0.002
h_t = 0.001
L = 2

ave_value_ddW = 2.0

# --- requested sample counts (interior, boundary) ---
N_r = 2000
N_b = 20

# --- Adam pre-training on the initial condition ---
iter_init_training = 20000

# --- MINRES settings for the preconditioned directions ---
minres_max_iter = 800
minres_tol = 1e-4

# --- network architecture ---
network_length = 3
hidden_dimension_net_u = 128
hidden_dimension_net_phi = 128
hidden_dimension_net_psi = 32

flag_init = False

# --- primal-dual iteration settings ---
iter = 6000
phi_psi_iter = 1
u_iter = 1
omega = 1.0
epsilon = 1

# --- reporting and reference solution ---
plot_period = 500
print_period = 100
N_plot = 100
chosen_dim_0 = 0
chosen_dim_1 = 0  # not used
flag_plot_real = True
m = 1  # number of time subintv of discrete scheme
N_space_discrt = 100
Iter_num_fixed_pt = 8000

# --- line search (only used with adaptive stepsizes) ---
number_stepsizes = 50
base_stepsize = 0.8

precond_type = "RD_eps_small"  # if a is bounded away from 0, use "RD_eps_large"

bd_lambda = 1

NPDHG_solver_on_entire_time_interval(
    device=device,
    save_path=save_path,
    T=T,
    h_t=h_t,
    L=L,
    ave_value_ddW=ave_value_ddW,
    N_r=N_r,
    N_b=N_b,
    iter_init_training=iter_init_training,
    minres_max_iter=minres_max_iter,
    minres_tol=minres_tol,
    network_length=network_length,
    hidden_dimension_net_u=hidden_dimension_net_u,
    hidden_dimension_net_phi=hidden_dimension_net_phi,
    hidden_dimension_net_psi=hidden_dimension_net_psi,
    flag_init=flag_init,
    iter=iter,
    phi_psi_iter=phi_psi_iter,
    u_iter=u_iter,
    omega=omega,
    epsilon=epsilon,
    plot_period=plot_period,
    print_period=print_period,
    N_plot=N_plot,
    chosen_dim_0=chosen_dim_0,
    chosen_dim_1=chosen_dim_1,
    flag_plot_real=flag_plot_real,
    m=m,  # m is the number of subintervals for IMEX scheme
    N_schm=N_space_discrt,
    Iter_num_fixed_pt=Iter_num_fixed_pt,
    number_stepsizes=number_stepsizes,
    base_stepsize=base_stepsize,
    precond_type=precond_type,
    bd_lambda=bd_lambda,
    stepsize_0=0.2,
    adaptive_or_fixed_stepsize="fixed",
    tau_u=0.01,
    tau_phi=0.02,
    tau_psi=0.02,
)



In [None]:
# @title Compute RD equation on the entire time interval (compute sequentially) Start from check point: time=k_0*ht


def NPDHG_solver_on_entire_time_interval_start_from_t0(k_0, net_u_t_0, device, save_path, T, h_t, L, ave_value_ddW, N_r, N_b,
                              iter_init_training,
                              minres_max_iter, minres_tol,
                              network_length, hidden_dimension_net_u, hidden_dimension_net_phi, hidden_dimension_net_psi, flag_init,
                              iter, phi_psi_iter, u_iter, omega, epsilon,
                              plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, flag_plot_real, m, N_schm, Iter_num_fixed_pt, # m is the number of subintervals for IMEX scheme
                              number_stepsizes, base_stepsize,
                              precond_type,
                              bd_lambda,
                              stepsize_0=0.2,
                              adaptive_or_fixed_stepsize ="fixed", tau_u = 0.5 * 1e-1, tau_phi = 0.9 * 1e-1, tau_psi = 0.9 * 1e-1):
    """Resume the sequential time stepping from the checkpoint at step ``k_0``.

    ``net_u_t_0`` supplies the weights of the solution at time ``k_0 * h_t``.
    Steps ``k_0 .. int(T / h_t) - 1`` are then solved one after another with
    ``NPDHG_solver_on_single_time_interval``, each benchmarked against the
    reference computed by ``Fixed_pt_solver_1D`` on a grid of size ``h_t / m``.
    Step 0 (only reached when ``k_0 == 0``) runs 3000 outer iterations; all
    other steps run ``iter``.

    ``iter_init_training`` is accepted for signature parity with
    ``NPDHG_solver_on_entire_time_interval`` but is unused, because no
    initial-condition training happens here. ``stepsize_0`` and
    ``adaptive_or_fixed_stepsize`` are forwarded to the single-interval solver.
    Returns nothing.
    """
    N_t = int(T/h_t)

    # compute the numerical solution by fully implicit scheme (FIS).
    # The IMEX solve that used to follow was never read and has been dropped,
    # matching NPDHG_solver_on_entire_time_interval.
    ut_implicit_schm = Fixed_pt_solver_1D(T, h_t/m, L, N_schm, Iter_num_fixed_pt )

    # Start from a private copy of the checkpointed network.
    net_u_laststep = network_prim(network_length, dim, hidden_dimension_net_u, 1).to(device)
    net_u_laststep.load_state_dict(net_u_t_0.state_dict())

    # NPDHG solver
    for k in range(k_0, N_t):

        print("-----------------------------------------------------------------------")
        print("Solve on interval [{}, {}]".format(round(k*h_t, 3), round((k+1)*h_t, 3)))
        print("-----------------------------------------------------------------------")

        # use the solution solved from fully implicit scheme as benchmark;
        # .to(device) honours the device argument instead of forcing CUDA.
        benchmark_u = ut_implicit_schm[k*m+m-1, :].to(device)

        phy_time = (k+1) * h_t

        # The very first time step uses a fixed budget of 3000 outer iterations.
        outer_iterations = 3000 if k == 0 else iter

        net_u = NPDHG_solver_on_single_time_interval(device, save_path, net_u_laststep, phy_time, h_t, L, ave_value_ddW, N_r, N_b,
                      minres_max_iter, minres_tol,
                      network_length, hidden_dimension_net_u, hidden_dimension_net_phi, hidden_dimension_net_psi, flag_init,
                      outer_iterations, phi_psi_iter, u_iter, omega, epsilon,
                      plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, flag_plot_real, benchmark_u, N_schm,
                      number_stepsizes, base_stepsize,
                      precond_type,
                      bd_lambda,
                      stepsize_0=stepsize_0, adaptive_or_fixed_stepsize=adaptive_or_fixed_stepsize, tau_u = tau_u, tau_phi = tau_phi, tau_psi = tau_psi )
        net_u_laststep = net_u



In [None]:
# @title  apply NPDHG_solver (with time causality) start from time=k_0*ht

# NOTE(review): save_path is captured BEFORE the chdir below, so the checkpoint
# is read from the previous working directory while new checkpoints are
# written to the new one. Confirm this is intended.
save_path = os.getcwd()

# In Google Colab, change to working directory
os.chdir("/content/drive/MyDrive/NPDG_RD")
# Verify
print("Current working directory:", os.getcwd())

T =  0.04
h_t = 0.001
L=2

ave_value_ddW = 2.0

N_r = 2000
N_b = 20

iter_init_training = 20000

minres_max_iter = 800
minres_tol = 1e-4  # 1e-3

network_length = 3
hidden_dimension_net_u = 128
hidden_dimension_net_phi = 128
hidden_dimension_net_psi = 32

flag_init = False

iter = 6000
phi_psi_iter = 1
u_iter = 1
omega = 1.0
epsilon = 1

plot_period  = 500
print_period = 100
N_plot = 100
chosen_dim_0 = 0
chosen_dim_1 = 0 # not used
flag_plot_real = True
m = 1
N_space_discrt =   100
Iter_num_fixed_pt = 8000

number_stepsizes = 50
base_stepsize = 0.8

# Must be one of the names handled by NPDHG_solver_on_single_time_interval.
# The previous value "RD" matched none of them and failed at the first use of
# loss_phi. "RD_eps_small" is the setting used for this test.
precond_type = "RD_eps_small" # if a is bounded away from 0, use "RD_eps_large"

bd_lambda = 1

# Index of the time step to resume from (checkpoint at time k_0 * h_t).
k_0 = 10

net_u_k_0 = network_prim(network_length, dim, hidden_dimension_net_u, 1).to(device)
net_u_k_0.load_state_dict(torch.load(os.path.join(save_path, 'netu_10.pt'), weights_only=True)) # suppose you computed to time = 10 * h_t

NPDHG_solver_on_entire_time_interval_start_from_t0(k_0, net_u_k_0, device, save_path, T, h_t, L, ave_value_ddW, N_r, N_b,
                          iter_init_training,
                          minres_max_iter, minres_tol,
                          network_length, hidden_dimension_net_u, hidden_dimension_net_phi, hidden_dimension_net_psi, flag_init,
                          iter, phi_psi_iter, u_iter, omega, epsilon,
                          plot_period, print_period, N_plot, chosen_dim_0, chosen_dim_1, flag_plot_real, m, N_space_discrt, Iter_num_fixed_pt, # m is the number of subintervals for IMEX scheme
                          number_stepsizes, base_stepsize,
                          precond_type,
                          bd_lambda,
                          stepsize_0=0.2,
                          adaptive_or_fixed_stepsize ="fixed", tau_u = 0.01, tau_phi = 0.02, tau_psi = 0.02)

