In [None]:
import pandas as pd
import torch
import random
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Load the modelling table from a path relative to the notebook's working directory.
# Later cells read the columns dim_0..dim_1023, patents_count,
# total_5yr_forward_citations, year, ticker and file_name from it.
df = pd.read_csv("../data/csr_embeddings_leq2019_citation_count_cpc_lagged.csv")

## Seed

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (default
    generator, current CUDA device and all CUDA devices) with the same value,
    then puts cuDNN into deterministic mode with autotuning disabled.

    :param seed: integer seed shared by all generators (default 42)
    """
    seeders = (
        random.seed,                 # Python built-in RNG (random.shuffle, random.randint, ...)
        np.random.seed,              # NumPy global RNG (np.random.rand, np.random.shuffle, ...)
        torch.manual_seed,           # PyTorch RNG (torch.rand, ...)
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Force the deterministic cuDNN kernels so repeated runs give the same output.
    cudnn.deterministic = True
    # Disable cuDNN's per-input algorithm autotuning, which could otherwise pick
    # different kernels (and therefore different results) between runs.
    cudnn.benchmark = False


set_seed(42)

In [None]:
# Show the loaded table as this cell's output (sanity check of the columns and row count).
df

## Dataset

In [None]:
class SimpleCSRDataset(torch.utils.data.Dataset):
    """Row-wise dataset built from the modelling DataFrame.

    Each item is a 4-tuple ``(x, y, index, year)``:

    * ``x``     - float32 tensor holding the columns ``dim_0`` .. ``dim_1023``
    * ``y``     - float32 tensor ``log1p([patents_count, total_5yr_forward_citations])``
    * ``index`` - the row's label in ``df.index``
    * ``year``  - the row's value in the ``year`` column
    """

    def __init__(self, df):
        feature_cols = [f'dim_{i}' for i in range(1024)]
        target_cols = ['patents_count', 'total_5yr_forward_citations']

        raw_targets = torch.tensor(df[target_cols].values, dtype=torch.float32)

        self.x = torch.tensor(df[feature_cols].values, dtype=torch.float32)
        # Targets are stored in log1p space; downstream metrics undo this with expm1.
        self.y = torch.log1p(raw_targets)
        self.indexes = df.index.tolist()
        self.years = df["year"].values

    def __len__(self):
        # Number of rows == first dimension of the feature matrix.
        return self.x.shape[0]

    def __getitem__(self, idx):
        sample = (self.x[idx], self.y[idx], self.indexes[idx], self.years[idx])
        return sample

## Model

In [None]:
class MLP_MTL(nn.Module):
    """Multi-task MLP: one shared trunk feeding two single-output heads.

    The trunk is Linear -> ReLU -> Dropout(0.5) -> Linear -> ReLU. Each head is
    Linear(head_dim, 1) followed by ReLU, so both outputs are non-negative.

    :param input_dim: size of the input feature vector
    :param hidden_dim: width of the trunk's hidden layer
    :param head_dim: width of the shared representation fed to both heads
    """

    def __init__(self, input_dim=1024, hidden_dim=256, head_dim=32):
        super().__init__()
        trunk_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, head_dim),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        # Heads are created in the order count -> citation.
        self.head_count = self._make_head(head_dim)
        self.head_citation = self._make_head(head_dim)

    @staticmethod
    def _make_head(head_dim):
        """Build one task head: a single linear unit followed by ReLU."""
        return nn.Sequential(nn.Linear(head_dim, 1), nn.ReLU())

    def forward(self, x):  # x: [B, input_dim]
        """Return a [B, 2] tensor: column 0 = count head, column 1 = citation head."""
        features = self.shared(x)
        per_task = [head(features) for head in (self.head_count, self.head_citation)]
        return torch.cat(per_task, dim=1)  # [B, 2]

## Evaluation metrics

In [None]:
# --- Metrics (expm1 to reverse log1p) ---
def compute_metrics(y_pred, y_true):
    """Compute aggregate error metrics per target on the original (expm1) scale.

    :param y_pred: tensor of predictions in log1p space; column 0 = count, column 1 = citation
    :param y_true: tensor of targets in log1p space, same layout as ``y_pred``
    :return: ``{"count": {...}, "citation": {...}}`` where each inner dict has the
             keys ``MSE``, ``MAE``, ``RMSE`` and ``SMAPE`` (SMAPE in percent)
    """
    pred_raw = torch.expm1(y_pred).cpu().numpy()
    true_raw = torch.expm1(y_true).cpu().numpy()

    results = {}
    for col, target in enumerate(["count", "citation"]):
        y_p = pred_raw[:, col]
        y_t = true_raw[:, col]
        err = y_p - y_t

        mse = np.mean(err ** 2)
        # The 1e-8 in the denominator avoids 0/0 when both values are exactly zero.
        smape = np.mean(2 * np.abs(err) / (np.abs(y_p) + np.abs(y_t) + 1e-8))*100
        results[target] = {
            "MSE": mse,
            "MAE": np.mean(np.abs(err)),
            "RMSE": np.sqrt(mse),
            "SMAPE": smape,
        }
    return results

## Loss function

In [None]:
class MSEPunishLoss(nn.Module):
    """Sum of two MSE losses plus an optional penalty on "phantom" citations.

    The base loss is ``MSE(count) + MSE(citation)``. When ``punish_mode`` is set,
    an extra term penalises the predicted citation value on rows whose true
    count is exactly zero, weighted by ``alpha``.

    Supported penalty modes (all computed on the zero-count rows only):

    * ``'abs'``       - sum of ``|pred_citation|``
    * ``'square'``    - sum of ``pred_citation ** 2``
    * ``'binary'``    - number of rows with ``pred_citation > 1e-4``
                        (a float count; carries no gradient)
    * ``'huber'``     - smooth-L1 distance to zero, summed
    * ``'ratio'``     - sum of ``pred_citation / (true_count + 1e-6)``; since
                        ``true_count`` is 0 on the selected rows this is
                        ``pred_citation * 1e6``
    * ``'threshold'`` - sum of ``max(pred_citation - 1.0, 0)``
    """

    # Penalty variants accepted by ``punish_mode``; '' (any falsy value) disables the penalty.
    SUPPORTED_PUNISH_MODES = ('abs', 'square', 'binary', 'huber', 'ratio', 'threshold')

    def __init__(self, alpha=1.0, punish_mode='', return_each=False):
        """
        :param alpha: weight of the penalty term
        :param punish_mode: '' for no penalty, or one of
                            'abs', 'square', 'binary', 'huber', 'ratio', 'threshold'
        :param return_each: if True, ``forward`` also returns the count loss and
                            the citation loss as Python floats
        :raises ValueError: if ``punish_mode`` is non-empty and not supported
        """
        super().__init__()
        # Fail fast: previously a typo was only detected when a batch happened to
        # contain a zero-count row, and was silently ignored otherwise.
        self._check_punish_mode(punish_mode)
        self.mse = nn.MSELoss()
        self.alpha = alpha
        self.punish_mode = punish_mode
        self.return_each = return_each

    @classmethod
    def _check_punish_mode(cls, punish_mode):
        """Raise ValueError unless ``punish_mode`` is falsy or a supported mode."""
        if punish_mode and punish_mode not in cls.SUPPORTED_PUNISH_MODES:
            raise ValueError(f"Unsupported punish_mode: {punish_mode}")

    def _punishment(self, pred_citation, true_count, mask):
        """Penalty for the rows selected by ``mask`` under the current mode."""
        mode = self.punish_mode
        flagged = pred_citation[mask]
        if mode == 'abs':
            return torch.sum(torch.abs(flagged))
        if mode == 'square':
            return torch.sum(flagged ** 2)
        if mode == 'binary':
            return torch.sum((flagged > 1e-4).float())
        if mode == 'huber':
            return nn.functional.smooth_l1_loss(
                flagged,
                torch.zeros_like(flagged),
                reduction='sum'
            )
        if mode == 'ratio':
            safe_count = true_count + 1e-6
            ratio = pred_citation / safe_count
            return torch.sum(ratio[mask])
        if mode == 'threshold':
            threshold = 1.0
            return torch.sum(torch.clamp(flagged - threshold, min=0))
        raise ValueError(f"Unsupported punish_mode: {mode}")

    def forward(self, preds, targets):
        """
        :param preds: tensor whose column 0 is the count prediction and column 1
                      the citation prediction
        :param targets: tensor with the same layout as ``preds``
        :return: the total loss tensor, or ``(total, count_loss, citation_loss)``
                 with the last two as floats when ``return_each`` is True
        :raises ValueError: if ``punish_mode`` was changed to an unsupported value
        """
        pred_count, pred_citation = preds[:, 0], preds[:, 1]
        true_count, true_citation = targets[:, 0], targets[:, 1]

        loss_count = self.mse(pred_count, true_count)
        loss_citation = self.mse(pred_citation, true_citation)
        loss_total = loss_count + loss_citation

        # Stays 0.0 when no penalty is configured or no row has a zero count.
        punishment = 0.0

        if self.punish_mode:
            # Re-validated here (independently of the batch contents) in case the
            # attribute was reassigned after construction.
            self._check_punish_mode(self.punish_mode)
            mask = (true_count == 0)
            if mask.any():
                punishment = self._punishment(pred_citation, true_count, mask)

        loss_total = loss_total + self.alpha * punishment

        if self.return_each:
            return loss_total, loss_count.item(), loss_citation.item()
        else:
            return loss_total

## Fold

In [None]:
# Build expanding-window folds: train on the first N years, validate on year N+1.
# Sorting and resetting the index makes the row labels equal to row positions;
# the datasets record these labels in their `indexes` attribute.
df = df.sort_values(by=['year', 'ticker']).reset_index(drop=True)

min_year = df['year'].min()
max_year = df['year'].max()

folds = []  # list of (train_df, val_df) tuples, in chronological order
min_train_years = 5  # the smallest training window spans this many calendar years

# Expanding-window loop: each iteration grows the training window by one year.
for train_length in range(min_train_years, max_year - min_year + 1):
    train_years = list(range(min_year, min_year + train_length))
    val_year = min_year + train_length  # validate on the single following year only

    train_df = df[df['year'].isin(train_years)].copy()
    val_df = df[df['year'] == val_year].copy()

    # Skip windows for which either side has no rows (e.g. a year missing from the data).
    if not train_df.empty and not val_df.empty:
        folds.append((train_df, val_df))

## Evaluate best model

In [None]:
def evaluate_best_model(model, val_loader, fold_id, best_epoch):
    """Run ``model`` over ``val_loader`` and collect predictions, labels and meta info.

    Each batch yielded by ``val_loader`` must be a sequence of
    ``(x, y, index, year)`` tuples (i.e. an un-collated batch).

    :param model: module to evaluate; it is switched to eval mode
    :param val_loader: iterable of batches as described above
    :param fold_id: zero-based fold number (stored one-based in the meta info)
    :param best_epoch: zero-based epoch number (stored one-based in the meta info)
    :return: three single-element lists: ``[predictions]``, ``[targets]`` (both CPU
             tensors concatenated over all batches) and ``[meta]``, where ``meta``
             is a list with one dict per sample (used for the CSV export)
    """
    model.eval()
    pred_chunks, target_chunks, records = [], [], []
    fold_no = fold_id + 1
    epoch_no = best_epoch + 1

    with torch.no_grad():
        for batch_no, batch in enumerate(val_loader, start=1):
            xs, ys, indices, years = zip(*batch)

            # Run on whichever device the model's parameters live on.
            model_device = next(model.parameters()).device
            inputs = torch.stack(xs).to(model_device)
            targets = torch.stack(ys).to(model_device)

            pred_chunks.append(model(inputs).cpu())
            target_chunks.append(targets.cpu())

            records.extend(
                {
                    "fold": fold_no,
                    "epoch": epoch_no,
                    "batch": batch_no,
                    "sample": position,
                    "index": index,
                    "year": year,
                }
                for position, (index, year) in enumerate(zip(indices, years))
            )

    return [torch.cat(pred_chunks)], [torch.cat(target_chunks)], [records]

## Training

In [None]:
# --- Training Config ---
# Trains one fresh MLP_MTL per expanding-window fold with early stopping on the
# average validation loss, then re-evaluates the best checkpoint of each fold.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epochs = 100
learning_rate = 1e-4

# Histories collected across folds (one entry appended per fold).
all_train_losses_total, all_train_losses_count, all_train_losses_citation = [], [], []
all_train_preds_total, all_train_targets_total, all_train_meta_info_total = [], [], []

all_val_losses_count, all_val_losses_citation, all_val_losses_total = [], [], []
all_val_preds_total, all_val_targets_total, all_val_meta_info_total = [], [], []

best_epoch_list = []  # one-based best epoch per fold

for fold_id, (train_df, val_df) in enumerate(folds):
    print(f"\n====== Fold {fold_id+1} ({train_df['year'].min()}–{train_df['year'].max()} → {val_df['year'].iloc[0]}) ======")

    # Per-epoch training predictions / targets / meta info for this fold.
    fold_train_preds_epochs, fold_train_targets_epochs, fold_train_meta_epochs = [], [], []

    train_dataset = SimpleCSRDataset(train_df)
    val_dataset = SimpleCSRDataset(val_df)

    # Dedicated generator so the shuffling order is the same in every fold/run.
    generator = torch.Generator().manual_seed(42)

    # DataLoader: the identity collate_fn keeps each batch as a list of
    # (x, y, index, year) tuples, which are unpacked with zip(*batch) below.
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=lambda x: x, generator=generator)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=lambda x: x)

    # Model
    model = MLP_MTL().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = MSEPunishLoss(alpha=1.0, punish_mode='', return_each=True)

    # Per-epoch average losses for this fold.
    train_loss_total_list, train_loss_count_list, train_loss_citation_list = [], [], []
    val_loss_total_list, val_loss_count_list, val_loss_citation_list = [], [], []

    # Early-stopping state.
    best_val_loss = float('inf')
    epochs_no_improve = 0
    patience = 50
    best_model_state = None
    best_epoch = None

    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        epoch_losses, epoch_loss_count_all, epoch_loss_citation_all = [], [], []
        train_preds, train_targets, train_meta_info = [], [], []
        for batch_id, batch in enumerate(train_loader):
            x, y, idxs, years = zip(*batch)
            x = torch.stack(x).to(device)
            y = torch.stack(y).to(device)

            preds = model(x)

            loss_total, loss_count, loss_citation = criterion(preds, y)
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            epoch_losses.append(loss_total.item())
            epoch_loss_count_all.append(loss_count)
            epoch_loss_citation_all.append(loss_citation)

            train_preds.append(preds.detach().cpu())
            train_targets.append(y.detach().cpu())

            for i, index in enumerate(idxs):
                train_meta_info.append({
                    "fold": fold_id + 1,
                    "epoch": epoch + 1,
                    "batch": batch_id + 1,
                    "sample": i,
                    "index": index
                })

        avg_train_loss = np.mean(epoch_losses)
        avg_train_count = np.mean(epoch_loss_count_all)
        avg_train_citation = np.mean(epoch_loss_citation_all)

        train_loss_total_list.append(avg_train_loss)
        train_loss_count_list.append(avg_train_count)
        train_loss_citation_list.append(avg_train_citation)

        fold_train_meta_epochs.append(train_meta_info)
        fold_train_preds_epochs.append(torch.cat(train_preds))
        fold_train_targets_epochs.append(torch.cat(train_targets))

        # ---- Validation ----
        # Per-epoch validation predictions are only needed for this epoch's
        # metrics; the values kept for the CSV export come from the best
        # checkpoint, re-evaluated after the epoch loop.
        model.eval()
        with torch.no_grad():
            val_losses = []
            val_loss_count_all = []
            val_loss_citation_all = []
            all_preds = []
            all_y = []
            for batch_id, batch in enumerate(val_loader):
                x, y, idxs, years = zip(*batch)
                x = torch.stack(x).to(device)
                y = torch.stack(y).to(device)

                preds = model(x)
                loss_total, loss_count, loss_citation = criterion(preds, y)

                val_losses.append(loss_total.item())
                val_loss_count_all.append(loss_count)
                val_loss_citation_all.append(loss_citation)

                all_preds.append(preds)
                all_y.append(y)

                print(f"[Fold {fold_id+1}][Epoch {epoch+1}][Val][Batch] | Samples: {len(x)}")
                print(f"  Total Loss: {loss_total.item():.4f} | Count Loss: {loss_count:.4f} | Citation Loss: {loss_citation:.4f}")
                print(f"  Predictions (expm1, rounded): {np.round(torch.expm1(preds).detach().cpu().numpy())}")
                print(f"  Ground truth (expm1, rounded): {np.round(torch.expm1(y).detach().cpu().numpy())}")
                print(f"  Indices: {idxs}")

            avg_val_loss = np.mean(val_losses)
            avg_val_count = np.mean(val_loss_count_all)
            avg_val_citation = np.mean(val_loss_citation_all)

            val_loss_total_list.append(avg_val_loss)
            val_loss_count_list.append(avg_val_count)
            val_loss_citation_list.append(avg_val_citation)

            all_preds = torch.cat(all_preds)
            all_y = torch.cat(all_y)

            # End-of-epoch summary.
            print(f"[Fold {fold_id+1}][Epoch {epoch+1}] Summary")
            print(f"  Train Avg Loss     -> Total: {avg_train_loss:.4f} | Count: {avg_train_count:.4f} | Citation: {avg_train_citation:.4f}")
            print(f"  Validation Avg Loss-> Total: {avg_val_loss:.4f} | Count: {avg_val_count:.4f} | Citation: {avg_val_citation:.4f}")

        metrics = compute_metrics(all_preds.cpu(), all_y.cpu())
        for name, vals in metrics.items():
            print(f"  [{name.upper()}] Metrics -> MSE: {vals['MSE']:.4f} | MAE: {vals['MAE']:.4f} | "
                f"RMSE: {vals['RMSE']:.4f} | SMAPE: {vals['SMAPE']:.2f}%")

        # ---- Early stopping ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors, which
            # the optimizer keeps updating in place. Clone them so the snapshot
            # really holds this epoch's weights.
            best_model_state = {
                key: value.detach().clone() for key, value in model.state_dict().items()
            }
            best_epoch = epoch
            print(f"  ✅ Validation loss improved. Saving model.")
        else:
            epochs_no_improve += 1
            print(f"  ❌ No improvement. Patience: {epochs_no_improve}/{patience}")

        if epochs_no_improve >= patience:
            print(f"  ⏹ Early stopping triggered at epoch {epoch+1}. Best val_loss: {best_val_loss:.4f}")
            if best_epoch is not None:
                print(f"  🔄 Best model loaded from epoch {best_epoch+1}")
            else:
                print(f"  🔄 Best model is not found (no improvement over {epochs} epochs)")
            # Actually stop training; without this the loop ran all `epochs` regardless.
            break

    if best_model_state is not None:
        # A best checkpoint exists: restore it before the final evaluation.
        model.load_state_dict(best_model_state)
        print("  🔄 Best model loaded for final evaluation or saving.")
        eval_epoch = best_epoch
    else:
        # The validation loss never improved (e.g. it was NaN throughout):
        # fall back to the weights of the last epoch that ran.
        print("  ⚠️ No best model found. Using final epoch model for evaluation.")
        eval_epoch = epoch

    # Single-element lists holding the evaluated model's validation outputs.
    fold_val_preds_epochs, fold_val_targets_epochs, fold_val_meta_epochs = evaluate_best_model(
        model, val_loader, fold_id, eval_epoch
    )

    # Store this fold's results in the cross-fold collections.
    all_train_losses_total.append(train_loss_total_list)
    all_train_losses_count.append(train_loss_count_list)
    all_train_losses_citation.append(train_loss_citation_list)

    all_train_preds_total.append(fold_train_preds_epochs)
    all_train_targets_total.append(fold_train_targets_epochs)
    all_train_meta_info_total.append(fold_train_meta_epochs)

    all_val_losses_total.append(val_loss_total_list)
    all_val_losses_count.append(val_loss_count_list)
    all_val_losses_citation.append(val_loss_citation_list)

    all_val_preds_total.append(fold_val_preds_epochs)
    all_val_targets_total.append(fold_val_targets_epochs)
    all_val_meta_info_total.append(fold_val_meta_epochs)

    best_epoch_list.append(best_epoch + 1 if best_epoch is not None else epochs)

## compute_sample_metrics

In [None]:
def compute_sample_metrics(y_true, y_pred):
    """Per-sample error metrics for the count and citation targets.

    :param y_true: numpy array of shape [N, 2], already on the original (expm1) scale
    :param y_pred: numpy array of shape [N, 2], already on the original (expm1) scale
    :return: dict mapping ``"<target>_<metric>"`` to a length-N array, for
             target in (count, citation) and metric in (MSE, MAE, RMSE, SMAPE).
             SMAPE is returned as a fraction, not a percentage.
    """
    diff = y_pred - y_true
    per_metric = {
        "MSE": diff ** 2,
        "MAE": np.abs(diff),
        # For a single sample the RMSE is sqrt(squared error), i.e. |error|.
        "RMSE": np.sqrt(diff ** 2),
        "SMAPE": 2 * np.abs(diff) / (np.abs(y_pred) + np.abs(y_true) + 1e-8),
    }

    # Column 0 holds the count target, column 1 the citation target.
    return {
        f"{target}_{metric}": values[:, col]
        for col, target in enumerate(("count", "citation"))
        for metric, values in per_metric.items()
    }

# Save validation results to CSV

In [None]:
# Build one record per (fold, sample) for the evaluated validation pass and one
# record per (fold, epoch, sample) for training, then write both to CSV.
val_summary_records, train_summary_records = [], []

# Look the two columns up once instead of calling df.iloc[...] for every record.
# meta["index"] is used as a row position, exactly as the former df.iloc lookup did.
# NOTE(review): this relies on df having been re-indexed 0..N-1 before the folds
# were built, so that index labels equal row positions - confirm if that changes.
file_name_by_row = df["file_name"].to_numpy()
year_by_row = df["year"].to_numpy()

# Column suffixes shared by the "val_" and "train_" halves of a record; the half
# that does not apply to a record is filled with NaN.
summary_fields = (
    "true_count", "pred_count", "true_citation", "pred_citation",
    "loss_total", "loss_count", "loss_citation",
    "count_MSE", "count_MAE", "count_SMAPE",
    "citation_MSE", "citation_MAE", "citation_SMAPE",
)

for fold_id in range(len(folds)):
    # === Validation ===
    # Element 0 is the single evaluated pass stored for this fold.
    fold_preds = torch.expm1(all_val_preds_total[fold_id][0].cpu()).numpy()
    fold_targets = torch.expm1(all_val_targets_total[fold_id][0].cpu()).numpy()
    fold_meta_info = all_val_meta_info_total[fold_id][0]

    metrics = compute_sample_metrics(fold_targets, fold_preds)

    for i, meta in enumerate(fold_meta_info):
        record = {
            "fold": meta["fold"],
            "epoch": meta["epoch"],
            "batch": meta["batch"],
            "local_sample": meta["sample"],
            "index": meta["index"],
            "file": file_name_by_row[meta["index"]],
            # The validation meta already carries the year; it was previously
            # dropped, leaving the column empty for validation rows.
            "year": meta["year"],
            "val_true_count": round(fold_targets[i, 0], 4),
            "val_pred_count": round(fold_preds[i, 0], 4),
            "val_true_citation": round(fold_targets[i, 1], 4),
            "val_pred_citation": round(fold_preds[i, 1], 4),
            "val_count_MSE": round(metrics["count_MSE"][i], 4),
            "val_count_MAE": round(metrics["count_MAE"][i], 4),
            "val_count_SMAPE": round(metrics["count_SMAPE"][i], 4),
            "val_citation_MSE": round(metrics["citation_MSE"][i], 4),
            "val_citation_MAE": round(metrics["citation_MAE"][i], 4),
            "val_citation_SMAPE": round(metrics["citation_SMAPE"][i], 4),
        }
        # Training columns stay empty on validation rows.
        record.update({f"train_{field}": np.nan for field in summary_fields})
        val_summary_records.append(record)

    # === Training ===
    for epoch in range(len(all_train_preds_total[fold_id])):
        fold_preds = torch.expm1(all_train_preds_total[fold_id][epoch]).numpy()
        fold_targets = torch.expm1(all_train_targets_total[fold_id][epoch]).numpy()
        fold_meta_info = all_train_meta_info_total[fold_id][epoch]

        metrics = compute_sample_metrics(fold_targets, fold_preds)

        # Epoch-level average losses, identical for every sample of this epoch.
        epoch_loss_total = round(all_train_losses_total[fold_id][epoch], 4)
        epoch_loss_count = round(all_train_losses_count[fold_id][epoch], 4)
        epoch_loss_citation = round(all_train_losses_citation[fold_id][epoch], 4)

        for i, meta in enumerate(fold_meta_info):
            record = {
                "fold": meta["fold"],
                "epoch": meta["epoch"],
                "batch": meta["batch"],
                "local_sample": meta["sample"],
                "index": meta["index"],
                "file": file_name_by_row[meta["index"]],
                "year": year_by_row[meta["index"]],
            }
            # Validation columns stay empty on training rows.
            record.update({f"val_{field}": np.nan for field in summary_fields})
            record.update({
                "train_true_count": round(fold_targets[i, 0], 4),
                "train_pred_count": round(fold_preds[i, 0], 4),
                "train_true_citation": round(fold_targets[i, 1], 4),
                "train_pred_citation": round(fold_preds[i, 1], 4),
                "train_loss_total": epoch_loss_total,
                "train_loss_count": epoch_loss_count,
                # Previously missing, so this column was NaN for every row.
                "train_loss_citation": epoch_loss_citation,
                "train_count_MSE": round(metrics["count_MSE"][i].item(), 4),
                "train_count_MAE": round(metrics["count_MAE"][i].item(), 4),
                "train_count_SMAPE": round(metrics["count_SMAPE"][i].item(), 4),
                "train_citation_MSE": round(metrics["citation_MSE"][i].item(), 4),
                "train_citation_MAE": round(metrics["citation_MAE"][i].item(), 4),
                "train_citation_SMAPE": round(metrics["citation_SMAPE"][i].item(), 4),
            })
            train_summary_records.append(record)

# Write the CSV files: validation only, and validation + training combined.
val_summary_df = pd.DataFrame(val_summary_records)
train_summary_df = pd.DataFrame(train_summary_records)
combined_df = pd.concat([val_summary_df, train_summary_df], ignore_index=True)
val_summary_df.to_csv("../output/mlp_mtl_v8_expanding_log_val_detailed.csv", index=False)
combined_df.to_csv("../output/mlp_mtl_v8_expanding_log_train_val_detailed.csv", index=False)
print("✅ Saved detailed validation results to 'train_val_detailed.csv'")

In [None]:
# Print, for every fold, the count / citation losses averaged over all epochs
# that were run (not just the best epoch), followed by the mean of those
# per-fold averages.
print("\n====== Final Summary: Each Fold's Avg Loss (per Y) ======")
for fold_id in range(len(folds)):
    avg_train_count = np.mean(all_train_losses_count[fold_id])
    avg_train_citation = np.mean(all_train_losses_citation[fold_id])
    avg_val_count = np.mean(all_val_losses_count[fold_id])
    avg_val_citation = np.mean(all_val_losses_citation[fold_id])

    print(f"[Fold {fold_id+1}]")
    print(f"  Train  -> Count Loss: {avg_train_count:.4f} | Citation Loss: {avg_train_citation:.4f}")
    print(f"  Valid  -> Count Loss: {avg_val_count:.4f} | Citation Loss: {avg_val_citation:.4f}")

# Mean over folds of each fold's per-epoch average (every fold weighs the same,
# regardless of how many epochs it ran).
mean_train_count = np.mean([np.mean(x) for x in all_train_losses_count])
mean_train_citation = np.mean([np.mean(x) for x in all_train_losses_citation])
mean_val_count = np.mean([np.mean(x) for x in all_val_losses_count])
mean_val_citation = np.mean([np.mean(x) for x in all_val_losses_citation])

# The number of folds depends on the year range of the data, so report the
# actual count instead of a hard-coded "5".
print(f"\n====== Final Overall Avg Loss Across {len(folds)} Folds ======")
print(f"Train  -> Count Loss: {mean_train_count:.4f} | Citation Loss: {mean_train_citation:.4f}")
print(f"Valid  -> Count Loss: {mean_val_count:.4f} | Citation Loss: {mean_val_citation:.4f}")


In [None]:
import numpy as np
import torch

# Merge the stored validation predictions, targets and meta info of all folds
# into flat structures (values are still in log1p space at this point).
y_pred_all = torch.cat([torch.cat(fold) for fold in all_val_preds_total])
y_true_all = torch.cat([torch.cat(fold) for fold in all_val_targets_total])
meta_all = [meta for fold in all_val_meta_info_total for epoch in fold for meta in epoch]

# Map back to the original scale (inverse of log1p).
y_pred = torch.expm1(y_pred_all).numpy()
y_true = torch.expm1(y_true_all).numpy()

# Per-sample errors; for a single sample the "RMSE" equals the absolute error.
abs_error = np.abs(y_pred - y_true)
squared_error = (y_pred - y_true) ** 2
rmse_error = np.sqrt(squared_error)
smape = 2 * abs_error / (np.abs(y_pred) + np.abs(y_true) + 1e-8)

# Print every sample (switch to range(50) to reduce the amount of output).
print("\n===== Detailed Per-Sample Metric with Meta Info =====")
for i in range(len(y_true)):
    meta = meta_all[i]
    print(f"\n[Sample {i}] [Fold {meta['fold']}][Epoch {meta['epoch']}][Batch {meta['batch']}][Local Sample {meta['sample']}] Index: {meta['index']}")
    print(f"  COUNT    -> True: {y_true[i, 0]:.2f}, Pred: {y_pred[i, 0]:.2f}")
    print(f"              MSE: {squared_error[i, 0]:.4f}, MAE: {abs_error[i, 0]:.4f}, RMSE: {rmse_error[i, 0]:.4f}, SMAPE: {smape[i, 0]*100:.2f}%")
    print(f"  CITATION -> True: {y_true[i, 1]:.2f}, Pred: {y_pred[i, 1]:.2f}")
    print(f"              MSE: {squared_error[i, 1]:.4f}, MAE: {abs_error[i, 1]:.4f}, RMSE: {rmse_error[i, 1]:.4f}, SMAPE: {smape[i, 1]*100:.2f}%")


# Draw

In [None]:
import matplotlib.pyplot as plt

def plot_loss_curve_no_legend(train_losses, val_losses, ylabel, title, color_cycle=None):
    """Draw the train/validation loss curves of every fold on one axes, without a legend.

    :param train_losses: sequence with one per-epoch loss sequence per fold
    :param val_losses: same layout as ``train_losses``, for validation
    :param ylabel: label of the y axis
    :param title: title of the plot
    :param color_cycle: accepted but currently not used
    :return: list of line handles (train, val, train, val, ...) that can be
             passed to a separately drawn legend
    """
    _, ax = plt.subplots(figsize=(8, 5))

    line_handles = []
    for fold_no, train_curve in enumerate(train_losses, start=1):
        # Solid line = training, dashed line = validation.
        (train_line,) = ax.plot(train_curve, label=f"Train - Fold {fold_no}")
        (val_line,) = ax.plot(val_losses[fold_no - 1], linestyle='--', label=f"Val - Fold {fold_no}")
        line_handles += [train_line, val_line]

    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    plt.tight_layout()
    plt.show()

    return line_handles


def plot_legend_only(handles, ncol=4, fontsize='small', figsize=(12, 2)):
    """Draw a figure that contains nothing but a legend.

    Useful when the main plot has too many curves to host a readable legend.

    :param handles: line handles returned by a plotting call
    :param ncol: number of legend columns
    :param fontsize: legend font size
    :param figsize: size of the legend-only figure
    """
    legend_options = dict(
        handles=handles,
        loc='center',
        ncol=ncol,
        frameon=False,
        fontsize=fontsize,
    )
    legend_fig = plt.figure(figsize=figsize)
    legend_fig.legend(**legend_options)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


def plot_all_loss_curves_with_separate_legend(train_losses_count, val_losses_count,
                                               train_losses_citation, val_losses_citation):
    """Draw the Count and Citation learning curves, each followed by its own legend figure.

    All four arguments hold one per-epoch loss sequence per fold.
    """
    panels = (
        ("Count", train_losses_count, val_losses_count),
        ("Citation", train_losses_citation, val_losses_citation),
    )
    for target, train_curves, val_curves in panels:
        print(f"🎨 Drawing {target} Loss curve...")
        curve_handles = plot_loss_curve_no_legend(
            train_losses=train_curves,
            val_losses=val_curves,
            ylabel=f"{target} Loss",
            title=f"{target} Loss Learning Curve (Train vs. Validation)",
        )
        plot_legend_only(curve_handles, ncol=4)

In [None]:
# Learning curves for both targets, using the per-fold, per-epoch average losses
# collected during training.
plot_all_loss_curves_with_separate_legend(
    all_train_losses_count, all_val_losses_count,
    all_train_losses_citation, all_val_losses_citation
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Merge the stored validation predictions and targets of all folds (log1p space).
y_pred_all = torch.cat([tensor for fold in all_val_preds_total for tensor in fold])
y_true_all = torch.cat([tensor for fold in all_val_targets_total for tensor in fold])

# Undo log1p. torch.expm1 is used (as elsewhere in this notebook) instead of
# applying the NumPy ufunc to torch tensors.
y_pred_all = torch.expm1(y_pred_all)
y_true_all = torch.expm1(y_true_all)

# One scatter plot per target: column 0 = count, column 1 = citation.
for target_col, target_name in enumerate(["Count", "Citation"]):
    true_values = y_true_all[:, target_col].numpy()
    pred_values = y_pred_all[:, target_col].numpy()

    # Reference line y = x over the joint value range. The previous line joined
    # (min true, min pred) to (max true, max pred), which is not the identity
    # line and therefore did not show where perfect predictions would lie.
    low = float(min(true_values.min(), pred_values.min()))
    high = float(max(true_values.max(), pred_values.max()))

    plt.figure(figsize=(6, 6))
    plt.scatter(true_values, pred_values, alpha=0.5)
    plt.plot([low, high], [low, high], 'r--')
    plt.xlabel(f"Ground Truth ({target_name})")
    plt.ylabel(f"Predictions ({target_name})")
    plt.title(f"Predictions vs Ground Truth ({target_name}) – All Folds")
    plt.grid(True)
    plt.show()


In [None]:
# Number of validation samples pooled across all folds (rows of y_true_all).
print(f"Total number of samples: {y_true_all.shape[0]}")