# Physics-informed neural network for 3D inverse modeling of natural-state geothermal systems

In [None]:
import random
from pathlib import Path
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

## Define parameters for the PINN training

In [None]:
from dataclasses import dataclass
from enum import Enum, auto


class OptimizationMode(Enum):
    """
    Optimizer family selected for one training stage.

    Attributes
    ----------
    TF
        Train the stage with a TensorFlow (Keras) optimizer.
    LBFGS
        Train the stage with the L-BFGS solver.
    """

    # Explicit values matching what `auto()` assigns (1, 2, ... in definition order)
    TF = 1
    LBFGS = 2


@dataclass
class TaskConfiguration:
    """
    Data class that holds every setting of one PINN training task.

    The field order is part of the constructor interface; the four
    ``*_optimization_steps`` fields are the only ones with a default value.

    Attributes
    ----------
    dtype: str
        NumPy dtype name (e.g. "float32") used for the loaded arrays and
        passed to the physics-informed losses
    data_csv: str
        Path of the CSV file of the reference model
    x1_column_name: str
        Column name for x axis in the CSV file
    x2_column_name: str
        Column name for y axis in the CSV file
    x3_column_name: str
        Column name for z axis (elevation) in the CSV file
    T_column_name: str
        Column name for temperature in the CSV file
    p_column_name: str
        Column name for pressure in the CSV file
    k_column_name: str
        Column name for logarithm of permeability in the CSV file
    N_x1: int
        The number of data points along x axis
    N_x2: int
        The number of data points along y axis
    N_x3: int
        The number of data points along z axis
    data_wells_csv: str
        Path of the CSV file for well data
    x1_column_name_wells: str
        Column name for x-axis in the well data
    x2_column_name_wells: str
        Column name for y-axis in the well data
    x3_column_name_wells: str
        Column name for z-axis (elevation) in the well data
    T_column_name_wells: str
        Column name for temperature in the well data
    p_column_name_wells: str
        Column name for pressure in the well data
    k_column_name_wells: str
        Column name for logarithm of permeability in the well data
    observation_point_min_elevation_for_training: int
        Minimum elevation (inclusive) of the well data points used for training
    observation_point_max_elevation_for_validation: int
        Maximum elevation (inclusive) of the well data points used for validation
    number_of_collocation_points: int
        The number of collocation points drawn from the reference grid
    qN: float
        Heat flow at the bottom boundary
    Lambda: float
        Thermal conductivity
    fix_random_numbers: bool
        A flag that determines whether the random seeds are fixed
    seed: int
        Seed used for the random number generators when ``fix_random_numbers`` is True
    first_optimization_mode: OptimizationMode
        Optimization mode for the first training
    second_optimization_mode: OptimizationMode
        Optimization mode for the second training
    third_optimization_mode: OptimizationMode
        Optimization mode for the third training
    fourth_optimization_mode: OptimizationMode
        Optimization mode for the fourth training
    first_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
        Learning-rate schedule for the first training. Only read when the
        stage runs in ``OptimizationMode.TF``; may be None for L-BFGS stages.
    second_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
        Learning-rate schedule for the second training (same convention)
    third_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
        Learning-rate schedule for the third training (same convention)
    fourth_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
        Learning-rate schedule for the fourth training (same convention)
    first_optimization_steps: int
        Maximum number of iterations (epochs) in the first training
    second_optimization_steps: int
        Maximum number of iterations (epochs) in the second training
    third_optimization_steps: int
        Maximum number of iterations (epochs) in the third training
    fourth_optimization_steps: int
        Maximum number of iterations (epochs) in the fourth training
    """

    # --- Data type and reference-model CSV ---
    dtype: str
    data_csv: str
    x1_column_name: str
    x2_column_name: str
    x3_column_name: str
    T_column_name: str
    p_column_name: str
    k_column_name: str
    N_x1: int
    N_x2: int
    N_x3: int
    # --- Well-data CSV ---
    data_wells_csv: str
    x1_column_name_wells: str
    x2_column_name_wells: str
    x3_column_name_wells: str
    T_column_name_wells: str
    p_column_name_wells: str
    k_column_name_wells: str
    # --- Train / validation split and collocation points ---
    observation_point_min_elevation_for_training: int
    observation_point_max_elevation_for_validation: int
    number_of_collocation_points: int
    # --- Physical constants ---
    qN: float
    Lambda: float
    # --- Reproducibility ---
    fix_random_numbers: bool
    seed: int
    # --- Optimizer selection per training stage ---
    first_optimization_mode: OptimizationMode
    second_optimization_mode: OptimizationMode
    third_optimization_mode: OptimizationMode
    fourth_optimization_mode: OptimizationMode
    # --- Learning-rate schedules (None is allowed for L-BFGS stages) ---
    first_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
    second_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
    third_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
    fourth_optimization_lr: Optional[tf.keras.optimizers.schedules.LearningRateSchedule]
    # --- Maximum number of steps per training stage ---
    first_optimization_steps: int = 1
    second_optimization_steps: int = 1
    third_optimization_steps: int = 1
    fourth_optimization_steps: int = 1

In [None]:
# Define the PINN's training data and settings.
# Every later cell reads its inputs from this single object.
task_config = TaskConfiguration(
    dtype="float32",
    # Reference model: CSV path and column names
    data_csv="Reference_model.csv",
    x1_column_name="X_Easting",
    x2_column_name="Y_Northing",
    x3_column_name="Elevation",
    T_column_name="T_degC",
    p_column_name="P_Pa",
    k_column_name="log10PER",
    # Number of grid points along x, y and z
    N_x1=18,
    N_x2=11,
    N_x3=18,
    # Well data: CSV path and column names
    data_wells_csv="Welldata_30wells.csv",
    x1_column_name_wells="X_Easting",
    x2_column_name_wells="Y_Northing",
    x3_column_name_wells="Elevation",
    T_column_name_wells="T_degC",
    p_column_name_wells="P_Pa",
    k_column_name_wells="log10PER",
    # NOTE(review): both thresholds are applied inclusively (>= for training,
    # <= for validation), so well points exactly at this elevation end up in
    # both the training and the validation set - confirm this is intended.
    observation_point_min_elevation_for_training=-1600,
    observation_point_max_elevation_for_validation=-1600,
    number_of_collocation_points=2000,
    # Bottom heat flow and thermal conductivity (qN / Lambda is used as the
    # temperature gradient of the bottom Neumann condition)
    qN=0.6,
    Lambda=2.0,
    fix_random_numbers=True,
    seed=211*17,
    # Stages 1-3 use a TF optimizer with a piecewise-constant learning rate.
    # NOTE(review): a new optimizer is created for every stage, so the step
    # count presumably restarts from zero each time; the last boundary (10000)
    # is larger than every stage's step count, which would leave the third
    # learning rate of each schedule unused - confirm.
    first_optimization_mode=OptimizationMode.TF,
    first_optimization_steps=7000,
    first_optimization_lr=tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,10000],[5e-3,2.5e-3,5e-4]),
    second_optimization_mode=OptimizationMode.TF,
    second_optimization_steps=2000,
    second_optimization_lr=tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,10000],[5e-4,1e-4,5e-5]),
    third_optimization_mode=OptimizationMode.TF,
    third_optimization_steps=6000,
    third_optimization_lr=tf.keras.optimizers.schedules.PiecewiseConstantDecay([3000,10000],[5e-4,1e-4,5e-5]),
    # Stage 4 uses L-BFGS; its learning rate is not read in that mode, hence None
    fourth_optimization_mode=OptimizationMode.LBFGS,
    fourth_optimization_steps=500,
    fourth_optimization_lr=None,
)

## Create output directory

In [None]:
from datetime import datetime

# Build a timestamp from the local time of this run.
# The strftime format "%y%m%d-%H%M%S" yields e.g. "240131-153045"
# (2-digit year, month, day, then 24-hour time).
timestamp = datetime.now().strftime("%y%m%d-%H%M%S")

# All outputs of this run go under "result-{timestamp}", relative to the
# current working directory.
output_dir = Path(f"result-{timestamp}")

# Create the output directory (no error if it already exists)
output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
from dataclasses import asdict
import json

# The task settings are saved in JSON format directly under the output directory.
# `default=str` writes every value that json cannot encode natively (the
# OptimizationMode members and the learning-rate schedule objects) as its
# str() form, so the file is a human-readable record rather than something
# that can be loaded back into a TaskConfiguration.
task_config_save_path = output_dir/'task_config.json'
with open(task_config_save_path, 'w', encoding='utf-8') as jsonfile:
    json.dump(asdict(task_config), jsonfile, indent=2, default=str)

## Fix random number

In [None]:
# When requested, seed every random number generator used in this notebook
# (Python's `random`, NumPy's global generator and TensorFlow) with the same
# seed, and ask TensorFlow to execute its ops deterministically.
if task_config.fix_random_numbers:
    random.seed(task_config.seed)
    np.random.seed(task_config.seed)
    tf.random.set_seed(task_config.seed)
    tf.config.experimental.enable_op_determinism()

## Read input csv data

In [None]:
def read_csv_data(
    csv_path: str,
    x1_column_name="X_Easting",
    x2_column_name="Y_Northing",
    x3_column_name="Elevation",
    T_column_name="T_degC",
    p_column_name="P_Pa",
    k_column_name="log10PER",
    dtype="float32",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load point coordinates and three quantities from a CSV file

    Parameters
    ----------
    csv_path: str
        Path of the CSV file
    x1_column_name: str
        Name of the column holding the x coordinate
    x2_column_name: str
        Name of the column holding the y coordinate
    x3_column_name: str
        Name of the column holding the z coordinate (elevation)
    T_column_name: str
        Name of the column holding the temperature
    p_column_name: str
        Name of the column holding the pressure
    k_column_name: str
        Name of the column holding the permeability
    dtype: str
        Data type of the returned arrays

    Returns
    -------
    (X, T, p, k)
        X: the three coordinate columns side by side, shape = (num_points, 3)
        T: temperature column, shape = (num_points, 1)
        p: pressure column, shape = (num_points, 1)
        k: permeability column, shape = (num_points, 1)
    """
    table = pd.read_csv(csv_path)

    def column_vector(column_name):
        # One CSV column converted to `dtype` and shaped as (num_points, 1)
        return np.array(table[column_name], dtype=dtype).reshape(-1, 1)

    coordinate_columns = (x1_column_name, x2_column_name, x3_column_name)
    coordinates = np.hstack([column_vector(name) for name in coordinate_columns])

    return (
        coordinates,
        column_vector(T_column_name),
        column_vector(p_column_name),
        column_vector(k_column_name),
    )

In [None]:
# Read reference data (the full model grid) from the specified CSV file.
# The configured dtype is forwarded explicitly so that `task_config.dtype`
# is actually honored instead of silently falling back to the function default.
X, T_true, p_true, k_true = read_csv_data(
    task_config.data_csv,
    x1_column_name=task_config.x1_column_name,
    x2_column_name=task_config.x2_column_name,
    x3_column_name=task_config.x3_column_name,
    T_column_name=task_config.T_column_name,
    p_column_name=task_config.p_column_name,
    k_column_name=task_config.k_column_name,
    dtype=task_config.dtype,
)

# Read well data (the observations) from the specified CSV file
X_star, T_star, p_star, k_star = read_csv_data(
    task_config.data_wells_csv,
    x1_column_name=task_config.x1_column_name_wells,
    x2_column_name=task_config.x2_column_name_wells,
    x3_column_name=task_config.x3_column_name_wells,
    T_column_name=task_config.T_column_name_wells,
    p_column_name=task_config.p_column_name_wells,
    k_column_name=task_config.k_column_name_wells,
    dtype=task_config.dtype,
)

In [None]:
# Define normalizers for temperature, pressure, and permeability.
# The min/max of each normalizer are taken from the well data (*_star),
# not from the reference model.
from modules.normalizer import MinMaxNormalizer

T_normalizer = MinMaxNormalizer(min_value=T_star.min(), max_value=T_star.max())
p_normalizer = MinMaxNormalizer(min_value=p_star.min(), max_value=p_star.max())
k_normalizer = MinMaxNormalizer(min_value=k_star.min(), max_value=k_star.max())

In [None]:
def plot_true_data(
    X: np.ndarray, T_true: np.ndarray, p_true: np.ndarray, k_true: np.ndarray,
    savefig_prefix: Optional[str] = None,
):
    """
    Draw one 3D scatter plot per reference quantity (T, p, k)

    Each figure colors the points of ``X`` by one quantity, using a fixed
    colormap and fixed color limits, and is shown with ``plt.show()``.

    Parameters
    ----------
    X: np.ndarray
        Point coordinates; columns 0, 1 and 2 are used as x, y and z
    T_true: np.ndarray
        Temperature at each point of ``X``
    p_true: np.ndarray
        Pressure at each point of ``X``
    k_true: np.ndarray
        Permeability value at each point of ``X``
    savefig_prefix: Optional[str]
        If given, each figure is also saved to "{savefig_prefix}{name}.png",
        where name is "T_true", "p_true" or "k_true". Missing parent
        directories are created.
    """
    # (values, figure name, colormap, lower color limit, upper color limit)
    for true_values, var_name, cmap_name, vmin, vmax in [
        (T_true, "T_true", "hot", 15, 350),
        (p_true, "p_true", "cool", 1e5, 2.5e7),
        (k_true, "k_true", "jet", -17, -12),
    ]:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        pl = ax.scatter(xs=X[:, 0], ys=X[:, 1], zs=X[:, 2], vmin=vmin, vmax=vmax, c=true_values, s=30, marker="s", cmap=cmap_name)
        plt.colorbar(pl)
        ax.set_title(var_name)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        # Label the third axis as well; the plot is three-dimensional
        ax.set_zlabel("z")
        plt.tight_layout()
        if savefig_prefix:
            savefig_path = Path(f"{savefig_prefix}{var_name}.png")
            savefig_path.parent.mkdir(parents=True, exist_ok=True)
            # Save figure
            plt.savefig(savefig_path)
        plt.show()


# Plot the reference data and save the three figures under "<output_dir>/figures/"
plot_true_data(
    X, T_true, p_true, k_true,
    savefig_prefix=f'{output_dir.as_posix()}/figures/',
)

## Network initialization

In [None]:
from modules.neuralnet import PINN_NeuralNet3

# Lower bounds: per-axis minimum of the reference grid coordinates
lb = X.min(0)
# Upper bounds: per-axis maximum of the reference grid coordinates
ub = X.max(0)

# Initialize PINN model
# NOTE(review): lb/ub are presumably used to scale the input coordinates
# inside the network - confirm in modules.neuralnet.
model = PINN_NeuralNet3(lb, ub)
# The network takes batches of 3 input features (x, y, z)
model.build(input_shape=(None, 3))

In [None]:
# Maps each predicted variable (T, p, k) to the position of its output in the
# tuple returned by the network. The loss objects below receive this index.
NN_output_indices = {"T": 0, "p": 1, "k": 2}

## Define loss function

### Extract well data at or above the specified elevation threshold as training set

In [None]:
# Select the well data points used for training: those whose elevation is
# greater than or equal to observation_point_min_elevation_for_training,
# i.e. at or above the threshold.
# The mask is a plain boolean array; the original trailing comma wrapped it
# in a one-element tuple, which np.flatnonzero only handled by accident.
x3_match_train = (
    X_star[:, 2] >= task_config.observation_point_min_elevation_for_training
)

train_obs_point_indices = np.flatnonzero(x3_match_train)

# Coordinates, temperatures, pressures, and permeabilities at the extracted indices
train_X_obs = X_star[train_obs_point_indices]
train_T_obs = T_star[train_obs_point_indices]
train_p_obs = p_star[train_obs_point_indices]
train_k_obs = k_star[train_obs_point_indices]

In [None]:
# Plot the locations of a set of points (observation or collocation points)
def plot_points(points: np.ndarray, x1_range=(None, None), x2_range=(None, None), x3_range=(None, None)):
    """
    Show a 3D scatter plot of point locations

    Parameters
    ----------
    points: np.ndarray
        Point coordinates; columns 0, 1 and 2 are used as x, y and z
    x1_range: tuple
        (min, max) passed to ``set_xlim``; None leaves a limit automatic
    x2_range: tuple
        (min, max) passed to ``set_ylim``; None leaves a limit automatic
    x3_range: tuple
        (min, max) passed to ``set_zlim``; None leaves a limit automatic
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=10)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    # Label the third axis as well; the plot is three-dimensional
    ax.set_zlabel("z")
    ax.set_xlim(x1_range)
    ax.set_ylim(x2_range)
    ax.set_zlim(x3_range)
    plt.grid()
    plt.tight_layout()
    plt.show()


# Show the training observation points, with the axes spanning the whole
# reference grid so the plot is comparable with the other point plots.
plot_points(
    train_X_obs,
    x1_range=(X[:, 0].min(), X[:, 0].max()),
    x2_range=(X[:, 1].min(), X[:, 1].max()),
    x3_range=(X[:, 2].min(), X[:, 2].max()),
)

In [None]:
from modules.loss import ObservationLoss

# Define one loss component per observed quantity (temperature, pressure,
# permeability) on the training points. The observed values are normalized
# with the matching normalizer before being handed to the loss, and
# NN_output_index tells the loss which network output to compare against.
T_obs_loss = ObservationLoss(
    observed_point_coordinates=train_X_obs,
    observed_values=T_normalizer.normalize(train_T_obs),
    NN_output_index=NN_output_indices["T"],
)
p_obs_loss = ObservationLoss(
    observed_point_coordinates=train_X_obs,
    observed_values=p_normalizer.normalize(train_p_obs),
    NN_output_index=NN_output_indices["p"],
)
k_obs_loss = ObservationLoss(
    observed_point_coordinates=train_X_obs,
    observed_values=k_normalizer.normalize(train_k_obs),
    NN_output_index=NN_output_indices["k"],
)

### Extract well data at or below the specified elevation threshold as validation set

In [None]:
# Select the well data points used for validation: those whose elevation is
# less than or equal to observation_point_max_elevation_for_validation,
# i.e. at or below the threshold.
# The mask is a plain boolean array; the original trailing comma wrapped it
# in a one-element tuple, which np.flatnonzero only handled by accident.
x3_match_val = (
    X_star[:, 2] <= task_config.observation_point_max_elevation_for_validation
)

val_obs_point_indices = np.flatnonzero(x3_match_val)

# Coordinates, temperatures, pressures, and permeabilities at the extracted indices
val_X_obs = X_star[val_obs_point_indices]
val_T_obs = T_star[val_obs_point_indices]
val_p_obs = p_star[val_obs_point_indices]
val_k_obs = k_star[val_obs_point_indices]

In [None]:
# Plot locations of validation data points, with the axes spanning the whole
# reference grid.
plot_points(
    val_X_obs,
    x1_range=(X[:, 0].min(), X[:, 0].max()),
    x2_range=(X[:, 1].min(), X[:, 1].max()),
    x3_range=(X[:, 2].min(), X[:, 2].max()),
)

In [None]:
# For the observed values held out for validation, define the prediction-error
# loss of each variable. These objects are only registered as monitored
# metrics in the training stages below, never as training losses.
val_T_obs_loss = ObservationLoss(
    observed_point_coordinates=val_X_obs,
    observed_values=T_normalizer.normalize(val_T_obs),
    NN_output_index=NN_output_indices["T"],
)
val_p_obs_loss = ObservationLoss(
    observed_point_coordinates=val_X_obs,
    observed_values=p_normalizer.normalize(val_p_obs),
    NN_output_index=NN_output_indices["p"],
)
val_k_obs_loss = ObservationLoss(
    observed_point_coordinates=val_X_obs,
    observed_values=k_normalizer.normalize(val_k_obs),
    NN_output_index=NN_output_indices["k"],
)

### Loss for boundary conditions

#### Upper surface boundary condition

In [None]:
from modules.loss import DirichletBoundaryConditionLoss

# Extract point indices at the upper surface boundary:
# the grid points whose elevation equals the maximum elevation of the grid
x3 = X[:, 2]
top_boundary_point_indices = np.flatnonzero(x3 == x3.max())
X_top_boundary = X[top_boundary_point_indices]

# Dirichlet boundary condition for temperature.
# The prescribed values are the (normalized) reference temperatures on the top surface.
TbD_loss = DirichletBoundaryConditionLoss(
    boundary_coordinates=X_top_boundary,
    boundary_values=T_normalizer.normalize(T_true[top_boundary_point_indices]),
    NN_output_index=NN_output_indices["T"],
)

# Dirichlet boundary condition for pressure.
# The prescribed values are the (normalized) reference pressures on the top surface.
pbD_loss = DirichletBoundaryConditionLoss(
    boundary_coordinates=X_top_boundary,
    boundary_values=p_normalizer.normalize(p_true[top_boundary_point_indices]),
    NN_output_index=NN_output_indices["p"],
)

#### Side boundary condition

In [None]:
from modules.loss import NeumannBoundaryConditionLoss

# Extract point indices at the west (minimum x) and east (maximum x) boundaries
x1 = X[:, 0]
west_boundary_point_indices = np.flatnonzero(x1 == x1.min())
east_boundary_point_indices = np.flatnonzero(x1 == x1.max())

# Normal vector (-x direction) repeated for every point of the west boundary
west_boundary_normal_vectors = np.tile(
    np.array([-1.0, 0.0, 0.0]), (len(west_boundary_point_indices), 1)
)

# Normal vector (+x direction) repeated for every point of the east boundary
east_boundary_normal_vectors = np.tile(
    np.array([1.0, 0.0, 0.0]), (len(east_boundary_point_indices), 1)
)

# Extract point indices at the south (minimum y) and north (maximum y) boundaries
x2 = X[:, 1]
south_boundary_point_indices = np.flatnonzero(x2 == x2.min())
north_boundary_point_indices = np.flatnonzero(x2 == x2.max())

# Normal vector (-y direction) repeated for every point of the south boundary
south_boundary_normal_vectors = np.tile(
    np.array([0.0, -1.0, 0.0]), (len(south_boundary_point_indices), 1)
)

# Normal vector (+y direction) repeated for every point of the north boundary
north_boundary_normal_vectors = np.tile(
    np.array([0.0, 1.0, 0.0]), (len(north_boundary_point_indices), 1)
)

# Neumann boundary condition for pressure (mass flow) on the four side faces.
# Coordinates, prescribed values (all zero) and normal vectors are
# concatenated in the same west, east, south, north order so that their rows
# stay aligned. Points on the vertical edges belong to two faces and
# therefore appear twice, once with each face normal.
pbN_loss = NeumannBoundaryConditionLoss(
    boundary_coordinates=np.concatenate(
        [X[west_boundary_point_indices], X[east_boundary_point_indices],
         X[south_boundary_point_indices], X[north_boundary_point_indices]]
    ),
    boundary_values=np.zeros(
        shape=(len(west_boundary_point_indices)
               + len(east_boundary_point_indices)
               + len(south_boundary_point_indices)
               + len(north_boundary_point_indices), 1)
    ),
    normal_vector=np.concatenate(
        [west_boundary_normal_vectors, east_boundary_normal_vectors,
         south_boundary_normal_vectors, north_boundary_normal_vectors]
    ),
    normalizer=p_normalizer,
    NN_output_index=NN_output_indices["p"],
)

#### Bottom boundary condition

In [None]:
# Extract point indices at the bottom boundary:
# the grid points whose elevation equals the minimum elevation of the grid
x3 = X[:, 2]
bottom_boundary_point_indices = np.flatnonzero(x3 == x3.min())
X_bottom_boundary = X[bottom_boundary_point_indices]

# Normal vector (-z direction) repeated for every point of the bottom boundary
bottom_boundary_normal_vectors = np.tile(
    np.array([0.0, 0.0, -1.0]), (len(bottom_boundary_point_indices), 1)
)

# Temperature gradient at the bottom: heat flow divided by thermal conductivity
dTdn = task_config.qN / task_config.Lambda

# Neumann boundary condition for temperature (heat flow) at the bottom.
# The same gradient value is prescribed at every bottom point.
TbN_loss = NeumannBoundaryConditionLoss(
    boundary_coordinates=X_bottom_boundary,
    boundary_values=dTdn * np.ones(shape=(len(bottom_boundary_point_indices), 1)),
    normal_vector=bottom_boundary_normal_vectors,
    normalizer=T_normalizer,
    NN_output_index=NN_output_indices["T"],
)

### Loss for physical laws

#### Collocation points only

In [None]:
from modules.loss import (
    PhysicsInformedLoss_r1,
    PhysicsInformedLoss_r2,
)

# Select collocation points where physical laws are calculated.
# The points are drawn at random, without repetition, from the reference grid;
# np.random.choice raises ValueError if more points are requested than the
# grid contains. The draw uses NumPy's global generator, so it is
# reproducible when the seed has been fixed.

collocation_point_indices = np.random.choice(
    np.arange(len(X)), task_config.number_of_collocation_points, replace=False
)

# First physics-informed residual loss on the collocation points
PI1_loss = PhysicsInformedLoss_r1(
    collocation_point_coordinates=X[collocation_point_indices],
    T_normalizer=T_normalizer,
    p_normalizer=p_normalizer,
    k_normalizer=k_normalizer,
    dtype=task_config.dtype,
)

# Second physics-informed residual loss on the same collocation points;
# this one additionally takes the thermal conductivity
PI2_loss = PhysicsInformedLoss_r2(
    collocation_point_coordinates=X[collocation_point_indices],
    T_normalizer=T_normalizer,
    p_normalizer=p_normalizer,
    k_normalizer=k_normalizer,
    Lambda=task_config.Lambda,
    dtype=task_config.dtype,
)

In [None]:
# Plot locations of collocation points, with the axes spanning the whole
# reference grid.
plot_points(
    X[collocation_point_indices],
    x1_range=(X[:, 0].min(), X[:, 0].max()),
    x2_range=(X[:, 1].min(), X[:, 1].max()),
    x3_range=(X[:, 2].min(), X[:, 2].max()),
)

#### Whole grid (for test evaluation)

In [None]:
# Physics-informed residual losses evaluated on every point of the reference
# grid instead of on the random collocation subset.
# These are used for test evaluation, not for training.
# The configured dtype is passed explicitly, exactly as for the collocation
# losses, so that both sets of residuals are computed with the same dtype.

all_PI1_loss = PhysicsInformedLoss_r1(
    collocation_point_coordinates=X,
    T_normalizer=T_normalizer,
    p_normalizer=p_normalizer,
    k_normalizer=k_normalizer,
    dtype=task_config.dtype,
)

all_PI2_loss = PhysicsInformedLoss_r2(
    collocation_point_coordinates=X,
    T_normalizer=T_normalizer,
    p_normalizer=p_normalizer,
    k_normalizer=k_normalizer,
    Lambda=task_config.Lambda,
    dtype=task_config.dtype,
)

## Execute training

### First training

#### Setting

In [None]:
from modules.solver_callback import (
    PrintCurrentLossAndMetricsCallback,
    CsvLoggerCallback,
    NNWeightsCheckpointCallback,
    AllGridPredictionCallback,
)

# Define loss function.
# Stage 1 fits the network to the training observations only.
loss_functions_stage1 = {
    "train_T_MSE": T_obs_loss,
    "train_p_MSE": p_obs_loss,
    "train_k_MSE": k_obs_loss,
}

# Define loss components to be monitored
# These loss components are not used for training
metrics_stage1 = {
    "val_T_MSE": val_T_obs_loss,
    "val_p_MSE": val_p_obs_loss,
    "val_k_MSE": val_k_obs_loss,
    "TbD_loss": TbD_loss,
    "pbD_loss": pbD_loss,
    "TbN_loss": TbN_loss,
    "pbN_loss": pbN_loss,
    "PI1_loss": PI1_loss,
    "PI2_loss": PI2_loss,
    "all_PI1_loss": all_PI1_loss,
    "all_PI2_loss": all_PI2_loss,
}

# Callback setting.
# All stage-1 artifacts (loss history CSV, weight checkpoints, whole-grid
# predictions) are written below the output directory.
# NOTE(review): `interval` is presumably the number of training steps between
# two invocations of a callback - confirm in modules.solver_callback.
history_csv_path_stage1 = output_dir / "history_stage1.csv"
predicts_save_dir_stage1 = output_dir / "save_predicts_stage1"
callbacks_stage1 = [
    PrintCurrentLossAndMetricsCallback(interval=100),
    CsvLoggerCallback(csv_path=history_csv_path_stage1),
    NNWeightsCheckpointCallback(
        checkpoint_dir=output_dir / "checkpoints_stage1", interval=1000
    ),
    AllGridPredictionCallback(
        save_prediction_dir=predicts_save_dir_stage1,
        interval=1000,
        all_grid_point_coordinates=X,
        T_normalizer=T_normalizer,
        p_normalizer=p_normalizer,
        k_normalizer=k_normalizer,
    ),
]

#### Network training

In [None]:
# Execute the first training.
# The solver class is chosen from the configured optimization mode; both
# branches train the same `model` object with the stage-1 losses, metrics
# and callbacks.
from modules.solver import PINNSolver, PINNSolverLBfgs


if task_config.first_optimization_mode == OptimizationMode.TF:
    print('Start training with a TF optimizer')
    # Adam driven by the learning-rate schedule configured for this stage
    optimizer_stage1 = tf.keras.optimizers.Adam(
        learning_rate=task_config.first_optimization_lr
    )
    solver_stage1 = PINNSolver(
        model=model,
        optimizer=optimizer_stage1,
        loss_functions=loss_functions_stage1,
        metrics=metrics_stage1,
    )
    solver_stage1.solve(
        n_steps=task_config.first_optimization_steps,
        callbacks=callbacks_stage1,
    )

elif task_config.first_optimization_mode == OptimizationMode.LBFGS:
    print('Start training with an L-BFGS optimizer')
    # The L-BFGS solver takes no optimizer or learning rate; an empty dict of
    # extra L-BFGS options is passed
    solver_stage1 = PINNSolverLBfgs(
        model=model,
        loss_functions=loss_functions_stage1,
        metrics=metrics_stage1,
        lbfgs_kwargs=dict()
    )
    solver_stage1.solve(
        n_steps=task_config.first_optimization_steps,
        callbacks=callbacks_stage1,
    )

else:
    # Any other value is a configuration error
    raise ValueError(f'first_optimization_mode must be "TF" or "LBFGS", but {task_config.first_optimization_mode}')

#### Plot training results

In [None]:
def plot_learning_curve(
    history_csv_path: str,
    target_metrics: Optional[List[str]] = None,
    savefig_path: Optional[str] = None,
):
    """
    Plot selected columns of a loss-history CSV against its index

    Parameters
    ----------
    history_csv_path: str
        Path of the history CSV file; its first column is used as the index
        (horizontal axis)
    target_metrics: Optional[List[str]]
        Names of the CSV columns to plot. Defaults to ["train_loss"] when
        omitted or None.
    savefig_path: Optional[str]
        If given, the figure is also saved to this path. Missing parent
        directories are created.
    """
    # A None default avoids sharing one mutable list object between calls
    if target_metrics is None:
        target_metrics = ["train_loss"]

    # Read loss history
    history_df = pd.read_csv(history_csv_path, index_col=0)

    # Plot learning curve.
    # The axes are created here and handed to pandas; calling plt.figure()
    # first and letting DataFrame.plot() create its own figure would leave an
    # additional, empty figure behind.
    fig, ax = plt.subplots(figsize=(10, 4))
    history_df.plot(
        y=target_metrics,
        use_index=True,
        title="Learning Curve",
        grid=True,
        xlim=(0, None),
        ylim=(0, None),
        logy=False,  # If vertical axis is plotted in logarithm scale, set it "True"
        ax=ax,
    )
    if savefig_path:
        Path(savefig_path).parent.mkdir(parents=True, exist_ok=True)
        # Save figure
        fig.savefig(savefig_path)
    plt.show()


# Plot the stage-1 learning curve (the default "train_loss" column of the
# history CSV) and save it under "<output_dir>/figures_stage1/".
plot_learning_curve(
    history_csv_path_stage1,
    savefig_path=(output_dir / "figures_stage1/learning_curve.png").as_posix(),
)

In [None]:
def plot_prediction(prediction_csv_path: str, savefig_prefix: Optional[str] = None, dtype='float32'):
    """
    Draw one 3D scatter plot per predicted quantity stored in a CSV file

    The CSV file must contain the coordinate columns "x1", "x2", "x3" and the
    value columns "T_pred", "p_pred", "k_pred". Each figure uses a fixed
    colormap and fixed color limits and is shown with ``plt.show()``.

    Parameters
    ----------
    prediction_csv_path: str
        Path of the CSV file holding the predicted values
    savefig_prefix: Optional[str]
        If given, each figure is also saved to "{savefig_prefix}{name}.png",
        where name is "T_pred", "p_pred" or "k_pred". Missing parent
        directories are created.
    dtype: str
        Data type the CSV columns are converted to before plotting
    """
    # Predicted values are read from the csv file
    prediction_df = pd.read_csv(prediction_csv_path, index_col=None)

    X = np.array((prediction_df["x1"], prediction_df["x2"], prediction_df["x3"]), dtype=dtype).T

    # (column / figure name, colormap, lower color limit, upper color limit)
    for var_name, cmap_name, vmin, vmax in [
        ("T_pred", "hot", 15, 350),
        ("p_pred", "cool", 1e5, 2.5e7),
        ("k_pred", "jet", -17, -12),
    ]:
        var = np.array(prediction_df[var_name], dtype=dtype).reshape(X.shape[0], 1)

        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        pl = ax.scatter(
            xs=X[:, 0], ys=X[:, 1], zs=X[:, 2], vmin=vmin, vmax=vmax, c=var,
            s=30, marker="s", cmap=cmap_name)
        plt.colorbar(pl)
        ax.set_title(var_name)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        # Label the third axis as well; the plot is three-dimensional
        ax.set_zlabel("z")
        plt.tight_layout()
        if savefig_prefix:
            savefig_path = Path(f"{savefig_prefix}{var_name}.png")
            savefig_path.parent.mkdir(parents=True, exist_ok=True)
            # Save figure
            plt.savefig(savefig_path)
        plt.show()


# Plot the spatial distribution of the predicted quantities (T, P, and log10K)
# from the prediction file named after the solver's final step counter.
# NOTE(review): this assumes AllGridPredictionCallback has written
# "predicted_<step_counter>.csv" for the last step - confirm.
plot_prediction(
    f"{predicts_save_dir_stage1}/predicted_{solver_stage1.step_counter}.csv",
    savefig_prefix=f"{output_dir.as_posix()}/figures_stage1/",
    dtype=task_config.dtype
)

### Second training

#### Setting

In [None]:
# Define loss function.
# Stage 2 adds the two physics-informed losses to the observation losses.
loss_functions_stage2 = {
    "train_T_MSE": T_obs_loss,
    "train_p_MSE": p_obs_loss,
    "train_k_MSE": k_obs_loss,
    "PI1_loss": PI1_loss,
    "PI2_loss": PI2_loss,
}

# Final losses of the previous training stage.
# NOTE: average_Tpk_loss is computed here but not used in this stage,
# because both physics-informed weights below are fixed to 0.
last_losses = solver_stage1.get_last_loss_and_metrics()
average_Tpk_loss = np.average(
    [
        last_losses["train_T_MSE"],
        last_losses["train_p_MSE"],
        last_losses["train_k_MSE"],
    ]
)
# NOTE(review): with a weight of 0 the physics-informed terms presumably do
# not contribute to the optimized loss but are still recorded as losses, which
# stage 3 relies on when it reads last_losses["PI1_loss"] / ["PI2_loss"] -
# confirm in modules.solver.
loss_weights_stage2 = {
    "PI1_loss": 0,
    "PI2_loss": 0,
}

# Set loss component to be monitored during the training
# These loss components do not influence NN training
metrics_stage2 = {
    "val_T_MSE": val_T_obs_loss,
    "val_p_MSE": val_p_obs_loss,
    "val_k_MSE": val_k_obs_loss,
    "TbD_loss": TbD_loss,
    "pbD_loss": pbD_loss,
    "TbN_loss": TbN_loss,
    "pbN_loss": pbN_loss,
    "all_PI1_loss": all_PI1_loss,
    "all_PI2_loss": all_PI2_loss,
}

# Callback setting.
# Checkpoints and whole-grid predictions use a shorter interval (100) in this
# stage than in the other stages (1000).
history_csv_path_stage2 = output_dir / "history_stage2.csv"
predicts_save_dir_stage2 = output_dir / "save_predicts_stage2"
callbacks_stage2 = [
    PrintCurrentLossAndMetricsCallback(interval=100),
    CsvLoggerCallback(csv_path=history_csv_path_stage2),
    NNWeightsCheckpointCallback(
        checkpoint_dir=output_dir / "checkpoints_stage2", interval=100
    ),
    AllGridPredictionCallback(
        save_prediction_dir=predicts_save_dir_stage2,
        interval=100,
        all_grid_point_coordinates=X,
        T_normalizer=T_normalizer,
        p_normalizer=p_normalizer,
        k_normalizer=k_normalizer,
    ),
]

#### Network training

In [None]:
# Run the second training.
# The solver class is chosen from the configured optimization mode; both
# branches continue training the same `model` object with the stage-2 losses,
# loss weights, metrics and callbacks.
if task_config.second_optimization_mode == OptimizationMode.TF:
    # Announce the optimizer, as the first stage and the L-BFGS branch do
    print('Start training with a TF optimizer')
    # Adam driven by the learning-rate schedule configured for this stage
    optimizer_stage2 = tf.keras.optimizers.Adam(
        learning_rate=task_config.second_optimization_lr
    )
    solver_stage2 = PINNSolver(
        model=model,
        optimizer=optimizer_stage2,
        loss_functions=loss_functions_stage2,
        loss_weights=loss_weights_stage2,
        metrics=metrics_stage2,
    )
    solver_stage2.solve(
        n_steps=task_config.second_optimization_steps,
        callbacks=callbacks_stage2,
    )

elif task_config.second_optimization_mode == OptimizationMode.LBFGS:
    print('Start training with an L-BFGS optimizer')
    # The L-BFGS solver takes no optimizer or learning rate
    solver_stage2 = PINNSolverLBfgs(
        model=model,
        loss_functions=loss_functions_stage2,
        metrics=metrics_stage2,
        loss_weights=loss_weights_stage2,
        lbfgs_kwargs=dict()
    )
    solver_stage2.solve(
        n_steps=task_config.second_optimization_steps,
        callbacks=callbacks_stage2,
    )

else:
    # Any other value is a configuration error
    raise ValueError(f'second_optimization_mode must be "TF" or "LBFGS", but {task_config.second_optimization_mode}')

#### Plot training results

In [None]:
# Plot the stage-2 learning curve (the default "train_loss" column of the
# history CSV) and save it under "<output_dir>/figures_stage2/".
plot_learning_curve(
    history_csv_path_stage2,
    savefig_path=(output_dir / "figures_stage2/learning_curve.png").as_posix(),
)

In [None]:
# Plot the spatial distribution of the predicted quantities from the
# prediction file named after the solver's final step counter.
# NOTE(review): this assumes AllGridPredictionCallback has written
# "predicted_<step_counter>.csv" for the last step - confirm.
plot_prediction(
    f"{predicts_save_dir_stage2}/predicted_{solver_stage2.step_counter}.csv",
    savefig_prefix=f"{output_dir.as_posix()}/figures_stage2/",
    dtype=task_config.dtype
)

### Third training

#### Setting

In [None]:
# Define loss function.
# Stage 3 trains on the observation losses, the two top-surface Dirichlet
# losses and the two physics-informed losses.
loss_functions_stage3 = {
    "train_T_MSE": T_obs_loss,
    "train_p_MSE": p_obs_loss,
    "train_k_MSE": k_obs_loss,
    "TbD_loss": TbD_loss,
    "pbD_loss": pbD_loss,
    "PI1_loss": PI1_loss,
    "PI2_loss": PI2_loss,
}

# Weights of loss component from the loss history of previous training.
# Each physics-informed loss is weighted by the ratio of the average
# observation loss to that loss at the end of stage 2, which brings the
# weighted term to the magnitude of the average observation loss.
# NOTE(review): a physics-informed loss of exactly 0 would make the ratio
# infinite; losses without an entry here presumably keep the solver's default
# weight - confirm in modules.solver.
last_losses = solver_stage2.get_last_loss_and_metrics()
average_Tpk_loss = np.average(
    [
        last_losses["train_T_MSE"],
        last_losses["train_p_MSE"],
        last_losses["train_k_MSE"],
    ]
)
loss_weights_stage3 = {
    "PI1_loss": float(average_Tpk_loss / last_losses["PI1_loss"]),
    "PI2_loss": float(average_Tpk_loss / last_losses["PI2_loss"]),
}

# Set loss component to be monitored during the training
# These loss components do not influence NN training.
metrics_stage3 = {
    "val_T_MSE": val_T_obs_loss,
    "val_p_MSE": val_p_obs_loss,
    "val_k_MSE": val_k_obs_loss,
    "all_PI1_loss": all_PI1_loss,
    "all_PI2_loss": all_PI2_loss,
}

# Callback setting.
# All stage-3 artifacts are written below the output directory.
history_csv_path_stage3 = output_dir / "history_stage3.csv"
predicts_save_dir_stage3 = output_dir / "save_predicts_stage3"
callbacks_stage3 = [
    PrintCurrentLossAndMetricsCallback(interval=100),
    CsvLoggerCallback(csv_path=history_csv_path_stage3),
    NNWeightsCheckpointCallback(
        checkpoint_dir=output_dir / "checkpoints_stage3", interval=1000
    ),
    AllGridPredictionCallback(
        save_prediction_dir=predicts_save_dir_stage3,
        interval=1000,
        all_grid_point_coordinates=X,
        T_normalizer=T_normalizer,
        p_normalizer=p_normalizer,
        k_normalizer=k_normalizer,
    ),
]

#### Network training

In [None]:
# Run the third training.
# The solver class is chosen from the configured optimization mode; both
# branches continue training the same `model` object with the stage-3 losses,
# loss weights, metrics and callbacks.
if task_config.third_optimization_mode == OptimizationMode.TF:
    # Announce the optimizer, as the first stage and the L-BFGS branch do
    print('Start training with a TF optimizer')
    # Adam driven by the learning-rate schedule configured for this stage
    optimizer_stage3 = tf.keras.optimizers.Adam(
        learning_rate=task_config.third_optimization_lr
    )
    solver_stage3 = PINNSolver(
        model=model,
        optimizer=optimizer_stage3,
        loss_functions=loss_functions_stage3,
        loss_weights=loss_weights_stage3,
        metrics=metrics_stage3,
    )
    solver_stage3.solve(
        n_steps=task_config.third_optimization_steps,
        callbacks=callbacks_stage3,
    )

elif task_config.third_optimization_mode == OptimizationMode.LBFGS:
    print('Start training with an L-BFGS optimizer')
    # The L-BFGS solver takes no optimizer or learning rate
    solver_stage3 = PINNSolverLBfgs(
        model=model,
        loss_functions=loss_functions_stage3,
        loss_weights=loss_weights_stage3,
        metrics=metrics_stage3,
        lbfgs_kwargs=dict()
    )
    solver_stage3.solve(
        n_steps=task_config.third_optimization_steps,
        callbacks=callbacks_stage3,
    )

else:
    # Any other value is a configuration error
    raise ValueError(f'third_optimization_mode must be "TF" or "LBFGS", but {task_config.third_optimization_mode}')

#### Plot training results

In [None]:
# Plot the stage-3 learning curve (the default "train_loss" column of the
# history CSV) and save it under "<output_dir>/figures_stage3/".
plot_learning_curve(
    history_csv_path_stage3,
    savefig_path=(output_dir / "figures_stage3/learning_curve.png").as_posix(),
)

In [None]:
# Plot the spatial distribution of the predicted quantities from the
# prediction file named after the solver's final step counter.
# NOTE(review): this assumes AllGridPredictionCallback has written
# "predicted_<step_counter>.csv" for the last step - confirm.
plot_prediction(
    f"{predicts_save_dir_stage3}/predicted_{solver_stage3.step_counter}.csv",
    savefig_prefix=f"{output_dir.as_posix()}/figures_stage3/",
    dtype=task_config.dtype
)

### Fourth training

#### Setting

In [None]:
# Define loss function.
# Stage 4 uses the same loss components as stage 3.
loss_functions_stage4 = {
    "train_T_MSE": T_obs_loss,
    "train_p_MSE": p_obs_loss,
    "train_k_MSE": k_obs_loss,
    "TbD_loss": TbD_loss,
    "pbD_loss": pbD_loss,
    "PI1_loss": PI1_loss,
    "PI2_loss": PI2_loss,
}

# Weights of loss component from the loss history of previous training.
# Each physics-informed loss is weighted by the ratio of the average
# observation loss to that loss at the end of stage 3.
# NOTE(review): a physics-informed loss of exactly 0 would make the ratio
# infinite - confirm this cannot happen in practice.
last_losses = solver_stage3.get_last_loss_and_metrics()
average_Tpk_loss = np.average(
    [
        last_losses["train_T_MSE"],
        last_losses["train_p_MSE"],
        last_losses["train_k_MSE"],
    ]
)
loss_weights_stage4 = {
    "PI1_loss": float(average_Tpk_loss / last_losses["PI1_loss"]),
    "PI2_loss": float(average_Tpk_loss / last_losses["PI2_loss"]),
}

# Set loss component to be monitored during the training
# These loss components do not influence NN training
metrics_stage4 = {
    "val_T_MSE": val_T_obs_loss,
    "val_p_MSE": val_p_obs_loss,
    "val_k_MSE": val_k_obs_loss,
    "all_PI1_loss": all_PI1_loss,
    "all_PI2_loss": all_PI2_loss,
}

# Callback setting.
# All stage-4 artifacts are written below the output directory.
history_csv_path_stage4 = output_dir / "history_stage4.csv"
predicts_save_dir_stage4 = output_dir / "save_predicts_stage4"
callbacks_stage4 = [
    PrintCurrentLossAndMetricsCallback(interval=100),
    CsvLoggerCallback(csv_path=history_csv_path_stage4),
    NNWeightsCheckpointCallback(
        checkpoint_dir=output_dir / "checkpoints_stage4", interval=1000
    ),
    AllGridPredictionCallback(
        save_prediction_dir=predicts_save_dir_stage4,
        interval=1000,
        all_grid_point_coordinates=X,
        T_normalizer=T_normalizer,
        p_normalizer=p_normalizer,
        k_normalizer=k_normalizer,
    ),
]

#### Network training

In [None]:
# Run the fourth training.
# The solver class is chosen from the configured optimization mode; both
# branches continue training the same `model` object with the stage-4 losses,
# loss weights, metrics and callbacks.
if task_config.fourth_optimization_mode == OptimizationMode.TF:
    # Announce the optimizer, as the first stage and the L-BFGS branch do
    print('Start training with a TF optimizer')
    # Adam driven by the learning-rate schedule configured for this stage
    optimizer_stage4 = tf.keras.optimizers.Adam(
        learning_rate=task_config.fourth_optimization_lr
    )
    solver_stage4 = PINNSolver(
        model=model,
        optimizer=optimizer_stage4,
        loss_functions=loss_functions_stage4,
        loss_weights=loss_weights_stage4,
        metrics=metrics_stage4,
    )
    solver_stage4.solve(
        n_steps=task_config.fourth_optimization_steps,
        callbacks=callbacks_stage4,
    )

elif task_config.fourth_optimization_mode == OptimizationMode.LBFGS:
    print('Start training with an L-BFGS optimizer')
    # The L-BFGS solver takes no optimizer or learning rate
    solver_stage4 = PINNSolverLBfgs(
        model=model,
        loss_functions=loss_functions_stage4,
        loss_weights=loss_weights_stage4,
        metrics=metrics_stage4,
        lbfgs_kwargs=dict()
    )
    solver_stage4.solve(
        n_steps=task_config.fourth_optimization_steps,
        callbacks=callbacks_stage4,
    )

else:
    # Any other value is a configuration error
    raise ValueError(f'fourth_optimization_mode must be "TF" or "LBFGS", but {task_config.fourth_optimization_mode}')

#### Plot training results

In [None]:
# Plot the stage-4 learning curve (the default "train_loss" column of the
# history CSV) and save it under "<output_dir>/figures_stage4/".
plot_learning_curve(
    history_csv_path_stage4,
    savefig_path=(output_dir / "figures_stage4/learning_curve.png").as_posix(),
)

In [None]:
# Plot the spatial distribution of the predicted quantities from the
# prediction file named after the solver's final step counter.
# NOTE(review): the prediction callback of this stage uses interval=1000 while
# the stage is configured for 500 steps, so this relies on the callback also
# writing "predicted_<step_counter>.csv" at the final step - confirm in
# modules.solver_callback.
plot_prediction(
    f"{predicts_save_dir_stage4}/predicted_{solver_stage4.step_counter}.csv",
    savefig_prefix=f"{output_dir.as_posix()}/figures_stage4/",
    dtype=task_config.dtype
)