In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# -- 1. MLP和PINN模型定义 --

class MLP(nn.Module):
    """Fully connected feed-forward network with tanh hidden activations.

    One ``nn.Linear`` layer is created for every consecutive pair of widths
    in ``layer_dims``. Weights use Xavier/Glorot uniform initialisation and
    biases start at zero, which helps keep training stable. ``tanh`` is
    applied after every layer except the last one.

    Args:
        layer_dims: Sequence of layer widths, input width first and output
            width last (for example ``[2, 20, 20, 1]``). It must contain at
            least two entries so that at least one linear layer exists.

    Raises:
        ValueError: If ``layer_dims`` has fewer than two entries.
    """

    def __init__(self, layer_dims):
        super().__init__()
        # Fail at construction time: with fewer than two widths no layer is
        # created and forward() would only fail later on self.layers[-1].
        if len(layer_dims) < 2:
            raise ValueError(
                "layer_dims must contain at least an input and an output "
                f"width, got {layer_dims!r}"
            )

        self.layers = nn.ModuleList()
        for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
            linear_layer = nn.Linear(in_dim, out_dim)
            nn.init.xavier_uniform_(linear_layer.weight)
            nn.init.zeros_(linear_layer.bias)
            self.layers.append(linear_layer)
        self.activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through all layers; the last layer has no activation."""
        last_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            # The output layer stays linear.
            if index < last_index:
                x = self.activation(x)
        return x

class PINN_Burgers(nn.Module):
    """Physics-informed neural network for the viscous Burgers equation.

    An ``MLP`` approximates the solution ``u(x, t)``. The viscosity ``nu`` is
    learned together with the network weights; it is stored as ``log_nu`` and
    exposed as ``exp(log_nu)``, so it stays strictly positive during
    optimisation.

    Args:
        layer_dims: Layer widths forwarded to ``MLP``. The first width has to
            match the number of columns of ``torch.cat([x, t], dim=1)``
            (2 when ``x`` and ``t`` are column vectors).
        true_nu: Despite the name, this is only used as the *initial* value
            of the learnable viscosity. ``None`` selects 0.1. Must be
            positive because its logarithm is taken.

    Raises:
        ValueError: If the initial viscosity is not strictly positive.
    """

    def __init__(self, layer_dims, true_nu=None):
        super().__init__()

        initial_nu = 0.1 if true_nu is None else true_nu
        # log() of a non-positive value would silently put NaN/-inf into the
        # parameter and poison training from the first step.
        if not initial_nu > 0:
            raise ValueError(
                f"initial nu must be strictly positive, got {initial_nu!r}"
            )

        self.network = MLP(layer_dims)

        # Optimise log(nu) rather than nu itself; see the `nu` property.
        self.log_nu = nn.Parameter(
            torch.tensor([np.log(initial_nu)], dtype=torch.float32)
        )

    @property
    def nu(self):
        """Current viscosity, ``exp(log_nu)``, which is always positive."""
        return torch.exp(self.log_nu)

    def forward(self, x, t):
        """Evaluate the network on ``x`` and ``t`` concatenated along dim 1."""
        inputs = torch.cat([x, t], dim=1)
        return self.network(inputs)

    def compute_pde_residual(self, x, t):
        """Return the Burgers residual ``u_t + u * u_x - nu * u_xx``.

        Derivatives are obtained with automatic differentiation. The
        caller's tensors are not modified: inputs that do not already track
        gradients are replaced by detached local views that do. Inputs that
        already require gradients are used unchanged so any surrounding
        graph stays connected.

        Args:
            x: Spatial coordinates; concatenated with ``t`` along dim 1.
            t: Time coordinates, same leading size as ``x``.

        Returns:
            Tensor with the same shape as the network output, still attached
            to the autograd graph so it can be used in a loss.
        """
        # detach() creates a new tensor object sharing storage, so enabling
        # gradients on it leaves the caller's tensor untouched.
        if not x.requires_grad:
            x = x.detach().requires_grad_(True)
        if not t.requires_grad:
            t = t.detach().requires_grad_(True)

        u = self.forward(x, t)

        # create_graph=True keeps the derivative graph so that higher-order
        # derivatives and the backward pass through the residual work.
        ones = torch.ones_like(u)
        u_t = torch.autograd.grad(u, t, grad_outputs=ones, create_graph=True)[0]
        u_x = torch.autograd.grad(u, x, grad_outputs=ones, create_graph=True)[0]
        u_xx = torch.autograd.grad(
            u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True
        )[0]

        # Burgers equation: u_t + u * u_x - nu * u_xx = 0
        residual = u_t + u * u_x - self.nu * u_xx
        return residual

# -- 2. 主训练脚本 --
if __name__ == '__main__':
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data loading and preparation ---
    try:
        # Read the arrays inside a `with` block so the archive's file handle
        # is closed right away; `data` stays usable as a plain mapping.
        with np.load('burgers_shock_solution.npz') as npz_file:
            data = {key: npz_file[key] for key in npz_file.files}
    except FileNotFoundError:
        print("错误：找不到 'burgers_shock_solution.npz' 文件。")
        print("请确保该文件与脚本在同一目录下。")
        # Signal the failure with a non-zero exit status. The bare exit()
        # helper comes from the `site` module and would report success.
        raise SystemExit(1) from None

    x_data = torch.tensor(data['x'], dtype=torch.float32)
    t_data = torch.tensor(data['t'], dtype=torch.float32)
    u_solution = torch.tensor(data['u'], dtype=torch.float32)

    # Coordinate grid for training. With indexing='ij', T and X both have
    # shape (len(t), len(x)).
    T, X = torch.meshgrid(t_data.squeeze(), x_data.squeeze(), indexing='ij')

    # The flattened solution must line up element-for-element with the
    # flattened grid, otherwise every (x, t, u) triple is mismatched.
    if u_solution.numel() != T.numel():
        raise ValueError(
            f"'u' has {u_solution.numel()} values but the (t, x) grid has "
            f"{T.numel()} points"
        )
    grid_shape = tuple(T.shape)
    u_shape = tuple(u_solution.shape)
    u_grid = u_solution
    if u_shape != grid_shape and u_shape == grid_shape[::-1]:
        # The array is unambiguously stored as (x, t); bring it to (t, x).
        u_grid = u_solution.T
    # NOTE(review): when len(t) == len(x) the layout cannot be detected
    # from the shape; confirm the file stores 'u' on the (t, x) grid.

    # Training triples (x, t, u), each as a column vector.
    x_train = X.reshape(-1, 1)
    t_train = T.reshape(-1, 1)
    u_train = u_grid.reshape(-1, 1)

    # Move all training data to the selected device.
    x_train = x_train.to(device)
    t_train = t_train.to(device)
    u_train = u_train.to(device)

    # --- Model, loss and optimiser ---
    pinn_model = PINN_Burgers(layer_dims=[2, 20, 20, 20, 20, 1]).to(device)

    # A small learning rate is used for stability.
    optimizer = torch.optim.Adam(pinn_model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()

    # --- Training loop ---
    epochs = 20000
    # Weight of the physics loss; a key hyper-parameter that needs tuning.
    lambda_physics = 1e-2
    # Threshold for gradient-norm clipping.
    clip_value = 1.0

    print("--- 开始训练 ---")
    for epoch in range(epochs):
        optimizer.zero_grad()

        # 1. Data loss: mismatch between prediction and reference solution.
        u_pred = pinn_model(x_train, t_train)
        data_loss = loss_fn(u_pred, u_train)

        # 2. Physics loss: mean squared PDE residual on the same points.
        pde_residual = pinn_model.compute_pde_residual(x_train, t_train)
        physics_loss_raw = loss_fn(pde_residual, torch.zeros_like(pde_residual))

        # Combined objective.
        total_loss = data_loss + lambda_physics * physics_loss_raw

        # Stop as soon as the loss becomes NaN; further steps are useless.
        if torch.isnan(total_loss):
            print(f"错误：在第 {epoch+1} 个 epoch 损失变为 NaN。训练提前终止。")
            print("建议尝试进一步降低学习率或 lambda_physics 的值。")
            break

        total_loss.backward()

        # Clip gradients before the update to guard against exploding
        # gradients.
        torch.nn.utils.clip_grad_norm_(pinn_model.parameters(), clip_value)

        optimizer.step()

        if (epoch + 1) % 1000 == 0:
            # Report the unweighted physics loss so its magnitude is visible.
            print(f"Epoch [{epoch+1}/{epochs}], Total Loss: {total_loss.item():.6f}, "
                  f"Data Loss: {data_loss.item():.6f}, Physics Loss (raw): {physics_loss_raw.item():.6f}, "
                  f"Predicted nu: {pinn_model.nu.item():.5f}")

    print("\n--- 训练完成 ---")
    # Only report the learned viscosity when training did not end in NaN.
    if not torch.isnan(total_loss):
        print(f"最终预测的粘性系数 nu 是: {pinn_model.nu.item():.5f}")
        print("（真实值为 0.07）")

