# PINN Wave Equation - Lambda vs Loss

> Supporting code for the paper *Towards Optimally Weighted Physics-Informed Neural Networks in Ocean Modelling*, submitted to NeurIPS 2021.

In [None]:
import numpy as np
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

In [None]:
# Experiment tracking and the hyperparameter sweep below go through
# Weights & Biases.
# You need a (free) WandB.ai account.
import wandb

# avoid excess logging
# Only records at ERROR level or above are let through the "wandb" logger.
import logging
logger = logging.getLogger("wandb")
logger.setLevel(logging.ERROR)

# Authenticate this session with the W&B service.
# NOTE(review): the trailing semicolon presumably suppresses the echoed
# return value in the notebook output — confirm before removing it.
wandb.login();

In [None]:
# Load the raw arrays from disk. Shapes as annotated by the original author.
folder = "../Data/Waves_Square/"
X = np.loadtxt(folder + "X_star.txt")  # 2 x N
T = np.loadtxt(folder + "T_star.txt")  # T
U = np.loadtxt(folder + "U_star.txt")  # N x T
V = np.loadtxt(folder + "H_0.txt")     # N

for label, array in (("X", X), ("T", T), ("U", U), ("V", V)):
    print(label + ".shape:", array.shape)

n_times = T.shape[0]
n_points = X.shape[1]
print("samples:", n_times*n_points)

# Flatten the (time, point) grid into one row per sample, time-major:
# all points at the first time, then all points at the second time, ...
xx = np.tile(X.T, (n_times, 1))               # TN x 2
tt = np.repeat(T, n_points).reshape(-1, 1)    # TN x 1
X = np.hstack((xx, tt))                       # TN x 3  (note: rebinds X)

# Targets in the same row order as X.
uu = np.reshape(U.T, (-1, 1))                 # TN x 1
vv = np.tile(V, n_times).reshape(-1, 1)       # TN x 1
Y = np.hstack((uu, vv))                       # TN x 2

In [None]:
class PINN_Wave(nn.Module):
    """Fully connected physics-informed network with inputs (x, y, t) and outputs (u, v).

    The network is a stack of ``layer_depth`` hidden blocks (linear layer +
    activation) followed by a plain linear output layer.

    Args:
        layer_width: number of units in every hidden layer.
        layer_depth: number of hidden blocks.
        activation_function: key into ``_ACTIVATIONS`` selecting the hidden
            activation.
        initializer: ``'xavier'`` applies Xavier-uniform initialisation to the
            weight of every linear layer; any other value keeps PyTorch's
            default initialisation.

    Raises:
        KeyError: if ``activation_function`` is not a known activation name.
    """

    # Activation name -> module class. Only the requested one is instantiated.
    _ACTIVATIONS = {
        'lrelu': nn.LeakyReLU,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'softplus': nn.Softplus,
        'softsign': nn.Softsign,
        'tanhshrink': nn.Tanhshrink,
        'celu': nn.CELU,
        'gelu': nn.GELU,
        'elu': nn.ELU,
        'selu': nn.SELU,
        'logsigmoid': nn.LogSigmoid,
    }

    def __init__(self, layer_width=50, layer_depth=5,
                 activation_function='tanh', initializer='none'):
        super().__init__()

        input_width = 3   # columns of the input: x, y, t
        output_width = 2  # columns of the output: u, v

        sizes = [input_width] + [layer_width]*layer_depth + [output_width]
        # Consecutive (in, out) pairs for the hidden blocks only; the final
        # pair is handled by the plain linear output layer below.
        hidden_blocks = [self.block(dim_in, dim_out, activation_function)
                         for dim_in, dim_out in zip(sizes[:-2], sizes[1:-1])]
        self.net = nn.Sequential(
            *hidden_blocks,
            nn.Linear(sizes[-2], sizes[-1]),  # output layer is regular linear transformation
        )

        if initializer == 'xavier':
            def init_weights(m):
                if isinstance(m, nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
            self.net.apply(init_weights)

    def forward(self, x):
        """Evaluate the network on ``x`` (last dimension of size 3)."""
        # Call the module rather than .forward() so registered hooks run.
        return self.net(x)

    def block(self, dim_in, dim_out, activation_function):
        """Return one hidden block: ``Linear(dim_in, dim_out)`` then the named activation.

        Raises:
            KeyError: if ``activation_function`` is not in ``_ACTIVATIONS``.
        """
        try:
            activation_cls = self._ACTIVATIONS[activation_function]
        except KeyError:
            raise KeyError(
                "unknown activation_function %r; expected one of %s"
                % (activation_function, sorted(self._ACTIVATIONS))
            ) from None
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            activation_cls(),
        )

    def f(self, x, y, t, u, v):
        """Return the PDE residual at the given points.

        The residual is ``u_tt - v*(u_xx + u_yy) - (v_x*u_x + v_y*u_y)``,
        which is the expanded form of ``u_tt - div(v * grad(u))``.

        Args:
            x, y, t: coordinate tensors that ``u`` and ``v`` were computed
                from; they must be part of the autograd graph.
            u, v: network outputs evaluated at ``(x, y, t)``.

        Returns:
            Tensor of residuals, still attached to the graph so that a loss
            built from it can be back-propagated to the model parameters.
        """
        def d(out, inp):
            # Derivative of `out` w.r.t. `inp`; create_graph keeps the result
            # differentiable for second derivatives and for the final
            # backward pass.
            return grad(out, inp, create_graph=True,
                        grad_outputs=torch.ones_like(out))[0]

        u_x, u_y, u_t = d(u, x), d(u, y), d(u, t)
        u_xx, u_yy, u_tt = d(u_x, x), d(u_y, y), d(u_t, t)
        v_x, v_y = d(v, x), d(v, y)
        return u_tt - v*(u_xx + u_yy) - (v_x*u_x + v_y*u_y)

    def loss(self, Xu, Yu, Xf=None):
        """Return the list of unweighted loss terms.

        Args:
            Xu: inputs with known targets, one row ``(x, y, t)`` per sample.
            Yu: targets for ``Xu``, one row ``(u, v)`` per sample.
            Xf: optional collocation points, one row ``(x, y, t)`` per sample,
                at which the PDE residual is penalised.

        Returns:
            ``[data_mse]`` when ``Xf`` is None, otherwise
            ``[data_mse, residual_mse]``. Weighting is left to the caller.
        """
        losses = [F.mse_loss(self(Xu), Yu)]

        if Xf is not None:
            # Differentiate w.r.t. a private leaf that shares Xf's storage.
            # This leaves the caller's tensor untouched: its requires_grad
            # flag is not flipped and no .grad accumulates on it across
            # repeated backward() calls.
            Xf = Xf.detach().requires_grad_(True)
            x, y, t = Xf[:, 0], Xf[:, 1], Xf[:, 2]
            # Rebuild the input from the separate columns so that x, y and t
            # are themselves nodes of the graph that f() can differentiate
            # against.
            Y_hat = self(torch.stack((x, y, t), 1))
            u, v = Y_hat[:, 0], Y_hat[:, 1]
            f = self.f(x, y, t, u, v)
            losses.append(F.mse_loss(f, torch.zeros_like(f)))
        return losses

In [None]:
# Global experiment parameters.
project = 'PINN Waves Lambda vs Loss (Nu=500)'
epochs = 10000  # alternative: 15000
N = X.shape[0]

# Seed PyTorch's global RNG once for the whole session.
torch.manual_seed(2021)

# Prefer the first CUDA device and fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_available else "cpu")
print("Device:", device)

# Sweep definition: pick the search strategy by name.
method = 'random'

if method == 'grid':
    grid_parameters = {
        'alpha': {'values': [0.0, 0.5, 1.0]},
        'func': {'values': ['tanh']},
        'Nu': {'values': [500]},  # alternative: [100, 1000, 10000]
        'Nf': {'values': [10000]},
    }
    sweep_config = {
        'project': project,
        'method': 'grid',
        'parameters': grid_parameters,
    }

if method == 'random':
    # alpha is sampled uniformly; the sample counts are held fixed.
    random_parameters = {
        'alpha': {'distribution': 'uniform', 'min': 0.0, 'max': 0.01},
        'Nu': {'distribution': 'constant', 'value': 500},
        'Nf': {'distribution': 'constant', 'value': 10000},
    }
    sweep_config = {
        'project': project,
        'method': 'random',
        'metric': {
            'name': 'Data loss lowest (validation)',
            'goal': 'minimize',
        },
        'parameters': random_parameters,
    }

# Register the sweep with W&B; the returned id is what the agent consumes.
sweep_id = wandb.sweep(sweep_config, project=project)

# train
models = []  # every trained model is appended here, one per sweep run
def model_train(window=250, print_every=1000):
    """Train one PINN_Wave model for the configuration of the current sweep run.

    Reads ``alpha``, ``Nu`` and ``Nf`` from ``wandb.config``, draws ``Nu``
    labelled samples and ``Nf`` collocation points from the global ``X``/``Y``
    arrays, trains for ``epochs`` epochs with Adam, logs metrics to W&B every
    epoch, saves the final weights to ``folder + name + '.pth'`` and appends
    the model to the global ``models`` list.

    The training objective is ``(1 - alpha) * data_mse + alpha * residual_mse``;
    when ``alpha`` is exactly 0 the residual is not evaluated at all.

    Args:
        window: number of most recent epochs over which the "lowest"
            validation metrics are taken (must be >= 1).
        print_every: print a progress line every this many epochs.
    """
    from collections import deque

    wandb.init()
    config = wandb.config
    # Fixed seed: every run of the sweep sees the same sample selection.
    rng = np.random.default_rng(2021)
    name = 'WavesLambdaVsLoss_a%g_nu%g_nf%g' % (config.alpha, config.Nu, config.Nf)

    # data
    Nu = int(config.Nu)
    Xu_idx = rng.choice(X.shape[0], Nu, replace=False)
    Xu = X[Xu_idx,:]
    Yu = Y[Xu_idx,:]

    Nf = int(config.Nf)
    Xf_idx = rng.choice(X.shape[0], Nf, replace=False)
    Xf = X[Xf_idx,:]

    print("Xu.shape:", Xu.shape)
    print("Yu.shape:", Yu.shape)
    print("Xf.shape:", Xf.shape)

    Xu = torch.tensor(Xu, dtype=torch.float, device=device)
    Yu = torch.tensor(Yu, dtype=torch.float, device=device)
    Xf = torch.tensor(Xf, dtype=torch.float, device=device)
    # Validation uses the complete data set (which includes the training rows).
    Xval = torch.tensor(X, dtype=torch.float, device=device)
    Yval = torch.tensor(Y, dtype=torch.float, device=device)

    # Mean square of the targets; normalises the validation MSE below.
    normL2_u = F.mse_loss(Yval, torch.zeros_like(Yval)).item()

    # model
    model = PINN_Wave(layer_width=50,
                      layer_depth=5,
                      activation_function='tanh',
                      initializer='xavier',
                      )
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # training
    alpha = config.alpha
    use_physics = alpha != 0.0
    # Only the last `window` values are ever inspected, so bounded deques
    # replace the ever-growing arrays.
    recent_val_losses = deque(maxlen=window)
    recent_rel_errors = deque(maxlen=window)
    start_time = time.time()
    for epoch in range(epochs):
        if use_physics:
            data_term, phys_term = model.loss(Xu, Yu, Xf)
            train_data_loss = (1.0-alpha)*data_term
            phys_loss = alpha*phys_term
            loss = train_data_loss + phys_loss
        else:
            train_data_loss = model.loss(Xu, Yu, None)[0]
            phys_loss = torch.zeros((), device=device)
            loss = train_data_loss

        # Capture the pre-update values; they are logged with the rest below.
        metrics = {
            'Data loss (training)': train_data_loss.item(),
            'Physics loss': phys_loss.item(),
            'Loss (training)': loss.item(),
        }

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Validation is evaluation only: skip building an autograd graph over
        # the full data set every epoch.
        with torch.no_grad():
            val_data_loss = model.loss(Xval, Yval, None)[0]
            val_loss = (val_data_loss + phys_loss).item()
        val_data_value = val_data_loss.item()
        relative_error = np.sqrt((1/normL2_u)*val_data_value).item()

        recent_val_losses.append(val_data_value)
        recent_rel_errors.append(relative_error)

        metrics.update({
            'Data loss (validation)': val_data_value,
            'Loss (validation)': val_loss,
            'Relative error': relative_error,
            'Data loss lowest (validation)': min(recent_val_losses),
            'Lowest Rel Error': min(recent_rel_errors),
        })
        wandb.log(metrics, step=epoch)

        if epoch % print_every == 0:
            elapsed = time.time() - start_time
            print('Epoch: %d, Loss: %.3e, Time: %.2fs' %
                      (epoch, val_data_value, elapsed))
            start_time = time.time()

    torch.save(model.state_dict(), folder + name + '.pth')
    models.append(model)

# Start a sweep agent that calls model_train for each configuration it
# receives from the sweep registered above.
# NOTE(review): no `count` is passed, so with the random search the agent
# presumably keeps launching runs until it is stopped — confirm.
wandb.agent(sweep_id, project=project, function=model_train);