# Physics-Informed Neural Networks

Original [paper](https://www.sciencedirect.com/science/article/pii/S0021999118307125) by M.Raissi et al. 

In complex physical, biological or engineering systems, obtaining data is sometimes unachievable. State of the art machine learning techniques cannot provide any guarantee of convergence given the lack of training data. 

**Traditional physics model** creation is a task of a domain expert, who parametrises physics models to best fit a system of interest. For example, creating a model of aircraft dynamics using equations of drag, lift, gravity, thrust, etc., and parametrising the model to attempt to closely match the model to a specific aircraft.

However, sometimes physical systems are hard to parametrise and there does not exist a closed form solution. 

In [50]:
import jax.numpy as np
from jax import value_and_grad, jit
from jax import random, vmap
import optax

KEY = random.PRNGKey(1)

## Traditional MLP

First implement `init_params` and `forward`

In [17]:
def init_params(layers, key=KEY):
    '''
    Initialize parameters in the MLP. Weights are initialized
    using Xavier initialization, while biases are zero initialized.

    Returns
    - params: the initialized parameters
    '''
    def xavier_init(input_dim, output_dim, key=key):
        '''Use Xavier initialization for weights of a single layer'''
        std_dev = np.sqrt(2/(input_dim + output_dim)) # compute standard deviation for xavier init
        w = std_dev * random.normal(key, (input_dim, output_dim)) # initialize the weights
        return w

    params = []

    for l in range(len(layers) - 1):
        w = xavier_init(layers[l], layers[l+1]) # xavier initialize the weight
        b = np.zeros(layers[l+1]) # zero initialize the bias
        params.append((w, b)) # append weight and bias for this layer to params
    
    return params


def forward(params, x):
    '''
    Forward pass through the MLP. In PINN, the nonlinearity are
    applied using tanh.

    Arguments
    - params: weights and biases for all layers of MLP
    - x: input to the MLP

    Returns
    - out: output of the MLP
    '''
    activations = x
    for w, b in params[:-1]:
        out = np.dot(activations, w) + b # Perform linear operation
        activations = np.tanh(out) # apply tanh activation
    
    final_w, final_b = params[-1]
    out = np.dot(activations, final_w) + final_b # Do not apply nonlinearity to last layer
    return out

batched_forward = vmap(forward, in_axes=(None, 0))

Test implementation for `init_param` and `forward`.

In [41]:
layers = [2, 10, 1]
x = random.uniform(KEY, (5, 2))
out = batched_forward(init_params(layers), x)
expected = np.array([[-0.05742961],
                     [-0.08960884],
                     [-0.04750253],
                     [-0.17843515],
                     [-0.09102767]])
assert np.allclose(out, expected)

Now implement `mseloss` and `update`

In [53]:
def mseloss(params, x, true):
    pred = batched_forward(params, x)
    return np.mean(np.square(pred - true))

def fit(params, optimizer, X, Y):
    opt_state = optimizer.init(params)

    @jit
    def step(params, opt_state, x, y):
        loss_value, grads = value_and_grad(mseloss)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i, (x, y) in enumerate(zip(X, Y)):
        params, opt_state, loss_value = step(params, opt_state, x, y)
        print(f'step {i}, loss: {loss_value}')

    return params

In [58]:
layers = [5, 10, 10, 1]
X = random.uniform(KEY, (20, 5))
Y = random.uniform(KEY, (20, 1))

optimizer = optax.adam(learning_rate=1e-2)
trained_params = fit(init_params(layers), optimizer, X, Y)

step 0, loss: 0.9785979390144348
step 1, loss: 0.2405158132314682
step 2, loss: 0.07331878691911697
step 3, loss: 0.2963838279247284
step 4, loss: 0.0870107039809227
step 5, loss: 0.02348405122756958
step 6, loss: 0.3177001178264618
step 7, loss: 0.372842013835907
step 8, loss: 0.009827470406889915
step 9, loss: 0.010460514575242996


In [62]:
mseloss(trained_params, X, Y)

DeviceArray(0.09587543, dtype=float32)

In [60]:
Y

DeviceArray([[0.7551559 ],
             [0.3129729 ],
             [0.12388372],
             [0.548188  ],
             [0.4223112 ],
             [0.30576992],
             [0.82008433],
             [0.95633745],
             [0.3566252 ],
             [0.55691683]], dtype=float32)