# Imports

In [242]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from datetime import datetime
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
from scipy import integrate
from sklearn.preprocessing import MinMaxScaler

# Classes + Helpers

## Create U, S

In [261]:
def create_U(length_scale, Nu, Nt):
  """ Create synthetic data for u """

  # Create a kernel and a Gaussian Process Regressor
  kernel = RBF(length_scale)
  gp = GaussianProcessRegressor(kernel=kernel)

  # Create x-axis
  t = np.linspace(0, 1, Nt).reshape(-1, 1)

  # Create u samples --> U
  U = np.zeros((Nu, Nt))
  for i in range(Nu):
    random_state = np.random.randint(0, 10000)
    U[i, :] = gp.sample_y(t, random_state=random_state).flatten()

  return t, U

In [262]:
def create_S(t_samples, u_samples):
  # Interpolate the sampled u values
  def interp_u(t):
      return np.interp(t, t_samples, u_samples)

  # Define the ODE function ds/dt - u(t) = 0
  def myODE(t, s):
      return interp_u(t)

  # Define the time span for integration
  t_span = (t_samples[0], t_samples[-1])

  # Define the initial condition for s
  s0 = 0

  # Call solve_ivp to solve the ODE
  sol = integrate.solve_ivp(myODE, t_span, [s0], t_eval=t_samples)

  # Extract the solution
  t = sol.t
  s = sol.y[0]

  return s, t

## Physics-Informed DeepONet

In [263]:
class FeedForward(nn.Module):
  def __init__(self, layers, activation=nn.ReLU()):
    super().__init__()
    self.layers = layers
    self.n_layers = len(layers)
    self.activation = activation

    dense_layers = [
        self.dense_layer(in_features=self.layers[i],
                         out_features=self.layers[i+1])
        for i in range(self.n_layers - 1)]

    dense_layers.append(nn.Linear(in_features=self.layers[-2],
                                  out_features=self.layers[-1]))

    self.feed_forward = nn.Sequential(*dense_layers)

  def dense_layer(self, in_features, out_features):
    dense_layer = nn.Sequential(
        nn.Linear(in_features=in_features,
                  out_features=out_features),
        self.activation,
    )
    return dense_layer

  def forward(self, x):
    return self.feed_forward(x)

In [264]:
class PIDeepONet(nn.Module):
  def __init__(self,
               branch_layers,
               trunk_layers,
               activation,):
    super().__init__()
    # self.bias = nn.Parameter(T.zeros(1,))
    self.branch_model = FeedForward(branch_layers, activation)
    self.trunk_model = FeedForward(trunk_layers, activation)

  def forward(self, U, t):
    branch_out = self.branch_model.forward(U)   # N x D
    trunk_out = self.trunk_model.forward(t)   # T x D
    S = T.matmul(branch_out, T.transpose(trunk_out, 0, 1))   # N x T   #+ self.bias
    return S

In [265]:
class ODESolver(nn.Module):
  def __init__(self,
               branch_layers,
               trunk_layers,
               activation,
               ODE_weight,
               IC_weight,
               epochs,
               lr,
               patience,
               weight_decay=0,
               dtype=T.float32,
               device='cpu',
               chkpt_path='model.pth'):
    super().__init__()

    self.device = device
    self.dtype = dtype
    self.chkpt_path = chkpt_path
    self.patience = patience
    self.w_ODE = T.tensor(ODE_weight, dtype=dtype, device=device)
    self.w_IC = T.tensor(IC_weight, dtype=dtype, device=device)
    self.epochs = epochs
    self.PIDeepONet = PIDeepONet(branch_layers, trunk_layers, activation)
    self.optimizer = T.optim.Adam(lr=lr, params=self.PIDeepONet.parameters(), weight_decay=weight_decay)

  def forward(self, U, t):
    return self.PIDeepONet.forward(U, t)

  def calculate_derivatives(self, t, U):
    Nu = U.shape[0]
    Nt = len(t)
    S_t = T.zeros(Nu, Nt)
    for i in range(Nu):
      u = U[i, :]
      s_i = self.forward(u, t).squeeze()
      grad_s_i = T.autograd.grad(
          s_i, t,
          grad_outputs=T.ones_like(s_i),
          create_graph=True,
          retain_graph=True
      )
      S_t[i, :] = grad_s_i[0].squeeze()

    return S_t

  def loss_fn(self, t_init, U_init, t, U):
    S_t = self.calculate_derivatives(t, U)
    ODE_residual = S_t - U
    IC_residual = self.forward(U_init, t_init)
    total_loss = self.w_ODE * T.mean(T.square(ODE_residual)) + \
                 self.w_IC * T.mean(F.mse_loss(IC_residual, T.tensor(0, device=self.device, dtype=self.dtype)))
    return total_loss

  def train_step(self, t_init, U_init, t, U):
    # Set 'PIDeepONet' model in training mode
    self.PIDeepONet.train()

    # Calculate loss
    loss = self.loss_fn(t_init, U_init, t, U)

    # Zero the gradients
    self.optimizer.zero_grad()

    # Back-propagate the loss
    loss.backward()

    # Implement one step of optimization
    self.optimizer.step()

    return loss.item()

  def train(self, t_init, U_init, t, U):
    losses = list()
    for epoch in tqdm(range(self.epochs)):
      flag = 0
      loss = self.train_step(t_init, U_init, t, U)

      # if self.use_scheduler:
      #   self.scheduler.step(loss)

      losses.append(loss)
      if epoch == 0:
        best_val_loss = loss
        best_epoch = -1
        self.checkpoint()
        flag = 1
        print(f"Epoch: {epoch+1}/{self.epochs} | Loss: {loss}- *Checkpoint*")
      else:
        if loss < best_val_loss:
          best_val_loss = loss
          best_epoch = epoch
          self.checkpoint()
          flag = 1
          if epoch % 1 == 0:
            print(f"Epoch: {epoch+1}/{self.epochs} | Loss: {loss} - *Checkpoint*")
        elif epoch - best_epoch > self.patience:
          if epoch % 1 == 0:
            print(f"\nEarly stopping applied at epoch {epoch}.")
          break
      if (flag == 0) and (epoch % 1 == 0):
        print(f"Epoch: {epoch+1}/{self.epochs} | Loss: {loss}")

    return losses

  def checkpoint(self):
    T.save({
      "optimizer": self.optimizer.state_dict(),
      "model": self.PIDeepONet.state_dict()
    }, self.chkpt_path)

# MAIN

In [266]:
Nu = 20
Nt = 100
length_scale = 0.4
activation = nn.Tanh()
n_hidden_layers = 3
hidden_layers_size = 50
ODE_weight = 1
IC_weight = 1
epochs = 5000
lr = 1e-3
batch_size = 20
batch_size_init = 20
dtype = T.float32
device = 'cpu'

t_train, U_train = create_U(length_scale, Nu, Nt)

scaler = MinMaxScaler()
U_train_scaled = scaler.fit_transform(U_train)

t_train = T.tensor(t_train, dtype=dtype, device=device, requires_grad=True).reshape(-1, 1)
t_train_init = T.tensor([[0.]], dtype=dtype, device=device)
U_train_scaled = T.tensor(U_train_scaled, dtype=dtype, device=device)
U_train_scaled_init = U_train_scaled[U_train_scaled[:, 0] == 0]

branch_layers = [len(t_train)] + n_hidden_layers * [hidden_layers_size]
trunk_layers = [1] + n_hidden_layers * [hidden_layers_size]

print(U_train_scaled.shape, t_train.shape)
print(U_train_scaled_init.shape)

torch.Size([20, 100]) torch.Size([100, 1])
torch.Size([1, 100])


In [267]:
ode_solver = ODESolver(branch_layers,
                       trunk_layers,
                       activation,
                       ODE_weight,
                       IC_weight,
                       epochs,
                       lr,
                       patience=epochs+1,
                       chkpt_path='model.pth')

losses = ode_solver.train(t_train_init, U_train_scaled_init, t_train, U_train_scaled)

  0%|          | 0/5000 [00:00<?, ?it/s]

  self.w_IC * T.mean(F.mse_loss(IC_residual, T.tensor(0, device=self.device, dtype=self.dtype)))


Epoch: 1/5000 | Loss: 0.31568753719329834- *Checkpoint*
Epoch: 2/5000 | Loss: 0.155369833111763 - *Checkpoint*
Epoch: 3/5000 | Loss: 0.06675956398248672 - *Checkpoint*
Epoch: 4/5000 | Loss: 0.050128232687711716 - *Checkpoint*
Epoch: 5/5000 | Loss: 0.07919018715620041
Epoch: 6/5000 | Loss: 0.08946357667446136
Epoch: 7/5000 | Loss: 0.07410049438476562
Epoch: 8/5000 | Loss: 0.055118378251791
Epoch: 9/5000 | Loss: 0.04596545547246933 - *Checkpoint*
Epoch: 10/5000 | Loss: 0.04751885309815407
Epoch: 11/5000 | Loss: 0.05368581786751747
Epoch: 12/5000 | Loss: 0.05777481570839882
Epoch: 13/5000 | Loss: 0.05651657655835152
Epoch: 14/5000 | Loss: 0.05061580240726471
Epoch: 15/5000 | Loss: 0.04298719763755798 - *Checkpoint*
Epoch: 16/5000 | Loss: 0.036708783358335495 - *Checkpoint*
Epoch: 17/5000 | Loss: 0.03365742415189743 - *Checkpoint*
Epoch: 18/5000 | Loss: 0.03384966403245926
Epoch: 19/5000 | Loss: 0.035613369196653366
Epoch: 20/5000 | Loss: 0.0367731899023056
Epoch: 21/5000 | Loss: 0.0361150