In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from scipy.stats.qmc import LatinHypercube

import os
os.environ['KERAS_BACKEND'] = 'torch'
import torch
import keras
from keras import layers

In [None]:
# First we define our data points.
Nbd = 100
N = 5000

# Our boundary conditions are u(x,0) = -sin(pi x); u(-1,t) = u(1,t) = 0
xt_bd = np.vstack((
    np.vstack((np.linspace(-1,1,Nbd),np.zeros(Nbd))).transpose(),
    np.vstack((-np.ones(Nbd),np.linspace(1/Nbd,1,Nbd))).transpose(),
    np.vstack((np.ones(Nbd),np.linspace(1/Nbd,1,Nbd))).transpose()
),dtype=np.float32)
u_bd = np.hstack((
    -np.sin(np.pi*np.linspace(-1,1,Nbd)),
    np.zeros(2*Nbd)
),dtype=np.float32)
ν = 0.01

# We sample interior points with a latin hypercube.
sampler = LatinHypercube(2)
xt = sampler.random(n=N)
xt[:,0] = 2*xt[:,0]-1

xt = np.vstack((xt_bd,xt),dtype=np.float32)

In [None]:
# The custom loss function is the core of the PINN.
def lossfn(y_true,y_pred):
    """Boundary-data loss plus PDE-residual loss.

    Parameters
    ----------
    y_true : target values of u at the boundary/initial points.
    y_pred : the model's output at those same points.

    Returns
    -------
    A scalar torch tensor: 5 * (boundary MSE) + (mean squared PDE residual).

    Reads the module-level `model`, `xt` (all collocation points) and `ν`.
    The residual is du/dt + u*du/dx - (ν/π)*d²u/dx².
    """
    # Where we have Dirichlet boundary conditions, we can just use that error.
    # Taking the mean (rather than dividing by a hard-coded 3*Nbd) keeps the
    # normalization right for any number of boundary points passed in.
    bd_loss = torch.mean(keras.losses.mean_squared_error(y_true,y_pred))

    # Run the model forward and watch the derivatives with respect to (x,t).
    # By calling autograd.grad() instead of just u.backward(), we can avoid
    # taking the derivatives of weights on this pass since we don't need them.
    # A freshly created leaf tensor has no .grad, so nothing needs resetting.
    xt_tensor = torch.tensor(xt,requires_grad=True, device=y_pred.device)
    u = model(xt_tensor).squeeze()

    # create_graph=True keeps the derivative itself differentiable (and
    # implies retain_graph=True), which we need both for the second
    # derivative and for backpropagating the loss to the weights.
    xt_grad = torch.autograd.grad(
        u,xt_tensor,grad_outputs=torch.ones_like(u),
        create_graph=True
    )[0]

    du_dx = xt_grad[:,0]
    du_dt = xt_grad[:,1]

    # create_graph=True is required here as well: without it d2u_dx2 comes
    # back detached, and loss.backward() would silently treat the viscous
    # term as a constant with respect to the network weights.
    xt_grad2 = torch.autograd.grad(
        du_dx,xt_tensor,grad_outputs=torch.ones_like(du_dx),
        create_graph=True
    )[0]

    d2u_dx2 = xt_grad2[:,0]

    # compute the physical loss, averaged over every point in xt (boundary
    # and interior points alike), not just the N interior ones
    residual = du_dt + u * du_dx - ( ν / np.pi) * d2u_dx2
    phys_loss = torch.mean(torch.pow(residual,2))

    # Weight the physical and boundary loss with a hyperparameter weighting
    return 5*bd_loss + phys_loss

In [None]:
# A deep neural network with 8 size-20 layers
# Input is an (x, t) pair; each hidden layer is tanh-activated; the output
# is a single linear unit for u.
nnlayers = [20] * 8
hidden_stack = [layers.Dense(width, activation='tanh') for width in nnlayers]
model = keras.Sequential([
    keras.Input(shape=(2,)),
    *hidden_stack,
    layers.Dense(1),
])

# No optimizer is passed here; only the custom loss is registered.
model.compile(loss=lossfn)

In [None]:
# Unfortunately, this is where we must leave keras behind
# and write a torch-style training loop

def run_epoch(model, input, target):
    """Perform one optimizer step and return the resulting loss as a float.

    `model(input)` is scored against `target` with the module-level `lossfn`,
    and the step is taken with the module-level `optimizer`.
    """
    # Doing this with a closure isn't necessary for most optimizers, but
    # LBFGS needs it and that's the best one for PINNs usually, so it's
    # good practice.
    def evaluate():
        optimizer.zero_grad()
        step_loss = lossfn(target, model(input))
        step_loss.backward()
        return step_loss

    # optimizer.step() hands back whatever the closure returned.
    return optimizer.step(evaluate).item()

In [None]:
epochs = 10000     # upper bound on the number of optimizer steps
patience = 10      # how many previous losses the convergence test looks at
threshold = 1e-4   # relative tolerance for the convergence test

losses = np.zeros(epochs)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
bar = tqdm(range(epochs))
for e in bar:
    model.train(True)
    loss = run_epoch(model,xt_bd,u_bd)
    losses[e] = loss
    bar.set_description(f'epoch {e+1}, loss: {loss:.3e}')

    # Not enough history yet to judge convergence.
    if e <= patience:
        continue

    # Converged when each of the previous `patience` losses is within a
    # relative `threshold` of the current one.
    recent = losses[e-patience:e]
    if np.max(np.abs(recent-loss)) < threshold*loss:
        print('Model converged.')
        break

In [None]:
# losses[e] was written on the last epoch that ran (whether the loop broke
# early or finished), so the slice has to include index e.
plt.semilogy(losses[:e+1])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Training convergence');

In [None]:
# Now we run the model for the whole domain and see what it looks like
# Grid: 512 x-values in [-1,1] by 1024 t-values in [0,1]; rows index t.
x_full, t_full = np.meshgrid(np.linspace(-1,1,512),np.linspace(0,1,1024))

# Flatten the grid into (x, t) rows, predict, then fold back to the grid shape.
xt_full = np.column_stack((x_full.ravel(),t_full.ravel()))
u_full = model.predict(xt_full,batch_size=10000).reshape(1024,512)

In [None]:
# Three stacked panels: the profile on the last t row, the full (x, t)
# field, and the profile on the first t row.
_, axs = plt.subplots(3)
top, middle, bottom = axs
top.plot(x_full[-1],u_full[-1])
middle.imshow(u_full,origin='lower',extent=[-1, 1, 0, 1])
bottom.plot(x_full[0],u_full[0])
plt.tight_layout()