# **Re**lative **Lo**ss **B**alancing with **Ra**ndom **Lo**okbacks for Kirchhoff PDE
This notebook implements the concepts from the paper [Multi-Objective Loss Balancing for Physics-Informed Deep Learning](https://arxiv.org/abs/2110.09813) and the Medium article [Improving PINNs through Adaptive Loss Balancing](https://medium.com/p/55662759e701). It showcases the performance gains obtained by applying loss balancing schemes to the training of Physics-Informed Neural Networks (PINNs).



In [None]:
import os
import numpy as np
import seaborn as sns
import tensorflow as tf
import matplotlib.pyplot as plt
sns.set_theme(style="whitegrid")

from typing import Tuple, Callable, List, Union
from tensorflow.experimental.numpy import isclose
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

## Utility functions
We start by defining a few utility functions that will be useful later. The first, *compute_derivatives*, computes all the derivatives necessary to formulate the Kirchhoff PINN objective (the Kirchhoff equation is a fourth-order differential equation), while *compute_moments* computes the bending moments that appear in the boundary conditions.



In [None]:
def compute_derivatives(x, y, u):
    """
    Differentiates the prediction `u` with respect to the coordinates `x` and `y`
    up to fourth order.

    Parameters
    ----------
    x : tf.Tensor
        The x-coordinate of the collocation points, of shape (batch_size, 1).
    y : tf.Tensor
        The y-coordinate of the collocation points, of shape (batch_size, 1).
    u : tf.Tensor
        The prediction made by the PINN, of shape (batch_size, 1).

    Returns
    -------
    tuple
        The derivatives of `u`, in this order: `dudxx`, `dudyy`, `dudxxxx`, `dudyyyy`, `dudxxyy`.
    """
    # first-order derivatives in both directions
    first_x, first_y = tf.gradients(u, [x, y])

    # pure second-order derivatives
    second_x = tf.gradients(first_x, x)[0]
    second_y = tf.gradients(first_y, y)[0]

    # third-order derivatives; the xxy one is only an intermediate towards dudxxyy
    third_x, third_xxy = tf.gradients(second_x, [x, y])
    third_y = tf.gradients(second_y, y)[0]

    # fourth-order derivatives entering the governing equation
    fourth_x = tf.gradients(third_x, x)[0]
    fourth_xxyy = tf.gradients(third_xxy, y)[0]
    fourth_y = tf.gradients(third_y, y)[0]

    return second_x, second_y, fourth_x, fourth_y, fourth_xxyy


def compute_moments(D, nue, dudxx, dudyy):
    """
    Evaluates the moments along the x and y axes from the second-order derivatives.

    Parameters
    ----------
    D : float
        The flexural stiffness.
    nue : float
        Poisson's ratio.
    dudxx : tf.Tensor
        The second-order derivative of `u` with respect to `x`, of shape (batch_size, 1).
    dudyy : tf.Tensor
        The second-order derivative of `u` with respect to `y`, of shape (batch_size, 1).

    Returns
    -------
    tuple
        The pair `(mx, my)` of moments along the x and y axes.
    """
    # each moment combines the curvature along its own axis with the
    # Poisson-weighted curvature along the other axis
    curvature_x = dudxx + nue * dudyy
    curvature_y = nue * dudxx + dudyy
    return -D * curvature_x, -D * curvature_y


## Define Kirchhoff PDE Class
This class represents a plate governed by the [Kirchhoff plate bending PDE](https://en.wikipedia.org/wiki/Kirchhoff%E2%80%93Love_plate_theory). It provides a set of utility functions to train a [Physics-Informed Neural Network (PINN)](https://arxiv.org/pdf/1711.10566.pdf) on the Kirchhoff PDE. The class has functions that generate training and validation data, calculate the loss, and visualize the results of the PINN's predictions.

In [None]:
# Small constant used as the absolute tolerance when detecting boundary points
# and as a guard against zero denominators when dividing by previous loss values.
EPS = 1e-7

class KirchhoffPDE:
    """
    Class representing a Kirchhoff plate, providing several methods for training a Physics-Informed Neural Network.

    The plate occupies the rectangle [0, W] x [0, H]. The class samples collocation points,
    evaluates the individual physics-informed loss terms and visualises predictions.
    """
    def __init__(self, p: Callable[[tf.Tensor, tf.Tensor], tf.Tensor], u_val: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
                 T: float, nue: float, E: float, H: float, W: float):
        """
        Initialize the KirchhoffPDE class.

        PARAMETERS
        ----------------
        p : Callable[[tf.Tensor, tf.Tensor], tf.Tensor]
            The load function, taking x and y coordinates as inputs and returning the load.
        u_val : Callable[[tf.Tensor, tf.Tensor], tf.Tensor]
            Reference solution, taking x and y coordinates as inputs; used to validate the predictions.
        T : float
            Thickness of the plate.
        nue : float
            Poisson's ratio.
        E : float
            Young's modulus.
        H : float
            Height of the plate.
        W : float
            Width of the plate.
        """
        self.p = p
        self.u_val = u_val
        self.T = T
        self.nue = nue
        self.E = E
        self.D = (E * T**3) / (12 * (1 - nue**2))  # flexural stiffness of the plate
        self.H = H
        self.W = W
        # number of loss terms returned by `compute_loss` in training mode (L_f, L_b0, L_b2)
        self.num_terms = 3

    def training_batch(self, batch_size_domain:int=800, batch_size_boundary:int=100) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generates a batch of collocation points by randomly sampling `batch_size_domain` points inside the domain
        and `batch_size_boundary` points on each of the four boundaries.

        PARAMETERS
        --------------------
        batch_size_domain : int
            number of points to be sampled inside of the domain
        batch_size_boundary : int
            number of points to be sampled on each of the four boundaries

        RETURNS
        --------------------
        Tuple[tf.Tensor, tf.Tensor]
            x and y coordinates, each of shape (batch_size_domain + 4 * batch_size_boundary, 1).
            The domain points come first, followed by the boundaries x=0, x=W, y=0 and y=H.
        """
        x_in = tf.random.uniform(shape=(batch_size_domain, 1), minval=0, maxval=self.W)
        x_b1 = tf.zeros(shape=(batch_size_boundary, 1))
        x_b2 = tf.zeros(shape=(batch_size_boundary, 1)) + self.W
        x_b3 = tf.random.uniform(shape=(batch_size_boundary, 1), minval=0, maxval=self.W)
        x_b4 = tf.random.uniform(shape=(batch_size_boundary, 1), minval=0, maxval=self.W)
        x = tf.concat([x_in, x_b1, x_b2, x_b3, x_b4], axis=0)

        y_in = tf.random.uniform(shape=(batch_size_domain, 1), minval=0, maxval=self.H)
        y_b1 = tf.random.uniform(shape=(batch_size_boundary, 1), minval=0, maxval=self.H)
        y_b2 = tf.random.uniform(shape=(batch_size_boundary, 1), minval=0, maxval=self.H)
        y_b3 = tf.zeros(shape=(batch_size_boundary, 1))
        y_b4 = tf.zeros(shape=(batch_size_boundary, 1)) + self.H
        y = tf.concat([y_in, y_b1, y_b2, y_b3, y_b4], axis=0)

        return x, y

    def get_train_dataset(self, batch_size_domain:int=800, batch_size_boundary:int=100):
        """
        Creates a tf.data.Dataset generator for training.

        Parameters
        ----------
        batch_size_domain : int
            number of points to be sampled inside of the domain. Default is 800.
        batch_size_boundary : int
            number of points to be sampled on each of the four boundaries. Default is 100.

        Returns
        -------
        tf.data.Dataset
            An endless `tf.data.Dataset` yielding `(xy, xy)` pairs of shape (None, 2).
            The coordinates are passed as both inputs and targets, so that the loss and
            the metrics receive the collocation points in place of labels.
        """
        def generator():
            while True:
                xy = tf.concat(self.training_batch(batch_size_domain, batch_size_boundary), axis=-1)
                yield xy, xy

        return tf.data.Dataset.from_generator(
            generator,
            output_types=(tf.float32, tf.float32),
            output_shapes=((None, 2), (None, 2))
        )

    def validation_batch(self, grid_width:int=64, grid_height:int=64):
        """
        Generates a grid of points that can easily be used to generate an image of the plate,
        where each point is a pixel.

        PARAMETERS
        ----------
        grid_width : int
            width of the grid
        grid_height : int
            height of the grid

        RETURNS
        -------
        tuple
            x, y and the reference solution `u_val(x, y)`, each of shape (grid_width * grid_height, 1).
        """
        # a complex step makes np.mgrid include both end points with the given number of samples
        x, y = np.mgrid[0:self.W:complex(0, grid_width), 0:self.H:complex(0, grid_height)]
        x = tf.cast(x.reshape(grid_width * grid_height, 1), dtype=tf.float32)
        y = tf.cast(y.reshape(grid_width * grid_height, 1), dtype=tf.float32)
        u = self.u_val(x, y)
        return x, y, u

    def compute_loss(self, x, y, preds, eval=False):
        """
        Computes the physics-informed loss for Kirchhoff's plate bending equation.

        Parameters
        ----------
        x : tf.Tensor of shape (batch_size, 1)
            x coordinate of the points in the current batch
        y : tf.Tensor of shape (batch_size, 1)
            y coordinate of the points in the current batch
        preds : tf.Tensor of shape (batch_size, 6)
            predictions made by our PINN (dim 0) as well as dudxx (dim 1), dudyy (dim 2),
            dudxxxx (dim 3), dudyyyy (dim 4), dudxxyy (dim 5), i.e. the column order
            produced by `KirchhoffPINN.call`
        eval : bool
            if True, additionally returns the squared error against the reference solution

        Returns
        -------
        tuple
            The point-wise loss terms `(L_f, L_b0, L_b2)`, each of shape (batch_size, 1),
            followed by `L_u` when `eval` is True.
        """
        u = preds[:, 0:1]
        dudxx, dudyy = preds[:, 1:2], preds[:, 2:3]
        # columns 4 and 5 hold dudyyyy and dudxxyy respectively; the mixed derivative
        # is the one that enters the biharmonic operator with a factor of two
        dudxxxx, dudyyyy, dudxxyy = preds[:, 3:4], preds[:, 4:5], preds[:, 5:6]

        # governing equation loss
        f = dudxxxx + 2 * dudxxyy + dudyyyy - self.p(x, y) / self.D
        L_f = f**2

        # determine which points are on the boundaries of the domain
        # if a point is on either of the boundaries, its value is 1 and 0 otherwise
        x_lower = tf.cast(isclose(x, 0.     , rtol=0., atol=EPS), dtype=tf.float32)
        x_upper = tf.cast(isclose(x, self.W, rtol=0., atol=EPS), dtype=tf.float32)
        y_lower = tf.cast(isclose(y, 0.     , rtol=0., atol=EPS), dtype=tf.float32)
        y_upper = tf.cast(isclose(y, self.H, rtol=0., atol=EPS), dtype=tf.float32)

        # compute 0th order boundary condition loss
        L_b0 = ((x_lower + x_upper + y_lower + y_upper) * u)**2
        # compute 2nd order boundary condition loss
        mx, my = compute_moments(self.D, self.nue, dudxx, dudyy)
        L_b2 = ((x_lower + x_upper) * mx)**2 + ((y_lower + y_upper) * my)**2

        if eval:
            L_u = (self.u_val(x, y) - u)**2
            return L_f, L_b0, L_b2, L_u

        return L_f, L_b0, L_b2

    @tf.function
    def __validation_results(self, pinn: tf.keras.Model, image_width: int = 64, image_height: int = 64):
        """Computes the validation results for the given model.

        Parameters
        ----------
        pinn : tf.keras.Model
            A TensorFlow Keras model instance.
        image_width : int
            The width of the image (defaults to 64).
        image_height : int
            The height of the image (defaults to 64).

        Returns
        -------
        u_real : tf.Tensor
            A tensor containing the real displacement.
        u_pred : tf.Tensor
            A tensor containing the predicted displacement.
        mx : tf.Tensor
            A tensor containing the x-component of the moments.
        my : tf.Tensor
            A tensor containing the y-component of the moments.
        f : tf.Tensor
            A tensor containing the left-hand side of the governing equation,
            i.e. dudxxxx + 2 * dudxxyy + dudyyyy.
        p : tf.Tensor
            A tensor containing the load.
        """
        x, y, u_real = self.validation_batch(image_width, image_height)
        pred = pinn(tf.concat([x, y], axis=-1), training=False)
        u_pred, dudxx, dudyy, dudxxxx, dudyyyy, dudxxyy = pred[:, 0:1], pred[:, 1:2], pred[:, 2:3], pred[:, 3:4], pred[:, 4:5], pred[:, 5:6]
        mx, my = compute_moments(self.D, self.nue, dudxx, dudyy)
        f = dudxxxx + 2 * dudxxyy + dudyyyy
        p = self.p(x, y)
        return u_real, u_pred, mx, my, f, p


    def visualise(self, pinn: tf.keras.Model = None, image_width: int = 64, image_height: int = 64):
        """
        If no model is provided, visualises only the load distribution on the plate
        and the reference deformation. Otherwise, visualizes the results of the given model.

        Parameters
        ----------
        pinn : tf.keras.Model
            A TensorFlow Keras model instance.
        image_width : int
            The width of the image (defaults to 64).
        image_height : int
            The height of the image (defaults to 64).
        """
        def to_image(tensor):
            # turn a flat column of grid values into a 2D image
            return tensor.numpy().reshape(image_width, image_height)

        if pinn is None:
            x, y, u_real = self.validation_batch(image_width, image_height)
            load = to_image(self.p(x, y))
            _, axs = plt.subplots(1, 2, figsize=(8, 3.2), dpi=100)
            self.__show_image(
                load,
                axis=axs[0],
                title='Load distribution on the plate',
                z_label='$\\left[\\frac{NM}{m^2}\\right]$'
            )
            self.__show_image(
                to_image(u_real),
                axis=axs[1],
                title='Deformation',
                z_label='[m]'
            )
            plt.tight_layout()
            plt.show()

        else:
            u_real, u_pred, mx, my, f, p = (
                to_image(t) for t in self.__validation_results(pinn, image_width, image_height)
            )

            _, axs = plt.subplots(3, 2, figsize=(9.5, 12))
            self.__show_image(u_pred, axs[0, 0], 'Predicted Displacement (m)')
            self.__show_image((u_pred - u_real)**2, axs[0, 1], 'Squared Error Displacement')
            self.__show_image(mx, axs[1, 0], 'Moments mx')
            self.__show_image(my, axs[1, 1], 'Moments my')
            self.__show_image(f, axs[2, 0], 'Governing Equation')
            # the governing equation equates f with p / D (same residual as in `compute_loss`)
            self.__show_image((f - p / self.D)**2, axs[2, 1], 'Squared Error Governing Equation')

            # Hide x labels and tick labels for top plots and y ticks for right plots.
            for ax in axs.flat:
                ax.label_outer()

            plt.tight_layout()
            plt.show()

    def __show_image(self, img:np.ndarray, axis:plt.Axes=None, title:str='', x_label='x [m]', y_label='y [m]', z_label=''):
        """
        Draws `img` as a heat map with a colour bar on `axis` (a new figure is created if `axis` is None).

        The first image dimension is displayed along the horizontal axis; the tick labels at both
        ends of each axis show 0 and the plate's width / height. Returns the image artist.
        """
        if axis is None:
             _, axis = plt.subplots(1, 1, figsize=(4, 3.2), dpi=100)
        im = axis.imshow(np.rot90(img, k=3), cmap='plasma', origin='lower', aspect='auto')
        plt.colorbar(im, label=z_label, ax=axis)
        axis.set_xticks([0, img.shape[0]-1])
        axis.set_xticklabels([0, self.W])
        axis.set_yticks([0, img.shape[1]-1])
        axis.set_yticklabels([0, self.H])
        axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)
        axis.set_title(title)
        return im

## Kirchhoff PINN Loss Function

### Default Kirchhoff PDE Loss
This class inherits from the keras Loss class and can be used in the keras API (i.e. model.fit()) for training our Kirchhoff PINN.

In [None]:
class KirchhoffLoss(tf.keras.losses.Loss):
    """
    Unweighted physics-informed loss for the Kirchhoff plate bending PINN.

    Parameters
    ----------
    plate: KirchhoffPDE
        The KirchhoffPDE object representing the plate bending physics.
    name: str, optional
        The name of the loss, by default 'KirchhoffLoss'
    """
    def __init__(self, plate:KirchhoffPDE, name='KirchhoffLoss'):
        super().__init__(name=name)
        self.plate = plate

    def call(self, xy, preds):
        """
        Returns the batch mean of the summed loss terms L_f + L_b0 + L_b2.

        `xy` holds the x coordinates in its first column and the y coordinates in the
        remaining one; `preds` is the output of the PINN for those points.
        """
        x_coords = xy[:, :1]
        y_coords = xy[:, 1:]
        governing, dirichlet, moment = self.plate.compute_loss(x_coords, y_coords, preds)
        pointwise_total = governing + dirichlet + moment
        return tf.reduce_mean(pointwise_total)

### ReLoBRaLo Loss Balancing Objective
This class inherits from the KirchhoffLoss class and balances the contributions towards the total loss by scaling the terms L_f, L_b0 and L_b2 according to the ReLoBRaLo loss balancing scheme.

$$
\begin{aligned}
&\lambda_i^{\textit{bal}}(t, t') = k\cdot\frac{\operatorname{exp}\left(\frac{L_i(t)}{\tau L_i(t')}\right)}{\sum_{j=1}^k \operatorname{exp} \left(\frac{L_j(t)}{\tau L_j(t')} \right)}, \; i \in \{1, \dots, k\}\\
&\lambda_{i}^{\textit{hist}}(t) = \rho\lambda_i(t-1) + (1-\rho)\lambda_i^{\textit{bal}}(t, 0)\\
&\lambda_i(t) = \alpha\lambda_{i}^{\textit{hist}} + (1-\alpha)\lambda_i^{\textit{bal}}(t, t-1)
\end{aligned}
$$

where $k$ is the number of loss terms, $\tau$ is the softmax temperature, $\alpha$ is the exponential decay rate, and $\rho$ is a Bernoulli random variable whose expectation $\mathbb{E}[\rho]$ should be chosen close to 1. The intermediate step $\lambda_i^{\textit{bal}}(t, t')$ calculates scalings based on the relative improvement of each term between time steps $t'$ and $t$. The following step $\lambda_{i}^{\textit{hist}}(t)$ determines whether the scalings calculated in the previous time step ($\rho$ evaluates to 1) or the relative improvements since the beginning of training ($\rho$ evaluates to 0) should be carried forward. This concept of randomly retaining or discarding the history of scalings is what we denote as "random lookbacks". Finally, the scaling $\lambda_i(t)$ for term $i$ is obtained by means of an exponential decay, where $\alpha$ controls the weight given to past scalings versus the scalings calculated in the current time step.

In [None]:
class ReLoBRaLoKirchhoffLoss(KirchhoffLoss):
    """
    Class for the ReLoBRaLo Kirchhoff Loss.
    This class extends the Kirchhoff Loss to have dynamic weighting for each term in the calculation of the loss.
    """
    def __init__(self, plate:KirchhoffPDE, alpha:float=0.999, temperature:float=1., rho:float=0.9999,
                 name='ReLoBRaLoKirchhoffLoss'):
        """
        Parameters
        ----------
        plate : KirchhoffPDE
            An instance of KirchhoffPDE class containing the `compute_loss` function.
        alpha, optional : float
            Controls the exponential weight decay rate.
            Value between 0 and 1. The smaller, the more stochasticity.
            0 means no historical information is transmitted to the next iteration.
            1 means only first calculation is retained. Defaults to 0.999.
        temperature, optional : float
            Softmax temperature coefficient. Controls the "sharpness" of the softmax operation.
            Defaults to 1.
        rho, optional : float
            Probability that the Bernoulli random variable evaluates to 1, i.e. that the
            scalings of the previous iteration are retained instead of performing a lookback.
            Value between 0 and 1. The smaller, the more often lookbacks happen.
            0 means the history term is always calculated w.r.t. the initial loss values.
            1 means the history term is always the scalings of the previous training iteration.
            Defaults to 0.9999.
        """
        super().__init__(plate, name=name)
        self.alpha = alpha
        self.temperature = temperature
        self.rho = rho
        # 64-bit counter: a 16-bit one wraps around after 32,767 calls, which would
        # re-trigger the first-iteration branches in the middle of a long training run
        self.call_count = tf.Variable(0, trainable=False, dtype=tf.int64)

        self.lambdas = [tf.Variable(1., trainable=False) for _ in range(plate.num_terms)]
        self.last_losses = [tf.Variable(1., trainable=False) for _ in range(plate.num_terms)]
        self.init_losses = [tf.Variable(1., trainable=False) for _ in range(plate.num_terms)]

    def _softmax_scalings(self, losses, reference_losses):
        """Softmax over the ratios `loss / (temperature * reference)`, rescaled to sum to the number of terms."""
        ratios = tf.stack([
            loss / (reference * self.temperature + EPS)
            for loss, reference in zip(losses, reference_losses)
        ])
        # subtracting the maximum leaves the softmax unchanged but avoids overflow in exp
        return tf.nn.softmax(ratios - tf.reduce_max(ratios)) * tf.cast(len(losses), dtype=tf.float32)

    def call(self, xy, preds):
        """
        Computes the ReLoBRaLo-weighted sum of the mean loss terms and updates the internal
        state (scalings, previous losses, initial losses and call counter).
        """
        x, y = xy[:, :1], xy[:, 1:]
        losses = [tf.reduce_mean(loss) for loss in self.plate.compute_loss(x, y, preds)]

        # in first iteration (self.call_count == 0), drop lambda_hat and use init lambdas, i.e. lambda = 1
        #   i.e. alpha = 1 and rho = 1
        # in second iteration (self.call_count == 1), drop init lambdas and use only lambda_hat
        #   i.e. alpha = 0 and rho = 1
        # afterwards, default procedure (see paper)
        #   i.e. alpha = self.alpha and rho = Bernoulli random variable with p = self.rho
        alpha = tf.cond(tf.equal(self.call_count, 0),
                lambda: 1.,
                lambda: tf.cond(tf.equal(self.call_count, 1),
                                lambda: 0.,
                                lambda: self.alpha))
        rho = tf.cond(self.call_count < 2,
              lambda: 1.,
              lambda: tf.cast(tf.random.uniform(shape=()) < self.rho, dtype=tf.float32))

        # compute new lambdas w.r.t. the losses in the previous iteration
        lambdas_hat = self._softmax_scalings(losses, self.last_losses)

        # compute new lambdas w.r.t. the losses in the first iteration
        init_lambdas_hat = self._softmax_scalings(losses, self.init_losses)

        # use rho for deciding, whether a random lookback should be performed
        new_lambdas = [
            tf.stop_gradient(
                rho * alpha * previous + (1 - rho) * alpha * init_lambdas_hat[i] + (1 - alpha) * lambdas_hat[i]
            )
            for i, previous in enumerate(self.lambdas)
        ]
        # update the variables in place; the attribute keeps pointing at the tf.Variable objects
        for var, lam in zip(self.lambdas, new_lambdas):
            var.assign(lam)

        # compute weighted loss; the scalings are treated as constants by the optimiser
        loss = tf.reduce_sum([lam * term for lam, term in zip(new_lambdas, losses)])

        # in first iteration, store losses in self.init_losses to be accessed in next iterations
        first_iteration = self.call_count < 1
        for var, term in zip(self.init_losses, losses):
            var.assign(tf.where(first_iteration, tf.stop_gradient(term), var))
        # store current losses in self.last_losses to be accessed in the next iteration
        for var, term in zip(self.last_losses, losses):
            var.assign(tf.stop_gradient(term))

        self.call_count.assign_add(1)

        return loss

## Metrics for logging

### Custom Metric for logging Kirchhoff Loss Terms

In [None]:
class KirchhoffMetric(tf.keras.metrics.Metric):
    """
    Kirchhoff metric to log the values of each loss term, i.e. L_f, L_b0 and L_b2,
    together with the squared error L_u against the reference solution.

    Each update overwrites the stored values with the batch means of the latest batch.
    """
    def __init__(self, plate: KirchhoffPDE, name='kirchhoff_metric', **kwargs):
        """Initialize Kirchhoff metric with a KirchhoffPDE instance and metric name.

        Parameters
        ----------
        plate : KirchhoffPDE
            Instance of the KirchhoffPDE.
        name : str, optional
            Name of the metric. Defaults to 'kirchhoff_metric'.
        """
        super().__init__(name=name, **kwargs)
        self.plate = plate
        # one scalar state per logged quantity
        self.L_f_mean = self.add_weight(name='L_f_mean', initializer='zeros')
        self.L_b0_mean = self.add_weight(name='L_b0_mean', initializer='zeros')
        self.L_b2_mean = self.add_weight(name='L_b2_mean', initializer='zeros')
        self.L_u_mean = self.add_weight(name='L_u_mean', initializer='zeros')

    def update_state(self, xy, y_pred, sample_weight=None):
        """Stores the batch mean of every loss term; `xy` carries the coordinates, `sample_weight` is ignored."""
        x_coords = xy[:, :1]
        y_coords = xy[:, 1:]
        terms = self.plate.compute_loss(x_coords, y_coords, y_pred, eval=True)
        L_f, L_b0, L_b2, L_u = terms
        states_and_terms = (
            (self.L_f_mean, L_f),
            (self.L_b0_mean, L_b0),
            (self.L_b2_mean, L_b2),
            (self.L_u_mean, L_u),
        )
        for state, term in states_and_terms:
            state.assign(tf.reduce_mean(term[:, 0], axis=0))

    def reset_state(self):
        """Sets all stored values back to zero."""
        for state in (self.L_f_mean, self.L_b0_mean, self.L_b2_mean, self.L_u_mean):
            state.assign(0)

    def result(self):
        """Returns the stored values keyed by the name of the loss term."""
        return {
            'L_f': self.L_f_mean,
            'L_b0': self.L_b0_mean,
            'L_b2': self.L_b2_mean,
            'L_u': self.L_u_mean,
        }

### Custom Metric for logging ReLoBRaLo weights

In [None]:
class ReLoBRaLoLambdaMetric(tf.keras.metrics.Metric):
    """
    A custom TensorFlow metric class to monitor the lambdas of the ReLoBRaLoKirchhoffLoss.

    Each update copies the current scalings held by the loss object into the metric's state.
    """
    def __init__(self, loss:ReLoBRaLoKirchhoffLoss, name='relobralo_lambda_metric', **kwargs):
        """
        Parameters
        ----------
        loss : ReLoBRaLoKirchhoffLoss
            The ReLoBRaLoKirchhoffLoss object that holds the lambdas.
        name : str, optional
            The name of the metric. Defaults to 'relobralo_lambda_metric'.
        """
        super().__init__(name=name, **kwargs)
        self.loss = loss
        # one scalar state per scaling
        self.L_f_lambda_mean = self.add_weight(name='L_f_lambda_mean', initializer='zeros')
        self.L_b0_lambda_mean = self.add_weight(name='L_b0_lambda_mean', initializer='zeros')
        self.L_b2_lambda_mean = self.add_weight(name='L_b2_lambda_mean', initializer='zeros')

    def update_state(self, xy, y_pred, sample_weight=None):
        """Copies the loss object's current lambdas; the arguments are not used."""
        lambda_f, lambda_b0, lambda_b2 = self.loss.lambdas
        states_and_lambdas = (
            (self.L_f_lambda_mean, lambda_f),
            (self.L_b0_lambda_mean, lambda_b0),
            (self.L_b2_lambda_mean, lambda_b2),
        )
        for state, current in states_and_lambdas:
            state.assign(current)

    def reset_state(self):
        """Sets all stored values back to zero."""
        for state in (self.L_f_lambda_mean, self.L_b0_lambda_mean, self.L_b2_lambda_mean):
            state.assign(0)

    def result(self):
        """Returns the stored scalings keyed by the name of the corresponding loss term."""
        return {
            'L_f_lambda': self.L_f_lambda_mean,
            'L_b0_lambda': self.L_b0_lambda_mean,
            'L_b2_lambda': self.L_b2_lambda_mean,
        }

## PINN Model

In [None]:
class KirchhoffPINN(tf.keras.Model):
    """
    This class is an implementation of a physics-informed neural network (PINN)
    for the Kirchhoff plate bending partial differential equation (PDE).

    Besides the predicted deformation, the model also outputs the derivatives
    needed to evaluate the physics-informed loss.
    """
    def __init__(self, layer_widths: Union[List[int], Tuple[int, ...]]=(64, 64, 64), activation: Union[str, Callable]=tf.keras.activations.swish, **kwargs):
        """
        Parameters
        ----------
        layer_widths : sequence of int, optional
            Widths of the hidden layers in the model. Defaults to three layers of width 64.
        activation : Union[str, Callable], optional
            Activation function to be applied in each hidden layer.
        """
        super().__init__(**kwargs)
        self.layer_sequence = [tf.keras.layers.Dense(width, activation=activation, kernel_initializer='glorot_normal') for width in layer_widths]
        # linear output layer producing the scalar deformation u
        self.layer_sequence.append(tf.keras.layers.Dense(1, kernel_initializer='glorot_normal'))

    def call(self, xy, training=None, mask=None):
        """
        Parameters
        ----------
        xy : tf.Tensor of shape (batch_size, 2)
            x coordinates in the first column, y coordinates in the second.
        training, mask
            Accepted for compatibility with the Keras API; not used.

        Returns
        -------
        tf.Tensor of shape (batch_size, 6)
            The columns u, dudxx, dudyy, dudxxxx, dudyyyy, dudxxyy (in this order).
        """
        # x and y are split before being fed to the network so that derivatives
        # can be taken with respect to each coordinate separately
        x, y = xy[:, :1], xy[:, 1:]

        u = tf.concat([x, y], axis=-1)
        for layer in self.layer_sequence:
            u = layer(u)

        dudxx, dudyy, dudxxxx, dudyyyy, dudxxyy = compute_derivatives(x, y, u)

        return tf.concat([u, dudxx, dudyy, dudxxxx, dudyyyy, dudxxyy], axis=-1)

## Problem definition
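The Kirchhoff PDE states that the biharmonic operator applied to the plate's deformation $u$, i.e. $\frac{\partial^4 u}{\partial x^4} + 2\frac{\partial^4 u}{\partial x^2 \partial y^2} + \frac{\partial^4 u}{\partial y^4}$, is equal to the load $p$ divided by the flexural stiffness $D$. In order to obtain an analytical solution, we can define a load with a sinusoidal distribution: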
$$
\begin{equation}
    \begin{array}{@{}l@{}}
        p(x, y) = p_0\sin\left(\frac{x\pi}{W}\right)\sin\left(\frac{y\pi}{H}\right)\\
        u(x, y) = \frac{p_0}{\pi^4 D (W^{-2} + H^{-2})^2}\sin\left(\frac{x\pi}{W}\right)\sin\left(\frac{y\pi}{H}\right)
    \end{array}
\end{equation}
$$
We consider a concrete plate of width $W = 10$ m, height $H = 10$ m, base load $p_0 = 0.15$ MN/m$^2$, Young's modulus $E = 30{,}000$ MN/m$^2$, plate thickness $T = 0.2$ m and Poisson's ratio $\nu = 0.2$.

In [None]:
# geometry and material parameters of the plate
W = 10       # width
H = 10       # height
T = 0.2      # thickness
E = 30000    # Young's modulus
nue = 0.2    # Poisson's ratio
p0 = 0.15    # base load
D = (E * T**3) / (12 * (1 - nue**2)) # flexural stiffness of the plate


def load(x, y):
    """Sinusoidal load distribution with amplitude p0."""
    return p0 * tf.math.sin(x * np.pi / W) * tf.math.sin(y * np.pi / H)


def u_val(x, y):
    """Analytical deformation of the plate under the sinusoidal load."""
    amplitude = p0 / (np.pi**4 * D * (W**-2 + H**-2)**2)
    return amplitude * tf.math.sin(x * np.pi / W) * tf.math.sin(y * np.pi / H)


plate = KirchhoffPDE(p=load, u_val=u_val, T=T, nue=nue, E=E, W=W, H=H)
# without a model, this shows the load distribution and the analytical deformation
plate.visualise()

### Train without loss balancing
Now that the PDE has been defined, we can build the model as well as the loss function. We are first using the default KirchhoffLoss and will compare it to the ReLoBRaLoKirchhoffLoss later.

In [None]:
# baseline: default network trained on the unweighted sum of the loss terms
pinn = KirchhoffPINN()
loss = KirchhoffLoss(plate)
pinn.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=loss,
    metrics=[KirchhoffMetric(plate)],
)

In [None]:
# Use the same training budget as the ReLoBRaLo run (100 steps per epoch) so that the
# two loss curves are comparable; the callbacks below count their patience in epochs.
h = pinn.fit(
    plate.get_train_dataset(),
    epochs=1000,
    steps_per_epoch=100,
    callbacks=[
        ReduceLROnPlateau(monitor='loss', factor=0.1, patience=30, min_delta=0, verbose=True),
        EarlyStopping(monitor='loss', patience=100, restore_best_weights=True, verbose=True)
    ]
)

Visualise the progress of the several loss terms as well as the error against the analytical solution.

In [None]:
# log of every logged loss term over the epochs of the baseline run
fig = plt.figure(figsize=(6, 4.5), dpi=100)
for history_key, curve_label in (
    ('L_f', '$L_f$ governing equation'),
    ('L_b0', '$L_{b0}$ Dirichlet boundaries'),
    ('L_b2', '$L_{b2}$ Moment boundaries'),
    ('L_u', '$L_u$ analytical solution'),
):
    plt.plot(np.log(h.history[history_key]), label=curve_label)
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('Log-loss')
# a single title call: an earlier title would simply be overwritten by a later one
plt.title('Loss evolution Kirchhoff PDE')
plt.savefig('kirchhoff_loss_unscaled')

Visually inspect the error distribution on the physical domain.

In [None]:
# plot the baseline model's predictions, moments and squared errors on the validation grid
plate.visualise(pinn)

### Train with ReLoBRaLo

In [None]:
# same network architecture as the baseline, but trained with the balanced loss;
# the extra metric logs the scalings held by the loss object
relobralo_pinn = KirchhoffPINN()
relobralo_loss = ReLoBRaLoKirchhoffLoss(plate, temperature=0.1, rho=0.99, alpha=0.999)
relobralo_pinn.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=relobralo_loss,
    metrics=[
        KirchhoffMetric(plate),
        ReLoBRaLoLambdaMetric(relobralo_loss),
    ],
)

In [None]:
# train on freshly sampled collocation points; both callbacks monitor the training loss
h_relobralo = relobralo_pinn.fit(
    plate.get_train_dataset(),
    epochs=1000,
    steps_per_epoch=100,
    callbacks=[
        # divide the learning rate by ten after 30 epochs without improvement
        ReduceLROnPlateau(monitor='loss', factor=0.1, patience=30, min_delta=0, verbose=True),
        # stop after 100 epochs without improvement and restore the best weights
        EarlyStopping(monitor='loss', patience=100, restore_best_weights=True, verbose=True),
    ],
)

Visualise the progress of the several loss terms as well as the error against the analytical solution.

In [None]:
# log of every logged loss term over the epochs of the ReLoBRaLo run
fig = plt.figure(figsize=(6, 4.5), dpi=100)
for history_key, curve_label in (
    ('L_f', '$L_f$ governing equation'),
    ('L_b0', '$L_{b0}$ Dirichlet boundaries'),
    ('L_b2', '$L_{b2}$ Moment boundaries'),
    ('L_u', '$L_u$ analytical solution'),
):
    plt.plot(np.log(h_relobralo.history[history_key]), label=curve_label)
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('Log-loss')
plt.ylim([-17, -4])
plt.title('Loss evolution Kirchhoff PDE\nwith ReLoBRaLo')
plt.savefig('kirchhoff_loss_relobralo')

Visualise the progress of the scalings $\lambda$.

In [None]:
# evolution of the ReLoBRaLo scalings over the epochs
# (raw strings: '\l' is not a valid escape sequence in a regular string literal)
fig = plt.figure(figsize=(6, 4.5), dpi=100)
plt.plot(h_relobralo.history['L_f_lambda'], label=r'$\lambda_f$ governing equation')
plt.plot(h_relobralo.history['L_b0_lambda'], label=r'$\lambda_{b0}$ Dirichlet boundaries')
plt.plot(h_relobralo.history['L_b2_lambda'], label=r'$\lambda_{b2}$ Moment boundaries')
plt.legend()
plt.xlabel('Epochs')
plt.ylabel(r'scalings $\lambda$')
plt.title('ReLoBRaLo weights on Kirchhoff PDE')
plt.savefig('kirchhoff_lambdas_relobralo')

Visually inspect the error distribution on the physical domain.

In [None]:
# plot the ReLoBRaLo-trained model's predictions, moments and squared errors on the validation grid
plate.visualise(relobralo_pinn)