# JAXPI PINN Implementation

In [None]:
# Run this cell if you are using Google Colab: it clones the course repository,
# changes into its PINNs directory and installs that package in editable mode.
!git clone https://github.com/CPSHub/LecturePhysicsAwareML.git
%cd LecturePhysicsAwareML/PINNs
%pip install -e .

In [None]:
# Run this cell if you are working locally: it moves to the parent directory
# and installs the package found there in editable mode.
# NOTE(review): assumes the notebook is started from a subdirectory of the
# PINNs package root (so that `..` contains the installable package) — confirm.
%cd ..
%pip install -e .

In [None]:
from functools import partial

import jax
import jax.numpy as jnp

from pinns.lebb import Config
import pinns.lebb.cases as cases

def get_data(config: Config):
    """Generate reference data and boundary conditions for one beam case.

    The analytical solution selected by ``config.bc_case`` is evaluated on
    ``config.dataset_size`` equally spaced points in ``[0, L]``. If
    ``config.non_dim`` is set, coordinates, solution and boundary data are
    scaled with the length scale ``x0 = L`` and the deflection scale
    ``w0 = q0 * x0**4 / EI``.

    Args:
        config: Problem configuration. Reads ``EI``, ``L``, ``F``, ``q``,
            ``bc_case``, ``dataset_size`` and ``non_dim``.

    Returns:
        A tuple ``(x, y, bc, L_out, q_out)``:

        - ``x``: evaluation points of shape ``(dataset_size, 1)``.
        - ``y``: ``(w, w_x, w_xx, w_xxx, w_xxxx)`` as returned by the
          ``cases`` solution function (possibly scaled).
        - ``bc``: dict mapping ``"<field>_bc_coords"`` / ``"<field>_bc_values"``
          for the fields ``w``, ``w_x``, ``w_xx``, ``w_xxx`` to arrays, or to
          ``None`` when that field has no boundary condition.
        - ``L_out``: beam length (``1.0`` when non-dimensionalised).
        - ``q_out``: load term, ``q / EI`` in dimensional form and ``q / q0``
          after scaling.

    Raises:
        NotImplementedError: If ``config.bc_case`` is neither 0 nor 1.
    """
    EI = config.EI
    L = config.L
    F = config.F
    q = config.q

    if config.bc_case == 0:
        # w(0) = w_x(0) = 0; at x = L: w_xx = 0 and w_xxx = -F / EI.
        sol_fun = partial(cases.bc_case_0, EI, F, L)
        w_bc_coords = jnp.array([0.])
        w_bc_values = jnp.array([0.])
        w_x_bc_coords = jnp.array([0.])
        w_x_bc_values = jnp.array([0.])
        w_xx_bc_coords = jnp.array([L])
        w_xx_bc_values = jnp.array([0.])
        w_xxx_bc_coords = jnp.array([L])
        w_xxx_bc_values = jnp.array([- F / EI])
    elif config.bc_case == 1:
        # w = 0 and w_xx = 0 at both ends; no conditions on w_x or w_xxx.
        sol_fun = partial(cases.bc_case_1, EI, q, L)
        w_bc_coords = jnp.array([0., L])
        w_bc_values = jnp.array([0., 0.])
        w_x_bc_coords = None
        w_x_bc_values = None
        w_xx_bc_coords = jnp.array([0., L])
        w_xx_bc_values = jnp.array([0., 0.])
        w_xxx_bc_coords = None
        w_xxx_bc_values = None
    else:
        # Fail here; otherwise ``sol_fun`` would be unbound further down.
        raise NotImplementedError(
            "Data generation is not implemented for these boundary conditions."
        )

    # Evaluate the analytical solution on an (N, 1) grid over [0, L].
    # NOTE(review): assumes ``sol_fun`` returns five arrays (w and its first
    # four derivatives) — confirm against pinns.lebb.cases.
    x = jnp.linspace(0.0, L, config.dataset_size).reshape(-1, 1)
    w, w_x, w_xx, w_xxx, w_xxxx = jax.vmap(sol_fun)(x)

    if config.non_dim:
        x0 = L
        # Use a unit load scale when there is no distributed load.
        q0 = 1.0 if q == 0.0 else q
        # With these scales the load term q / EI becomes q / q0.
        w0 = q0 * x0**4 / EI

        def scale_bc(coords, values, order):
            """Scale one BC pair; the ``order``-th derivative scales by x0**order / w0."""
            if coords is None:
                return None, None
            return coords / x0, values / w0 * x0**order

        x = x / x0
        w = w / w0
        w_x = w_x / w0 * x0
        w_xx = w_xx / w0 * x0**2
        w_xxx = w_xxx / w0 * x0**3
        w_xxxx = w_xxxx / w0 * x0**4

        # Each pair is scaled from its own (unscaled) coordinates.
        w_bc_coords, w_bc_values = scale_bc(w_bc_coords, w_bc_values, 0)
        w_x_bc_coords, w_x_bc_values = scale_bc(w_x_bc_coords, w_x_bc_values, 1)
        w_xx_bc_coords, w_xx_bc_values = scale_bc(w_xx_bc_coords, w_xx_bc_values, 2)
        w_xxx_bc_coords, w_xxx_bc_values = scale_bc(
            w_xxx_bc_coords, w_xxx_bc_values, 3
        )

        q_out = q if q == 0.0 else 1.0
        L_out = 1.0
    else:
        q_out = q / EI
        L_out = L

    bc = {
        "w_bc_coords": w_bc_coords,
        "w_bc_values": w_bc_values,
        "w_x_bc_coords": w_x_bc_coords,
        "w_x_bc_values": w_x_bc_values,
        "w_xx_bc_coords": w_xx_bc_coords,
        "w_xx_bc_values": w_xx_bc_values,
        "w_xxx_bc_coords": w_xxx_bc_coords,
        "w_xxx_bc_values": w_xxx_bc_values,
    }
    return x, (w, w_x, w_xx, w_xxx, w_xxxx), bc, L_out, q_out

### PINN model

In [None]:
from typing import Tuple, Self

import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
import paramax
import equinox as eqx

from pinns.nn import FFNN


def _non_trainable_bc(bc, coords_key, values_key):
    """Return the BC pair under the given keys as column-shaped NonTrainables.

    Yields ``(None, None)`` when ``bc[coords_key]`` is None (the values entry is
    then not read); otherwise both arrays are reshaped to ``(-1, 1)`` and
    wrapped in ``paramax.NonTrainable``.
    """
    coords = bc[coords_key]
    if coords is None:
        return None, None
    wrapped_coords = paramax.NonTrainable(coords.reshape(-1, 1))
    wrapped_values = paramax.NonTrainable(bc[values_key].reshape(-1, 1))
    return wrapped_coords, wrapped_values


class PINN(eqx.Module):
    """Physics-informed network for the beam deflection ``w(x)``.

    A small FFNN (1 -> 8 -> 8 -> 1, tanh hidden activations, identity output)
    receives ``x / L``. Derivatives up to fourth order are obtained by nesting
    ``jax.jacfwd``. The PDE residual is ``w_xxxx - q``, and boundary data for
    ``w``, ``w_x``, ``w_xx`` and ``w_xxx`` enter as mean-squared penalties.

    NOTE(review): boundary arrays are stored wrapped in ``paramax.NonTrainable``
    and are used directly in ``losses``. The training loop therefore presumably
    unwraps the model (``paramax.unwrap``) before evaluating the loss — confirm.
    """

    nn: FFNN
    q: float  # load term on the right-hand side of w_xxxx = q
    L: float  # length used to scale the network input
    w_bc_coords: paramax.NonTrainable | None
    w_bc_values: paramax.NonTrainable | None
    w_x_bc_coords: paramax.NonTrainable | None
    w_x_bc_values: paramax.NonTrainable | None
    w_xx_bc_coords: paramax.NonTrainable | None
    w_xx_bc_values: paramax.NonTrainable | None
    w_xxx_bc_coords: paramax.NonTrainable | None
    w_xxx_bc_values: paramax.NonTrainable | None

    def __init__(
        self,
        L: float,
        q: float,
        bc: dict[str, Array | None],
        *,
        key: PRNGKeyArray
    ):
        """Build the network and store the load term, length and BC data.

        Args:
            L: Length used to scale the network input.
            q: Load term of the residual ``w_xxxx - q``.
            bc: Mapping with ``"<field>_bc_coords"`` / ``"<field>_bc_values"``
                entries for ``w``, ``w_x``, ``w_xx`` and ``w_xxx``. A ``None``
                coordinate entry disables that boundary condition.
            key: PRNG key for the network initialisation.
        """
        self.nn = FFNN(
            in_features=1,
            hidden_features=[8, 8],
            out_features=1,
            activations=[jax.nn.tanh, jax.nn.tanh],
            final_activation=lambda z: z,
            key=key,
        )
        self.q = q
        self.L = L

        self.w_bc_coords, self.w_bc_values = _non_trainable_bc(
            bc, "w_bc_coords", "w_bc_values"
        )
        self.w_x_bc_coords, self.w_x_bc_values = _non_trainable_bc(
            bc, "w_x_bc_coords", "w_x_bc_values"
        )
        self.w_xx_bc_coords, self.w_xx_bc_values = _non_trainable_bc(
            bc, "w_xx_bc_coords", "w_xx_bc_values"
        )
        self.w_xxx_bc_coords, self.w_xxx_bc_values = _non_trainable_bc(
            bc, "w_xxx_bc_coords", "w_xxx_bc_values"
        )

    def __call__(self, x: Array) -> Array:
        """Return ``(w, w_x, w_xx, w_xxx, w_xxxx)`` of this model at ``x``."""
        derivative_fns = (self.w, self.w_x, self.w_xx, self.w_xxx, self.w_xxxx)
        return tuple(fn(self, x) for fn in derivative_fns)

    def forward(self, x: Array) -> Tuple[Array, ...]:
        """Evaluate the network on the length-scaled input ``x / L``."""
        scaled_x = x / self.L
        return self.nn(scaled_x)

    def w(self, model: Self, x: Array) -> Array:
        """Deflection predicted by ``model`` at ``x``."""
        return model.forward(x)

    def w_x(self, model: Self,  x: Array) -> Array:
        """First derivative of ``w`` with respect to ``x`` (forward mode)."""
        jacobian = jax.jacfwd(self.w, argnums=1)(model, x)
        return jacobian[0]

    def w_xx(self, model: Self, x: Array) -> Array:
        """Second derivative of ``w`` with respect to ``x``."""
        jacobian = jax.jacfwd(self.w_x, argnums=1)(model, x)
        return jacobian[0]

    # Bending moment (-EI * w_xx) and shear force (-EI * w_xxx) helpers could
    # be added here once EI is stored on the model; it currently is not a field.

    def w_xxx(self, model: Self, x: Array) -> Array:
        """Third derivative of ``w`` with respect to ``x``."""
        jacobian = jax.jacfwd(self.w_xx, argnums=1)(model, x)
        return jacobian[0]

    def w_xxxx(self, model: Self, x: Array) -> Array:
        """Fourth derivative of ``w`` with respect to ``x``."""
        jacobian = jax.jacfwd(self.w_xxx, argnums=1)(model, x)
        return jacobian[0]

    def res_w(self, model: Self, x: Array):
        """PDE residual ``w_xxxx - q`` at a single point ``x``."""
        return self.w_xxxx(model, x) - self.q

    def losses(self, model, x):
        """Unweighted loss terms: four BC mean-squared errors and the residual MSE.

        A boundary condition whose coordinates are ``None`` contributes 0.
        """
        def bc_mse(derivative, coords, targets):
            # Mean squared mismatch of one derivative at its BC coordinates.
            if coords is None:
                return jnp.array(0.)
            predictions = jax.vmap(derivative, (None, 0))(model, coords)
            return jnp.mean((predictions - targets)**2)

        w_term = bc_mse(self.w, self.w_bc_coords, self.w_bc_values)
        w_x_term = bc_mse(self.w_x, self.w_x_bc_coords, self.w_x_bc_values)
        w_xx_term = bc_mse(self.w_xx, self.w_xx_bc_coords, self.w_xx_bc_values)
        w_xxx_term = bc_mse(self.w_xxx, self.w_xxx_bc_coords, self.w_xxx_bc_values)

        residuals = jax.vmap(self.res_w, (None, 0))(model, x)
        residual_term = jnp.mean(residuals**2)

        return {
            "w_bc": w_term,
            "w_x_bc": w_x_term,
            "w_xx_bc": w_xx_term,
            "w_xxx_bc": w_xxx_term,
            "rw": residual_term,
        }

    def loss(self, model, weights, x):
        """Weighted sum of the terms from ``losses`` (``weights`` uses the same keys)."""
        scaled_terms = jax.tree.map(
            lambda term, weight: term * weight, self.losses(model, x), weights
        )
        return jax.tree.reduce(lambda acc, term: acc + term, scaled_terms)


In [None]:
def get_config(bc_case: int, non_dim: bool):
    """Build the problem and training configuration for a boundary-condition case.

    Args:
        bc_case: Boundary-condition case; 0 and 1 are supported.
        non_dim: Whether ``get_data`` should non-dimensionalise the problem.

    Returns:
        A ``Config`` with the beam parameters (``EI``, ``L``, ``F``, ``q``),
        dataset size, optimiser settings and unit loss weights.

    Raises:
        NotImplementedError: If ``bc_case`` is neither 0 nor 1.
    """
    if bc_case == 0:
        # Case 0 uses the force F and no distributed load.
        EI = 1e6
        L = 1.0
        F = 1.0
        q = 0.0
    elif bc_case == 1:
        # Case 1 uses the distributed load q; F is unused.
        EI = 1e6
        L = 1.0
        F = None
        q = 1.0
    else:
        # Fail here; otherwise EI, L, F and q would be unbound below.
        raise NotImplementedError(
            "No configuration implemented for these boundary conditions."
        )

    config = Config(
        EI=EI,
        L=L,
        F=F,
        q=q,
        bc_case=bc_case,
        dataset_size=1_000,
        steps=50_000,
        learning_rate=1e-3,
        batch_size=32,
        # Equal weighting of all loss terms produced by PINN.losses().
        weights={
            "w_bc": 1.0,
            "w_x_bc": 1.0,
            "w_xx_bc": 1.0,
            "w_xxx_bc": 1.0,
            "rw": 1.0
        },
        non_dim=non_dim
    )

    return config

In [None]:
from pinns.lebb import train


# Boundary-condition case 0 in dimensional form (Task 1 varies EI and non_dim).
config = get_config(bc_case=0, non_dim=False)
# x: evaluation points, y: analytical solution and derivatives,
# bc: boundary data, L and q: (possibly scaled) length and load term for the PINN.
x, y, bc, L, q = get_data(config)

# Fixed seed so that the network initialisation is reproducible.
key = jax.random.PRNGKey(1234)
model = PINN(L, q, bc, key=key)

# NOTE(review): assumes train() returns the trained model — confirm in pinns.lebb.
model = train(model, x, config)

In [None]:
from pinns.lebb import evaluate

# Evaluate the trained model against the reference data y from get_data().
# NOTE(review): the exact metrics/plots produced are defined in pinns.lebb.evaluate.
evaluate(model, x, y)

### Tasks

#### 1. Dimensionality:

For `bc_case=0`, work through the following steps:

a) Set the parameter EI to $10^6$ (in `get_config()`) and set `non_dim=False`. Can the model approximate the solution?

b) Try setting the parameter EI to $1.0$. Can the model approximate the solution now?

c) Set EI back to $10^6$ and set `non_dim=True`. Can the model fit the data now?

d) Summarize your findings.

#### 2. Boundary conditions

Fix the boundary conditions for `bc_case=1` within `get_data()`. Then train the PINN and test your implementation against the analytical solution.