# PINN Introduction

In [None]:
# Run this cell if you are using Google Colab
!git clone https://github.com/CPSHub/LecturePhysicsAwareML.git
%cd LecturePhysicsAwareML/PINNs
%pip install -e .

In [None]:
# Run this cell, if you are working locally
%cd ..
%pip install -e .

### Define boundary conditions.

In [None]:
import jax.numpy as jnp
from paml_pinns.lebb import Config, get_data_decorator


@get_data_decorator
def get_data(config: Config):
    """Assemble the boundary-condition dictionary for the selected case.

    Each quantity (`w`, `w_x`, `M`, `Q`) gets a `*_bc_coords` entry with the
    positions at which it is prescribed and a `*_bc_values` entry with the
    prescribed values at those positions.

    Args:
        config: Configuration object; this function reads `bc_case`, `L`
            and `F` from it.

    Returns:
        A dict with the keys `w_bc_coords`, `w_bc_values`, `w_x_bc_coords`,
        `w_x_bc_values`, `M_bc_coords`, `M_bc_values`, `Q_bc_coords` and
        `Q_bc_values`.
        NOTE(review): `get_data_decorator` wraps this function, so the value
        seen by callers may differ from this dict — confirm in
        `paml_pinns.lebb`.

    Raises:
        NotImplementedError: If `config.bc_case` is neither 0 nor 1.
    """
    if config.bc_case == 0:
        # w and w_x are fixed to zero at x = 0; M vanishes and Q equals F
        # at x = L.
        w_bc_coords = jnp.array([0.])
        w_bc_values = jnp.array([0.])
        w_x_bc_coords = jnp.array([0.])
        w_x_bc_values = jnp.array([0.])
        M_bc_coords = jnp.array([config.L])
        M_bc_values = jnp.array([0.])
        Q_bc_coords = jnp.array([config.L])
        Q_bc_values = jnp.array([config.F])
    elif config.bc_case == 1:
        # Start of task scope
        ######################
        pass # remove this for testing
        # w_bc_coords = 
        # w_bc_values = 
        # w_x_bc_coords = 
        # w_x_bc_values = 
        # M_bc_coords = 
        # M_bc_values = 
        # Q_bc_coords = 
        # Q_bc_values = 
        ######################
        # End of task scope
    else:
        # The exception must be raised, not merely constructed; otherwise the
        # code below fails with an unrelated unbound-variable error.
        raise NotImplementedError("Data generation is not implemented for these boundary conditions.")

    bc = {
        "w_bc_coords": w_bc_coords,
        "w_bc_values": w_bc_values,
        "w_x_bc_coords": w_x_bc_coords,
        "w_x_bc_values": w_x_bc_values,
        "M_bc_coords": M_bc_coords,
        "M_bc_values": M_bc_values,
        "Q_bc_coords": Q_bc_coords,
        "Q_bc_values": Q_bc_values
    }
    return bc


### Configure beam properties.

In [None]:
from paml_pinns.lebb import get_config_decorator
from typing import Tuple


@get_config_decorator
def get_config(bc_case: int) -> Tuple[float, float, float, float]:
    """Return the beam parameters for the selected boundary-condition case.

    Args:
        bc_case: Identifier of the boundary-condition case (0 or 1).

    Returns:
        The tuple `(EI, L, F, q)`.
        NOTE(review): `get_config_decorator` wraps this function; the call in
        this notebook additionally passes `non_dim`, so the decorated
        signature and return value presumably differ — confirm in
        `paml_pinns.lebb`.

    Raises:
        NotImplementedError: If `bc_case` is neither 0 nor 1.
    """
    if bc_case == 0:
        EI = 1e6
        L = 1.0
        F = 1.0
        q = 0.0
    elif bc_case == 1:
        EI = 1e6
        L = 1.0
        F = 0.0  # Unused dummy
        q = 1.0
    else:
        # Raise instead of only constructing the exception, so an unsupported
        # case is reported clearly rather than as an unbound-variable error.
        raise NotImplementedError("No configuration implemented for these boundary conditions.")

    return EI, L, F, q

### Data generation, model creation & training

In [None]:
import jax
from paml_pinns.lebb import PINN
from paml_pinns import train


# Select the boundary-condition case and whether to non-dimensionalize, then
# build the data for it.
config = get_config(bc_case=0, non_dim=False)
x, y, bc, EI, L, q = get_data(config)

# Weights handed to `train`; presumably one factor per boundary-condition loss
# term plus `rw` for the residual term — confirm against `paml_pinns.train`.
weights = dict(w_bc=1.0, w_x_bc=1.0, M_bc=1.0, Q_bc=1.0, rw=1.0)

# Fixed seed so that model initialization is reproducible.
key = jax.random.PRNGKey(1234)
model = PINN(EI, L, q, bc, key=key)

model = train(model, x, weights, steps=50_000)

### Model evaluation

In [None]:
from paml_pinns.lebb import evaluate


# Evaluate the trained model at the coordinates x; y is presumably the
# reference solution returned by get_data — confirm in paml_pinns.lebb.
evaluate(model, x, y)

### Tasks

#### Task 1: Dedimensionalization

For `bc_case=0`, ...

a) Set `EI=1e6` and `non_dim=False`. Can the model approximate the solution?

b) Now set `EI=1.0`. Can the model approximate the solution now?

c) Set `EI=1e6` again and set `non_dim=True`. The model should now be able to approximate the solution very well. Why does it work better than in a)?  

#### Task 2: Boundary conditions

Fix the boundary conditions for `bc_case=1` within `get_data()`. Next, calibrate the PINN and test your implementation against the analytical solution. Note that boundary conditions can be set to `None` to be ignored.

*Solution:*

In [None]:
import jax.numpy as jnp

from paml_pinns.lebb import Config, get_data_decorator


@get_data_decorator
def get_data(config: Config):
    """Assemble the boundary-condition dictionary for `bc_case == 1`.

    `w` and `M` are prescribed as zero at both `x = 0` and `x = config.L`.
    The `w_x` and `Q` entries are set to `None`, which the task description
    says causes a boundary condition to be ignored.

    Args:
        config: Configuration object; this function reads `bc_case` and `L`
            from it.

    Returns:
        A dict with the keys `w_bc_coords`, `w_bc_values`, `w_x_bc_coords`,
        `w_x_bc_values`, `M_bc_coords`, `M_bc_values`, `Q_bc_coords` and
        `Q_bc_values`.
        NOTE(review): `get_data_decorator` wraps this function, so the value
        seen by callers may differ from this dict — confirm in
        `paml_pinns.lebb`.

    Raises:
        NotImplementedError: If `config.bc_case` is not 1.
    """
    if config.bc_case != 1:
        # Raise instead of only constructing the exception; otherwise the
        # dict below would fail with an unrelated unbound-variable error.
        raise NotImplementedError("Data generation is not implemented for these boundary conditions.")

    both_ends = jnp.array([0., config.L])
    zeros = jnp.array([0., 0.])

    return {
        "w_bc_coords": both_ends,
        "w_bc_values": zeros,
        "w_x_bc_coords": None,
        "w_x_bc_values": None,
        "M_bc_coords": both_ends,
        "M_bc_values": zeros,
        "Q_bc_coords": None,
        "Q_bc_values": None,
    }