# 导入必要模块

In [None]:
import os
import torch
import time
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import data_processing as dp
from model import PhysicsInformedNN

# Reproducibility: seed both the NumPy and the PyTorch random generators.
torch.manual_seed(1234)
np.random.seed(1234)

# Autograd anomaly detection is switched on for the whole session.
# NOTE(review): this is a debugging aid that adds overhead to backward
# passes -- confirm it is still wanted for full training runs.
torch.autograd.set_detect_anomaly(True)

# Plotting: Tk backend, Times New Roman rendered through LaTeX, and a
# faint black grid.
matplotlib.use('TkAgg')
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'text.usetex': True,
    'grid.color': 'k',
    'grid.alpha': 0.2,
})

# Directory the script was launched from; passed to the data loader later.
current_path = os.getcwd()

# 设置参数，加载数据

In [None]:
# Training configuration.
epochs = 300
# Layer widths: 2 inputs, four hidden layers of 50 units, 1 output.
layers = [2] + [50] * 4 + [1]
# One flag per entry of `layers`, handed to PhysicsInformedNN as-is.
# NOTE(review): presumably marks skip/residual connections -- confirm in model.py.
connections = [0, 1] * 3

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(
    "========  Using device  ========",
    f"============  {device}  ============",
    sep="\n",
)

# Load Data
# `option` is the first argument of dp.load_data and `filename` the file to
# read from `current_path`.
# NOTE(review): presumably `option` selects the file format and has to match
# the extension of `filename` -- confirm in data_processing.
option = 'npz'
filename = 'variables.npz'
# Alternative recordings (update `option` accordingly when switching):
# filename = 'test16-1.mat'
# filename = 'dvSave-2023_03_26_02_21_16.npy'
# Pass `option` itself rather than a repeated literal, so that editing the
# setting above actually changes what the loader is asked to do.
Timestamp, xEvent, yEvent, polarities = dp.load_data(option, current_path, filename)

# Data Cleansing
# A single figure with three side-by-side 3D panels shows the events at each
# stage: raw, after hot-pixel cleansing, and after rotation. Every panel is
# plotted in (xEvent, Timestamp, yEvent) order and coloured by yEvent.
fig = plt.figure()
dp.plot_data(
    fig.add_subplot(131, projection='3d'),
    xEvent, Timestamp, yEvent,
    title='Original Data', color=yEvent
)

# Stage 1: hot-pixel cleansing; returns all four event arrays.
xEvent, Timestamp, yEvent, polarities = dp.HotPixel_cleansing(
    xEvent, Timestamp, yEvent, polarities
)
dp.plot_data(
    fig.add_subplot(132, projection='3d'),
    xEvent, Timestamp, yEvent,
    title='After HotPixel Cleansing', color=yEvent
)

# Stage 2: rotation of the three coordinate arrays (polarities are not
# passed in and are left as they were after stage 1).
# NOTE(review): 'TLS' presumably stands for total least squares -- confirm.
xEvent, Timestamp, yEvent = dp.data_rotate(xEvent, Timestamp, yEvent, option='TLS')
dp.plot_data(
    fig.add_subplot(133, projection='3d'),
    xEvent, Timestamp, yEvent,
    title='After Data Rotation', color=yEvent
)

# Convert to torch.Tensor
# Each array becomes a float32 tensor on `device` that tracks gradients, with
# a trailing axis of size 1 appended by unsqueeze(1). The right-hand side is
# fully evaluated from the original arrays before the names are rebound.
xEvent, Timestamp, yEvent = (
    torch.tensor(
        values,
        dtype=torch.float32,
        device=device,
        requires_grad=True,
    ).unsqueeze(1)
    for values in (xEvent, Timestamp, yEvent)
)
print('====== Data Loading Done! ======')

# 初始化模型

In [None]:
# Set USE_pth to True to read 'model.pth' (resolved against the working
# directory) and hand its contents to pinn.load after construction.
USE_pth = False
print('===== Model Initialization =====')
# Build the network from the layer/connection configuration, the target
# device, the three event tensors and the epoch count.
pinn = PhysicsInformedNN(
    layers, connections, device,
    xEvent, Timestamp, yEvent,
    epochs
)
if USE_pth:
    # map_location remaps stored tensors onto the active device, so a
    # checkpoint written on a CUDA machine still loads when this script has
    # fallen back to the CPU (and vice versa).
    pinn.load(torch.load('model.pth', map_location=device))
# Print the `dnn` attribute of the model for inspection.
print(pinn.dnn)