In [None]:
### package
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime, timedelta

In [None]:
### count params
import operator
from functools import reduce

# print the number of parameters
def count_params(model):
    """Return the total number of scalar entries across all of ``model.parameters()``.

    Uses ``Tensor.numel()``, which also handles 0-dim (scalar) parameters. The previous
    ``reduce(operator.mul, list(p.size()))`` raised TypeError on such parameters because
    ``reduce`` over an empty size tuple has no initial value.
    """
    return sum(p.numel() for p in model.parameters())

In [None]:
### read data function
def read_one_year(year, c_idx_list, timestep, datatype,
                  height=157, width=103, keep_missing_hour=True):
    """Load every `timestep`-hourly frame of one calendar year.

    Frames run from Jan 1 00:00 of `year` (inclusive) to Jan 1 00:00 of `year + 1`
    (exclusive). Each frame is located with construct_file_path and read with
    read_data_from_file.

    Args:
        year: calendar year to load.
        c_idx_list: channel indices kept from every saved tensor.
        timestep: spacing between frames, in hours (must be > 0).
        datatype: "surfgrid" or "obs" (selects the file layout).
        height, width: spatial size of the all-NaN placeholder frames.
        keep_missing_hour: if True, a missing file becomes an all-NaN frame with mask 0.0,
            so the time axis stays evenly spaced. If False, the hour is dropped.

    Returns:
        (data, mask): `data` stacks the frames along a new leading time axis
        (placeholders are [len(c_idx_list), height, width]). `mask` holds 1.0 for usable
        frames and 0.0 for missing/invalid ones. If no frame was kept at all, empty tensors
        of shape [0, len(c_idx_list), height, width] and [0] are returned.

    Raises:
        ValueError: if `timestep` is not positive (the loop would never terminate).
    """
    if timestep <= 0:
        raise ValueError("timestep must be a positive number of hours.")

    n_channels = len(c_idx_list)

    def _invalid_frame():
        # All-NaN placeholder for an unusable hour, so it can never pass as real data.
        return torch.full((n_channels, height, width), float('nan'))

    data_tensor = []
    mask_tensor = []
    current_time = datetime(year, 1, 1, 0)
    end_time = datetime(year + 1, 1, 1, 0)
    delta = timedelta(hours=timestep)

    while current_time < end_time:
        timestamp = current_time.strftime('%Y%m%d%H')
        path = construct_file_path(year, timestamp, datatype)   # build the full file path
        grid_data = read_data_from_file(path, c_idx_list)       # None when the file is missing

        if grid_data is not None:
            valid_values = grid_data[torch.isfinite(grid_data)]  # values that are neither NaN nor ±Inf
            if valid_values.numel() == 0:
                # Every value is NaN/±Inf. Previously such a frame was kept with mask 1.0.
                data_tensor.append(_invalid_frame())
                mask_tensor.append(0.0)
                print(f"No finite values in file: {path}. Replacing with NaN tensor.")
            elif torch.all(valid_values == valid_values[0]):
                # A frame whose finite values are all identical is treated as corrupted data.
                data_tensor.append(_invalid_frame())
                mask_tensor.append(0.0)
                print(f"All valid values are the same in file: {path}. Replacing with NaN tensor.")
            else:
                data_tensor.append(grid_data)
                mask_tensor.append(1.0)
        elif keep_missing_hour:
            data_tensor.append(_invalid_frame())
            mask_tensor.append(0.0)

        current_time += delta

    if not data_tensor:
        # Only possible with keep_missing_hour=False; torch.stack([]) would raise here.
        return torch.empty((0, n_channels, height, width)), torch.empty(0)
    return torch.stack(data_tensor), torch.tensor(mask_tensor)    # [total hours, C, H, W], [total hours]


def construct_file_path(year, timestamp, mode):
    """Build the relative .pt path of one hourly frame.

    `mode` picks the directory layout:
      - "surfgrid": ../PT_grid_data_<year>(hour)/surfgrid_RCEC_<timestamp>.pt
      - "obs":      ../PT_observation_<year>(hour)/observation_<timestamp>.pt

    Raises:
        ValueError: for any other `mode`.
    """
    layouts = (
        ("surfgrid", "../PT_grid_data_{year}(hour)/surfgrid_RCEC_{timestamp}.pt"),
        ("obs", "../PT_observation_{year}(hour)/observation_{timestamp}.pt"),
    )
    for layout_name, template in layouts:
        if mode == layout_name:
            return template.format(year=year, timestamp=timestamp)
    raise ValueError("Invalid mode! Must be 'surfgrid' or 'obs'.")


def read_data_from_file(path, c_idx):
    """Load one saved frame and keep only the channels listed in `c_idx`.

    Returns `tensor[c_idx, ...]` (the saved tensor presumably is [C, H, W], giving
    [len(c_idx), H, W]). If `path` does not exist, prints a warning and returns None.
    """
    if not os.path.exists(path):
        print(f"⚠️ Missing file: {path}")
        return None

    full_frame = torch.load(path, weights_only=True)
    return full_frame[c_idx, ...]


def read_data(year_list: list, c_idx_list: list, timestep: int,
              datatype: str, keep_missing_hour=True):
    """
    Load several years of data and concatenate them along the time axis (multi-channel input).

    Args:
        year_list: years to load, in the order they are concatenated.
        c_idx_list: channel indices kept from every saved frame.
        timestep: spacing between frames, in hours.
        datatype: "surfgrid" or "obs".
        keep_missing_hour: forwarded to read_one_year.

    Returns:
        data_tensor: [T_total, C, H, W], the per-year frames concatenated. With timestep=1
            and keep_missing_hour=True a year contributes 8760 frames (8784 in a leap year).
        mask_tensor: [T_total], 1.0 for valid frames and 0.0 for missing/invalid ones.

    Raises:
        ValueError: if `year_list` is empty (torch.cat would otherwise fail opaquely).
    """
    years = list(year_list)
    if not years:
        raise ValueError("year_list must contain at least one year.")

    all_data_tensors = []
    all_mask_tensors = []   # 0.0 where the hour is missing or its data is invalid

    for year in years:
        data_tensor, mask_tensor = read_one_year(
            int(year), c_idx_list, int(timestep), datatype, keep_missing_hour=keep_missing_hour
        )
        all_data_tensors.append(data_tensor)
        all_mask_tensors.append(mask_tensor)

    return torch.cat(all_data_tensors), torch.cat(all_mask_tensors)

In [None]:
### loading data tensor
# Channels taken from each surfgrid frame, in order: PM25, windspeed, winddir, K, humidity%
surf_channels = [3, 5, 6, 7, 8]

# 2020-2022 for training; each call returns data [T_total, C, H, W] and mask [T_total].
train_data_tensor, train_mask_tensor = read_data(
    year_list=[2020, 2021, 2022],
    c_idx_list=surf_channels,
    timestep=1,
    datatype="surfgrid",
)

# 2023 held out for validation.
val_data_tensor, val_mask_tensor = read_data(
    year_list=[2023],
    c_idx_list=surf_channels,
    timestep=1,
    datatype="surfgrid",
)

print(train_data_tensor.shape)
print(val_data_tensor.shape)

In [None]:
### convert wind to u,v 
def convert_wind_to_uv(data_tensor: torch.Tensor, speed_idx: int, dir_idx: int):
    """Replace the (wind speed, wind direction in degrees) channels with (u, v), in place.

    u = -speed * sin(dir) is written into channel `speed_idx`, and
    v = -speed * cos(dir) is written into channel `dir_idx`. The same tensor object is
    returned. It is mutated rather than copied, so no second full-size tensor is allocated.

    Guard against double conversion: if any value in the direction channel is negative,
    the tensor is assumed to be converted already (raw directions are presumably in
    [0, 360)) and is returned untouched.
    """
    direction_deg = data_tensor[:, dir_idx]
    already_uv = torch.any(direction_deg < 0)
    if already_uv:
        print("Skip: already converted to u/v.")
        return data_tensor

    magnitude = data_tensor[:, speed_idx]
    angle = direction_deg * torch.pi / 180.0
    neg_magnitude = -magnitude
    u_wind = neg_magnitude * torch.sin(angle)
    v_wind = neg_magnitude * torch.cos(angle)

    # Both components are computed before either write: `magnitude` and
    # `direction_deg` are views into data_tensor.
    data_tensor[:, speed_idx] = u_wind
    data_tensor[:, dir_idx] = v_wind

    print("Wind converted to u/v.")
    return data_tensor

# Channels 1 and 2 of the loaded tensors are windspeed / winddir; afterwards they hold u / v.
# convert_wind_to_uv skips tensors whose direction channel already has negative values.
wind_channels = dict(speed_idx=1, dir_idx=2)
train_data_tensor = convert_wind_to_uv(train_data_tensor, **wind_channels)
val_data_tensor = convert_wind_to_uv(val_data_tensor, **wind_channels)

In [None]:
class TimeSeriesWindowDataset(Dataset):
    """Sliding-window time-series dataset for gridded PM2.5 forecasting.

    Expects `data_tensor` shaped [T_total, C, H, W] with PM2.5 in channel 0, and a
    per-step `mask_tensor` shaped [T_total] (non-zero = usable step). A window of
    T_in + T consecutive steps becomes a sample only if every step in it is unmasked.

    Each sample is (X, y):
      - X: PM2.5 for the T_in input steps, followed by channels 1..C-1 at the T target
        steps, flattened over time x channel into a single channel axis. Optionally two
        normalized [lon, lat] channels are appended.
      - y: PM2.5 at the T target steps, shape [T, H, W].
    """

    # Frame whose channels 0 / 1 hold the lat / lon grids (the previously hard-coded path).
    DEFAULT_LONLAT_PATH = "C:/Users/kevin/Documents/新 科技部計畫/PT_grid_data_2023(hour)/surfgrid_RCEC_2023010100.pt"

    def __init__(self, data_tensor: torch.Tensor,
                 mask_tensor: torch.Tensor,
                 T_in: int, T: int,
                 stride: int = 1,
                 add_lonlat: bool = False,
                 lonlat_path=None):
        """
        Args:
            data_tensor: [T_total, C, H, W] time series.
            mask_tensor: [T_total] validity mask (non-zero = valid step).
            T_in: number of input (history) steps.
            T: number of target steps.
            stride: distance, in time steps, between consecutive window starts.
            add_lonlat: append normalized lon/lat channels to every X.
            lonlat_path: .pt file whose channels 0 and 1 are the lat and lon grids.
                Defaults to DEFAULT_LONLAT_PATH.
        """
        self.data = data_tensor
        self.mask = mask_tensor
        self.T_in = T_in
        self.T = T
        self.stride = stride
        self.window_size = T_in + T
        self.add_lonlat = add_lonlat

        if add_lonlat:
            path = lonlat_path if lonlat_path is not None else self.DEFAULT_LONLAT_PATH
            self.grid_lonlat = self._load_and_process_lonlat(path)
        else:
            self.grid_lonlat = None

        self.valid_indices = self._compute_valid_indices()

    def _load_and_process_lonlat(self, path: str) -> torch.Tensor:
        """Load lat (channel 0) / lon (channel 1) grids and min-max scale each to [-1, 1].

        Returns a [2, H, W] tensor ordered (lon, lat).
        """
        latlon_tensor = torch.load(path, weights_only=True)  # [C, H, W]
        lat = latlon_tensor[0]
        lon = latlon_tensor[1]

        def normalize(tensor):
            min_val = tensor.min()
            max_val = tensor.max()
            span = max_val - min_val
            if span == 0:
                # A constant grid would give 0/0 = NaN; map it to the midpoint instead.
                return torch.zeros_like(tensor)
            return 2 * (tensor - min_val) / span - 1

        lat = normalize(lat)
        lon = normalize(lon)
        return torch.stack([lon, lat], dim=0)  # [2, H, W]

    def _compute_valid_indices(self):
        """Start indices (every `stride` steps) of windows whose mask is entirely non-zero."""
        valid_indices = []
        total_steps = self.data.shape[0]
        for i in range(0, total_steps - self.window_size + 1, self.stride):
            window_mask = self.mask[i: i + self.window_size]
            if window_mask.all():
                valid_indices.append(i)
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        start_idx = self.valid_indices[idx]
        window = self.data[start_idx:start_idx + self.window_size]

        X_pm25 = window[:self.T_in, :1]     # [T_in, 1, H, W]  PM2.5 history
        condition = window[self.T_in:, 1:]  # [T, C-1, H, W]   other channels at target steps
        if self.T_in == self.T:
            # Per-time-step layout [PM2.5, condition...], then flatten time x channel
            # (unchanged from the original so existing checkpoints keep their channel order).
            X = torch.cat([X_pm25, condition], dim=1).flatten(0, 1)
        else:
            # cat(dim=1) requires equal time lengths, so flatten each part separately:
            # [T_in PM2.5 channels] followed by [T * (C-1) condition channels].
            X = torch.cat([X_pm25.flatten(0, 1), condition.flatten(0, 1)], dim=0)

        y = window[self.T_in:, :1].flatten(0, 1)  # [T, H, W]  PM2.5 targets

        if self.add_lonlat and self.grid_lonlat is not None:
            X = torch.cat([X, self.grid_lonlat], dim=0)

        return X, y

In [None]:
### build dataset/dataloader
T_in = 1     # input (history) steps per sample
T = 1        # forecast steps per sample
stride = 2   # a window starts every `stride` time steps

train_dataset = TimeSeriesWindowDataset(train_data_tensor, train_mask_tensor, T_in, T, stride, add_lonlat=True)
val_dataset = TimeSeriesWindowDataset(val_data_tensor, val_mask_tensor, T_in, T, stride, add_lonlat=True)

n_train = len(train_dataset)
n_val = len(val_dataset)
print("訓練資料數量: ", n_train)
print("測試資料數量: ", n_val)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation keeps its last partial batch. With drop_last=True up to batch_size - 1
# windows were never evaluated, and a validation set smaller than batch_size produced
# zero batches (the training loop then divided by zero).
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
### check dataloader
def _print_channel_stats(batch):
    """Print min / median / max of every channel of a (B, C, H, W) batch."""
    for c in range(batch.shape[1]):
        values = batch[:, c, :, :].flatten()
        print(f"   - 通道 {c}: min={values.min().item():.2f}, "
              f"median={values.median().item():.2f}, max={values.max().item():.2f}")


def check_dataloader_nan_and_range(dataloader, name="", max_batches=None):
    """Print, per batch, the X / y shapes, whether they contain NaN, and per-channel stats.

    Per-channel statistics are skipped for a tensor that contains NaN.

    Args:
        dataloader: iterable of (batch_X, batch_y) pairs, each (B, T*C, H, W).
        name: label printed in the header line.
        max_batches: inspect at most this many batches. None (the default) inspects every
            batch, which for a full training loader prints several lines per batch.
    """
    from itertools import islice

    print(f"\n📦 檢查 DataLoader：{name}")
    # islice avoids loading one extra batch just to discover the limit was reached.
    batches = dataloader if max_batches is None else islice(dataloader, max_batches)
    for i, (batch_X, batch_y) in enumerate(batches):
        nan_X = torch.isnan(batch_X).any().item()
        nan_y = torch.isnan(batch_y).any().item()

        print(f"🔁 Batch {i+1}:")

        print(f" - X shape: {batch_X.shape}, 含 NaN: {nan_X}")
        if nan_X:
            print(" - X 數值範圍: (含 NaN，略過 per-channel 統計)")
        else:
            _print_channel_stats(batch_X)

        print(f" - y shape: {batch_y.shape}, 含 NaN: {nan_y}")
        if nan_y:
            print(" - y 數值範圍: (含 NaN，略過 per-channel 統計)")
        else:
            _print_channel_stats(batch_y)

# ✅ Run the checks on both loaders
for loader_to_check, loader_label in ((train_loader, "Train"), (val_loader, "Validation")):
    check_dataloader_nan_and_range(loader_to_check, loader_label)

In [None]:
### build model
from neuralop.models import TFNO

tfno_kwargs = dict(
    n_modes=(28, 18),          # keep this moderate: high values overfit very easily
    hidden_channels=28,
    in_channels=7,             # must equal X's channel count from the dataset (T_in = T = 1, add_lonlat=True)
    out_channels=1,            # one PM2.5 map; the training loss adds it to the last PM2.5 input
    positional_embedding=None,
    domain_padding=0.1,
    n_layers=3,
    factorization='tucker',
    implementation='factorized',
    rank=0.02,                 # a low value works fine here
)
model = TFNO(**tfno_kwargs)
model.cuda()
print("模型參數量: ", count_params(model))

In [None]:
### model summary
from torchinfo import summary

# Dummy batch: 2 samples x 7 input channels on the 157 x 103 grid.
summary_input_size = (2, 7, 157, 103)
print(summary(model, input_size=summary_input_size))

In [None]:
### build optimizer
learning_rate = 1e-3
epochs = 100
# The scheduler is stepped once per training batch, so T_max must be the total number of
# optimizer steps. len(train_loader) is exactly the batches per epoch and already accounts
# for drop_last. max(1, ...) keeps T_max > 0 so CosineAnnealingLR cannot divide by zero
# on a tiny training set.
iterations = max(1, epochs * len(train_loader))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations, eta_min=1e-4)

In [None]:
### record loss / best val loss
# Running state for the training loop; re-run this cell to start a fresh history.
best_val_loss = float('inf')
train_loss_history, val_loss_history = [], []

In [None]:
### Training
from timeit import default_timer as timer
from datetime import timedelta
from neuralop.losses import LpLoss

model_name = 'best_FNO_test'
loss_csv_path = 'loss_history_FNO_test.csv'
device = "cuda"
myloss = nn.MSELoss()
#myloss = LpLoss(d=2, p=1)
# def myloss(pred, y):
#     # mask: keep only grid points where y > 12
#     mask = y > 12
#
#     # no point passes the threshold -> return 0 instead of averaging an empty selection
#     if mask.sum() == 0:
#         return torch.tensor(0.0, device=y.device)
#
#     # RMSE over the selected pred / y points
#     mse = F.mse_loss(pred[mask], y[mask], reduction='mean')
#     rmse = torch.sqrt(mse)
#     return rmse


def run_epoch(loader, desc, train):
    """Run one pass over `loader` and return the sample-weighted mean loss.

    The loss compares `pred + x[:, :1]` with `y`: the model output is added to the first
    X channel, which the dataset fills with PM2.5 history. When `train` is True, the pass
    also runs backward, optimizer.step() and the per-batch scheduler.step().
    Returns NaN (with a warning) if `loader` yields no batches instead of dividing by zero.
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    loss_sum = 0.0
    n_seen = 0
    with torch.set_grad_enabled(train):
        for x, y in tqdm(loader, desc=desc):
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            T0_pm25 = x[:, :1]
            loss = myloss(pred + T0_pm25, y)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

            loss_sum += loss.item() * x.size(0)
            n_seen += x.size(0)

    if n_seen == 0:
        print(f"⚠️ {desc}: loader yielded no batches, loss set to NaN.")
        return float('nan')
    return loss_sum / n_seen


def save_loss_history():
    """Write the full per-epoch train / val loss history to `loss_csv_path`."""
    loss_df = pd.DataFrame({
        'Epoch': range(1, len(train_loss_history) + 1),
        'Train_Loss': train_loss_history,
        'Val_Loss': val_loss_history
    })
    loss_df.to_csv(loss_csv_path, index=False)


# Start timing
start_time = timer()

# Early-stopping settings
patience = 20
early_stop_counter = 0

for ep in range(epochs):
    # ========== training phase ==========
    train_loss = run_epoch(train_loader, f"Training Epoch {ep+1}/{epochs}", train=True)
    train_loss_history.append(train_loss)
    print(f"Train Loss: {train_loss:.6f}")

    # ========== validation phase ==========
    val_loss = run_epoch(val_loader, f"Validation Epoch {ep+1}/{epochs}", train=False)
    val_loss_history.append(val_loss)
    print(f"Valid Loss: {val_loss:.6f}")

    # Save a checkpoint whenever validation improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0  # reset the counter
        model.save_checkpoint(save_folder='checkpoint', save_name=model_name)
        print(f"💾 Best model saved (val_loss={val_loss:.6f})")
    else:
        early_stop_counter += 1
        print(f"⚠️ No improvement. Early stop counter: {early_stop_counter}/{patience}")

    # Persist the loss history every 10 epochs
    if (ep + 1) % 10 == 0:
        save_loss_history()
        print(f"📊 Save loss history (up to epoch {ep+1})")

    # Stop once patience is exhausted
    if early_stop_counter >= patience:
        print(f"🛑 Early stopping triggered at epoch {ep+1}.")
        break

    # end of one epoch
    print('-'*50)

# Always write the final history: an early stop, or a last epoch that is not a multiple
# of 10, would otherwise leave the most recent epochs out of the CSV.
save_loss_history()
print(f"📊 Save loss history (up to epoch {len(train_loss_history)})")

# Stop timing
end_time = timer()
elapsed_time = end_time - start_time
formatted_time = str(timedelta(seconds=int(elapsed_time)))
print(f"⏱️ Total training time: {formatted_time}")

In [None]:
# Ensure the directory of the loss CSV exists. os.path.dirname() returns '' for a bare file
# name such as 'loss_history_FNO_test.csv', and os.makedirs('') raises FileNotFoundError,
# so only create a directory when the path actually has one.
loss_csv_dir = os.path.dirname(loss_csv_path)
if loss_csv_dir:
    os.makedirs(loss_csv_dir, exist_ok=True)

In [None]:
# Export the loss history (lower-case column names) to the project results folder.
# NOTE(review): "H38_W26" in the file name does not match the 157 x 103 grid used above; confirm.
loss_history_out = './project_structural_machine_learning/FNO_H38_W26_loss_history.csv'
os.makedirs(os.path.dirname(loss_history_out), exist_ok=True)  # to_csv does not create folders

loss_df = pd.DataFrame({
    'epoch': list(range(1, len(train_loss_history)+1)),
    'train_loss': train_loss_history,
    'val_loss': val_loss_history
})
loss_df.to_csv(loss_history_out, index=False)