In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, context
from mindspore.common.initializer import initializer, HeUniform, Uniform

In [None]:
# set MindSpore context: eager (PyNative) execution on CPU.
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

# Fix random seeds so the network's random weight initialization (HeUniform /
# Uniform in NN) -- and therefore the loss curve and all plots -- are
# reproducible under Restart Kernel -> Run All.
SEED = 42
mindspore.set_seed(SEED)
np.random.seed(SEED)

In [None]:
# define neural network for PINNs
class NN(nn.Cell):
    """Fully connected MLP used as the PINN surrogate.

    Layout: one ``Dense(input_size, hidden_size)`` layer followed by
    ``depth - 1`` ``Dense(hidden_size, hidden_size)`` layers (at least one
    hidden Dense layer is always built). Every hidden Dense layer uses
    He-uniform weights and is followed by its own ``act()`` instance. A final
    linear ``Dense(hidden_size, output_size)`` head with ``Uniform(scale=0.1)``
    weights and no activation produces the output.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of every hidden layer.
        output_size: number of output features per sample.
        depth: number of activated hidden Dense layers (values below 1 behave like 1).
        act: activation Cell class, instantiated once per hidden layer.
    """

    def __init__(self, input_size, hidden_size, output_size, depth, act=nn.Tanh):
        super().__init__()

        # Output widths of the hidden Dense layers and their matching input widths.
        out_widths = [hidden_size] * max(depth, 1)
        in_widths = [input_size] + out_widths[:-1]

        # Interleave each hidden Dense layer with a fresh activation (Dense first).
        layers = [
            cell
            for n_in, n_out in zip(in_widths, out_widths)
            for cell in (nn.Dense(n_in, n_out, weight_init=HeUniform()), act())
        ]

        # Linear read-out head: small uniform init, no activation.
        layers.append(nn.Dense(hidden_size, output_size, weight_init=Uniform(scale=0.1)))

        self.net = nn.SequentialCell(layers)

    def construct(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        return self.net(x)

In [None]:
class Net:
    """PINN solver for the 1D convection-diffusion equation.

    Trains the MLP ``NN`` so that ``u(x, t)`` approximately satisfies

        u_t + c * u_x - D * u_xx = 0      on [0, 1] x [0, 1]
        u(0, t) = 1,  u(1, t) = 0,  u(x, 0) = exp(-10 x^2)

    by minimizing (BC/IC data MSE) + (PDE residual MSE) with Adam.

    Args:
        convection_coeff: convection speed ``c``.
        diffusion_coeff: diffusion coefficient ``D``.
    """

    def __init__(self, convection_coeff=0.5, diffusion_coeff=0.01):
        self.device = context.get_context("device_target")
        self.loss_history = []
        self.epoch_history = []

        self.model = NN(
            input_size=2,
            hidden_size=20,
            output_size=1,
            depth=4,
            act=nn.Tanh
        )

        # Define the spatial and temporal domain
        self.x_min, self.x_max = 0.0, 1.0
        self.t_min, self.t_max = 0.0, 1.0

        # Define convection and diffusion coefficients
        self.c = convection_coeff
        self.D = diffusion_coeff

        # Visualization grid spacing. A linspace-based grid always ends exactly
        # at the domain edge; np.arange(min, max + step, step) can overshoot
        # (e.g. step 0.1 gives 12 points ending at 1.1).
        self.h_vis = 0.01
        self.k_vis = 0.01
        x_vis = self._uniform_grid(self.x_min, self.x_max, self.h_vis)
        t_vis = self._uniform_grid(self.t_min, self.t_max, self.k_vis)

        # PDE collocation points: a 100 x 100 grid flattened to shape (10000, 2),
        # columns (x, t).
        x_train_pde_res = np.linspace(self.x_min, self.x_max, 100)
        t_train_pde_res = np.linspace(self.t_min, self.t_max, 100)
        pde_points_x, pde_points_t = np.meshgrid(x_train_pde_res, t_train_pde_res, indexing='ij')
        pde_points_np = np.stack([pde_points_x.flatten(), pde_points_t.flatten()], axis=1)
        self.pde_points = Tensor(pde_points_np, mindspore.float32)
        # NOTE(review): kept from the original; ops.grad(grad_position=0) selects the
        # differentiated input by position -- confirm whether this flag is needed.
        self.pde_points.requires_grad = True

        # Training-data resolution along the boundaries / initial line
        x_bc_res = np.linspace(self.x_min, self.x_max, 50)
        t_bc_res = np.linspace(self.t_min, self.t_max, 50)

        # (x, t) points on x = x_min, x = x_max and t = t_min; each has shape (50, 2).
        bc1 = np.stack(np.meshgrid([self.x_min], t_bc_res, indexing='ij')).reshape(2, -1).T
        bc2 = np.stack(np.meshgrid([self.x_max], t_bc_res, indexing='ij')).reshape(2, -1).T
        ic = np.stack(np.meshgrid(x_bc_res, [self.t_min], indexing='ij')).reshape(2, -1).T

        self.X_train = Tensor(np.concatenate([bc1, bc2, ic]), mindspore.float32)

        # Targets: u(x_min, t) = 1 (LBC), u(x_max, t) = 0 (RBC), u(x, t_min) = exp(-10 x^2) (IC)
        y_bc1 = np.ones(len(bc1))
        y_bc2 = np.zeros(len(bc2))
        y_ic = np.exp(-10 * ic[:, 0]**2)

        y_train_np = np.concatenate([y_bc1, y_bc2, y_ic]).reshape(-1, 1)
        self.y_train = Tensor(y_train_np, mindspore.float32)

        # Loss function
        self.criterion = nn.MSELoss()
        # Global iteration counter; keeps increasing across repeated train() calls.
        self.iter = 1

        # Adam
        self.optimizer = nn.Adam(self.model.trainable_params(), learning_rate=1e-3)

        # Visualization grid. meshgrid's default 'xy' indexing gives arrays of shape
        # (len(t), len(x)); plot_points is their row-major flattening, columns (x, t).
        self.plot_x = x_vis
        self.plot_t = t_vis
        self.plot_X, self.plot_T = np.meshgrid(self.plot_x, self.plot_t)
        plot_points_np = np.vstack([self.plot_X.ravel(), self.plot_T.ravel()]).T
        self.plot_points = Tensor(plot_points_np, mindspore.float32)

    @staticmethod
    def _uniform_grid(start, stop, step):
        """Return evenly spaced points from ``start`` to ``stop`` inclusive.

        The spacing is as close to ``step`` as possible. At least the two
        endpoints are always returned.

        Raises:
            ValueError: if ``step`` is not positive.
        """
        if step <= 0:
            raise ValueError("step must be positive")
        n_points = int(round((stop - start) / step)) + 1
        return np.linspace(start, stop, max(n_points, 2))

    # PDE Loss
    def pde_loss(self, x_pde):
        """Mean squared PDE residual ``u_t + c*u_x - D*u_xx`` at the points ``x_pde``.

        Args:
            x_pde: Tensor of collocation points with columns (x, t).

        Returns:
            Scalar MSE of the residual against zero.
        """
        # First derivatives w.r.t. the input. ops.grad differentiates the network
        # output (default all-ones sensitivity, i.e. its sum). NN applies only
        # row-wise Dense/activation layers, so row i holds du/dx and du/dt at point i.
        grad_fn = mindspore.ops.grad(self.model, grad_position=0)
        grads = grad_fn(x_pde)
        u_x = grads[:, 0:1]
        u_t = grads[:, 1:2]

        # Second derivative: differentiate sum(u_x) w.r.t. the input again and
        # keep the x column.
        grad2_fn = mindspore.ops.grad(lambda x: grad_fn(x)[:, 0:1].sum(), grad_position=0)
        u_xx = grad2_fn(x_pde)[:, 0:1]

        f_pde = u_t + self.c * u_x - self.D * u_xx
        return self.criterion(f_pde, ops.ZerosLike()(f_pde))

    # Train step
    def train_step(self, epochs):
        """Run ``epochs`` Adam updates on data loss + PDE loss.

        Appends each iteration's total loss to ``loss_history`` and the global
        iteration number to ``epoch_history``. Prints progress every 100 iterations.
        """
        def forward_fn(x, y, pde_points):
            pred_data = self.model(x)
            loss_data = self.criterion(pred_data, y)
            loss_pde = self.pde_loss(pde_points)
            total_loss = loss_data + loss_pde
            # With has_aux=True only total_loss is differentiated; the
            # components are returned for logging.
            return total_loss, loss_data, loss_pde

        grad_fn = mindspore.value_and_grad(forward_fn, None, self.optimizer.parameters, has_aux=True)

        for _ in range(epochs):
            (total_loss, loss_data, loss_pde), grads = grad_fn(
                self.X_train, self.y_train, self.pde_points
            )

            self.optimizer(grads)

            self.loss_history.append(total_loss.item())
            self.epoch_history.append(self.iter)

            if self.iter % 100 == 0:
                print(f"Iter {self.iter}, Total Loss: {total_loss.item():.4e}, "
                      f"Data Loss: {loss_data.item():.4e}, PDE Loss: {loss_pde.item():.4e}")
            self.iter += 1

    # Train
    def train(self, epochs_adam=8000):
        """Put the model in training mode and run ``epochs_adam`` Adam iterations."""
        print("Starting Adam training...")
        self.model.set_train()
        self.train_step(epochs_adam)
        print("Training complete.")

    # Visualization
    def plot_results(self):
        """Evaluate the model on the visualization grid and show all result figures.

        The figures are: time snapshots, spatial-location traces, a heatmap,
        a 3D surface and the loss history.
        """
        self.model.set_train(False)
        # plot_points follow meshgrid 'xy' order -> reshape to (len(t), len(x)).
        u_pred = self.model(self.plot_points).asnumpy().reshape(self.plot_T.shape)

        # Masks are applied to NumPy copies: the data is only needed for plotting.
        X_train_np = self.X_train.asnumpy()
        y_train_np = self.y_train.asnumpy().flatten()

        self._plot_time_snapshots(u_pred, X_train_np, y_train_np)
        self._plot_spatial_locations(u_pred)
        self._plot_heatmap(u_pred, X_train_np)
        self._plot_surface(u_pred)
        self._plot_loss_history()

    def _plot_time_snapshots(self, u_pred, X_train_np, y_train_np,
                             times_to_plot=(0.0, 0.1, 0.25, 0.5, 0.75, 1.0)):
        """Plot u versus x at several times, overlaid with the IC/BC training targets.

        Args:
            u_pred: predictions on the visualization grid, shape (len(plot_t), len(plot_x)).
            X_train_np: NumPy copy of ``X_train`` with columns (x, t).
            y_train_np: flattened NumPy copy of ``y_train``.
            times_to_plot: requested times, each snapped to the nearest grid time.
        """
        fig, ax = plt.subplots(figsize=(10, 6))
        colors = plt.cm.viridis(np.linspace(0, 1, len(times_to_plot)))
        for color, t_val in zip(colors, times_to_plot):
            t_idx = np.argmin(np.abs(self.plot_t - t_val))
            ax.plot(self.plot_x, u_pred[t_idx, :], label=f'PINN Pred. at t={t_val:.2f}', color=color)

        x_train, t_train = X_train_np[:, 0], X_train_np[:, 1]
        # (mask, color, marker, label) per group of training targets. Corner points
        # (x_min or x_max at t_min) fall into more than one group.
        groups = [
            (t_train == self.t_min, 'red', 'o', 'IC Training Data'),
            (x_train == self.x_min, 'green', 'x', f'BC Training Data (x={self.x_min})'),
            (x_train == self.x_max, 'blue', 'x', f'BC Training Data (x={self.x_max})'),
        ]
        for mask, color, marker, label in groups:
            ax.scatter(x_train[mask], y_train_np[mask], color=color, marker=marker,
                       s=30, label=label, zorder=5)

        ax.set_title('PINN Predicted Solution $u(x,t)$ at Different Time Snapshots')
        ax.set_xlabel('x')
        ax.set_ylabel('u')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    def _plot_spatial_locations(self, u_pred, x_to_plot=(0.0, 0.25, 0.5, 0.75, 1.0)):
        """Plot u versus t at several x locations (each snapped to the nearest grid x)."""
        fig, ax = plt.subplots(figsize=(10, 6))
        colors = plt.cm.plasma(np.linspace(0, 1, len(x_to_plot)))
        for color, x_val in zip(colors, x_to_plot):
            x_idx = np.argmin(np.abs(self.plot_x - x_val))
            ax.plot(self.plot_t, u_pred[:, x_idx], label=f'PINN Pred. at x={x_val:.2f}', color=color)

        ax.set_title('PINN Predicted Solution $u(x,t)$ at Different Spatial Locations Over Time')
        ax.set_xlabel('t')
        ax.set_ylabel('u')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    def _plot_heatmap(self, u_pred, X_train_np):
        """Filled-contour plot of u(x, t) with the training points overlaid."""
        fig, ax = plt.subplots(figsize=(10, 8))
        contour = ax.contourf(self.plot_X, self.plot_T, u_pred, 50, cmap='viridis')
        fig.colorbar(contour, ax=ax, label='u(x,t)')

        ax.scatter(X_train_np[:, 0], X_train_np[:, 1], color='red', marker='o', s=10,
                   label='Training Data Points', alpha=0.5)

        ax.set_title('PINN Predicted Solution $u(x,t)$ (Heatmap)')
        ax.set_xlabel('x')
        ax.set_ylabel('t')
        ax.legend()
        fig.tight_layout()
        plt.show()

    def _plot_surface(self, u_pred):
        """3D surface plot of u(x, t) over the visualization grid."""
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(self.plot_X, self.plot_T, u_pred, cmap=cm.viridis,
                               linewidth=0, antialiased=False)
        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5, label='u(x,t)')
        ax.set_xlabel('x')
        ax.set_ylabel('t')
        ax.set_zlabel('u')
        ax.set_title('PINN Predicted Solution $u(x,t)$ (3D Surface)')
        plt.show()

    def _plot_loss_history(self):
        """Plot the recorded total loss per iteration on a log scale."""
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(self.epoch_history, self.loss_history)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Total Loss (Log Scale)')
        ax.set_title('PINN Training Loss History')
        ax.set_yscale('log')
        ax.grid(True)
        plt.show()

In [None]:
if __name__ == '__main__':
    # Physical coefficients of u_t + c*u_x = D*u_xx.
    convection_speed, diffusion_coeff = 0.5, 0.01

    pinn_solver = Net(
        convection_coeff=convection_speed,
        diffusion_coeff=diffusion_coeff,
    )

    # Train with Adam, then render every diagnostic figure.
    print("Starting PINN training...")
    pinn_solver.train(epochs_adam=8000)

    print("Training complete. Generating plots...")
    pinn_solver.plot_results()