<a href="https://colab.research.google.com/github/JunaidAkhter/Physics-informed-DeepONets/blob/main/PytorchAntider.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as onp
import jax.numpy as np
from jax import random, grad, vmap, jit
from jax.example_libraries import optimizers
from jax.experimental.ode import odeint
from jax.nn import relu
from jax.config import config

import itertools
from functools import partial
from torch.utils import data
from tqdm import trange
import matplotlib.pyplot as plt

%matplotlib inline

## Let us generate the data first. 
We use RBF to generate the training and the testing data. \\
**Note:** The data is being generated using `Jax`. However, we use Pytorch for learning. Hence we convert the generated data to numpy arrays which is later on converted to `torch` tensors. 

In [2]:
#@title RBF and data generation. 
# Length scale of a Gaussian random field (GRF)
length_scale = 0.2

# Define RBF kernel
def RBF(x1, x2, params):
    output_scale, lengthscales = params
    diffs = np.expand_dims(x1 / lengthscales, 1) - \
            np.expand_dims(x2 / lengthscales, 0)
    r2 = np.sum(diffs**2, axis=2)
    return output_scale * np.exp(-0.5 * r2)

# Geneate training data corresponding to one input sample
def generate_one_training_data(key, m=100, P=1):
    # Sample GP prior at a fine grid
    N = 512
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    X = np.linspace(0, 1, N)[:,None]
    K = RBF(X, X, gp_params)
    L = np.linalg.cholesky(K + jitter*np.eye(N))
    gp_sample = np.dot(L, random.normal(key, (N,)))

    # Create a callable interpolation function  
    u_fn = lambda x, t: np.interp(t, X.flatten(), gp_sample)

    # Input sensor locations and measurements
    x = np.linspace(0, 1, m)
    u = vmap(u_fn, in_axes=(None,0))(0.0, x)

    # Output sensor locations and measurements
    y_train = random.uniform(key, (P,)).sort() 
    s_train = odeint(u_fn, 0.0, np.hstack((0.0, y_train)))[1:] # JAX has a bug and always returns s(0), so add a dummy entry to y and return s[1:]

    # Tile inputs
    u_train = np.tile(u, (P,1))

    # training data for the residual
    u_r_train = np.tile(u, (m, 1))  # CREATES m COPIES of u.  
    y_r_train = x
    s_r_train = u    # STUPID NAMING WALLAHI

    #print("shape of u_r_train:", u_r_train.shape)
    #print("shape of s_r_train:", s_r_train.shape)

    return u_train, y_train, s_train, u_r_train, y_r_train,  s_r_train

# Geneate test data corresponding to one input sample
def generate_one_test_data(key, m=100, P=100):
    # Sample GP prior at a fine grid
    N = 512
    gp_params = (1.0, length_scale)
    jitter = 1e-10
    X = np.linspace(0, 1, N)[:,None]
    K = RBF(X, X, gp_params)
    L = np.linalg.cholesky(K + jitter*np.eye(N))
    gp_sample = np.dot(L, random.normal(key, (N,)))

    # Create a callable interpolation function  
    u_fn = lambda x, t: np.interp(t, X.flatten(), gp_sample)

    # Input sensor locations and measurements
    x = np.linspace(0, 1, m)
    u = vmap(u_fn, in_axes=(None,0))(0.0, x)

    # Output sensor locations and measurements
    y = np.linspace(0, 1, P)
    s = odeint(u_fn, 0.0, y)

    # Tile inputs
    u = np.tile(u, (P,1))

    return u, y, s 

# Geneate training data corresponding to N input sample
def generate_training_data(key, N, m, P):
    config.update("jax_enable_x64", True)
    keys = random.split(key, N)
    gen_fn = jit(lambda key: generate_one_training_data(key, m, P))
    u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = vmap(gen_fn)(keys)

    u_train = np.float32(u_train.reshape(N * P,-1))
    y_train = np.float32(y_train.reshape(N * P,-1))
    s_train = np.float32(s_train.reshape(N * P,-1))

    u_r_train = np.float32(u_r_train.reshape(N * m,-1))
    y_r_train = np.float32(y_r_train.reshape(N * m,-1))
    s_r_train = np.float32(s_r_train.reshape(N * m,-1))

    config.update("jax_enable_x64", False)
    return u_train, y_train, s_train, u_r_train, y_r_train,  s_r_train

# Geneate test data corresponding to N input sample
def generate_test_data(key, N, m, P):
    config.update("jax_enable_x64", True)
    keys = random.split(key, N)
    gen_fn = jit(lambda key: generate_one_test_data(key, m, P))
    u, y, s = vmap(gen_fn)(keys)
    u = np.float32(u.reshape(N * P,-1))
    y = np.float32(y.reshape(N * P,-1))
    s = np.float32(s.reshape(N * P,-1))

    config.update("jax_enable_x64", False)
    return u, y, s

In [3]:
#@title creating training data and converting jax.numpy to onp
N_train = 10000 # number of input samples
m = 100 # number of input sensors
P_train = 1   # number of output sensors
key_train = random.PRNGKey(0) # use different key for generating training data and test data 
u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = generate_training_data(key_train, N_train, m, P_train)

#changing to numpy 
u_train, y_train, s_train, u_r_train, y_r_train, s_r_train = onp.array(u_train), onp.array(y_train), onp.array(s_train), onp.array(u_r_train), onp.array(y_r_train), onp.array(s_r_train) 

print("type of data that we have now:", type(u_train), type(y_train), type(s_train), type(u_r_train), type(y_r_train), type(s_r_train))





type of data that we have now: <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [4]:
# Generate test data
N_test = 100
P_test = m
key_test = random.PRNGKey(12345)
keys_test = random.split(key_test, N_test)

u_test, y_test, s_test =  generate_test_data(key_test, N_test, m, m)

u_test, y_test, s_test = onp.array(u_test), onp.array(y_test), onp.array(s_test)
print(type(u_test), type(y_test), type(s_test))

<class 'numpy.ndarray'> <class 'numpy.ndarray'> <class 'numpy.ndarray'>


In [5]:
# THIS WHOLE SCRIPT CAN BE WRITTEN WITH THE TEMPLATE OF OUR PACKAGE AND IT IS GOING TO LOOK BEAUTIFUL.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from torch.autograd import grad
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ExponentialLR
import itertools
from functools import partial
from torch.utils import data
from tqdm import trange
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from typing import Callable


In [6]:
class DataGenerator(Dataset):

    """The inputs should be provided as numpy arrays"""

    def __init__(self, u, y, s, batch_size=64):
        self.u = torch.tensor(u, requires_grad=True)  # Convert to PyTorch tensor
        self.y = torch.tensor(y, requires_grad=True)  # Convert to PyTorch tensor
        self.s = torch.tensor(s, requires_grad=True)  # Convert to PyTorch tensor
        
        self.N = u.shape[0]
        self.batch_size = batch_size

    def __len__(self):
        return self.N // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        end = (index + 1) * self.batch_size

        inputs, outputs = self.__data_generation(start, end)
        return inputs, outputs

    def __data_generation(self, start, end):
        idx = torch.randperm(self.N, device=self.u.device)[start:end]
        s = self.s[idx]
        y = self.y[idx]
        u = self.u[idx]
        
        inputs = (u, y)
        outputs = s
        return inputs, outputs


In [7]:
# Create data set
batch_size = 10000
operator_dataset = DataGenerator(u_train, y_train, s_train, batch_size)
physics_dataset = DataGenerator(u_r_train, y_r_train, s_r_train, batch_size)
operator_data = DataLoader(operator_dataset, batch_size=64, shuffle=True)
physics_data = DataLoader(physics_dataset, batch_size=64, shuffle=True)


In [8]:
# Define the Vanilla PyTorch model
class MLP(nn.Module):
    def __init__(self, layers, activation=F.relu):
        super(MLP, self).__init__()
        self.activation = activation
        self.layers = nn.ModuleList()
        for i in range(len(layers)-1):
            self.layers.append(nn.Linear(layers[i], layers[i+1]))

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        x = self.layers[-1](x)
        return x

In [9]:

class DeepONet(nn.Module):
    def __init__(self, branch_layers, trunk_layers):    
        super(DeepONet, self).__init__()

        #TODO: Make sure that the parameters are getting updated by writing the model this way

        
        # Network initialization and evaluation functions
        self.branch = MLP(branch_layers, torch.tanh)
        self.trunk = MLP(trunk_layers, torch.tanh)

        # TODO: Make sure that model parameters are being fed to the optimizer. Do we need to feed the parameters from both models?
        # TODO: Chat GPT tells that we need concatenate the model parameters from branch and trunk networks and feed it to the optimizer like below.  

        #parameters = list(model1.parameters()) + list(model2.parameters())
        # Define the optimizer
        
        # Define DeepONet architecture
    def forward(self, u, y):
        B = self.branch(u)
        T = self.trunk(y)
        #print("B", B)
        outputs = torch.sum(B * T, dim=-1)                                   #WHY IS dim = -1 here?
        return outputs


In [10]:

def s(model: nn.Module(), u: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the value of the approximate solution from the DeepONet model"""
    return model(u, y)


def ds(model: nn.Module(), u: torch.Tensor, y: torch.Tensor, order: int = 1) -> torch.Tensor:
    """Compute neural network derivative with respect to input features using PyTorch autograd engine"""

    df_value = s(model, u, y)


    for _ in range(order):
        df_value = torch.autograd.grad(
            df_value.reshape((-1, 1)),
            y,
            grad_outputs=torch.ones_like(y),
            create_graph=True,
            retain_graph=True,
        )[0]

    return df_value


In [11]:

def residue(model:nn.Module(), u, y):

    #TODO: adapt it to higer order derrivatives. Maybe we need to create a resideue class and define df inside it. 

    return ds(model,u, y)


# Define operator loss
def loss_operator(model:nn.Module, batch):
    inputs, outputs = batch
    u, y = inputs
    
    pred = s(model, u, y)

    #pred = self(u, y)                                                  #STUPID LINE BY CHAT GPT
    loss = torch.mean((outputs.view(-1) - pred.view(-1))**2)
    return loss

def loss_physics(model:nn.Module(), batch):

    #TODO: clear the air regarding this loss. I think the formulation in the original code is wrong. 

    inputs, outputs = batch
    u, y = inputs
    
    pred = residue(model, u, y)
    
    loss = torch.mean((outputs.view(-1) - pred.view(-1))**2)

    return loss 


def total_loss(model:nn.Module(), operator_batch, physics_batch):
    """Summing up the two losses"""
    #TODO: One can think of weighed sum of the two losses instead of plain sum. 
    loss_op = loss_operator(model, operator_batch)
    loss_ph = loss_physics(model, physics_batch)

    return loss_op + loss_ph


In [15]:


def train(model:nn.Module(), 
    operator_dataset, 
    physics_dataset, 
#    loss_fn: Callable,  #TODO: Make this callable like PINN script
    learning_rate: 
    int = 0.01,
    max_epochs: int = 4_0000,
)-> nn.Module():

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    operator_data = iter(operator_dataset)
    physics_data = iter(physics_dataset)

    pbar = trange(max_epochs)
    for epoch in pbar:
        operator_batch = next(operator_data)
        physics_batch = next(physics_data)


            
        #Optimization step
        loss: torch.Tensor = total_loss(model, operator_batch, physics_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 50 == 0:
            print(f"Epoch: {epoch} - Loss: {float(loss):>7f}")



    return model


In [16]:
# Creating the object PI_DeepOneta
m = 100
branch_layers = [m, 50, 50, 50, 50, 50]
trunk_layers =  [1, 50, 50, 50, 50, 50]
model = DeepONet(branch_layers, trunk_layers)

In [17]:
#LET US TRAIN THE NETWORK NOW
#loss_fn = partial(total_loss, ) #TODO: complete this to make loss_fn callable
trained_model = train(model, operator_dataset, physics_dataset)

  0%|          | 1/40000 [00:00<6:20:27,  1.75it/s]

Epoch: 0 - Loss: 1.206528


  0%|          | 51/40000 [00:26<4:52:53,  2.27it/s]

Epoch: 50 - Loss:     nan


  0%|          | 101/40000 [00:52<4:59:41,  2.22it/s]

Epoch: 100 - Loss:     nan


  0%|          | 151/40000 [01:13<4:59:20,  2.22it/s]

Epoch: 150 - Loss:     nan


  1%|          | 201/40000 [01:36<5:28:33,  2.02it/s]

Epoch: 200 - Loss:     nan


  1%|          | 251/40000 [01:59<5:36:52,  1.97it/s]

Epoch: 250 - Loss:     nan


  1%|          | 301/40000 [02:18<3:49:15,  2.89it/s]

Epoch: 300 - Loss:     nan


  1%|          | 351/40000 [02:39<3:50:52,  2.86it/s]

Epoch: 350 - Loss:     nan


  1%|          | 401/40000 [02:59<4:38:54,  2.37it/s]

Epoch: 400 - Loss:     nan


  1%|          | 451/40000 [03:20<5:02:41,  2.18it/s]

Epoch: 450 - Loss:     nan


  1%|▏         | 501/40000 [03:42<4:07:47,  2.66it/s]

Epoch: 500 - Loss:     nan


  1%|▏         | 551/40000 [04:03<5:49:35,  1.88it/s]

Epoch: 550 - Loss:     nan


  2%|▏         | 601/40000 [04:24<4:21:41,  2.51it/s]

Epoch: 600 - Loss:     nan


  2%|▏         | 651/40000 [04:45<4:00:57,  2.72it/s]

Epoch: 650 - Loss:     nan


  2%|▏         | 701/40000 [05:06<4:35:02,  2.38it/s]

Epoch: 700 - Loss:     nan


  2%|▏         | 744/40000 [05:24<4:45:36,  2.29it/s]


KeyboardInterrupt: ignored