In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from pprint import pprint
import h5py
import time
import matplotlib.pyplot as plt
import torch.nn.init as init

<h3>Geometry</h3>

In [None]:
# ---------------------------------------------------------------------------
# Notebook-wide configuration: compute device, default batch sizes, the
# space-time domain, and the time-normalisation helpers used by every cell
# below.
# ---------------------------------------------------------------------------

# Compute device: CUDA when PyTorch reports one, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default number of points requested per batch for each loss term.
PDE_batch_size = 4096
IC_batch_size  = 4096
BC_batch_size  = 1024

# Initial/Boundary condition voltage
V0 = 1

# Spatial box [x_min, x_max] x [y_min, y_max] x [z_min, z_max] and its extents.
x_min, x_max = -1, 1
y_min, y_max = -1, 1
z_min, z_max = -1.5, 1.5
factor_x = x_max - x_min
factor_y = y_max - y_min
factor_z = z_max - z_min

# Physical time window and its extent.
t_min, t_max = 0, 10
factor_t = t_max - t_min

# Offset added before taking the logarithm so that t = t_min = 0 stays finite.
c_time_err = 1e-2
logt_min = np.log(t_min + c_time_err)
logt_max = np.log(t_max + c_time_err)
factor_logt = logt_max - logt_min


# True  -> logarithmic time normalisation, False -> linear time normalisation.
flag_norm_log = True

if flag_norm_log:
    # Log transform: tau = log(t + c_time_err) / factor_logt.
    # The map is scaled but not shifted by logt_min, so the normalised window
    # has width 1 and starts at logt_min / factor_logt rather than at 0.
    def Norm_time(time_coords):
        """Map physical time to normalised log-time (torch tensors or numpy/scalars)."""
        log_fn = torch.log if isinstance(time_coords, torch.Tensor) else np.log
        return log_fn(time_coords + c_time_err) / factor_logt

    def Inv_Norm_time(norm_time_coords):
        """Inverse of ``Norm_time``: map normalised log-time back to physical time."""
        exp_fn = torch.exp if isinstance(norm_time_coords, torch.Tensor) else np.exp
        return exp_fn(norm_time_coords * factor_logt) - c_time_err
else:
    # Linear transform: tau = t / factor_t (same arithmetic for tensors and scalars).
    def Norm_time(time_coords):
        """Map physical time to normalised linear time."""
        return time_coords / factor_t

    def Inv_Norm_time(norm_time_coords):
        """Inverse of ``Norm_time``: map normalised linear time back to physical time."""
        return norm_time_coords * factor_t


# Normalised-time bounds used when sampling collocation points.
norm_t_min, norm_t_max = Norm_time(t_min), Norm_time(t_max)

In [None]:
# Round-trip sanity check of the time normalisation: composing Norm_time and
# Inv_Norm_time in either order should return the probe value, so both printed
# differences are expected to be ~0 (floating-point round-off only).
probe_value = 2
print("Error of normalization:")
for round_trip in (Norm_time(Inv_Norm_time(probe_value)),
                   Inv_Norm_time(Norm_time(probe_value))):
    print(round_trip - probe_value)

<h1>PINN</h1>

<h3>Initial & Boundary conditions</h3>

In [None]:
# ==== IC ====
def gen_IC_points(n_points=IC_batch_size):
    """Sample random points on the initial-time slice together with their target values.

    NOTE: only ``n_points // 2`` points are actually generated (the request is
    halved before sampling).

    Args:
        n_points: requested batch size; the returned batch has ``n_points // 2`` rows.

    Returns:
        points: tensor of shape ``(n_points // 2, 4)`` with columns ``(t, x, y, z)``;
            ``t`` is fixed to ``norm_t_min`` and ``x, y, z`` are uniform in the box.
        values: tensor of shape ``(n_points // 2, 1)`` holding
            ``sin(pi * (z - z_min) / (z_max - z_min))``.
    """
    n_rows = n_points // 2

    def uniform_column(lower, upper):
        # One column of uniform samples in [lower, upper) on the notebook device.
        return lower + (upper - lower) * torch.rand(n_rows, 1, device=device)

    t_column = norm_t_min * torch.ones(n_rows, 1, device=device)
    # Draw order (x, then y, then z) matters for reproducibility under a fixed seed.
    x_column = uniform_column(x_min, x_max)
    y_column = uniform_column(y_min, y_max)
    z_column = uniform_column(z_min, z_max)

    points = torch.cat((t_column, x_column, y_column, z_column), dim=-1)
    # Half sine wave along z: zero on both z faces, maximum 1 at mid-height.
    z_fraction = (z_column - z_min) / (z_max - z_min)
    values = torch.sin(torch.pi * z_fraction)
    return points, values

def compute_IC_loss(model, n_points=IC_batch_size, loss_function=nn.MSELoss()):
    """Initial-condition loss of ``model`` on a freshly sampled batch.

    Args:
        model: callable network mapping ``(N, 4)`` coordinates to ``(N, 1)`` values.
        n_points: batch size forwarded to ``gen_IC_points``.
        loss_function: criterion comparing predictions with the target values.

    Returns:
        Scalar loss tensor (differentiable w.r.t. the model parameters).
    """
    coords, values = gen_IC_points(n_points)
    # The IC term only compares values, so no derivative w.r.t. the inputs is
    # needed: leaving requires_grad off avoids back-propagating into the
    # coordinates on every Loss.backward() without changing parameter gradients.
    # Call the module itself (not .forward) so registered hooks are honoured.
    values_pred = model(coords)
    return loss_function(values_pred, values)

# ==== BC ====
def gen_BC_points(n_points=BC_batch_size):
    """Sample random points on the two z faces of the box with zero target values.

    The same random ``(t, x, y)`` samples are reused for both faces, so the
    returned batch has ``2 * n_points`` rows: first the ``z = z_min`` face,
    then the ``z = z_max`` face.

    Args:
        n_points: number of points per face.

    Returns:
        points: tensor of shape ``(2 * n_points, 4)`` with columns ``(t, x, y, z)``.
        values: zero tensor of shape ``(2 * n_points, 1)``.
    """
    def uniform_column(lower, upper):
        # One column of uniform samples in [lower, upper) on the notebook device.
        return lower + (upper - lower) * torch.rand(n_points, 1, device=device)

    # Draw order (t, then x, then y) matters for reproducibility under a fixed seed.
    t_column = uniform_column(norm_t_min, norm_t_max)
    x_column = uniform_column(x_min, x_max)
    y_column = uniform_column(y_min, y_max)

    face_points = []
    for z_face in (z_min, z_max):
        z_column = z_face * torch.ones(n_points, 1, device=device)
        face_points.append(torch.cat((t_column, x_column, y_column, z_column), dim=-1))

    points = torch.cat(face_points, dim=0)
    # Homogeneous condition: the target is 0 on both faces.
    values = torch.zeros(2 * n_points, 1, device=device)
    return points, values

def compute_BC_loss(model, n_points=BC_batch_size, loss_function=nn.MSELoss()):
    """Boundary-condition loss of ``model`` on a freshly sampled batch.

    Args:
        model: callable network mapping ``(N, 4)`` coordinates to ``(N, 1)`` values.
        n_points: per-face batch size forwarded to ``gen_BC_points``.
        loss_function: criterion comparing predictions with the target values.

    Returns:
        Scalar loss tensor (differentiable w.r.t. the model parameters).
    """
    coords, values = gen_BC_points(n_points)
    # The BC term only compares values, so no derivative w.r.t. the inputs is
    # needed: leaving requires_grad off avoids back-propagating into the
    # coordinates on every Loss.backward() without changing parameter gradients.
    # Call the module itself (not .forward) so registered hooks are honoured.
    values_pred = model(coords)
    return loss_function(values_pred, values)


# ==== PLOTS ====

# --- Initial-condition samples ---------------------------------------------
# Draw a batch from the IC sampler and move everything to NumPy for plotting.
points, IC_values = gen_IC_points(10000)
x_coords = points[:, 1].detach().cpu().numpy()
z_coords = points[:, 3].detach().cpu().numpy()
IC_values = IC_values.detach().cpu().numpy()

# XZ-plane map coloured by the initial value.
fig = plt.scatter(x_coords, z_coords, c=IC_values, s=2)
plt.title("Initial condition")
plt.xlabel("x")
plt.ylabel("z")
plt.colorbar(fig)
plt.show()

# Initial profile V(z).
fig = plt.scatter(z_coords, IC_values, s=2)
plt.title("Initial condition")
plt.xlabel("z")
plt.ylabel("V")
plt.show()

# --- Boundary-condition samples --------------------------------------------
# Draw a batch from the BC sampler and move everything to NumPy for plotting.
points, BC_values = gen_BC_points(10000)
x_coords = points[:, 1].detach().cpu().numpy()
z_coords = points[:, 3].detach().cpu().numpy()
BC_values = BC_values.detach().cpu().numpy()

# XZ-plane map of the boundary values, colour scale fixed to [0, V0].
fig = plt.scatter(x_coords, z_coords, c=BC_values, s=2, vmin=0, vmax=V0)
plt.title("Boundary condition")
plt.xlabel("x")
plt.ylabel("z")
plt.colorbar(fig)
plt.show()

<h3>PDE</h3>

In [None]:
## HEAT EQUATION
class QSM_PDE_heat:
    """Collocation sampling and residual of the heat equation ``u_t = Laplacian(u)``.

    Coordinates are ordered ``(t, x, y, z)`` where ``t`` is the *normalised*
    time produced by ``Norm_time``; the residual therefore carries the chain-rule
    factor ``dt/dtau`` of the chosen time normalisation.
    """

    def gen_PDE_points(self, n_points=PDE_batch_size):
        """Sample ``n_points`` uniform collocation points in the space-time box.

        Returns:
            Tensor of shape ``(n_points, 4)`` with columns ``(t, x, y, z)`` on ``device``.
        """
        def uniform_column(lower, upper):
            # One column of uniform samples in [lower, upper).
            return lower + (upper - lower) * torch.rand(n_points, 1, device=device)

        # Draw order (t, x, y, z) is kept fixed for reproducibility under a seed.
        columns = [
            uniform_column(norm_t_min, norm_t_max),
            uniform_column(x_min, x_max),
            uniform_column(y_min, y_max),
            uniform_column(z_min, z_max),
        ]
        return torch.cat(columns, dim=-1)


    def compute_PDE(self, coords, pred_func, flag_norm_log=True):
        """Evaluate the heat-equation residual at ``coords``.

        Args:
            coords: ``(N, 4)`` tensor with ``requires_grad=True`` that was used to
                produce ``pred_func``.
            pred_func: network output at ``coords`` (first column is the field u).
            flag_norm_log: True if column 0 of ``coords`` is log-normalised time,
                False if it is linearly normalised time.

        Returns:
            Tensor of shape ``(N,)`` holding ``u_tau - Laplacian(u) * dt/dtau``.
        """
        # A single backward pass yields all four first derivatives at once;
        # the columns follow the coordinate order (t, x, y, z).
        grad_u = self.get_derivative(pred_func, coords, 1)
        u_t = grad_u[:, 0]
        # Each second derivative needs its own pass through the graph.
        u_xx = self.get_derivative(grad_u[:, 1], coords, 1)[:, 1]
        u_yy = self.get_derivative(grad_u[:, 2], coords, 1)[:, 2]
        u_zz = self.get_derivative(grad_u[:, 3], coords, 1)[:, 3]
        Delta_u = u_xx + u_yy + u_zz
        if flag_norm_log:
            # tau = log(t + c) / F  =>  dt/dtau = F * exp(tau * F)
            time_scale = torch.exp(coords[:, 0] * factor_logt) * factor_logt
            eq = u_t - Delta_u * time_scale
        else:
            # tau = t / T  =>  dt/dtau = T
            eq = u_t - Delta_u * factor_t
        return eq


    def get_derivative(self, y, x, n: int = 1):
        """Apply ``torch.autograd.grad`` of ``y`` w.r.t. ``x`` ``n`` times.

        Each application differentiates the *sum* of all entries of the current
        tensor (``grad_outputs`` is all ones) and returns a tensor shaped like
        ``x``. For ``n == 1`` and per-sample outputs this is the per-sample
        gradient; for ``n > 1`` the result is the gradient of the summed previous
        gradient, i.e. it mixes components rather than being a pure n-th partial
        derivative. ``n == 0`` returns ``y`` unchanged.

        The graph is kept (``create_graph=True``) so the result can itself be
        differentiated or back-propagated through. With ``allow_unused=True`` the
        result is ``None`` when ``y`` does not depend on ``x``.
        """
        result = y
        for _ in range(n):
            result = torch.autograd.grad(
                result,
                x,
                torch.ones_like(result),
                create_graph=True,
                retain_graph=True,
                allow_unused=True,
            )[0]
        return result


    def compute_PDE_loss(self, model, n_points=PDE_batch_size, loss_function=nn.MSELoss(), flag_norm_log=True):
        """PDE-residual loss of ``model`` on a freshly sampled collocation batch.

        Args:
            model: callable network mapping ``(N, 4)`` coordinates to ``(N, 1)`` values.
            n_points: number of collocation points.
            loss_function: criterion applied between the residual and zero.
            flag_norm_log: forwarded to ``compute_PDE``; must match the time
                normalisation used to build the coordinates.

        Returns:
            Scalar loss tensor.
        """
        coords = self.gen_PDE_points(n_points=n_points)
        # Input gradients are required to differentiate the prediction.
        coords.requires_grad_(True)
        # Call the module itself (not .forward) so registered hooks are honoured.
        values_pred = model(coords)
        pde_pred = self.compute_PDE(coords, values_pred, flag_norm_log=flag_norm_log)
        # The zero target follows the residual's own dtype and device.
        return loss_function(pde_pred, torch.zeros_like(pde_pred))

<h3>PINN model</h3>

In [None]:
class PINN_model(nn.Module):
    """Fully connected tanh network trained on IC + BC + PDE-residual losses.

    Each row appended to ``self.history`` (one per epoch) has 11 columns:
    ``[IC_train, PDE_train, IC_val, PDE_val, w_IC, w_PDE, lambda_eff, lr,
    w_BC, BC_train, BC_val]``.
    """

    def __init__(self, input_dim, n_nodes, n_layers, n_batches, dropout):
        """Build the network.

        Args:
            input_dim: number of input coordinates.
            n_nodes: width of every hidden layer.
            n_layers: number of hidden layers.
            n_batches: batches per epoch, split between training and validation
                in ``train_model``.
            dropout: dropout probability applied after every hidden activation.
        """
        super().__init__()
        # PDE part
        self._PDE = QSM_PDE_heat()
        self.history = []
        # DNN part
        self.n_batches = n_batches
        self.dropout   = dropout
        layers = [nn.Linear(input_dim, n_nodes), nn.Tanh(), nn.Dropout(self.dropout)]
        for _ in range(n_layers - 1):
            layers.extend([nn.Linear(n_nodes, n_nodes), nn.Tanh(), nn.Dropout(self.dropout)])
        # Output head; the final Tanh bounds predictions to (-1, 1).
        layers.extend([nn.Linear(n_nodes, 1), nn.Tanh()])
        self.network = nn.Sequential(*layers)
        # Xavier-normal weights and zero biases for every linear layer.
        for layer in self.network:
            if isinstance(layer, nn.Linear):
                init.xavier_normal_(layer.weight, gain=1.0)
                init.zeros_(layer.bias)


    def forward(self, coords):
        """Return the network prediction for ``coords`` (last dim = ``input_dim``)."""
        return self.network(coords)


    def store_training_df(self, history, epochs):
        """Write the training history to CSV and save the model weights.

        Args:
            history: array-like with one 11-column row per epoch (see class docstring).
            epochs: kept for backward compatibility; the number of stored rows is
                taken from ``history`` itself so that histories accumulated over
                several ``train_model`` calls (or an empty history) are handled.

        Side effects:
            Writes ``history_training_plane.csv`` and ``end_plane_model.pth`` in
            the current working directory.
        """
        # reshape(-1, 11) also turns an empty history into a (0, 11) array.
        history = np.asarray(history, dtype=float).reshape(-1, 11)
        training_loss = history[:,4] * history[:,0] + history[:,5] * history[:,1] + history[:,8] * history[:,9]
        validation_loss = history[:,4] * history[:,2] + history[:,5] * history[:,3] + history[:,8] * history[:,10]
        df_train = pd.DataFrame(
            {
                "epochs"            : np.arange(len(history)),
                "training_loss"     : training_loss,
                "validation_loss"   : validation_loss,
                'lr'                : history[:,7],
                "IC_train_losses"   : history[:,0],
                "BC_train_losses"   : history[:,9],
                "PDE_train_losses"  : history[:,1],
                "weight_IC"         : history[:,4],
                "weight_BC"         : history[:,8],
                "weight_PDE"        : history[:,5]
            }
        )
        df_train.to_csv('history_training_plane.csv')
        torch.save(self.state_dict(), "end_plane_model.pth")


    def train_model(self, optimizer, patience = 10, loss_function = nn.MSELoss(), epochs = 200, validation_split = 0.1, use_log_time = None):
        """Train the network and record per-epoch statistics.

        Args:
            optimizer: optimizer already bound to ``self.parameters()``.
            patience: patience of the ``ReduceLROnPlateau`` scheduler (factor 0.5).
            loss_function: criterion used for the IC, BC and PDE terms.
            epochs: number of epochs to run.
            validation_split: fraction of ``self.n_batches`` used for validation.
            use_log_time: whether the PDE residual uses the log-time chain-rule
                factor. ``None`` (default) follows the notebook-level
                ``flag_norm_log`` so the residual always matches ``Norm_time``.

        Returns:
            ``self.history`` (list of per-epoch rows, see class docstring).

        Raises:
            ValueError: if the split leaves no training batch.
        """
        if use_log_time is None:
            # Keep the residual consistent with the time normalisation in use.
            use_log_time = flag_norm_log

        print("Start :")
        start_time = time.time()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=patience)
        # number of batches for training and validation (from the model's own setting)
        n_train_batches = int( (1 - validation_split) * self.n_batches)
        n_val_batches   = self.n_batches - n_train_batches
        if n_train_batches < 1:
            raise ValueError(
                f"n_batches={self.n_batches} with validation_split={validation_split} "
                "leaves no training batch"
            )
        print(f"  number of epochs             : {epochs}")
        print(f"  number of train batches      : {n_train_batches}")
        print(f"  number of validation batches : {n_val_batches}")
        print(" ")
        print("Training in progress...")
        print(" ")

        weight_IC, weight_BC, weight_PDE = 1, 1, 1
        # EPOCHS
        for epoch in range(epochs):
            start = time.time()

            # === TRAINING STEP ===
            self.train()
            start_tr = time.time()
            IC_total_loss, BC_total_loss, PDE_total_loss = 0, 0, 0
            for _ in range(n_train_batches):
                optimizer.zero_grad()
                IC_loss  = compute_IC_loss(model=self, loss_function=loss_function)
                BC_loss  = compute_BC_loss(model=self, loss_function=loss_function)
                PDE_loss = self._PDE.compute_PDE_loss(model=self, loss_function=loss_function, flag_norm_log=use_log_time)
                Loss = weight_IC * IC_loss + weight_BC * BC_loss + weight_PDE * PDE_loss
                Loss.backward()
                optimizer.step()
                IC_total_loss   += IC_loss.item()
                BC_total_loss   += BC_loss.item()
                PDE_total_loss  += PDE_loss.item()
            IC_total_loss  /= n_train_batches
            BC_total_loss  /= n_train_batches
            PDE_total_loss /= n_train_batches
            end_tr = time.time()

            # === VALIDATION STEP ===
            # No torch.no_grad() here: the PDE residual needs autograd even in eval mode.
            self.eval()
            start_val = time.time()
            if n_val_batches > 0:
                IC_val_loss, BC_val_loss, PDE_val_loss = 0, 0, 0
                for _ in range(n_val_batches):
                    IC_val_loss  += compute_IC_loss(model=self, loss_function=loss_function).item()
                    BC_val_loss  += compute_BC_loss(model=self, loss_function=loss_function).item()
                    PDE_val_loss += self._PDE.compute_PDE_loss(model=self, loss_function=loss_function, flag_norm_log=use_log_time).item()
                IC_val_loss   /= n_val_batches
                BC_val_loss   /= n_val_batches
                PDE_val_loss  /= n_val_batches
                monitored_loss = weight_IC * IC_val_loss + weight_BC * BC_val_loss + weight_PDE * PDE_val_loss
            else:
                # No validation batches: report NaN and drive the scheduler with
                # the training loss instead of dividing by zero.
                IC_val_loss = BC_val_loss = PDE_val_loss = float("nan")
                monitored_loss = weight_IC * IC_total_loss + weight_BC * BC_total_loss + weight_PDE * PDE_total_loss
            scheduler.step(monitored_loss)
            end_val = time.time()

            end = time.time()

            # Ratio of the weighted PDE loss to the weighted IC loss (floored to avoid /0).
            lambda_eff = weight_PDE / max(weight_IC, 1e-8) * PDE_total_loss / max(IC_total_loss, 1e-8)
            current_lr = optimizer.param_groups[0]['lr']
            # === PRINT LOSSES ===
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"  Training losses   ==>  IC: {IC_total_loss} - BC: {BC_total_loss} - pde: {PDE_total_loss}")
            print(f"  Validation losses ==>  IC: {IC_val_loss} - BC: {BC_val_loss} - PDE: {PDE_val_loss}")
            print(f"  time : {end - start} s  - train: {end_tr - start_tr} s  - val: {end_val - start_val} s")
            print(f"     ( w_IC : {weight_IC} - w_BC : {weight_BC} - w_PDE : {weight_PDE} - lambda_eff : {lambda_eff} )")
            print(f"     ( learning rate : {current_lr} )")
            self.history.append([ IC_total_loss, PDE_total_loss, IC_val_loss, PDE_val_loss,
                            weight_IC, weight_PDE, lambda_eff, current_lr, weight_BC, BC_total_loss, BC_val_loss ])
            print(" ")
            print(" ")


        # END EPOCHS
        self.store_training_df(np.array(self.history), epochs)
        print(f"Done!  (time : {time.time() - start_time})")
        return self.history

In [None]:
# construction of PINN
# Rebind `model` first so a model left over from a previous run of this cell is
# released before the new one is built (presumably to free device memory).
model = 0
input_dim = 4     # (t, x, y, z)
n_nodes = 64
n_layers = 4
n_batches = 32
dropout = 0
model = PINN_model( input_dim, n_nodes, n_layers, n_batches, dropout ).to(device)

# training of PINN
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-9)
loss_function = nn.MSELoss()
history = model.train_model(optimizer, patience=100, loss_function=loss_function, epochs=1000)
print(model)


# PLOTS OF TRAINING AND VALIDATION STEPS
def plot_history_columns(history, curves, legend_title, ylabel):
    """Plot selected columns of the training history against the epoch index.

    Args:
        history: list of per-epoch rows returned by ``train_model``.
        curves: iterable of ``(column_index, label)`` pairs to draw.
        legend_title: title shown on the legend box.
        ylabel: y-axis label (log scale is always used).
    """
    for column, label in curves:
        plt.plot([row[column] for row in history], label=label)
    plt.legend(title=legend_title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.2)
    plt.yscale("log")
    plt.show()


# History columns: 0/2 = IC train/val, 9/10 = BC train/val, 1/3 = PDE train/val,
# 6 = lambda_eff, 7 = learning rate.
plot_history_columns(history, [(0, "Training"), (2, "Validation")], "IC error", "Loss values (MSE)")
plot_history_columns(history, [(9, "Training"), (10, "Validation")], "BC error", "Loss values (MSE)")
plot_history_columns(history, [(1, "Training"), (3, "Validation")], "PDE error", "Loss values (MSE)")
# These two curves are not losses, so they get their own y-axis labels.
plot_history_columns(history, [(6, "lambda eff")], "Lambda effective", "lambda_eff")
plot_history_columns(history, [(7, "lr")], "Learning rate", "Learning rate")

In [None]:
# PREDICTION of V using PINN model
# One XZ-plane scatter of the predicted field per sampled (normalised) time;
# x, y and z are drawn at random for every panel.
sampled_times = np.linspace( norm_t_min, norm_t_max, 6 )
sampled_x = (x_max + x_min) / 2
sampled_y = (y_max + y_min) / 2
n_samples = 20000
plt.figure(figsize=(16,4))
model.eval()
pde_class = QSM_PDE_heat()
n_panels = len(sampled_times)

for i_plot, n_t in enumerate(sampled_times, 1):
    real_t = Inv_Norm_time(n_t)
    # Random collocation points, then pin the time column to this snapshot.
    X_pred = pde_class.gen_PDE_points(n_samples)
    X_pred[:, 0] = torch.full_like(X_pred[:, 0], float(n_t))
    X_Sample = X_pred.detach().cpu().numpy()
    # predicted result
    with torch.no_grad():
        V_pred = model.forward(X_pred)
    V_pred = V_pred.detach().cpu().numpy()

    ax = plt.subplot(1, n_panels, i_plot)
    ax.set_title(f"Pred {real_t:.5f} ns")
    im0 = ax.scatter(X_Sample[:, 1], X_Sample[:, 3], c=V_pred, s=2, cmap='plasma', vmin=0, vmax=1)
    ax.set_ylabel("z")
    ax.set_xlabel("x")
    plt.colorbar(im0, ax=ax)

plt.tight_layout()
plt.savefig("plot_heat.jpg", format="jpg", dpi=300)
plt.show()




# === Settings ===
# Twelve snapshots; each row shows the XZ scatter (left) and the V(z) profile
# along the line x = sampled_x, y = sampled_y (right).
sampled_times = norm_t_min + (norm_t_max - norm_t_min) * np.linspace(0, 1, 12)
n_rows = len(sampled_times)
plt.figure(figsize=(8, 3 * n_rows))
model.eval()
pde_class = QSM_PDE_heat()

for i_plot, n_t in enumerate(sampled_times, 1):
    real_t = Inv_Norm_time(n_t)

    # Left panel: random points at this time, coloured by the prediction.
    X_pred = pde_class.gen_PDE_points(n_samples)
    X_pred[:, 0] = torch.full_like(X_pred[:, 0], float(n_t))
    X_Sample = X_pred.detach().cpu().numpy()
    with torch.no_grad():
        V_pred = model.forward(X_pred)
    V_pred = V_pred.detach().cpu().numpy().squeeze()
    ax_left = plt.subplot(n_rows, 2, 2 * i_plot - 1)
    im0 = ax_left.scatter(X_Sample[:, 1], X_Sample[:, 3], c=V_pred, s=2, cmap='plasma', vmin=0, vmax=1)
    ax_left.set_title(f"t = {real_t:.2f} ns\nV(x,z)")
    ax_left.set_xlabel("x")
    ax_left.set_ylabel("z")
    plt.colorbar(im0, ax=ax_left)

    # Right panel: reuse the same tensor, now pinned to the central line in
    # (x, y) with z swept evenly from z_min to z_max.
    X_pred[:, 1] = torch.full_like(X_pred[:, 0], sampled_x)
    X_pred[:, 2] = torch.full_like(X_pred[:, 0], sampled_y)
    X_pred[:, 3] = torch.linspace(z_min, z_max, X_pred.shape[0])
    X_Sample = X_pred.detach().cpu().numpy()
    with torch.no_grad():
        V_pred = model.forward(X_pred)
    V_pred = V_pred.detach().cpu().numpy().squeeze()
    ax_right = plt.subplot(n_rows, 2, 2 * i_plot)
    ax_right.plot(X_Sample[:, 3], V_pred)
    ax_right.set_title(f"V(z) at x={sampled_x:.2f}, y={sampled_y:.2f}")
    ax_right.set_xlabel("z")
    ax_right.set_ylabel("V")
    ax_right.set_ylim(0, 1)

plt.tight_layout()
plt.savefig("plot_heat2.jpg", dpi=300)
plt.show()

