# Solving the equation:
# $ u_x -cos(x) = 0$
# s.t.
* # $u(0) = 0$
* # $u(2\pi) = 0$

# Exact solution: $ u(x) = sin(x)$

# Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from datetime import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Helpers

In [2]:
class FeedForward(nn.Module):
  def __init__(self,
               n_layers,
               layers,
               activation=nn.ReLU()):

    super().__init__()
    self.n_layers = n_layers
    self.layers = layers
    self.activation = activation

    dense_layers = [
        self.dense_layer(in_features=self.layers[i],
                         out_features=self.layers[i+1])
        for i in range(self.n_layers-1)]
    dense_layers.append(nn.Linear(in_features=self.layers[-2],
                                  out_features=self.layers[-1]))

    self.feed_forward = nn.Sequential(*dense_layers)

  def dense_layer(self, in_features, out_features):
    dense_layer = nn.Sequential(
      nn.Linear(in_features=in_features,
                out_features=out_features),
      self.activation,
    )
    return dense_layer

  def forward(self, x):
    return self.feed_forward(x)

In [3]:
class PINN():
  def __init__(self,
               x, xl, xr, ul, ur,
               layers, n_layers,
               epochs, patience,
               optimizer,
               scheduler, use_scheduler,
               scheduler_p, scheduler_f,
               chkpt_path='model.pth',
               lr=1e-3,
               weight_decay=0,
               activation=nn.ReLU(),
               dtype=torch.float32,
               device='cpu'):

    super().__init__()

    # Some constants
    self.device = device
    self.dtype = dtype
    self.x = torch.tensor(x, dtype=dtype, device=device, requires_grad=True)
    self.xl = torch.tensor(xl, dtype=dtype, device=device)
    self.xr = torch.tensor(xr, dtype=dtype, device=device)
    self.ul = torch.tensor(ul, dtype=dtype, device=device)
    self.ur = torch.tensor(ur, dtype=dtype, device=device)
    self.epochs = epochs
    self.chkpt_path = chkpt_path
    self.patience = patience

    # Neural network
    self.feed_forward = FeedForward(n_layers=n_layers,
                                    layers=layers,
                                    activation=activation)

    # Setting optimizer and scheduler
    if optimizer == 'Adam':
      self.optimizer = torch.optim.Adam(lr=lr, params=self.feed_forward.parameters(), weight_decay=weight_decay)
    self.use_scheduler = use_scheduler
    if use_scheduler:
      if scheduler == 'ReduceLROnPlateau':
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=scheduler_f, patience=scheduler_p, verbose=True)

  def calculate_u(self, x):
    return self.feed_forward(x)

  def calculate_f(self, x):
    u = self.calculate_u(x)

    u_x = torch.autograd.grad(
        u, x,
        grad_outputs=torch.ones_like(u),
        retain_graph=True,
        create_graph=True
    )[0]

    f_hat = u_x - torch.cos(x)
    return f_hat

  def loss_fn(self):
    ul_hat = self.calculate_u(xl)
    ur_hat = self.calculate_u(xr)
    f_hat = self.calculate_f(self.x.view(-1, 1))
    loss = torch.mean(f_hat ** 2) + (self.ul - ul_hat) ** 2 + (self.ur - ur_hat) ** 2
    return loss

  def train_step(self):
    self.feed_forward.train()
    loss = 0
    batch_loss = self.loss_fn()
    loss += batch_loss.item()
    self.optimizer.zero_grad()
    batch_loss.backward()
    self.optimizer.step()
    # loss /= len(self.x)
    return loss

  def train(self):
    t0 = pd.Timestamp.now()

    losses = list()
    for epoch in tqdm(range(self.epochs)):
      flag = 0
      loss = self.train_step()

      if self.use_scheduler:
        self.scheduler.step(loss)

      losses.append(loss)
      if epoch == 0:
        best_val_loss = loss
        best_epoch = -1
        self.checkpoint()
        flag = 1
        print(f"Epoch: {epoch+1}/{self.epochs} | Loss: {loss}- *Checkpoint*")
      else:
        if loss < best_val_loss:
          best_val_loss = loss
          best_epoch = epoch
          self.checkpoint()
          flag = 1
          if epoch % 100 == 0:
            print(f"Epoch: {epoch+1}/{self.epochs} | Loss: {loss} - *Checkpoint*")
        elif epoch - best_epoch > self.patience:
          if epoch % 100 == 0:
            print(f"\nEarly stopping applied at epoch {epoch}.")
          break
      if (flag == 0) and (epoch % 100 == 0):
        print(f"Epoch: {epoch+1}/{self.epochs} | Loss: {loss}")

    print(f"\nTOTAL TRAINING TIME: ")
    self.timer(t0, pd.Timestamp.now())

    return losses

  def checkpoint(self):
    torch.save({
      "optimizer": self.optimizer.state_dict(),
      "model": self.feed_forward.state_dict()
    }, self.chkpt_path)

  def timer(self, start, end):
    dt = pd.Timedelta(end - start)
    seconds = dt.seconds
    minutes = seconds / 60.
    hours = seconds / 3600.
    print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))

  def predict(self, x):
    t0 = datetime.now()
    x = torch.tensor(x[:, 0], dtype=self.dtype, device=self.device, requires_grad=True).float()
    self.feed_forward.eval()
    u = self.calculate_u(x.view(-1, 1))
    f = self.calculate_f(x.view(-1, 1))
    u = u.detach().cpu().numpy()
    f = f.detach().cpu().numpy()
    print(f"TOTAL INFERENCE TIME: ")
    self.timer(t0, datetime.now())
    return u, f

# MAIN

In [13]:
X_MIN, X_MAX, N_X = 0, 2*np.pi, int(1023)

# **************  Create Training data **************
x = torch.linspace(X_MIN, X_MAX, N_X).view(-1, 1)
xl = torch.zeros(size=(1, ), dtype = torch.float32)
xr = 2*np.pi*torch.ones(size=(1, ), dtype = torch.float32)

# Boundary Conditions
ul = torch.zeros(size=(1, ), dtype = torch.float32)
ur = torch.zeros(size=(1, ), dtype = torch.float32)

# TRAIN

In [14]:
EPOCHS = int(10000)
PATIENCE = 100
LAYERS = [1, 50, 50, 20, 50, 50, 1]
N_LAYERS = len(LAYERS) - 1
PATH = "model.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

pinn = PINN(
    x=x, xl=xl, xr=xr, ul=ul, ur=ur,
    layers=LAYERS,
    n_layers=N_LAYERS,
    epochs=EPOCHS,
    patience=EPOCHS+1,
    optimizer='Adam',
    scheduler='ReduceLROnPlateau',
    use_scheduler=False,
    scheduler_p=1000,
    scheduler_f=.33,
    lr=1e-3,
    weight_decay=0,
    activation=nn.ReLU()
)

loss = pinn.train()

  self.x = torch.tensor(x, dtype=dtype, device=device, requires_grad=True)
  self.xl = torch.tensor(xl, dtype=dtype, device=device)
  self.xr = torch.tensor(xr, dtype=dtype, device=device)
  self.ul = torch.tensor(ul, dtype=dtype, device=device)
  self.ur = torch.tensor(ur, dtype=dtype, device=device)


  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch: 1/10000 | Loss: 0.4999271631240845- *Checkpoint*
Epoch: 101/10000 | Loss: 0.31352680921554565
Epoch: 201/10000 | Loss: 0.31118494272232056
Epoch: 301/10000 | Loss: 0.311550498008728
Epoch: 401/10000 | Loss: 0.3116736114025116
Epoch: 501/10000 | Loss: 0.31246688961982727
Epoch: 601/10000 | Loss: 0.31251123547554016
Epoch: 701/10000 | Loss: 0.31249532103538513
Epoch: 801/10000 | Loss: 0.3138500452041626
Epoch: 901/10000 | Loss: 0.3141089677810669
Epoch: 1001/10000 | Loss: 0.3173569142818451
Epoch: 1101/10000 | Loss: 0.3165907859802246
Epoch: 1201/10000 | Loss: 0.3168674111366272
Epoch: 1301/10000 | Loss: 0.31913501024246216
Epoch: 1401/10000 | Loss: 0.31759053468704224
Epoch: 1501/10000 | Loss: 0.3175644874572754
Epoch: 1601/10000 | Loss: 0.3177047371864319
Epoch: 1701/10000 | Loss: 0.3195549249649048
Epoch: 1801/10000 | Loss: 0.3196220397949219
Epoch: 1901/10000 | Loss: 0.31996670365333557
Epoch: 2001/10000 | Loss: 0.3221568763256073
Epoch: 2101/10000 | Loss: 0.32115259766578674


KeyboardInterrupt: ignored

In [None]:
plt.figure(figsize=(30, 10))
plt.plot(loss, label='loss')
plt.legend()
plt.show()

In [None]:
checkpoint = torch.load("model.pth")
pinn.feed_forward.load_state_dict(checkpoint['model'])

# PINN inference
u_hat, gv_hat = pinn.predict(x)

# Real solution
u = torch.sin(x)

plt.figure(figsize=(30, 10))
plt.plot(x, u, label='u_real')
plt.plot(x[0:-1:10], u_hat.flatten()[0:-1:10], "*", label='u_pred')
plt.legend()
plt.show()