In [1]:
## Pawel Maczuga and Maciej Paszynski 2023

import meshio
import numpy as np
import torch
import os

from datetime import datetime
from torch import nn
from typing import Callable, Tuple, List
# from utils import get_initial_points, plot_intial_condition, plot_simulation_by_frame, create_gif, ReportContext, create_report, plot_running_average

In [7]:
!pip install meshio
!pip install xhtml2pdf
!mkdir img
!mkdir results

Collecting meshio
  Downloading meshio-5.3.4-py3-none-any.whl (167 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m167.7/167.7 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: meshio
Successfully installed meshio-5.3.4
Collecting xhtml2pdf
  Obtaining dependency information for xhtml2pdf from https://files.pythonhosted.org/packages/1d/b7/637d96fe25024fdaaa4d265ae353cafdca706167325109fc1e574174b2bf/xhtml2pdf-0.2.13-py3-none-any.whl.metadata
  Downloading xhtml2pdf-0.2.13-py3-none-any.whl.metadata (21 kB)
Collecting arabic-reshaper>=3.0.0 (from xhtml2pdf)
  Downloading arabic_reshaper-3.0.0-py3-none-any.whl (20 kB)
Collecting pyHanko>=0.12.1 (from xhtml2pdf)
  Obtaining dependency information for pyHanko>=0.12.1 from https://files.pythonhosted.org/packages/68/fc/d3d6dbb6ca6a9c755df5b6a1f2897e7560ef4c1384d8c89d182424f1582e/pyHanko-0.21.0-py3-none-any.whl.metadata
  Downloading pyHanko-0.21.0-py3-none-any.whl.metadata (9.4 kB)
Collecting pyhan

In [7]:
# Triangular mesh in AVS-UCD (.inp) format, read with meshio.avsucd; its normalized z values act as the sea floor.
MESH_FILENAME = os.path.join("data", "val_square_UTM_translated_4.inp")

In [8]:
# Use the GPU when available; collocation points and the model are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [9]:
# Release unused cached GPU memory held by PyTorch's allocator (e.g. from earlier runs in this kernel).
torch.cuda.empty_cache()

In [10]:
# Run identifier used in the checkpoint, GIF and report file names.
RUN_NUM = 7

### Parameters

In [15]:
## LENGTH = 1. # Domain size in x axis. Always starts at 0
TOTAL_TIME = 0.5 # Domain size in t axis. Always starts at 0
N_POINTS = 15 # Number of in single asxis
N_POINTS_PLOT = 100 # Number of points in single axis used in plotting

WEIGHT_RESIDUAL = 0.05 # Weight of residual part of loss function
WEIGHT_INITIAL = 3 # Weight of initial part of loss function
WEIGHT_BOUNDARY = 0.0005 # Weight of boundary part of loss function

# Original
# WEIGHT_RESIDUAL = 0.03 # Weight of residual part of loss function
# WEIGHT_INITIAL = 1.0 # Weight of initial part of loss function
# WEIGHT_BOUNDARY = 0.0005 # Weight of boundary part of loss function

LAYERS = 10
NEURONS_PER_LAYER = 120
EPOCHS = 50_000
LEARNING_RATE = 0.00005
GRAVITY = 9.81

In [16]:
# Time interval [0, TOTAL_TIME] shared by all point-sampling helpers.
t_domain = [0, TOTAL_TIME]

In [17]:
def floor(x, y):
    """Sea-floor level at (x, y).

    The floor is currently flat, so both coordinates are ignored and the
    constant level 2 is returned.
    """
    flat_level = 2
    return flat_level

In [18]:
# Load the triangular mesh (AVS-UCD format). The global `mesh` is also reused by
# mesh_from_tensors for its connectivity.
mesh = meshio.avsucd.read(MESH_FILENAME)
vertices = torch.tensor(mesh.points, dtype=torch.float32)  # Tensor of vertices' coordinates, (n_vertices, 3)
triangles = mesh.cells_dict['triangle']  # Connectivity information for triangles as NumPy array

print(vertices.shape)

# Function to compute partial derivatives at vertices
def compute_derivatives_at_vertices(vertices, triangles, out_device=None):
    """Accumulate per-vertex "slope" terms from the triangles of a surface mesh.

    For every triangle (v1, v2, v3) the cross product n = (v2 - v1) x (v3 - v1)
    is formed, and n_x / 2 and n_y / 2 are added to each of its three vertices.

    NOTE(review): these are area-weighted x/y components of the triangle
    normals, not dz/dx and dz/dy. For a plane z = a*x + b*y + c the normal is
    proportional to (-a, -b, 1), so dz/dx would be -n_x / n_z. Confirm the
    intended quantity before relying on dzdx/dzdy.

    Args:
        vertices: (N, 3) tensor of vertex coordinates.
        triangles: (M, 3) integer array/tensor of vertex indices per triangle.
        out_device: device for the returned tensors; defaults to the
            notebook-wide ``device``.

    Returns:
        (dx_vertices, dy_vertices): two length-N tensors.
    """
    tri = torch.as_tensor(triangles, dtype=torch.long).reshape(-1, 3).to(vertices.device)
    v1, v2, v3 = vertices[tri[:, 0]], vertices[tri[:, 1]], vertices[tri[:, 2]]
    # One batched cross product (instead of two per triangle in a Python loop).
    # `dim` is explicit because torch.cross without it is deprecated.
    half_normals = torch.cross(v2 - v1, v3 - v1, dim=1) / 2

    # Each triangle contributes once per corner; index_add_ accumulates repeated
    # vertex indices exactly like the former `+=` loop did.
    corner_idx = tri.reshape(-1)
    dx_vertices = torch.zeros(vertices.shape[0], device=vertices.device)
    dy_vertices = torch.zeros(vertices.shape[0], device=vertices.device)
    dx_vertices.index_add_(0, corner_idx, half_normals[:, 0].repeat_interleave(3).to(dx_vertices.dtype))
    dy_vertices.index_add_(0, corner_idx, half_normals[:, 1].repeat_interleave(3).to(dy_vertices.dtype))

    target = device if out_device is None else out_device
    return dx_vertices.to(target), dy_vertices.to(target)

# Compute derivatives at vertices
# Per-vertex slope accumulators read by dzdx/dzdy (only referenced from the
# commented-out residual variant in Loss.residual_loss at the moment).
dx_vertices, dy_vertices = compute_derivatives_at_vertices(vertices, triangles)

torch.Size([25, 3])


## PINN

In [19]:
class PINN(nn.Module):
    """Fully connected network mapping three features (x, y, t) to a single output.

    In the context of PINNs, the neural network is used as a universal function
    approximator for the solution of the differential equation.

    Args:
        num_hidden: number of hidden layers (the input layer plus
            ``num_hidden - 1`` middle layers).
        dim_hidden: width of every hidden layer.
        act: activation applied after every hidden layer.
    """
    def __init__(self, num_hidden: int, dim_hidden: int, act=nn.Tanh()):

        super().__init__()

        self.layer_in = nn.Linear(3, dim_hidden)
        self.layer_out = nn.Linear(dim_hidden, 1)

        num_middle = num_hidden - 1
        self.middle_layers = nn.ModuleList(
            [nn.Linear(dim_hidden, dim_hidden) for _ in range(num_middle)]
        )
        self.act = act

    def forward(self, x, y, t):
        """Evaluate the network on column tensors x, y, t (concatenated along dim 1)."""
        # Move the inputs to wherever the weights live (not the notebook-global
        # `device`), so a model moved with `.cpu()` or loaded elsewhere still works.
        x_stack = torch.cat([x, y, t], dim=1).to(self.layer_in.weight.device)
        out = self.act(self.layer_in(x_stack))
        for layer in self.middle_layers:
            out = self.act(layer(out))
        logits = self.layer_out(out)
        return logits

    def device(self):
        """Device of the model's parameters."""
        return next(self.parameters()).device


def f(pinn: PINN, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Evaluate the network approximation u(x, y, t) at the given points."""
    approximation = pinn(x, y, t)
    return approximation


def df(output: torch.Tensor, input: torch.Tensor, order: int = 1) -> torch.Tensor:
    """Differentiate ``output`` ``order`` times with respect to ``input`` via autograd.

    Each step computes grad(sum(value), input), i.e. the vector-Jacobian product
    with an all-ones vector. For networks that act row-wise (like PINN) this is
    the element-wise derivative. ``create_graph=True`` keeps every result
    differentiable, so higher orders and the final loss backward pass work.
    """
    df_value = output
    for _ in range(order):
        df_value = torch.autograd.grad(
            df_value,
            input,
            # grad_outputs must match the shape of the tensor being differentiated,
            # not the input; the two only coincide for (N, 1) -> (N, 1) maps.
            grad_outputs=torch.ones_like(df_value),
            create_graph=True,
            retain_graph=True,
        )[0]

    return df_value


def _derivative_wrt(pinn, x, y, t, wrt, f_val, order):
    """Differentiate the network output (or a precomputed ``f_val``) ``order`` times w.r.t. ``wrt``."""
    base = f(pinn, x, y, t) if f_val is None else f_val
    return df(base, wrt, order=order)


def dfdt(pinn: PINN, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor, f_val=None, order: int = 1):
    """Time derivative of the network output; pass ``f_val`` to reuse an already-computed output."""
    return _derivative_wrt(pinn, x, y, t, t, f_val, order)


def dfdx(pinn: PINN, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor, f_val=None, order: int = 1):
    """x-derivative of the network output; pass ``f_val`` to reuse an already-computed output."""
    return _derivative_wrt(pinn, x, y, t, x, f_val, order)


def dfdy(pinn: PINN, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor, f_val=None, order: int = 1):
    """y-derivative of the network output; pass ``f_val`` to reuse an already-computed output."""
    return _derivative_wrt(pinn, x, y, t, y, f_val, order)

def dzdx(x: torch.Tensor, y: torch.Tensor, order: int = 1):
    """x "slope" of the sea floor: the precomputed per-vertex ``dx_vertices``.

    NOTE(review): ``x``, ``y`` and ``order`` are ignored, and the result has one
    entry per mesh vertex rather than one per space-time collocation point.
    Confirm the shapes before re-enabling the bathymetry terms in the residual.
    """
    per_vertex = dx_vertices
    return per_vertex

def dzdy(x: torch.Tensor, y: torch.Tensor, order: int = 1):
    """y "slope" of the sea floor: the precomputed per-vertex ``dy_vertices``.

    NOTE(review): same caveats as dzdx (arguments ignored, per-vertex shape).
    """
    per_vertex = dy_vertices
    return per_vertex

## Loss function

In [20]:
def get_initial_points(x_domain: List[float], 
                       y_domain: List[float], 
                       t_domain: List[float], 
                       n_points: int, 
                       device=torch.device("cpu"), 
                       requires_grad=True):
    """Regular n_points x n_points grid over (x, y) at the initial time t_domain[0].

    Returns:
        (x, y, t0) column tensors of shape (n_points**2, 1); x varies slowest
        (``indexing="ij"``).
    """
    axes = [torch.linspace(bounds[0], bounds[1], n_points) for bounds in (x_domain, y_domain)]
    x_col, y_col = (
        grid.reshape(-1, 1).to(device).requires_grad_(requires_grad)
        for grid in torch.meshgrid(*axes, indexing="ij")
    )
    t_col = torch.full_like(x_col, t_domain[0], requires_grad=requires_grad)
    return (x_col, y_col, t_col)

def get_initial_mesh(x_domain: List[float], 
                     y_domain: List[float], 
                     t_domain: List[float], 
                     n_points: int, 
                     device=torch.device("cpu"), 
                     requires_grad=True):
    """Initial-condition points: every normalized mesh vertex at t = t_domain[0].

    ``x_domain``, ``y_domain`` and ``n_points`` are unused; they only keep the
    signature parallel to get_initial_points.

    Returns:
        (x, y, t0) column tensors of shape (n_vertices, 1).
    """
    x_raw, y_raw, _ = dump_points(MESH_FILENAME)
    # Reshape/move first and set requires_grad last, so x and y are leaf tensors.
    # Flagging before the reshape, as before, returned non-leaf views.
    x = x_raw.reshape(-1, 1).to(device)
    y = y_raw.reshape(-1, 1).to(device)
    x.requires_grad = requires_grad
    y.requires_grad = requires_grad

    t0 = torch.full_like(x, t_domain[0], requires_grad=requires_grad)
    return (x, y, t0)

In [21]:
def get_boundary_points(x_domain, y_domain, t_domain, n_points, device = torch.device("cpu"), requires_grad=True):
    """
         .+------+
       .' |    .'|
      +---+--+'  |
      |   |  |   |
    y |  ,+--+---+
      |.'    | .' t
      +------+'
         x

    Sample the four lateral faces of the (x, y, t) box on an n_points x n_points grid.

    Returns ``(down, up, left, right)``, each an ``(x, y, t)`` tuple of
    ``(n_points**2, 1)`` columns. ``down``/``up`` hold y fixed at
    ``y_domain[0]``/``y_domain[1]``; ``left``/``right`` hold x fixed at
    ``x_domain[0]``/``x_domain[1]``.
    """
    def as_column(grid):
        column = grid.reshape(-1, 1).to(device)
        column.requires_grad = requires_grad
        return column

    x_axis, y_axis, t_axis = (
        torch.linspace(bounds[0], bounds[1], n_points) for bounds in (x_domain, y_domain, t_domain)
    )
    x_mesh, t_mesh = torch.meshgrid(x_axis, t_axis, indexing="ij")
    y_mesh = torch.meshgrid(y_axis, t_axis, indexing="ij")[0]
    x_grid, y_grid, t_grid = as_column(x_mesh), as_column(y_mesh), as_column(t_mesh)

    def constant(value):
        return torch.full_like(t_grid, value, requires_grad=requires_grad)

    down = (x_grid, constant(y_domain[0]), t_grid)
    up = (x_grid, constant(y_domain[1]), t_grid)
    left = (constant(x_domain[0]), y_grid, t_grid)
    right = (constant(x_domain[1]), y_grid, t_grid)
    return down, up, left, right

#### Interior points (regular space-time grid)

In [22]:
def get_interior_points(x_domain, y_domain, t_domain, n_points, device = torch.device("cpu"), requires_grad=True):
    """Regular n_points**3 space-time grid over the (x, y, t) box.

    Returns:
        (x, y, t) column tensors of shape (n_points**3, 1), x varying slowest.
    """
    axes = [
        torch.linspace(bounds[0], bounds[1], steps=n_points, requires_grad=requires_grad)
        for bounds in (x_domain, y_domain, t_domain)
    ]
    x, y, t = (grid.reshape(-1, 1).to(device) for grid in torch.meshgrid(*axes, indexing="ij"))
    return x, y, t

#### Interior points based on the sea-bed mesh

In [23]:
def get_interior_points_mesh(t_domain, n_points, device=torch.device("cpu"), requires_grad=True):
    """Space-time collocation points: each normalized mesh vertex at ``n_points`` times.

    Returns:
        (x, y, z, t) column tensors of shape (n_vertices * n_points, 1) in
        vertex-major order (all time samples of vertex 0 first). ``z`` is the
        vertex's normalized height.
    """
    x_raw, y_raw, z_raw = dump_points(MESH_FILENAME)
    t_raw = torch.linspace(t_domain[0], t_domain[1], steps=n_points)
    x_grid, t_grid = torch.meshgrid(x_raw, t_raw, indexing="ij")
    y_grid, _      = torch.meshgrid(y_raw, t_raw, indexing="ij")
    z_grid, _      = torch.meshgrid(z_raw, t_raw, indexing="ij")
    x = x_grid.reshape(-1, 1).to(device)
    y = y_grid.reshape(-1, 1).to(device)
    z = z_grid.reshape(-1, 1).to(device)
    t = t_grid.reshape(-1, 1).to(device)
    # Honour the `requires_grad` argument (it used to be ignored and forced to True).
    for column in (x, y, z, t):
        column.requires_grad = requires_grad
    return x, y, z, t


def dump_points(filename):
    """Read mesh vertices with meshio and min-max scale each coordinate into [0, 1].

    Returns:
        (x, y, z): 1-D float32 tensors with one entry per vertex.

    A coordinate with zero range (e.g. a perfectly flat floor, where all z are
    equal) is mapped to 0 instead of producing 0/0 = NaN.
    """
    mesh = meshio.avsucd.read(filename)
    points = torch.tensor(mesh.points, dtype=torch.float32)
    scaled = []
    for coord in points.transpose(0, 1):
        lo, hi = torch.min(coord), torch.max(coord)
        span = hi - lo
        if span > 0:
            scaled.append((coord - lo) / span)
        else:
            scaled.append(torch.zeros(coord.shape[0], dtype=coord.dtype))
    x, y, z = scaled
    return x, y, z


def mesh_from_tensors(x,y,z):
    """Build a meshio.Mesh from per-vertex coordinate tensors.

    The connectivity (cells) is taken from the notebook-global ``mesh``.
    """
    point_rows = torch.stack((x, y, z), dim=1).tolist()
    return meshio.Mesh(points=point_rows, cells=mesh.cells)

In [24]:
# Normalized mesh coordinates and the space-time collocation points built from them.
x_raw, y_raw, z_raw = dump_points(MESH_FILENAME)
x_interior, y_interior, z_interior, t_interior = get_interior_points_mesh(t_domain, N_POINTS, device)
# Bounding box of the normalized mesh (min-max scaled, so normally [0, 1]).
x_domain = [x_interior.min().item(), x_interior.max().item()]
y_domain = [y_interior.min().item(), y_interior.max().item()]
# Each mesh vertex appears once per time sample, so these equal the vertex count.
X_POINTS = x_interior.size()[0] // N_POINTS
Y_POINTS = y_interior.size()[0] // N_POINTS  # previously (mistakenly) read x_interior
LENGTH = x_domain[1]

In [30]:
# Initial-condition points: every mesh vertex at t = t_domain[0] (read by Loss.initial_loss).
x_initial, y_initial, t_initial = get_initial_mesh(x_domain, y_domain, t_domain, N_POINTS, device)

In [31]:
# Points on the four lateral faces of the x/y bounding box (read by Loss.boundary_loss).
down, up, left, right = get_boundary_points(x_domain, y_domain, t_domain, N_POINTS, device)

In [32]:
class Loss:
    """Weighted PINN loss: PDE residual + initial condition + zero-normal-gradient boundary.

    Calling an instance returns ``(total, residual, initial, boundary)`` where
    ``total = weight_r * residual + weight_i * initial + weight_b * boundary``.

    NOTE(review): the collocation points are read from notebook globals
    (``x_interior``/``y_interior``/``z_interior``/``t_interior``,
    ``x_initial``/``y_initial``/``t_initial`` and ``down``/``up``/``left``/``right``).
    ``x_domain``, ``y_domain``, ``t_domain``, ``n_points`` and ``floor`` are
    stored but not used by any loss term.
    """

    def __init__(
        self,
        x_domain: Tuple[float, float],
        y_domain: Tuple[float, float],
        t_domain: Tuple[float, float],
        n_points: int,
        initial_condition: Callable,
        floor: Callable,
        weight_r: float = 1.0,
        weight_b: float = 1.0,
        weight_i: float = 1.0,
        verbose: bool = False,
    ):
        self.x_domain = x_domain
        self.y_domain = y_domain
        self.t_domain = t_domain
        self.n_points = n_points
        self.initial_condition = initial_condition
        self.floor = floor
        self.weight_r = weight_r
        self.weight_b = weight_b
        self.weight_i = weight_i
        # `verbose` is accepted for compatibility but deliberately not stored:
        # `self.verbose = ...` would shadow the `verbose` method below.

    def residual_loss(self, pinn: PINN):
        """Mean squared PDE residual at the space-time interior points.

        Residual: u_tt - g * (u_x**2 + (u - z) * u_xx + u_y**2 + (u - z) * u_yy),
        where z is the normalized floor height of each point.
        """
        x, y, z, t = x_interior, y_interior, z_interior, t_interior
        u = f(pinn, x, y, t)

        u_tt = dfdt(pinn, x, y, t, u, order=2)
        u_x = dfdx(pinn, x, y, t, u)
        u_y = dfdy(pinn, x, y, t, u)
        # Differentiate the first derivatives once more instead of calling
        # order=2, which would rebuild u_x / u_y from scratch (same values).
        u_xx = dfdx(pinn, x, y, t, u_x)
        u_yy = dfdy(pinn, x, y, t, u_y)

        # Variant with bathymetry slopes, kept for reference:
        # u_tt - g * ((u_x - dzdx) * u_x + (u - z) * u_xx + (u_y - dzdy) * u_y + (u - z) * u_yy)
        loss = u_tt - GRAVITY * (u_x ** 2 + (u - z) * u_xx + u_y ** 2 + (u - z) * u_yy)
        return loss.pow(2).mean()

    def initial_loss(self, pinn: PINN):
        """Mean squared mismatch between the network at the initial points and initial_condition."""
        x, y, t = x_initial, y_initial, t_initial
        pinn_init = self.initial_condition(x, y)
        loss = f(pinn, x, y, t) - pinn_init
        return loss.pow(2).mean()

    def boundary_loss(self, pinn: PINN):
        """Zero normal derivative: du/dy on the down/up faces, du/dx on the left/right faces."""
        x_down,  y_down,  t_down    = down
        x_up,    y_up,    t_up      = up
        x_left,  y_left,  t_left    = left
        x_right, y_right, t_right   = right

        loss_down  = dfdy(pinn, x_down,  y_down,  t_down)
        loss_up    = dfdy(pinn, x_up,    y_up,    t_up)
        loss_left  = dfdx(pinn, x_left,  y_left,  t_left)
        loss_right = dfdx(pinn, x_right, y_right, t_right)

        return (loss_down.pow(2).mean() + loss_up.pow(2).mean()
                + loss_left.pow(2).mean() + loss_right.pow(2).mean())

    def verbose(self, pinn: PINN, only_initial=False):
        """Return ``(total, residual, initial, boundary)`` loss tensors.

        This is what ``__call__`` (and therefore training) uses.
        ``only_initial`` is accepted for backward compatibility only: it used to
        select a branch computing exactly the same total, so it has no effect.
        """
        residual_loss = self.residual_loss(pinn)
        initial_loss = self.initial_loss(pinn)
        boundary_loss = self.boundary_loss(pinn)

        final_loss = \
            self.weight_r * residual_loss + \
            self.weight_i * initial_loss + \
            self.weight_b * boundary_loss

        return final_loss, residual_loss, initial_loss, boundary_loss

    def __call__(self, pinn: PINN, only_initial=False):
        """
        Allows you to use the instance of this class as if it were a function:

        ```
            >>> loss = Loss(*some_args)
            >>> calculated_loss = loss(pinn)
        ```
        """
        return self.verbose(pinn, only_initial)

## Train function

In [33]:
def train_model(
    nn_approximator: PINN,
    loss_fn: Callable,
    learning_rate: float = 0.01,
    max_epochs: int = 1_000
) -> Tuple[PINN, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Train ``nn_approximator`` with Adam on ``loss_fn`` for ``max_epochs`` steps.

    ``loss_fn(model)`` must return ``(total, residual, initial, boundary)``
    scalar tensors; only ``total`` is back-propagated. Whenever the total loss
    improves, the whole module is saved to ``./best_{RUN_NUM}.pt``. A
    KeyboardInterrupt stops training early and still returns the history so far.

    Returns:
        (model, total, residual, initial, boundary): the trained model and one
        NumPy array of per-epoch values for each loss component.
    """
    optimizer = torch.optim.Adam(nn_approximator.parameters(), lr=learning_rate)
    loss_values = []
    residual_loss_values = []
    initial_loss_values = []
    boundary_loss_values = []
    top_loss = float("inf")

    for epoch in range(max_epochs):
        try:
            loss = loss_fn(nn_approximator)
            total, residual, initial, boundary = (component.item() for component in loss)

            # Checkpoint *before* optimizer.step(). `total` was measured with the
            # current parameters, so this is the state that achieved it (saving
            # after the step stored parameters whose loss was never evaluated).
            if total < top_loss:
                torch.save(nn_approximator, f"./best_{RUN_NUM}.pt")
                top_loss = total

            optimizer.zero_grad()
            loss[0].backward()
#             torch.nn.utils.clip_grad_norm_(nn_approximator.parameters(), 0.5)
            optimizer.step()

            loss_values.append(total)
            residual_loss_values.append(residual)
            initial_loss_values.append(initial)
            boundary_loss_values.append(boundary)
            if (epoch + 1) % 1000 == 0:
                print(f"Epoch: {epoch + 1} - Loss: {total:>7f}, Residual Loss: {residual:>7f}, Initial Loss: {initial:>7f}, Boundary Loss: {boundary:>7f}")

        except KeyboardInterrupt:
            break

    return nn_approximator, np.array(loss_values), np.array(residual_loss_values), np.array(initial_loss_values), np.array(boundary_loss_values)

### Initial condition

In [34]:
def initial_condition(x: torch.Tensor, y: torch.Tensor, length=None) -> torch.Tensor:
    """Gaussian hump of height 2 on a flat level of 2, centred in the square domain.

    h0(x, y) = 2 * exp(-30 * r^2) + 2, where r is the distance to
    (length / 2, length / 2).

    Args:
        x, y: coordinate tensors (broadcastable).
        length: domain side length; defaults to the notebook-global ``LENGTH``.
    """
    if length is None:
        length = LENGTH
    centre = length / 2
    # Use r^2 directly: sqrt followed by squaring was redundant, and its
    # backward at r = 0 (a mesh vertex sits at the centre) yields NaN.
    r_squared = (x - centre) ** 2 + (y - centre) ** 2
    return 2 * torch.exp(-r_squared * 30) + 2

# Running code

## Train data

In [None]:
pinn = PINN(LAYERS, NEURONS_PER_LAYER, act=nn.Tanh()).to(device)

# Loss weights and collocation points come from the configuration cells above.
loss_fn = Loss(
    x_domain=x_domain,
    y_domain=y_domain,
    t_domain=t_domain,
    n_points=N_POINTS,
    initial_condition=initial_condition,
    floor=floor,
    weight_r=WEIGHT_RESIDUAL,
    weight_b=WEIGHT_BOUNDARY,
    weight_i=WEIGHT_INITIAL
)

pinn_trained, loss_values, residual_loss_values, initial_loss_values, boundary_loss_values = train_model(
    pinn, loss_fn=loss_fn, learning_rate=LEARNING_RATE, max_epochs=EPOCHS)

# Re-evaluate every loss component with the final (last-epoch) parameters.
losses = loss_fn.verbose(pinn)

for label, value in zip(("Total loss", "Interior loss", "Initial loss", "Boundary loss"), losses):
    print(f'{label}: \t{value:.5f} ({value:.3E})')

In [None]:
# Loss curves are saved under ./results/ and embedded in the PDF report below.
plot_running_average(loss_values, "Loss function (running average)", "total_loss")

In [None]:
# Residual (PDE) component; saved as ./results/residual_loss.png.
plot_running_average(residual_loss_values, "Residual loss function (running average)", "residual_loss")

In [None]:
# Initial-condition component; saved as ./results/initial_loss.png.
plot_running_average(initial_loss_values, "Initial loss function (running average)", "initial_loss")

In [None]:
# Boundary component; saved as ./results/boundary_loss.png.
plot_running_average(boundary_loss_values, "Boundary loss function (running average)", "boundary_loss")

In [None]:
# Exact initial condition vs. PINN prediction at t = t_domain[0] (colour map + 3-D surface each).
plot_intial_condition(x_domain=x_domain, 
                      y_domain=y_domain, 
                      t_domain=t_domain,
                      pinn=pinn,
                      initial_condition=initial_condition,
                      n_points=N_POINTS_PLOT,
                      length=LENGTH,
                      floor=floor)

In [None]:
# Write one pair of PNG frames per time step into ./img/ (read back by create_gif / create_gif_2d).
plot_simulation_by_frame(total_time=TOTAL_TIME,
                   x_domain=x_domain, 
                   y_domain=y_domain, 
                   t_domain=t_domain,
                   pinn=pinn,
                   n_points=N_POINTS_PLOT,
                   length=LENGTH,
                   floor=floor)

In [None]:
# Assemble the saved frames into ./results/tsunami_<RUN_NUM>.gif and ./results/tsunami_2d_<RUN_NUM>.gif.
create_gif(TOTAL_TIME, title=f"tsunami_{RUN_NUM}")
create_gif_2d(TOTAL_TIME, title=f"tsunami_2d_{RUN_NUM}")

In [None]:
# Collect hyper-parameters, final losses and loss-curve images for the PDF report.
# Loss values are passed as pre-formatted strings (".3E"), not floats.
date = datetime.now()

context: ReportContext = {
           'num': RUN_NUM,
           'date': date,
           'WEIGHT_RESIDUAL': WEIGHT_RESIDUAL, 
           'WEIGHT_INITIAL': WEIGHT_INITIAL, 
           'WEIGHT_BOUNDARY': WEIGHT_BOUNDARY,
           'LAYERS': LAYERS, 
           'NEURONS_PER_LAYER': NEURONS_PER_LAYER,
           'EPOCHS': EPOCHS, 
           'LEARNING_RATE': LEARNING_RATE,
           'total_loss': f"{losses[0]:.3E}",
           "residual_loss": f"{losses[1]:.3E}",
           "initial_loss": f"{losses[2]:.3E}",
           "boundary_loss": f"{losses[3]:.3E}",
           "img1": "./results/total_loss.png",
           "img2": "./results/residual_loss.png",
           "img3": "./results/initial_loss.png",
           "img4": "./results/boundary_loss.png",
    }

# Render the Jinja2 template found under env_path and write the PDF into ./results/.
create_report(context,
              env_path="/kaggle/",
              template_path="/input/report-template/report_template.html",
              report_title=f"report_{RUN_NUM}.pdf")

In [None]:
# # with profiling
# global x_domain, y_domain, t_domain, loss_fn, pinn_trained, loss_values, residual_loss_values, initial_loss_values, boundary_loss_values, losses

# def run_training():
#     pinn = PINN(LAYERS, NEURONS_PER_LAYER, act=nn.Tanh()).to(device)
    
#     x_domain = [0.0, LENGTH]
#     y_domain = [0.0, LENGTH]
#     t_domain = [0.0, TOTAL_TIME]
    
#     # train the PINN
#     loss_fn = Loss(
#         x_domain=x_domain,
#         y_domain=y_domain,
#         t_domain=t_domain,
#         n_points=N_POINTS,
#         initial_condition=initial_condition,
#         floor=floor,
#         weight_r=WEIGHT_RESIDUAL,
#         weight_b=WEIGHT_BOUNDARY,
#         weight_i=WEIGHT_INITIAL
#     )
    
#     pinn_trained, loss_values, residual_loss_values, initial_loss_values, boundary_loss_values = train_model(
#         pinn, loss_fn=loss_fn, learning_rate=LEARNING_RATE, max_epochs=EPOCHS)
    
#     pinn = pinn.cpu()
#     losses = loss_fn.verbose(pinn)
#     print(f'Total loss: \t{losses[0]:.5f} ({losses[0]:.3E})')
#     print(f'Interior loss: \t{losses[1]:.5f} ({losses[1]:.3E})')
#     print(f'Initial loss: \t{losses[2]:.5f} ({losses[2]:.3E})')
#     print(f'Boundary loss: \t{losses[3]:.5f} ({losses[3]:.3E})')

# %prun -D program3.prof run_training()

### Evaluation

In [None]:
# Reload the best checkpoint written by train_model, which saves to "./best_{RUN_NUM}.pt"
# relative to the working directory. The former "kaggle/working/..." path was also
# relative, so it pointed at <cwd>/kaggle/working/... instead of the saved file.
# The file is a pickled nn.Module from this notebook (trusted), so weights_only=False
# is needed on recent PyTorch; map_location lets a GPU checkpoint load on a CPU-only host.
model = torch.load(f"./best_{RUN_NUM}.pt", map_location=device, weights_only=False)

In [None]:
# Re-render the frames with the reloaded best checkpoint (overwrites ./img/img_*.png and ./img/img_2d_*.png).
plot_simulation_by_frame(total_time=TOTAL_TIME,
                   x_domain=x_domain, 
                   y_domain=y_domain, 
                   t_domain=t_domain,
                   pinn=model.to(device),
                   n_points=N_POINTS_PLOT,
                   length=LENGTH,
                   floor=floor)

In [14]:
import imageio
import jinja2
import matplotlib.pyplot as plt
import numpy as np
import torch

from datetime import datetime
from matplotlib.animation import FuncAnimation
from typing import List, TypedDict, Callable
from xhtml2pdf import pisa


class ReportContext(TypedDict):
    """Template variables rendered into the PDF report by create_report.

    The four loss entries are pre-formatted strings (e.g. ``f"{loss:.3E}"``),
    and img1..img4 are paths to the loss-curve PNGs.
    """
    num: int
    date: datetime
    WEIGHT_RESIDUAL: float
    WEIGHT_INITIAL: float
    WEIGHT_BOUNDARY: float
    LAYERS: int
    NEURONS_PER_LAYER: int
    EPOCHS: int
    LEARNING_RATE: float
    total_loss: str
    residual_loss: str
    initial_loss: str
    boundary_loss: str
    img1: str
    img2: str
    img3: str
    img4: str


def create_report(context: ReportContext, 
                  env_path: str, 
                  template_path: str, 
                  report_title: str) -> None:
    """Render a Jinja2 HTML template with ``context`` and convert it to a PDF.

    The template is loaded through a FileSystemLoader rooted at ``env_path``,
    and the PDF is written to ``./results/<report_title>``.

    Raises:
        RuntimeError: if xhtml2pdf reports errors while generating the PDF
            (previously these were silently ignored).
    """
    template_loader = jinja2.FileSystemLoader(env_path)
    template_env = jinja2.Environment(loader=template_loader)

    template = template_env.get_template(template_path)
    output_text = template.render(context)

    with open(f'./results/{report_title}', "w+b") as out_pdf_file_handle:
        pisa_status = pisa.CreatePDF(src=output_text, dest=out_pdf_file_handle)

    if pisa_status.err:
        raise RuntimeError(f"xhtml2pdf reported {pisa_status.err} error(s) while writing ./results/{report_title}")
        
def _frames_to_gif(total_time, title, step, base_dir, duration, frame_pattern):
    """Read one PNG per time step (named by ``frame_pattern``) and write them as a GIF."""
    n_frames = len(np.arange(0, total_time, step))
    frames = [imageio.v2.imread(base_dir + frame_pattern.format(i)) for i in range(n_frames)]
    imageio.mimsave(f'{base_dir}/results/{title}.gif', frames, duration=duration)


def create_gif_2d(total_time: float, 
               title: str, 
               step: float=0.01, 
               base_dir: str=".", 
               duration: float=0.1) -> None:
    """Assemble the colour-map frames img/img_2d_NNN.png into results/<title>.gif."""
    _frames_to_gif(total_time, title, step, base_dir, duration, '/img/img_2d_{:03d}.png')


def create_gif(total_time: float, 
               title: str, 
               step: float=0.01, 
               base_dir: str=".", 
               duration: float=0.1) -> None:
    """Assemble the 3-D surface frames img/img_NNN.png into results/<title>.gif."""
    _frames_to_gif(total_time, title, step, base_dir, duration, '/img/img_{:03d}.png')

def plot_intial_condition(x_domain: List[float], 
                          y_domain: List[float], 
                          t_domain: List[float],
                          length: float,
                          pinn: 'PINN',
                          initial_condition: Callable,
                          floor: Callable,
                          n_points: int) -> None:
    """Show the exact initial condition and the PINN prediction at t = t_domain[0].

    Each field is drawn twice: as a colour map and as a 3-D surface.
    """
    title = "Initial condition"
    x, y, t = get_initial_points(x_domain, y_domain, t_domain, n_points, requires_grad=False)

    for label in ("exact", "PINN"):
        z = initial_condition(x, y) if label == "exact" else pinn(x, y, t)
        caption = f"{title} - {label}"
        plot_color(z, x, y, n_points, n_points, caption)
        plot_3D(z, x, y, n_points, n_points, length, floor, caption)
    

def plot_frame(x_domain: List[float], 
               y_domain: List[float], 
               t_domain: List[float], 
               pinn: 'PINN', 
               idx: int, 
               t_value: float, 
               n_points: int,
               length: float,
               floor: Callable,
               base_dir: str=".") -> None:
    """Evaluate the PINN on an n_points x n_points grid at ``t_value`` and save two images.

    Writes ``<base_dir>/img/img_2d_<idx>.png`` (colour map) and
    ``<base_dir>/img/img_<idx>.png`` (3-D surface), with ``idx`` zero-padded to
    3 digits. Both figures are closed after saving, so rendering many frames
    does not keep every figure alive (matplotlib warns after 20 open figures).
    """
    x, y, _ = get_initial_points(x_domain, y_domain, t_domain, n_points, requires_grad=False)
    t = torch.full_like(x, t_value)
    # Plotting needs only values; don't build an autograd graph through the network.
    with torch.no_grad():
        z = pinn(x, y, t)
    title = f"PINN for t = {t_value}"

    fig_2d = plot_color(z, x, y, n_points, n_points, title)
    fig_2d.savefig(base_dir + '/img/img_2d_{:03d}.png'.format(idx))
    plt.close(fig_2d)

    fig_3d = plot_3D(z, x, y, n_points, n_points, length, floor, title)
    fig_3d.savefig(base_dir + '/img/img_{:03d}.png'.format(idx))
    plt.close(fig_3d)


def plot_simulation_by_frame(total_time: float, 
                             x_domain: List[float], 
                             y_domain: List[float], 
                             t_domain: List[float], 
                             pinn: 'PINN', 
                             n_points: int,
                             length: float,
                             floor: Callable,
                             step:float=0.01) -> None:
    """Save one PINN frame (colour map + 3-D surface) per time value.

    Frames are produced for t in ``np.arange(0, total_time, step)`` and numbered
    0, 1, 2, ... in that order.
    """
    frame_kwargs = dict(x_domain=x_domain, y_domain=y_domain, t_domain=t_domain,
                        pinn=pinn, n_points=n_points, length=length, floor=floor)
    for frame_idx, t_value in enumerate(np.arange(0, total_time, step)):
        plot_frame(idx=frame_idx, t_value=t_value, **frame_kwargs)

def running_average(y, window: int=100):
    """Moving average of ``y`` over ``window`` consecutive values ('valid' mode).

    The result has len(y) - window + 1 entries, or none if y is shorter than window.
    """
    prefix_sums = np.cumsum(np.insert(y, 0, 0))
    window_totals = prefix_sums[window:] - prefix_sums[:-window]
    return window_totals / float(window)

def plot_running_average(loss_values, title: str, path: str):
    """Plot the 100-step running average of ``loss_values`` on a log scale.

    The figure is saved to ``./results/<path>.png``.
    """
    smoothed = running_average(loss_values, window=100)
    fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.plot(smoothed)
    ax.set_yscale('log')
    fig.savefig(f'./results/{path}.png')

def plot_solution(pinn: 'PINN', x: torch.Tensor, t: torch.Tensor, figsize=(8, 6), dpi=100):
    """Animate a 1-D profile u(x) over the unique values of ``t``.

    Returns a matplotlib FuncAnimation with one frame per unique value of ``t``.

    NOTE(review): this looks like a leftover from a 1-D version of the notebook.
    ``f`` here takes (pinn, x, y, t), but ``animate`` calls it with three
    arguments, so any frame that reaches that call would raise a TypeError.
    Confirm, then either pass a ``y`` column or remove this helper.
    """

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    x_raw = torch.unique(x).reshape(-1, 1)
    t_raw = torch.unique(t)

    def animate(i):
        # NOTE(review): frames with i % 10 == 0 draw nothing; the condition may
        # have been meant the other way round — confirm.
        if not i % 10 == 0:
            t_partial = torch.ones_like(x_raw) * t_raw[i]
            f_final = f(pinn, x_raw, t_partial)
            ax.clear()
            ax.plot(
                # NOTE(review): the label indexes the raw `t`, not `t_raw`, so it
                # may not show this frame's time. `.numpy()` also requires CPU tensors.
                x_raw.detach().numpy(), f_final.detach().numpy(), label=f"Time {float(t[i])}"
            )
            ax.set_ylim(-1, 1)
            ax.legend()

    n_frames = t_raw.shape[0]
    return FuncAnimation(fig, animate, frames=n_frames, interval=100, repeat=False)

def plot_color(z: torch.Tensor, x: torch.Tensor, y: torch.Tensor, n_points_x, n_points_y, title, figsize=(8, 6), dpi=100, cmap="viridis"):
    """Draw ``z`` over the (x, y) grid as a pcolormesh colour map with a colour bar.

    x, y and z are detached, moved to CPU and reshaped to (n_points_x, n_points_y).
    Returns the created figure.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    X, Y, Z = (
        tensor.detach().cpu().numpy().reshape(n_points_x, n_points_y)
        for tensor in (x, y, z)
    )
    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    colour_mesh = ax.pcolormesh(X, Y, Z, cmap=cmap)
    fig.colorbar(colour_mesh, ax=ax)
    return fig

def plot_3D(z: torch.Tensor, x: torch.Tensor, y: torch.Tensor, n_points_x, n_points_y, length: float, floor: Callable, title, figsize=(8, 6), dpi=100, limit=4):
    """Plot a height field as a 3-D surface above the mesh floor.

    x, y and z are detached, moved to CPU and reshaped to (n_points_x, n_points_y).
    The floor is drawn as a trisurf of the normalized mesh vertices from
    ``dump_points(MESH_FILENAME)``. ``length`` and ``floor`` are accepted for API
    compatibility but are unused: the floor comes from the mesh, not from the
    ``floor`` callable. The z-axis is fixed to [0, ``limit``].

    Returns the created figure.
    """
    # Honour `dpi` like plot_color does (it used to be ignored here).
    fig = plt.figure(figsize=figsize, dpi=dpi)
    ax = fig.add_subplot(projection='3d')
    X, Y, Z = (
        tensor.detach().cpu().numpy().reshape(n_points_x, n_points_y)
        for tensor in (x, y, z)
    )
    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.axes.set_zlim3d(bottom=0, top=limit)

    ax.plot_surface(X, Y, Z)

    # Sea floor from the mesh vertices (formerly sampled from the `floor`
    # callable on a regular grid; that code path was removed).
    x_floor, y_floor, z_floor = dump_points(MESH_FILENAME)
    ax.plot_trisurf(x_floor, y_floor, z_floor, linewidth=0.2)

    return fig