<a href="https://colab.research.google.com/github/JunaidAkhter/Physics-informed-DeepONets/blob/main/CopyofPytorchAntider.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# `Pytorch` implementation of Physics informed DeepOnet to solve
$$
\frac{ds(x)}{dx} = u(x), \hspace{1cm} x ∈[0, 1]
$$
**Literature:**


1.   [DeepOnets](https://arxiv.org/pdf/1910.03193.pdf)
2.   [Physics Informed DeepONets](https://arxiv.org/pdf/2103.10974.pdf)



In [1]:
#@title importing modules
import numpy as onp
import jax.numpy as np
from jax import random, grad, vmap, jit
from jax.example_libraries import optimizers
from jax.experimental.ode import odeint
from jax.nn import relu
from jax.config import config
import itertools
from functools import partial
from torch.utils import data
from tqdm import trange
import matplotlib.pyplot as plt

%matplotlib inline

### Let us generate the data first.
We use RBF to generate the training and the testing data. \\
**Note:** The data is being generated using `Jax`. However, we use Pytorch for learning. Hence we convert the generated data to numpy arrays which is later on converted to `torch` tensors.

In [2]:
#@title RBF and data generation.
# Length scale of a Gaussian random field (GRF)
length_scale = 0.2

# Define RBF kernel
def RBF(x1, x2, params):
    output_scale, lengthscales = params
    diffs = np.expand_dims(x1 / lengthscales, 1) - \
            np.expand_dims(x2 / lengthscales, 0)
    r2 = np.sum(diffs**2, axis=2)
    return output_scale * np.exp(-0.5 * r2)

# Geneate training data corresponding to one input sample
def generate_one_training_data(key, m=100, P=1):
    # Sample GP prior at a fine grid
    N = 512
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    X = np.linspace(0, 1, N)[:,None]
    K = RBF(X, X, gp_params)
    L = np.linalg.cholesky(K + jitter*np.eye(N))
    gp_sample = np.dot(L, random.normal(key, (N,)))

    # Create a callable interpolation function
    u_fn = lambda x, t: np.interp(t, X.flatten(), gp_sample)

    # Input sensor locations and measurements
    x = np.linspace(0, 1, m)
    u = vmap(u_fn, in_axes=(None,0))(0.0, x)

    # Output sensor locations and measurements
    y_train = random.uniform(key, (P,)).sort()
    s_train = odeint(u_fn, 0.0, np.hstack((0.0, y_train)))[1:] # JAX has a bug and always returns s(0), so add a dummy entry to y and return s[1:]

    # Tile inputs
    u_train = np.tile(u, (P,1))

    # training data for the residual
    u_r_train = np.tile(u, (m, 1))  # CREATES m COPIES of u.
    y_r_train = x
    s_r_train = u    # STUPID NAMING WALLAHI

    #print("shape of u_r_train:", u_r_train.shape)
    #print("shape of s_r_train:", s_r_train.shape)

    return u_train, y_train, s_train, u_r_train, y_r_train,  s_r_train

# Geneate test data corresponding to one input sample
def generate_one_test_data(key, m=100, P=100):
    # Sample GP prior at a fine grid
    N = 512
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    X = np.linspace(0, 1, N)[:,None]
    K = RBF(X, X, gp_params)
    L = np.linalg.cholesky(K + jitter*np.eye(N))
    gp_sample = np.dot(L, random.normal(key, (N,)))

    # Create a callable interpolation function
    u_fn = lambda x, t: np.interp(t, X.flatten(), gp_sample)

    # Input sensor locations and measurements
    x = np.linspace(0, 1, m)
    u = vmap(u_fn, in_axes=(None,0))(0.0, x)

    # Output sensor locations and measurements
    y = np.linspace(0, 1, P)
    s = odeint(u_fn, 0.0, y)

    # Tile inputs
    u = np.tile(u, (P,1))

    return u, y, s

# Geneate training data corresponding to N input sample
def generate_training_data(key, N, m, P):
    config.update("jax_enable_x64", True)
    keys = random.split(key, N)
    gen_fn = jit(lambda key: generate_one_training_data(key, m, P))
    u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = vmap(gen_fn)(keys)

    u_train = np.float32(u_train.reshape(N * P,-1))
    y_train = np.float32(y_train.reshape(N * P,-1))
    s_train = np.float32(s_train.reshape(N * P,-1))

    u_r_train = np.float32(u_r_train.reshape(N * m,-1))
    y_r_train = np.float32(y_r_train.reshape(N * m,-1))
    s_r_train = np.float32(s_r_train.reshape(N * m,-1))

    config.update("jax_enable_x64", False)
    return u_train, y_train, s_train, u_r_train, y_r_train,  s_r_train

# Geneate test data corresponding to N input sample
def generate_test_data(key, N, m, P):
    config.update("jax_enable_x64", True)
    keys = random.split(key, N)
    gen_fn = jit(lambda key: generate_one_test_data(key, m, P))
    u, y, s = vmap(gen_fn)(keys)
    u = np.float32(u.reshape(N * P,-1))
    y = np.float32(y.reshape(N * P,-1))
    s = np.float32(s.reshape(N * P,-1))

    config.update("jax_enable_x64", False)
    return u, y, s

In [3]:
#@title Generating training data and converting jax.numpy to onp
N_train = 10000 # number of input samples
m = 100 # number of input sensors
P_train = 1   # number of output sensors
key_train = random.PRNGKey(0) # use different key for generating training data and test data
u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = generate_training_data(key_train, N_train, m, P_train)

#changing to numpy
#u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = onp.array(u_train), onp.array(y_train), onp.array(s_train), onp.array(u_r_train), onp.array(y_r_train), onp.array(s_r_train)

#print("type of data that we have now:", type(u_train), type(y_train), type(s_train), type(u_r_train), type(y_r_train), type(s_r_train))





In [4]:
print(len(u_train), u_train.shape)
print(len(u_r_train), u_r_train.shape)
print(type(u_train))

10000 (10000, 100)
1000000 (1000000, 100)
<class 'jaxlib.xla_extension.ArrayImpl'>


In [5]:
#@title Generating test data
N_test = 100
P_test = m
key_test = random.PRNGKey(12345)
keys_test = random.split(key_test, N_test)

u_test, y_test, s_test =  generate_test_data(key_test, N_test, m, m)

#u_test, y_test, s_test = onp.array(u_test), onp.array(y_test), onp.array(s_test)
print(type(u_test), type(y_test), type(s_test))

<class 'jaxlib.xla_extension.ArrayImpl'> <class 'jaxlib.xla_extension.ArrayImpl'> <class 'jaxlib.xla_extension.ArrayImpl'>


In [6]:
# Data generator
class DataGenerator(data.Dataset):
    def __init__(self, u, y, s,
                 batch_size=64, rng_key=random.PRNGKey(1234)):
        'Initialization'
        self.u = u # input sample
        self.y = y # location
        self.s = s # labeled data evulated at y (solution measurements, BC/IC conditions, etc.)

        self.N = u.shape[0]
        self.batch_size = batch_size
        self.key = rng_key

    def __getitem__(self, index):
        'Generate one batch of data'
        self.key, subkey = random.split(self.key)
        inputs, outputs = self.__data_generation(subkey)
        return inputs, outputs

    @partial(jit, static_argnums=(0,))
    def __data_generation(self, key):
        'Generates data containing batch_size samples'
        idx = random.choice(key, self.N, (self.batch_size,), replace=False)
        s = self.s[idx,:]
        y = self.y[idx,:]
        u = self.u[idx,:]
        # Construct batch
        inputs = (u, y)
        outputs = s
        return inputs, outputs

## Solving the Problem
Now that we have the data (as numpy arrays which can easily be converted to torch tensors), we would like to use this to learn the operator $G$ as discussed in the paper [Physics Informed DeepONets](https://arxiv.org/pdf/2103.10974.pdf).




In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from torch.autograd import grad
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ExponentialLR
import itertools
from functools import partial
from torch.utils import data
from tqdm import trange
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from typing import Callable


In [8]:
# Create data set
batch_size = 10000   #TODO: bigger batches do not work for some reasons.
operator_dataset = DataGenerator(u_train, y_train, s_train, batch_size)
physics_dataset = DataGenerator(u_r_train, y_r_train, s_r_train, batch_size)

In [22]:
#@title checking how the batch data looks like
max_epochs = 200

operator_data = iter(operator_dataset)
physics_data = iter(physics_dataset)


pbar = trange(max_epochs)
for epoch in pbar:

    operator_batch = next(operator_data)
    physics_batch = next(physics_data)

    inputs_o, outputs_o = operator_batch
    u_o, y_o = inputs_o
    u_o, y_o, outputs_o = onp.array(u_o), onp.array(y_o), onp.array(outputs_o)
    print("epoch: ", epoch)
    print("size of operator inputs:", len(u_o), len(y_o))
    print("type of operator inputs:", type(u_o), type(y_o))

    inputs_p, outputs_p = operator_batch
    u_p, y_p = inputs_p

    print("size of physics inputs:", len(u_p), len(y_p))

  0%|          | 1/200 [00:01<03:46,  1.14s/it]

epoch:  0
size of operator inputs: 10000 10000
type of operator inputs: <class 'numpy.ndarray'> <class 'numpy.ndarray'>
size of physics inputs: 10000 10000


  1%|          | 2/200 [00:02<03:51,  1.17s/it]

epoch:  1
size of operator inputs: 10000 10000
type of operator inputs: <class 'numpy.ndarray'> <class 'numpy.ndarray'>
size of physics inputs: 10000 10000


  2%|▏         | 3/200 [00:03<03:44,  1.14s/it]

epoch:  2
size of operator inputs: 10000 10000
type of operator inputs: <class 'numpy.ndarray'> <class 'numpy.ndarray'>
size of physics inputs: 10000 10000


  2%|▏         | 4/200 [00:04<03:41,  1.13s/it]

epoch:  3
size of operator inputs: 10000 10000
type of operator inputs: <class 'numpy.ndarray'> <class 'numpy.ndarray'>
size of physics inputs: 10000 10000


  2%|▎         | 5/200 [00:05<03:39,  1.13s/it]

epoch:  4
size of operator inputs: 10000 10000
type of operator inputs: <class 'numpy.ndarray'> <class 'numpy.ndarray'>
size of physics inputs: 10000 10000


  3%|▎         | 6/200 [00:06<03:37,  1.12s/it]

epoch:  5
size of operator inputs: 10000 10000
type of operator inputs: <class 'numpy.ndarray'> <class 'numpy.ndarray'>
size of physics inputs: 10000 10000


  3%|▎         | 6/200 [00:07<04:15,  1.32s/it]


KeyboardInterrupt: ignored

## Defining the DeepOnet

In [12]:
# Define the Vanilla PyTorch model
class MLP(nn.Module):
    def __init__(self, layers, activation=F.relu):
        super(MLP, self).__init__()
        self.activation = activation
        self.layers = nn.ModuleList()
        for i in range(len(layers)-1):
            self.layers.append(nn.Linear(layers[i], layers[i+1]))

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        x = self.layers[-1](x)
        return x

In [13]:
#@title DeepOnet class
class DeepONet(nn.Module):
    def __init__(self, branch_layers, trunk_layers):
        super(DeepONet, self).__init__()

        self.branch = MLP(branch_layers, torch.tanh)
        self.trunk = MLP(trunk_layers, torch.tanh)

    def forward(self, u, y):
        B = self.branch(u)
        T = self.trunk(y)
        #print("B", B)
        outputs = torch.sum(B * T, dim=-1)                                   #WHY IS dim = -1 here?
        return outputs


In [14]:
#@title evaluation and derrivatives
def s(model: nn.Module(), u: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the value of the approximate solution from the DeepONet model"""
    return model(u, y)


def ds(model: nn.Module(), u: torch.Tensor, y: torch.Tensor, order: int = 1) -> torch.Tensor:
    """Compute neural network derivative with respect to input features using PyTorch autograd engine"""

    df_value = s(model, u, y)


    for _ in range(order):
        df_value = torch.autograd.grad(
            df_value.reshape((-1, 1)),
            y,
            grad_outputs=torch.ones_like(y),
            create_graph=True,
            retain_graph=True,
        )[0]

    return df_value


## Defining loss functions that we want to minimize

In [26]:

def residue(model:nn.Module(), u, y):

    #TODO: adapt it to higer order derrivatives. Maybe we need to create a resideue class and define df inside it.

    return ds(model,u, y)


# Define operator loss
def loss_operator(model:nn.Module, batch):
    inputs, outputs = batch
    u, y = inputs

    #converting everything to numpy and then to torch tensor
    u, y, outputs = onp.array(u), onp.array(y), onp.array(outputs)
    u = torch.tensor(u, requires_grad=True)
    y = torch.tensor(y, requires_grad=True)
    outputs = torch.tensor(outputs, requires_grad=True)


    pred = s(model, u, y)
    #printing the size of each array as I realised that I was getting empty arrays with larger batch sizes.
    ''' print("size of pred:", len(pred))
    print("size of u:", len(u))
    print("size of y:", len(y)) '''


    #pred = self(u, y)                                                  #STUPID LINE BY CHAT GPT
    loss = torch.mean((outputs.view(-1) - pred.view(-1))**2)
    return loss

def loss_physics(model:nn.Module(), batch):

    #TODO: clear the air regarding this loss. I think the formulation in the original code is wrong.

    inputs, outputs = batch
    u, y = inputs

    #converting everything to numpy and then to torch tensors
    u, y, outputs = onp.array(u), onp.array(y), onp.array(outputs)
    u = torch.tensor(u, requires_grad=True)
    y = torch.tensor(y, requires_grad=True)
    outputs = torch.tensor(outputs, requires_grad=True)

    pred = residue(model, u, y)


    loss = torch.mean((outputs.view(-1) - pred.view(-1))**2)

    return loss


def total_loss(model:nn.Module(), operator_batch, physics_batch):
    """Summing up the two losses"""
    #TODO: One can think of weighed sum of the two losses instead of plain sum.
    loss_op = loss_operator(model, operator_batch)
    loss_ph = loss_physics(model, physics_batch)

    return loss_op + loss_ph


## Finally we can train the model.

In [30]:
def train(model:nn.Module(),
    operator_dataset,
    physics_dataset,
#    loss_fn: Callable,  #TODO: Make this callable like PINN script
    initial_learning_rate:
    int = 0.01,
    max_epochs: int = 40000,
)-> nn.Module():

    #TODO: Put the following two parameters in arguments.
    decay_steps = 1000
    decay_rate = 0.9


    operator_data = iter(operator_dataset)
    physics_data = iter(physics_dataset)

    pbar = trange(max_epochs)
    for epoch in pbar:

        operator_batch = next(operator_data)
        physics_batch = next(physics_data)

        '''         inputs_o, outputs_o = operator_batch
        u_o, y_o = inputs_o
        u_o, y_o, outputs_o = onp.array(u_o), onp.array(y_o), onp.array(outputs_o)
        #print("epoch: ", epoch)
        print("size of operator inputs:", len(u_o), len(y_o))
        print("type of operator inputs:", type(u_o), type(y_o))
        '''
        # decaying learning rate
        #learning_rate = initial_learning_rate * (decay_rate**(epoch/decay_steps))
        optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate)

        #Optimization step
        loss: torch.Tensor = total_loss(model, operator_batch, physics_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 50 == 0:

            #compute losses
            loss_value = total_loss(model, operator_batch, physics_batch)
            loss_operator_value = loss_operator(model, operator_batch)
            loss_physics_value = loss_physics(model, physics_batch)

            print(f"Epoch: {epoch} - Loss: {float(loss_value):>7f}",
                    f"Loss Physics: {float(loss_physics_value):>7f}"
                        f"Loss Operator: {float(loss_operator_value):>7f}")



    return model


In [31]:
# Creating the object PI_DeepOneta
m = 100
branch_layers = [m, 50, 50, 50, 50, 50]
trunk_layers =  [1, 50, 50, 50, 50, 50]
model = DeepONet(branch_layers, trunk_layers)

In [32]:
#LET US TRAIN THE NETWORK NOW
#loss_fn = partial(total_loss, ) #TODO: complete this to make loss_fn callable
trained_model = train(model, operator_dataset, physics_dataset)

  0%|          | 1/40000 [00:02<23:10:05,  2.09s/it]

Epoch: 0 - Loss: 1.244428 Loss Physics: 0.987765Loss Operator: 0.256663


  0%|          | 51/40000 [01:09<15:19:53,  1.38s/it]

Epoch: 50 - Loss: 1.146872 Loss Physics: 0.971613Loss Operator: 0.175259


  0%|          | 101/40000 [02:18<14:25:49,  1.30s/it]

Epoch: 100 - Loss: 1.015501 Loss Physics: 0.845070Loss Operator: 0.170431


  0%|          | 151/40000 [03:26<15:36:23,  1.41s/it]

Epoch: 150 - Loss: 1.288554 Loss Physics: 1.168812Loss Operator: 0.119742


  1%|          | 201/40000 [04:36<17:44:44,  1.61s/it]

Epoch: 200 - Loss: 0.921753 Loss Physics: 0.743011Loss Operator: 0.178742


  1%|          | 251/40000 [05:44<15:44:48,  1.43s/it]

Epoch: 250 - Loss: 0.876132 Loss Physics: 0.722714Loss Operator: 0.153418


  1%|          | 301/40000 [06:52<14:24:33,  1.31s/it]

Epoch: 300 - Loss: 0.725970 Loss Physics: 0.619660Loss Operator: 0.106310


  1%|          | 351/40000 [08:00<16:14:55,  1.48s/it]

Epoch: 350 - Loss: 0.561937 Loss Physics: 0.494897Loss Operator: 0.067040


  1%|          | 401/40000 [09:08<14:49:35,  1.35s/it]

Epoch: 400 - Loss: 0.694941 Loss Physics: 0.577832Loss Operator: 0.117109


  1%|          | 451/40000 [10:18<17:24:15,  1.58s/it]

Epoch: 450 - Loss: 0.559140 Loss Physics: 0.471276Loss Operator: 0.087864


  1%|▏         | 501/40000 [11:26<14:42:52,  1.34s/it]

Epoch: 500 - Loss: 0.580570 Loss Physics: 0.494653Loss Operator: 0.085917


  1%|▏         | 551/40000 [12:36<16:19:58,  1.49s/it]

Epoch: 550 - Loss: 0.562382 Loss Physics: 0.485354Loss Operator: 0.077029


  2%|▏         | 601/40000 [13:45<15:06:37,  1.38s/it]

Epoch: 600 - Loss: 0.439897 Loss Physics: 0.376824Loss Operator: 0.063072


  2%|▏         | 651/40000 [14:55<15:51:38,  1.45s/it]

Epoch: 650 - Loss: 0.469134 Loss Physics: 0.414000Loss Operator: 0.055134


  2%|▏         | 701/40000 [16:03<14:52:44,  1.36s/it]

Epoch: 700 - Loss: 0.500611 Loss Physics: 0.431242Loss Operator: 0.069369


  2%|▏         | 751/40000 [17:13<16:04:15,  1.47s/it]

Epoch: 750 - Loss: 0.307711 Loss Physics: 0.279774Loss Operator: 0.027936


  2%|▏         | 801/40000 [18:20<14:44:54,  1.35s/it]

Epoch: 800 - Loss: 0.430131 Loss Physics: 0.375812Loss Operator: 0.054319


  2%|▏         | 851/40000 [19:29<15:51:16,  1.46s/it]

Epoch: 850 - Loss: 0.467090 Loss Physics: 0.403372Loss Operator: 0.063719


  2%|▏         | 901/40000 [20:37<14:38:44,  1.35s/it]

Epoch: 900 - Loss: 0.619500 Loss Physics: 0.470818Loss Operator: 0.148681


  2%|▏         | 951/40000 [21:46<17:02:01,  1.57s/it]

Epoch: 950 - Loss: 0.407128 Loss Physics: 0.356406Loss Operator: 0.050722


  3%|▎         | 1001/40000 [22:52<14:46:15,  1.36s/it]

Epoch: 1000 - Loss: 0.350971 Loss Physics: 0.297015Loss Operator: 0.053956


  3%|▎         | 1051/40000 [24:00<14:37:36,  1.35s/it]

Epoch: 1050 - Loss: 0.375065 Loss Physics: 0.335332Loss Operator: 0.039733


  3%|▎         | 1101/40000 [25:08<15:36:29,  1.44s/it]

Epoch: 1100 - Loss: 0.461095 Loss Physics: 0.386516Loss Operator: 0.074579


  3%|▎         | 1151/40000 [26:20<16:27:00,  1.52s/it]

Epoch: 1150 - Loss: 0.385907 Loss Physics: 0.307212Loss Operator: 0.078695


  3%|▎         | 1201/40000 [27:30<15:37:21,  1.45s/it]

Epoch: 1200 - Loss: 0.386439 Loss Physics: 0.327964Loss Operator: 0.058475


  3%|▎         | 1251/40000 [28:39<19:32:46,  1.82s/it]

Epoch: 1250 - Loss: 0.355715 Loss Physics: 0.267412Loss Operator: 0.088304


  3%|▎         | 1297/40000 [29:42<14:46:20,  1.37s/it]


KeyboardInterrupt: ignored