# In [None]:
import os
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt

# ========== 1. Unpack the data archive ==========
# Location of the uploaded archive and the directory its contents go into.
zip_path = '/content/processed_data.zip'
extract_path = '/content/data/processed/'

# The target directory may already exist from an earlier run; that is fine.
os.makedirs(extract_path, exist_ok=True)

# ZipFile opens read-only by default; the context manager closes the handle
# even if extraction fails part-way.
with zipfile.ZipFile(zip_path) as archive:
    archive.extractall(path=extract_path)

print(f"processed_data.zip 已解壓縮到 {extract_path}")

# ========== 2. Load the data ==========
# Both arrays are read from the directory the archive was extracted into.
x_file = os.path.join(extract_path, 'X_train.npy')
y_file = os.path.join(extract_path, 'Y_train.npy')
X, Y = np.load(x_file), np.load(y_file)
print(f"成功讀取：X形狀 {X.shape}，Y形狀 {Y.shape}")

# torch.tensor copies the NumPy data and casts it to float32.
X_tensor, Y_tensor = (
    torch.tensor(array, dtype=torch.float32) for array in (X, Y)
)

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ========== 3. Drop samples containing NaN ==========
# A sample is kept only if neither its input nor its target holds a NaN.
# Everything after the sample axis is flattened first so the reduction runs
# over a single dimension: older PyTorch releases reject a tuple `dim` in
# `any()`, and this form does not depend on how many trailing axes X has.
x_has_nan = torch.isnan(X_tensor).flatten(start_dim=1).any(dim=1)
y_has_nan = torch.isnan(Y_tensor).flatten(start_dim=1).any(dim=1)
mask = ~x_has_nan & ~y_has_nan

# Apply the same boolean mask to both tensors so inputs and targets stay aligned.
X_tensor = X_tensor[mask]
Y_tensor = Y_tensor[mask]
print(f"清除 NaN 後資料形狀：X {X_tensor.shape}, Y {Y_tensor.shape}")

# ========== 4. Train/test split ==========
# 80/20 split. The test share is the remainder, so the two lengths always
# add up to total_len even when 0.8 * total_len is not an integer.
total_len = len(X_tensor)
train_len = int(0.8 * total_len)
test_len = total_len - train_len

full_dataset = TensorDataset(X_tensor, Y_tensor)

# Draw the split from a dedicated, seeded generator so that train/test
# membership is the same on every run; otherwise the reported test metrics
# change from run to run for reasons unrelated to the model.
# NOTE(review): samples are shuffled before splitting. If they are windows cut
# from one time series, a chronological split may be more appropriate - confirm.
split_seed = 42
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_len, test_len],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training batches are reshuffled each epoch; test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

print(f"切分完成：訓練集 {train_len} 筆，測試集 {test_len} 筆")

# ========== 5. TimesNet (simplified) ==========
class SimpleTimesNet(nn.Module):
    """Two-layer fully connected regressor applied to the flattened input.

    The "simplified TimesNet" here is a plain MLP: every sample is flattened
    to one vector and passed through Linear -> ReLU -> Linear.

    Args:
        input_size: Number of values per sample after flattening.
        hidden_size: Width of the hidden layer.
        output_size: Number of values predicted per sample.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        # These attribute names become the state_dict keys, so they must stay
        # stable for saved checkpoints to remain loadable.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return predictions of shape (batch, output_size).

        Args:
            x: Tensor whose first axis is the batch; all remaining axes are
                flattened and must multiply out to ``input_size``.
        """
        # reshape (unlike view) also accepts non-contiguous input such as a
        # transposed or sliced tensor; for contiguous input it is identical.
        x = x.reshape(x.size(0), -1)
        return self.linear2(self.relu(self.linear1(x)))

# Size the network from the loaded arrays: the model flattens each sample, so
# its input width is the product of X's two non-sample axes.
# (assumes X is 3-D and Y is 2-D - TODO confirm against the preprocessing step)
hidden_size = 128
axis1_len, axis2_len = X.shape[1], X.shape[2]
input_size = axis1_len * axis2_len
output_size = Y.shape[1]

# Build the network, then move its parameters to the selected device.
model = SimpleTimesNet(input_size, hidden_size, output_size)
model = model.to(device)
print("TimesNet 模型建立完成")

# ========== 6. Model training (rain-day weighting) ==========
# reduction='none' keeps the element-wise squared errors so that each sample
# can be weighted individually before averaging.
criterion = nn.MSELoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=0.0003)
epochs = 50
loss_history = []

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Squared error per element, then averaged over the target columns
        # to give one value per sample.
        per_element = criterion(model(batch_x), batch_y)
        per_sample = per_element.mean(dim=1)

        # Samples whose first target column is > 0 weigh 2.0, all others 1.0.
        # (presumably column 0 is the precipitation amount - confirm)
        rainy = batch_y[:, 0] > 0
        sample_weight = rainy.float() + 1.0
        loss = (per_sample * sample_weight).mean()

        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimizer step.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()

    # Epoch loss is the mean of the per-batch weighted losses.
    epoch_loss = running_loss / len(train_loader)
    loss_history.append(epoch_loss)

    # Report every tenth epoch only.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Weighted Loss: {epoch_loss:.6f}")

# ========== 7. Save the model and the loss history ==========
# Make sure the output directory exists before anything is written into it.
os.makedirs('/content/models/', exist_ok=True)

# Persist only the weights (state_dict), not the module object itself.
trained_weights = model.state_dict()
torch.save(trained_weights, '/content/models/timesnet_model.pth')

# One entry per epoch, in training order.
loss_array = np.array(loss_history)
np.save('/content/models/loss_history_timesnet.npy', loss_array)

print("模型與 Loss 已儲存")

# ========== 8. Evaluation and prediction (on the test set) ==========
model.eval()
pred_batches = []
true_batches = []

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_out = model(batch_x.to(device))
        # Bring predictions back to the CPU so they can become NumPy arrays;
        # the targets never left the CPU.
        pred_batches.append(batch_out.cpu().numpy())
        true_batches.append(batch_y.numpy())

# Stitch the per-batch arrays back together along the sample axis.
pred = np.concatenate(pred_batches, axis=0)
true = np.concatenate(true_batches, axis=0)

# Metrics are computed over every target column, not just the first one.
mae = mean_absolute_error(true, pred)
mse = mean_squared_error(true, pred)
rmse = np.sqrt(mse)
r2 = r2_score(true, pred)

print(f"\n測試集評估結果：MAE: {mae:.4f}, RMSE: {rmse:.4f}, R²: {r2:.4f}")

# ========== 9. Save figures ==========
# 9-1. Training-loss curve (one point per epoch)
loss_fig, loss_ax = plt.subplots(figsize=(8, 5))
loss_ax.plot(loss_history, label='Loss')
loss_ax.set_title("TimesNet Training Loss Curve (Rain-weighted)")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.grid()
loss_ax.legend()
# Write the file before show(), while the figure is still live.
loss_fig.savefig('/content/models/loss_curve_timesnet.png')
plt.show()

# 9-2. Prediction plot (first 200 test samples, first target column)
pred_fig, pred_ax = plt.subplots(figsize=(12, 6))
pred_ax.plot(true[:200, 0], label='Actual', linestyle='-')
pred_ax.plot(pred[:200, 0], label='Predicted', linestyle='--')
pred_ax.set_title('TimesNet: Actual vs Predicted Rainfall (Test Set)')
pred_ax.set_xlabel('Time Step')
pred_ax.set_ylabel('Precp')
pred_ax.grid()
pred_ax.legend()
# Write the file before show(), while the figure is still live.
pred_fig.savefig('/content/models/timesnet_prediction_curve.png')
plt.show()

# 9-3. Error distribution (actual minus predicted, first target column)
error = true[:, 0] - pred[:, 0]
hist_fig, hist_ax = plt.subplots(figsize=(8, 5))
hist_ax.hist(error, bins=50, color='lightcoral', edgecolor='black')
hist_ax.set_title("TimesNet Prediction Error Distribution (Test Set)")
hist_ax.set_xlabel("Prediction Error")
hist_ax.set_ylabel("Frequency")
hist_ax.grid()
# Write the file before show(), while the figure is still live.
hist_fig.savefig('/content/models/timesnet_error_hist.png')
plt.show()
