# In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

import torch
import torch.nn as nn
from torch.optim import LBFGS, Adam
from lib.PINN.Thermodynamics.Satellite.pinnsformer_sun.utils import *
from lib.PINN.Thermodynamics.Satellite.pinnsformer_sun.Pinnsformer import PINNsformer

# Fix random seeds so that results are reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Time-step size passed to make_time_sequence below.
step_size = 2e-2

# Load the evaluation data.
filename = '/data/datasets/satellite_heattransfer/data_7246.txt'
x_star, y_star, z_star, t_star, T_star = loading_evaluate_data(filename)

N = x_star.shape[0]

# Append a trailing axis to each coordinate array. (np.tile(a, 1) is only a
# copy, and torch.tensor below copies the data anyway, so it is not needed.)
x_star = np.expand_dims(x_star, -1)
y_star = np.expand_dims(y_star, -1)
z_star = np.expand_dims(z_star, -1)
t_star = make_time_sequence(t_star, num_step=1, step=step_size)

# Convert to float32 tensors on the target device. This is pure inference, so
# the inputs do not need requires_grad.
x_star = torch.tensor(x_star, dtype=torch.float32, device=device)
y_star = torch.tensor(y_star, dtype=torch.float32, device=device)
z_star = torch.tensor(z_star, dtype=torch.float32, device=device)
t_star = torch.tensor(t_star, dtype=torch.float32, device=device)

# Build the model and load the trained weights.
model = PINNsformer(d_out=1, d_hidden=512, d_model=64, N=1, heads=2).to(device)
load_params = torch.load('/data/checkpoints/pinnsformer_withsun.pt', map_location=device)
model.load_state_dict(load_params)
# Switch to inference mode so that any train-only behaviour (e.g. dropout, if
# the model contains it) is disabled.
model.eval()

# Predict without recording an autograd graph; it is never used here and would
# otherwise hold activations for every evaluation point in memory.
with torch.no_grad():
    T_pred = model(x_star, y_star, z_star, t_star)


def _first_step_flat(tensor):
    """Return ``tensor[:, 0, 0]`` as a 1-D host NumPy array.

    Index [:, 0, 0] presumably selects the first (only) time step and first
    channel of each point's sequence -- confirm against make_time_sequence.
    """
    return tensor.detach().cpu().numpy()[:, 0, 0].reshape(-1)


# Flatten predictions and coordinates to one value per evaluation point.
T_pred = _first_step_flat(T_pred)
x_star = _first_step_flat(x_star)
y_star = _first_step_flat(y_star)
z_star = _first_step_flat(z_star)
t_star = _first_step_flat(t_star)

# Load the evaluation data.
def loading_evaluate_data1(filename):
    """Read a space-delimited, header-less text file of evaluation samples.

    The first five columns are returned, in file order, as 1-D NumPy arrays
    named ``(x, y, z, t, T)`` by the caller; any extra columns are ignored.

    Args:
        filename: Path of the text file to read.

    Returns:
        A 5-tuple of NumPy arrays, one per column 0..4.
    """
    table = pd.read_csv(filename, delimiter=' ', header=None)
    return tuple(table.iloc[:, column].to_numpy() for column in range(5))

# Reload the raw file for the ground-truth panel. The animation indexes these
# arrays and the prediction arrays with the same mask, so their rows are assumed
# to be in the same order -- TODO confirm against loading_evaluate_data.
x, y, z, t, T = loading_evaluate_data1(filename)

# Sorted distinct time stamps, and how many there are.
times = np.unique(t)
t_steps = times.size

# Two side-by-side 3-D axes: dataset on the left, prediction on the right.
fig = plt.figure(figsize=(12, 8))
ax1, ax2 = (fig.add_subplot(position, projection='3d') for position in (121, 122))

# Scatter artists are created lazily by the first animation frame.
scat1 = scat2 = None

def _refresh_scatter(ax, artist, coords, values, all_values, label, **colorbar_opts):
    """Create a 3-D scatter on ``ax`` on first use, otherwise update it in place.

    On creation the colour range is fixed to ``all_values.min()..max()`` and a
    colorbar with ``label`` (plus any ``colorbar_opts``) is attached.
    Returns the scatter artist.
    """
    if artist is not None:
        # NOTE: writes the private ``_offsets3d`` attribute of the 3-D collection.
        artist._offsets3d = coords
        artist.set_array(values)
        return artist
    artist = ax.scatter(*coords, c=values, cmap='viridis', s=100,
                        vmin=all_values.min(), vmax=all_values.max())
    fig.colorbar(artist, ax=ax, shrink=0.4, aspect=10, label=label, **colorbar_opts)
    return artist


def show(frame):
    """FuncAnimation callback: draw the points whose time equals ``times[frame]``.

    The left axes show the dataset temperatures and the right axes the
    predictions. The module globals ``scat1`` and ``scat2`` are created on the
    first call and reused afterwards.
    """
    global scat1, scat2
    current_time = times[frame]
    # Select the points at this time step (tolerant float comparison).
    mask = np.isclose(t, current_time, atol=1e-6)

    scat1 = _refresh_scatter(ax1, scat1, (x[mask], y[mask], z[mask]),
                             T[mask], T, 'Temperature data')
    scat2 = _refresh_scatter(ax2, scat2, (x_star[mask], y_star[mask], z_star[mask]),
                             T_pred[mask], T_pred, 'Temperature pred', location='right')

    ax1.set_title(f'Dataset, Time = {current_time:.2f}s')
    ax2.set_title(f'Predict, Time = {current_time:.2f}s')
    for axis in (ax1, ax2):
        axis.set_xlabel('X')
        axis.set_ylabel('Y')
        axis.set_zlabel('Z')

# Allow embedded (HTML/JS) animations up to 50 MB.
plt.rcParams['animation.embed_limit'] = 50 * 1024 * 1024

# Delay between frames. The saved GIF derives its frame rate from this value,
# so the file plays at the same speed as the animation is configured for.
frame_interval_ms = 100

# Build the animation: one frame per distinct time stamp.
anim = FuncAnimation(fig, show, frames=len(times), interval=frame_interval_ms, blit=False)

# Close the figure so Jupyter does not also render a static copy of it.
plt.close(fig)

# Save the animation to a GIF at the frame rate matching frame_interval_ms.
anim.save('satellite_temperature_animation.gif',
          writer=PillowWriter(fps=1000 / frame_interval_ms))

# Display the saved GIF inline in a Jupyter notebook.
from IPython.display import Image
Image(filename='satellite_temperature_animation.gif')
