In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.nn.utils.parametrizations import weight_norm
import time
import pandas as pd

# Matplotlib text setup: request the SimHei font (for Chinese labels) and
# turn off the Unicode minus sign so tick labels use the ASCII hyphen.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Seed NumPy and torch (CPU and current CUDA device) with one shared value
init_seed = 42
np.random.seed(init_seed)
torch.manual_seed(init_seed)
torch.cuda.manual_seed(init_seed)

# Disable TF32 for cuDNN ops and CUDA matmuls
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

def fwd_gradients(Y, x):
    """Differentiate ``Y`` with respect to ``x`` using an all-ones weighting.

    Calls ``torch.autograd.grad`` with ``grad_outputs`` set to a tensor of
    ones shaped like ``Y`` and keeps the graph (``create_graph=True``) so the
    result can itself be differentiated again.

    Args:
        Y: Output tensor that depends on ``x``.
        x: Input tensor (must require grad).

    Returns:
        The first element of the tuple returned by ``torch.autograd.grad``.
    """
    ones = torch.ones_like(Y)
    grads = torch.autograd.grad(
        outputs=Y,
        inputs=x,
        grad_outputs=ones,
        create_graph=True,
    )
    return grads[0]

class Net(nn.Module):
    """Define the neural network model"""
    def __init__(self, layer_dim, X, device):
        super().__init__()
        self.X_mean = torch.from_numpy(X.mean(0, keepdims=True)).float().to(device)
        self.X_std = torch.from_numpy(X.std(0, keepdims=True)).float().to(device)
        self.num_layers = len(layer_dim)
        temp = []
        for l in range(1, self.num_layers):
            layer = torch.nn.Linear(layer_dim[l-1], layer_dim[l])
            layer = weight_norm(layer, name='weight', dim=0)
            torch.nn.init.normal_(layer.weight)
            temp.append(layer)
        self.layers = torch.nn.ModuleList(temp)
        
    def forward(self, x):
        """Forward propagation"""
        x = ((x - self.X_mean) / self.X_std)
        for i in range(0, self.num_layers-1):
            x = self.layers[i](x)
            if i < self.num_layers-2:
                x = F.silu(x)
        return x

class TSONN:
    """TSONN model: fits a field ``u(z, t)`` and a scalar parameter ``a``.

    The network is evaluated on a fixed 101 x 101 grid over ``z`` in
    ``[0, 10]`` and ``t`` in ``[0, 5]``. The total loss combines an initial
    condition term, two boundary terms, a pseudo-time-stepping PDE residual
    and a misfit against synthetic data generated by :meth:`exact_solution`.

    Attributes set by ``__init__``:
        model: ``Net`` instance approximating ``u``.
        a_raw: Trainable ``nn.Parameter`` with the current estimate of ``a``.
        log: Dict of per-epoch histories (losses, errors, time, ``a``).
    """

    # Value of ``a`` used for the synthetic data and the reference solution.
    A_TRUE = 0.016

    def __init__(self, layers, device):
        """Build the grids, constraint point sets, network and synthetic data.

        Args:
            layers: Layer widths forwarded to ``Net``.
            device: Torch device used for all tensors and the network.
        """
        self.Nx = 101
        self.Nt = 101
        self.layers = layers
        self.device = device

        # Reference grid; each row is a (z, t) pair.
        t = torch.linspace(0.0, 5.0, self.Nt)
        x = torch.linspace(0.0, 10.0, self.Nx)
        xx, tt = torch.meshgrid(x, t, indexing='ij')
        self.X_ref = torch.cat([xx.reshape(-1,1), tt.reshape(-1,1)], dim=1).to(self.device).requires_grad_(True)

        # Initial condition: first time column (t = 0) for every z, target 0.
        self.X_ic = torch.cat([xx[:, [0]], tt[:, [0]]], dim=1).to(self.device)
        self.u_ic = torch.full_like(xx[:, [0]], 0, dtype=torch.float32, device=self.device)

        # Boundaries: first grid row (z = 0) and last grid row (z = 10).
        self.X_lbc = torch.cat([xx[[0]], tt[[0]]], dim=0).T.to(self.device)
        self.X_lbc.requires_grad = True
        self.X_ubc = torch.cat([xx[[-1]], tt[[-1]]], dim=0).T.to(self.device)
        self.X_ubc.requires_grad = True

        self.log = {'losses':[], 'losses_b':[], 'losses_i':[], 'losses_f':[], 'losses_s':[], 'mse_exact':[], 'time':[], 'a':[]}
        self.min_loss = 1
        self.model = Net(layers, self.X_ref.cpu().detach().numpy(), self.device).to(self.device)
        # NOTE(review): the initial guess coincides with A_TRUE -- confirm
        # this is intended for the inverse problem.
        self.a_raw = nn.Parameter(torch.tensor([0.016], device=self.device))

        # Generate dataset using exact solution
        Nz = 10  # 10 points on z axis
        Nt = 100  # 100 points on t axis
        z = torch.linspace(0.0, 10.0, Nz, device=self.device).reshape(-1, 1)
        t = torch.linspace(0.0, 5.0, Nt, device=self.device).reshape(-1, 1)
        zz, tt = torch.meshgrid(z.squeeze(), t.squeeze(), indexing='ij')
        self.data_points = torch.cat([zz.reshape(-1,1), tt.reshape(-1,1)], dim=1).to(self.device)
        self.u_data = self.exact_solution(self.data_points[:, 0:1], self.data_points[:, 1:2],
                                         torch.tensor([self.A_TRUE], device=self.device)).to(self.device)

    def exact_solution(self, z, t, a, beta_deg=33, L=10, psi_d=-10, c=0.104, n_terms=9998):
        """Evaluate the analytical solution as a truncated sine series.

        The result is a closed-form part plus a sum over ``k = 1..n_terms``
        of decaying sine modes.

        Args:
            z: Tensor of spatial coordinates.
            t: Tensor of times, broadcastable against ``z``.
            a: Tensor holding the parameter ``a``.
            beta_deg: Angle in degrees.
            L: Domain length in ``z``.
            psi_d: Constant entering the amplitude ``1 - exp(a * psi_d)``.
            c: Coefficient dividing the decay rate ``mu_k``.
            n_terms: Number of series terms; the default keeps the previous
                fixed truncation.

        Returns:
            Tensor of solution values, float32 on ``self.device``.
        """
        beta = torch.tensor(beta_deg * torch.pi / 180.0, dtype=torch.float32, device=self.device)
        z = z.to(self.device, dtype=torch.float32)
        t = t.to(self.device, dtype=torch.float32)
        a = a.to(self.device, dtype=torch.float32)
        psi_d = torch.tensor(psi_d, dtype=torch.float32, device=self.device)
        L = torch.tensor(L, dtype=torch.float32, device=self.device)
        c = torch.tensor(c, dtype=torch.float32, device=self.device)

        # Quantities that do not depend on the series index.
        drift = a * torch.cos(beta)
        amplitude = 1 - torch.exp(a * psi_d)
        a_sq_quarter = a**2 / 4

        part_1 = amplitude * (1 - torch.exp(-drift * z)) / (1 - torch.exp(-drift * L))
        sum_terms = 0
        for k in range(1, n_terms + 1):
            lambda_k = k * torch.pi / L
            mu_k = (a_sq_quarter + lambda_k**2) / c
            sum_terms += ((-1) ** k) * (lambda_k / mu_k) * torch.sin(lambda_k * z) * torch.exp(-mu_k * t)
        part_2 = (2 * amplitude / (L * c)) * torch.exp(drift * (L - z) / 2) * sum_terms
        u_true = part_1 + part_2
        return u_true

    def _reference_solution(self):
        """Return the exact solution on ``X_ref`` for ``A_TRUE``, computed once.

        The grid and ``A_TRUE`` never change, so the series is evaluated on
        first use (without autograd tracking) and cached afterwards.
        """
        cached = getattr(self, '_u_ref', None)
        if cached is None:
            with torch.no_grad():
                z = self.X_ref[:, 0:1]
                t = self.X_ref[:, 1:2]
                a_true = torch.tensor([self.A_TRUE], device=self.device)
                cached = self.exact_solution(z, t, a_true)
            self._u_ref = cached
        return cached

    def Msei(self):
        """Initial condition loss: MSE between the model on ``X_ic`` and ``u_ic``."""
        u = self.model(self.X_ic)
        msei = F.mse_loss(u, self.u_ic)
        return msei

    def Mseb(self):
        """Boundary condition loss.

        Sum of the MSE against 0 on ``X_lbc`` and the MSE against
        ``1 - exp(-0.16)`` on ``X_ubc``.
        """
        u_lbc = self.model(self.X_lbc)
        u_ubc = self.model(self.X_ubc)
        upper_target = torch.full_like(
            u_ubc,
            1.0 - torch.exp(torch.tensor(-0.16, device=self.device)),
            device=self.device,
        )
        mseb = F.mse_loss(u_lbc, torch.zeros_like(u_lbc)) + F.mse_loss(u_ubc, upper_target)
        return mseb

    def TimeStepping(self):
        """Time stepping update: store the detached model output on ``X_ref`` as ``U0``."""
        u = self.model(self.X_ref)
        self.U0 = u.detach()

    def Msef(self):
        """PDE loss with pseudo-time stepping.

        With residual ``f = (0.39 / 0.06) * a * u_t - u_xx - a * cos(beta) * u_x``
        the loss is ``mean(((u - U0) / dt + f) ** 2)`` for ``dt = 10``.
        Requires ``self.U0`` from :meth:`TimeStepping`.
        """
        beta = torch.tensor(33 * torch.pi / 180.0, dtype=torch.float32, device=self.device)
        u = self.model(self.X_ref)
        u_xt = fwd_gradients(u, self.X_ref)
        u_x = u_xt[:,0:1]
        u_t = u_xt[:,1:2]
        u_xx = fwd_gradients(u_x, self.X_ref)[:,0:1]
        a = self.a_raw
        f = u_t * a * 0.39 / 0.06 - u_xx - a * torch.cos(beta) * u_x
        dt = 10
        msef = 1/dt**2 * ((u - self.U0 + dt * f)**2).mean()
        return msef

    def Mses(self):
        """Relative L2 error of the model against the reference solution on ``X_ref``."""
        u_true = self._reference_solution()
        u_pred = self.model(self.X_ref)
        mses = torch.norm(u_pred - u_true) / torch.norm(u_true)
        return mses

    def Loss(self):
        """Total loss.

        Returns:
            Tuple ``(loss, msei, mseb, msef, mse_data)`` where ``loss`` is the
            unweighted sum of the four terms.
        """
        msei = self.Msei()
        mseb = self.Mseb()
        msef = self.Msef()
        u_pred_data = self.model(self.data_points)
        mse_data = F.mse_loss(u_pred_data, self.u_data)
        loss = msei + mseb +  msef + mse_data
        return loss, msei, mseb, msef, mse_data

    def ResidualPoint(self):
        """Generate 10000 random (z, t) points in [0, 10] x [0, 5] into ``self.X``.

        NOTE(review): none of the loss terms in this class read ``self.X``.
        """
        t = torch.rand((10000, 1), device=self.device) * 5
        z = torch.rand((10000, 1), device=self.device) * 10
        self.X = torch.cat([z, t], dim=1).requires_grad_(True)

    def train(self, epoch):
        """Train the model for ``epoch`` outer (pseudo-time) steps.

        Each step builds a fresh L-BFGS optimizer over the network parameters
        and ``a_raw``, refreshes ``U0`` and runs one ``step``. Metrics are
        appended to ``self.log``; ``log['time']`` holds cumulative wall-clock
        seconds and continues from earlier calls. The state with the lowest
        MSE against the reference solution is written to ``best_model.pth``
        and ``best_a_raw.pth`` in the working directory. A NaN loss or a
        threefold loss jump triggers a reset (first step) or a reload of that
        checkpoint.

        Args:
            epoch: Number of outer steps to run.
        """
        # Shift the clock origin so times keep accumulating across calls.
        if len(self.log['time']) == 0:
            t_start = time.time()
        else:
            t_start = time.time() - self.log['time'][-1]

        best_mse = float('inf')
        for i in range(epoch):
            def closure():
                self.optimizer.zero_grad()
                self.loss, self.loss_i, self.loss_b, self.loss_f, self.loss_data = self.Loss()
                # The graph is rebuilt on every call, so it need not be retained.
                self.loss.backward()
                return self.loss

            self.optimizer = torch.optim.LBFGS(list(self.model.parameters()) + [self.a_raw], max_iter=100)
            self.ResidualPoint()
            self.TimeStepping()
            self.optimizer.step(closure)

            # Diagnostics only: no autograd graph needed.
            with torch.no_grad():
                self.loss_s = self.Mses()
                u_pred = self.model(self.X_ref)
                mse_exact = F.mse_loss(u_pred, self._reference_solution())
            self.log['mse_exact'].append(mse_exact.item())

            # Cumulative elapsed time; the origin is deliberately not reset.
            self.log['time'].append(time.time() - t_start)

            self.log['losses'].append(self.loss.item())
            self.log['losses_f'].append(self.loss_f.item())
            self.log['losses_b'].append(self.loss_b.item())
            self.log['losses_i'].append(self.loss_i.item())
            self.log['losses_s'].append(self.loss_s.item())
            a = self.a_raw
            self.log['a'].append(a.item())

            if mse_exact.item() < best_mse:
                best_mse = mse_exact.item()
                torch.save(self.model.state_dict(), 'best_model.pth')
                torch.save(self.a_raw, 'best_a_raw.pth')

            loss_is_nan = bool(torch.isnan(self.loss))
            loss_jumped = (i > 1) and (self.loss.item() > 3 * self.log['losses'][-2])
            if loss_is_nan or loss_jumped:
                if i == 0:
                    # Nothing to fall back on yet: start from a fresh network.
                    self.model = Net(self.layers, self.X_ref.cpu().detach().numpy(), self.device).to(self.device)
                    self.a_raw = nn.Parameter(torch.tensor([0.00], device=self.device))
                    continue
                else:
                    self.model.load_state_dict(torch.load('best_model.pth'))
                    self.a_raw = torch.load('best_a_raw.pth')
                    print('Loaded best model')
                    self.ResidualPoint()
                    # NOTE(review): this optimizer is replaced by a new L-BFGS
                    # instance at the top of the next iteration and is never
                    # stepped here.
                    self.optimizer = torch.optim.Adam(list(self.model.parameters()) + [self.a_raw], lr=1e-3)
                    self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.9)
                    continue

            self.TimeStepping()

            print(f'Epoch {i+1}/{epoch} | Loss: {self.loss.item():.4e} | PDE Loss: {self.loss_f.item():.4e} | Relative Error: {self.loss_s.item():.4e} | Exact Solution MSE: {mse_exact.item():.4e} | a: {a.item():.4e}')

if __name__ == '__main__':
    # Main program: train the model, then plot and export the diagnostics.
    # Output files written to the working directory:
    #   training_metrics_plot.png, training_metrics.xlsx, loss_plot.png,
    #   a_plot.png, solution_comparison.png, time_slice_comparison.png
    num_epochs = 100

    t1 = time.time()
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    layers = [2, 128, 128, 128, 1]
    model = TSONN(layers, device)
    model.train(num_epochs)
    t2 = time.time()
    print(f'Total training time: {t2 - t1:.2f} seconds')

    # Per-epoch histories; the epoch axis follows the number of logged
    # entries so it always matches the arrays plotted against it.
    total_loss = np.array(model.log['losses'])
    a_values = np.array(model.log['a'])
    relative_l2_error = np.array(model.log['losses_s'])
    epochs = np.arange(1, len(total_loss) + 1)

    # Plot total loss, parameter a, and relative L2 error
    plt.figure(figsize=(10, 6))
    ax1 = plt.gca()  # Main axis for total loss
    ax2 = ax1.twinx()  # Secondary axis for parameter a and relative L2 error

    # Plot total loss (log scale)
    ax1.plot(epochs, np.log10(total_loss), label='Total Loss (Log10)', color='blue')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Log10(Total Loss)', color='blue')
    ax1.tick_params(axis='y', labelcolor='blue')

    # Plot parameter a
    ax2.plot(epochs, a_values, label='Parameter a', color='green')
    ax2.axhline(y=0.016, color='green', linestyle='--', label='True a=0.016')
    # Plot relative L2 error
    ax2.plot(epochs, relative_l2_error, label='Relative L2 Error', color='red')
    ax2.set_ylabel('Parameter a / Relative L2 Error', color='black')
    ax2.tick_params(axis='y', labelcolor='black')

    # Combine legends
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    plt.title('Total Loss, Parameter a, and Relative L2 Error during Training')
    plt.tight_layout()
    plt.savefig('training_metrics_plot.png')
    plt.show()

    # Save data to Excel
    data = {
        'Epoch': epochs,
        'Total Loss': total_loss,
        'Parameter a': a_values,
        'Relative L2 Error': relative_l2_error
    }
    df = pd.DataFrame(data)
    df.to_excel('training_metrics.xlsx', index=False)

    print("Plot saved as 'training_metrics_plot.png', data saved as 'training_metrics.xlsx'")

    # Plot original loss curve
    plt.figure(figsize=(10,6))
    plt.plot(model.log['time'], np.log10(model.log['losses']), label='Total Loss')
    plt.plot(model.log['time'], np.log10(model.log['losses_s']), label='Relative L2 Error')
    plt.plot(model.log['time'], np.log10(model.log['losses_f']), label='PDE Loss')
    plt.plot(model.log['time'], np.log10(model.log['losses_b']), label='Boundary Loss')
    plt.plot(model.log['time'], np.log10(model.log['losses_i']), label='Initial Loss')
    plt.plot(model.log['time'], np.log10(model.log['mse_exact']), label='Exact Solution MSE')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Log10(Loss)')
    plt.legend()
    plt.title('Loss and Error Changes during Training')
    plt.savefig('loss_plot.png')
    plt.show()

    # Plot parameter a's change
    plt.figure(figsize=(10,6))
    plt.plot(model.log['a'], label='Parameter a')
    plt.axhline(y=0.016, color='r', linestyle='--', label='True a=0.016')
    plt.xlabel('Training Steps')
    plt.ylabel('Parameter a')
    plt.legend()
    plt.title('Change of Parameter a during Training')
    plt.savefig('a_plot.png')
    plt.show()

    # Plot solution comparison on the (Nx, Nt) reference grid
    XX = model.X_ref[:, 0].cpu().detach().numpy().reshape(model.Nx, model.Nt)
    TT = model.X_ref[:, 1].cpu().detach().numpy().reshape(model.Nx, model.Nt)
    z = model.X_ref[:, 0:1]
    t = model.X_ref[:, 1:2]
    u_pred = model.model(model.X_ref).cpu().detach().numpy().reshape(model.Nx, model.Nt)
    u_true = model.exact_solution(z, t, torch.tensor([0.016], device=model.device)).cpu().detach().numpy().reshape(model.Nx, model.Nt)
    error = np.abs(u_true - u_pred)

    fig, axs = plt.subplots(1, 3, figsize=(18, 5))
    c1 = axs[0].pcolor(TT, XX, u_true, cmap='jet')
    fig.colorbar(c1, ax=axs[0])
    axs[0].set_xlabel('$t$')
    axs[0].set_ylabel('$x$')
    axs[0].set_title('Reference Solution')

    c2 = axs[1].pcolor(TT, XX, u_pred, cmap='jet')
    fig.colorbar(c2, ax=axs[1])
    axs[1].set_xlabel('$t$')
    axs[1].set_ylabel('$x$')
    axs[1].set_title('Predicted Solution (iTSONN)')

    c3 = axs[2].pcolor(TT, XX, error, cmap='jet')
    fig.colorbar(c3, ax=axs[2])
    axs[2].set_xlabel('$t$')
    axs[2].set_ylabel('$x$')
    axs[2].set_title('Absolute Error')

    plt.tight_layout()
    plt.savefig('solution_comparison.png')
    plt.show()

    # Plot solutions at a quarter, half and the end of the time grid
    # (indices 25, 50 and 100 for the default Nt = 101).
    last_t_index = model.Nt - 1
    time_indices = [last_t_index // 4, last_t_index // 2, last_t_index]
    fig, axs = plt.subplots(1, 3, figsize=(18, 5))
    for idx, t_idx in enumerate(time_indices):
        axs[idx].plot(XX[:,0], u_true[:,t_idx], label='Exact Solution', color='blue')
        axs[idx].plot(XX[:,0], u_pred[:,t_idx], '--', label='Predicted Solution', color='red')
        axs[idx].set_xlabel('$x$')
        axs[idx].set_ylabel('$u$')
        axs[idx].set_title(f'$t = {TT[0, t_idx]:.2f}$')
        axs[idx].legend()
    plt.tight_layout()
    plt.savefig('time_slice_comparison.png')
    plt.show()
