In [None]:
import pandas as pd
import torch
import random
import numpy as np
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import KFold, GroupKFold
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.data import Sampler
from sklearn.preprocessing import StandardScaler

# Load the company-year panel. Later cells read these columns from it: `ticker`, `year`,
# `dim_0` .. `dim_1023` (embedding features), `patents_count` and `total_5yr_forward_citations`.
df = pd.read_csv("../data/csr_embeddings_leq2019_citation_count_cpc_lagged.csv")

## Seed

In [None]:
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, current CUDA device, all CUDA devices), then switches
    cuDNN to deterministic kernels with auto-tuning disabled.

    :param seed: integer seed applied to every generator.
    """
    # cuDNN: force deterministic kernels and disable the benchmark mode that
    # picks an algorithm per input shape (a source of run-to-run differences).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # The same seed goes to each library-level seeding function in turn.
    seeders = (
        random.seed,                 # random.shuffle(), random.randint(), ...
        np.random.seed,              # np.random.rand(), np.random.shuffle(), ...
        torch.manual_seed,           # torch.rand(), weight initialisation, ...
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
    )
    for seed_fn in seeders:
        seed_fn(seed)


set_seed(42)

In [None]:
# Show the loaded frame as this cell's output.
df

## Dataset

In [None]:
class SlidingWindowDatasetDynamic(torch.utils.data.Dataset):
    """Variable-length sliding-window samples built per company.

    For every row of a company (rows sorted by ``year``) one sample is created:
    the input is the embedding rows of that year and of up to ``max_len - 1``
    preceding rows, the target is ``log1p`` of that row's
    ``(patents_count, total_5yr_forward_citations)``.

    Windows are taken over consecutive *rows*, not consecutive calendar years,
    so a company with missing years yields windows that span the gap.

    :param df: frame with columns ``ticker``, ``year``, ``dim_0`` ..
        ``dim_{embed_dim - 1}``, ``patents_count`` and
        ``total_5yr_forward_citations``.
    :param company_list: tickers to keep; all other rows are ignored.
    :param max_len: maximum number of rows in one input window.
    :param embed_dim: number of ``dim_*`` embedding columns to read
        (defaults to 1024, the width used throughout this notebook).
    """

    def __init__(self, df, company_list, max_len=5, embed_dim=1024):
        self.samples = []
        self.max_len = max_len
        self.embed_dim = embed_dim

        # Built once instead of once per company.
        embedding_cols = [f'dim_{i}' for i in range(embed_dim)]

        # Keep only the requested companies (boolean indexing already returns a new frame).
        df = df[df['ticker'].isin(company_list)]

        for company_id, group in df.groupby('ticker'):
            group = group.sort_values('year').reset_index(drop=True)

            years = group['year'].values
            embeddings = group[embedding_cols].values
            counts = group['patents_count'].values
            citations = group['total_5yr_forward_citations'].values

            # Log-transformed targets, one row per year: shape (N, 2).
            y_all = np.log1p(np.vstack([counts, citations]).T)

            for end in range(len(group)):
                # Early rows get shorter windows: the window never starts before row 0.
                start = max(0, end - max_len + 1)
                y_year = years[end]

                self.samples.append({
                    'x_seq': torch.tensor(embeddings[start:end + 1], dtype=torch.float32),  # (window_len, embed_dim)
                    'y_seq': torch.tensor(y_all[end], dtype=torch.float32),                # (2,)
                    'ticker': company_id,
                    'y_year': y_year,
                    # "<ticker>_<first year in window>_<target year>"
                    'index': f"{company_id}_{years[start]}_{y_year}"
                })

    def __len__(self):
        """Return the number of windows across all kept companies."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(x_seq, y_seq, ticker, y_year, index)`` for sample ``idx``."""
        s = self.samples[idx]
        return s['x_seq'], s['y_seq'], s['ticker'], s['y_year'], s['index']

## Model

In [None]:
class Attention_MTL_Stepwise(nn.Module):
    """Single-query attention encoder followed by two regression heads.

    The query is projected from the last valid time step of each sequence;
    keys and values are projected from every valid time step (the last one
    included). The attention-weighted sum of the values goes through a shared
    bottleneck and then through one head per task. Both heads end in a ReLU,
    so outputs are non-negative.

    :param input_dim: size of each input embedding.
    :param hidden_dim: size of the query/key/value projections.
    :param head_dim: size of the shared bottleneck.
    """

    def __init__(self, input_dim=1024, hidden_dim=256, head_dim=32):
        super().__init__()

        # Attention projections. Creation order is kept (key, query, value) because it
        # determines the order in which the global RNG is consumed during initialisation.
        self.key_layer = nn.Linear(input_dim, hidden_dim)
        self.query_layer = nn.Linear(input_dim, hidden_dim)
        self.value_layer = nn.Linear(input_dim, hidden_dim)

        # Bottleneck shared by both tasks.
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim, head_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Task heads: output column 0 is the count prediction, column 1 the citation prediction.
        self.head_count = nn.Sequential(nn.Linear(head_dim, 1), nn.ReLU())
        self.head_citation = nn.Sequential(nn.Linear(head_dim, 1), nn.ReLU())

    def forward(self, x, lengths):
        """Predict both targets for a padded batch.

        :param x: padded inputs of shape [B, T_max, D].
        :param lengths: number of valid steps per sequence, shape [B]; it is used
            to index ``x`` and is compared against positions created on ``x.device``.
        :return: tensor of shape [B, 2].
        """
        batch_size, max_steps, _ = x.shape
        device = x.device

        # valid[b, t, 0] is True for real steps and False for padding.
        positions = torch.arange(max_steps, device=device)
        valid = (positions.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(-1)  # [B, T_max, 1]

        # Keys / values for every step, zeroed at padded positions.
        keys = self.key_layer(x) * valid      # [B, T_max, H]
        values = self.value_layer(x) * valid  # [B, T_max, H]

        # Query from the last valid step of each sequence.
        rows = torch.arange(batch_size, device=device)
        query = self.query_layer(x[rows, lengths - 1, :]).unsqueeze(1)  # [B, 1, H]

        # Scaled dot-product scores; padded positions get -inf so softmax gives them weight 0.
        scale = np.sqrt(keys.shape[-1])
        scores = torch.matmul(query, keys.transpose(1, 2)) / scale  # [B, 1, T_max]
        scores = scores.masked_fill(~valid.transpose(1, 2), float('-inf'))
        attention = torch.softmax(scores, dim=-1)

        # Context vector is used directly as the sequence representation (no residual connection).
        context = torch.matmul(attention, values).squeeze(1)  # [B, H]

        shared = self.shared_head(context)  # [B, head_dim]
        return torch.cat([self.head_count(shared), self.head_citation(shared)], dim=1)  # [B, 2]

## Evaluation metrics

In [None]:
# --- Metrics on the original scale (expm1 reverses the log1p applied to the targets) ---
def compute_metrics(y_pred, y_true):
    """Compute MSE, MAE, RMSE and SMAPE (in percent) for each target column.

    :param y_pred: tensor of log1p-scale predictions; column 0 is the count, column 1 the citations.
    :param y_true: tensor of log1p-scale targets with the same layout.
    :return: ``{"count": {...}, "citation": {...}}`` where each inner dict has
        the keys ``MSE``, ``MAE``, ``RMSE`` and ``SMAPE``.
    """
    pred_np = torch.expm1(y_pred).cpu().numpy()
    true_np = torch.expm1(y_true).cpu().numpy()

    def _column_metrics(col):
        """Metrics for one target column; values are not rounded."""
        p, t = pred_np[:, col], true_np[:, col]
        diff = p - t
        mse = np.mean(diff ** 2)
        return {
            "MSE": mse,
            "MAE": np.mean(np.abs(diff)),
            "RMSE": np.sqrt(mse),
            # The 1e-8 term keeps the ratio finite when prediction and target are both 0.
            "SMAPE": np.mean(2 * np.abs(diff) / (np.abs(p) + np.abs(t) + 1e-8)) * 100,
        }

    return {name: _column_metrics(col) for col, name in enumerate(("count", "citation"))}

In [None]:
def inspect_folds_and_epochs(df, folds, epochs=3):
    """Print which samples land in each batch, per fold and epoch (no training).

    :param df: frame passed on to ``SlidingWindowDatasetDynamic``.
    :param folds: iterable of ``(train_tickers, test_tickers)`` pairs.
    :param epochs: number of passes over both loaders to print.

    Each printed "Company:" value is only the ticker of the first sample in the
    batch, and the "Companies:" figure in the summary is the number of batches.
    """

    def _report(loader, phase, fold_no, epoch_no):
        """Print one line pair per batch of ``loader`` plus a summary line."""
        print(f"\n--- Fold {fold_no} Epoch {epoch_no} [{phase}] ---")
        n_batches = 0
        n_samples = 0
        for batch in loader:
            # The identity collate_fn yields a list of sample tuples; transpose it into columns.
            _, _, tickers, y_years, indices = zip(*batch)
            print(f"Company: {tickers[0]} | Samples: {len(indices)} | Years: {min(y_years)}-{max(y_years)}")
            print(f"  Indices: {indices}")
            n_batches += 1
            n_samples += len(indices)
        print(f"===> [{phase} Summary] Companies: {n_batches} | Total Samples: {n_samples}")

    for fold_no, (train_coms, test_coms) in enumerate(folds, start=1):
        print(f"\n====== Inspect Fold {fold_no} ======")

        train_loader = DataLoader(
            SlidingWindowDatasetDynamic(df, train_coms),
            batch_size=16, shuffle=True, collate_fn=lambda x: x,
        )
        val_loader = DataLoader(
            SlidingWindowDatasetDynamic(df, test_coms),
            batch_size=16, shuffle=False, collate_fn=lambda x: x,
        )

        for epoch_no in range(1, epochs + 1):
            _report(train_loader, "Train", fold_no, epoch_no)
            _report(val_loader, "Validation", fold_no, epoch_no)


In [None]:
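# inspect_folds_and_epochs(df, folds=folds, epochs=3)  # `folds` must be the list of (train, test) ticker splits built in the StratifiedKFold section, not an integer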

## Loss function

In [None]:
class MSEPunishLoss(nn.Module):
    """Sum of per-task MSE losses plus an optional penalty on citation predictions.

    The penalty only looks at samples whose true count (``targets[:, 0]``) is
    exactly 0 and penalises the citation prediction (``preds[:, 1]``) of those
    samples according to ``punish_mode``.
    """

    # Every penalty implemented by ``_punishment``.
    _PUNISH_MODES = ('abs', 'square', 'binary', 'huber', 'ratio', 'threshold')

    def __init__(self, alpha=1.0, punish_mode='', return_each=False):
        """
        :param alpha: weight of the penalty term.
        :param punish_mode: '' (no penalty) or one of 'abs', 'square', 'binary',
            'huber', 'ratio', 'threshold'.
        :param return_each: if True, ``forward`` also returns the count loss and
            the citation loss as Python floats.
        :raises ValueError: if ``punish_mode`` is neither empty nor supported.
        """
        super().__init__()
        # Fail at construction: previously a misspelled mode was only reported once a
        # batch happened to contain a zero-count sample.
        if punish_mode and punish_mode not in self._PUNISH_MODES:
            raise ValueError(f"Unsupported punish_mode: {punish_mode}")
        self.mse = nn.MSELoss()
        self.alpha = alpha
        self.punish_mode = punish_mode
        self.return_each = return_each

    def _punishment(self, pred_citation, true_count, mask):
        """Return the penalty over the samples selected by ``mask`` (true count == 0)."""
        mode = self.punish_mode
        masked_pred = pred_citation[mask]
        if mode == 'abs':
            return torch.sum(torch.abs(masked_pred))
        if mode == 'square':
            return torch.sum(masked_pred ** 2)
        if mode == 'binary':
            # Counts predictions above 1e-4; built from a comparison, so it carries no gradient.
            return torch.sum((masked_pred > 1e-4).float())
        if mode == 'huber':
            return nn.functional.smooth_l1_loss(
                masked_pred,
                torch.zeros_like(masked_pred),
                reduction='sum'
            )
        if mode == 'ratio':
            # On masked samples true_count is 0, so this divides the prediction by 1e-6.
            safe_count = true_count + 1e-6
            ratio = pred_citation / safe_count
            return torch.sum(ratio[mask])
        if mode == 'threshold':
            threshold = 1.0
            return torch.sum(torch.clamp(masked_pred - threshold, min=0))
        # Reachable only if ``punish_mode`` was changed after construction.
        raise ValueError(f"Unsupported punish_mode: {mode}")

    def forward(self, preds, targets):
        """Compute the loss.

        :param preds: tensor whose column 0 is the count prediction and column 1
            the citation prediction.
        :param targets: tensor with the same column layout.
        :return: the total loss tensor, or ``(total, count_loss, citation_loss)``
            with the last two as floats when ``return_each`` is True.
        """
        pred_count, pred_citation = preds[:, 0], preds[:, 1]
        true_count, true_citation = targets[:, 0], targets[:, 1]

        loss_count = self.mse(pred_count, true_count)
        loss_citation = self.mse(pred_citation, true_citation)
        loss_total = loss_count + loss_citation

        # Stays a plain 0.0 when no penalty applies, which leaves loss_total unchanged.
        punishment = 0.0
        if self.punish_mode:
            mask = (true_count == 0)
            if mask.any():
                punishment = self._punishment(pred_citation, true_count, mask)

        loss_total = loss_total + self.alpha * punishment

        if self.return_each:
            return loss_total, loss_count.item(), loss_citation.item()
        return loss_total

## StratifiedKFold

In [None]:
from sklearn.model_selection import StratifiedKFold

# One row per company: mean yearly patent count and mean 5-year forward citations.
company_df = (
    df.groupby('ticker')[['patents_count', 'total_5yr_forward_citations']]
    .mean()
    .reset_index()
)

# Step 1: flag companies whose mean patent count is above zero.
company_df['has_patents'] = (company_df['patents_count'] > 0).astype(int)

# Step 2: bin the mean citation value of every company (integer bin codes, since labels=False).
bins = [-1, 0, 10, 50, 200, np.inf]
company_df['citation_bin'] = pd.cut(
    company_df['total_5yr_forward_citations'], bins=bins, labels=False
)

# Step 3: combined stratification label.
# Companies without patents all get 0; the others get their citation bin shifted up by one.
company_df['stratify_label'] = np.where(
    company_df['has_patents'] == 0,
    0,
    company_df['citation_bin'] + 1,
)

stratified_kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Each fold is a (train_tickers, test_tickers) pair of arrays; the split is made at
# company level, so all rows of one company fall on the same side.
folds = [
    (
        company_df.iloc[train_idx]['ticker'].values,
        company_df.iloc[test_idx]['ticker'].values,
    )
    for train_idx, test_idx in stratified_kf.split(
        company_df['ticker'], company_df['stratify_label']
    )
]


In [None]:
def _mean_std_max(series):
    """Return ``(mean, std, max)`` of a pandas Series."""
    return series.mean(), series.std(), series.max()


# Describe the held-out part of every fold so the folds can be compared at a glance.
for i, (_, test_coms) in enumerate(folds):
    sub_df = df[df['ticker'].isin(test_coms)]
    sample_size = len(sub_df)

    # Row-level statistics of both targets.
    mean_cit, std_cit, max_cit = _mean_std_max(sub_df['total_5yr_forward_citations'])
    mean_cnt, std_cnt, max_cnt = _mean_std_max(sub_df['patents_count'])

    # Share (in %) of held-out companies whose patent counts sum to zero over all their rows.
    company_level = sub_df.groupby('ticker')['patents_count'].sum().reset_index()
    no_patent_ratio = (company_level['patents_count'] == 0).mean() * 100

    print("\n".join([
        f"[Fold {i+1}]",
        f"  Sample Size: {sample_size}",
        f"  Citation → mean: {mean_cit:.2f} | std: {std_cit:.2f} | max: {max_cit:.2f}",
        f"  Count    → mean: {mean_cnt:.2f} | std: {std_cnt:.2f} | max: {max_cnt:.2f}",
        f"  % Companies with NO patents: {no_patent_ratio:.1f}%",
    ]))

## Evaluate best model

In [None]:
def evaluate_best_model(model, val_loader, fold_id, best_epoch):
    """Run ``model`` over ``val_loader`` in eval mode and collect its outputs.

    :param model: module called as ``model(padded_inputs, lengths)``; tensors are
        moved to the device of its first parameter.
    :param val_loader: iterable of batches, each a list of
        ``(x_seq, y_seq, ticker, y_year, index)`` sample tuples.
    :param fold_id: 0-based fold number, stored 1-based in the metadata.
    :param best_epoch: 0-based epoch number, stored 1-based in the metadata.
    :return: three one-element lists: ``[predictions]``, ``[targets]`` (CPU
        tensors concatenated over all batches) and ``[metadata]`` (one dict per sample).
    """
    model.eval()
    pred_chunks, target_chunks, meta_rows = [], [], []

    with torch.no_grad():
        for batch_no, batch in enumerate(val_loader, start=1):
            device = next(model.parameters()).device
            x_seqs, y_seqs, _tickers, _y_years, indices = zip(*batch)

            # True sequence lengths are recorded before the batch is zero-padded.
            lengths = torch.tensor([seq.shape[0] for seq in x_seqs]).to(device)
            padded = pad_sequence(x_seqs, batch_first=True).to(device)
            targets = torch.stack(y_seqs).to(device)

            pred_chunks.append(model(padded, lengths).cpu())
            target_chunks.append(targets.cpu())

            meta_rows.extend(
                {
                    "fold": fold_id + 1,
                    "epoch": best_epoch + 1,
                    "batch": batch_no,
                    "sample": position,
                    "index": sample_index,
                }
                for position, sample_index in enumerate(indices)
            )

    # Wrapped in one-element lists: the CSV export cell reads element [0] of each.
    return [torch.cat(pred_chunks)], [torch.cat(target_chunks)], [meta_rows]


## Training

In [None]:
## Training
# --- Training config ---
# Cross-validation over the company-level (train, test) splits stored in `folds`.
company_ids = df['ticker'].unique()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epochs = 100
learning_rate = 1e-4

# Cross-fold containers: each list receives one entry per fold.
all_train_losses_total, all_train_losses_count, all_train_losses_citation = [], [], []
all_train_preds_total, all_train_targets_total, all_train_meta_info_total = [], [], []

all_val_losses_count, all_val_losses_citation, all_val_losses_total = [], [], []
all_val_preds_total, all_val_targets_total, all_val_meta_info_total = [], [], []

best_epoch_list = []

for fold_id, (train_coms, test_coms) in enumerate(folds):
    print(f"\n====== Fold {fold_id+1} ======")

    # Training predictions / targets / metadata are kept for every epoch (one entry per epoch).
    # Validation outputs are collected once, after training, from the restored best model.
    fold_train_preds_epochs, fold_train_targets_epochs, fold_train_meta_epochs = [], [], []

    # Datasets
    train_dataset = SlidingWindowDatasetDynamic(df, train_coms, max_len=5)
    test_dataset = SlidingWindowDatasetDynamic(df, test_coms, max_len=5)

    # Dedicated generator so the shuffling order is the same in every fold and run.
    generator = torch.Generator().manual_seed(42)

    # DataLoaders: the identity collate_fn keeps each batch as a list of sample tuples,
    # because sequences have different lengths and are padded inside the loop.
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=lambda x: x, generator=generator)
    val_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=lambda x: x)

    # Model, optimizer, loss
    model = Attention_MTL_Stepwise().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = MSEPunishLoss(alpha=1.0, punish_mode='', return_each=True)

    # Per-epoch average losses for this fold.
    train_loss_total_list, train_loss_count_list, train_loss_citation_list = [], [], []
    val_loss_total_list, val_loss_count_list, val_loss_citation_list = [], [], []

    # Early-stopping state.
    best_val_loss = float('inf')
    epochs_no_improve = 0
    patience = 50
    best_model_state = None
    best_epoch = None

    for epoch in range(epochs):
        # ---- Training pass ----
        model.train()
        epoch_losses, epoch_loss_count_all, epoch_loss_citation_all = [], [], []
        train_preds, train_targets, train_meta_info = [], [], []
        for batch_id, batch in enumerate(train_loader):
            x_seqs, y_seqs, tickers, y_years, indices = zip(*batch)
            # True lengths are recorded before padding so the model can mask padded steps.
            lengths = torch.tensor([x.shape[0] for x in x_seqs]).to(device)
            x_seqs = pad_sequence(x_seqs, batch_first=True).to(device)
            y_seqs = torch.stack(y_seqs).to(device)

            preds = model(x_seqs, lengths)

            loss_total, loss_count, loss_citation = criterion(preds, y_seqs)
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            epoch_losses.append(loss_total.item())
            epoch_loss_count_all.append(loss_count)
            epoch_loss_citation_all.append(loss_citation)

            train_preds.append(preds.detach().cpu())
            train_targets.append(y_seqs.detach().cpu())

            for i, index in enumerate(indices):
                train_meta_info.append({
                    "fold": fold_id + 1,
                    "epoch": epoch + 1,
                    "batch": batch_id + 1,
                    "sample": i,
                    "index": index
                })

        # Epoch losses are unweighted means over batches.
        avg_train_loss = np.mean(epoch_losses)
        avg_train_count = np.mean(epoch_loss_count_all)
        avg_train_citation = np.mean(epoch_loss_citation_all)

        train_loss_total_list.append(avg_train_loss)
        train_loss_count_list.append(avg_train_count)
        train_loss_citation_list.append(avg_train_citation)

        fold_train_meta_epochs.append(train_meta_info)
        fold_train_preds_epochs.append(torch.cat(train_preds))
        fold_train_targets_epochs.append(torch.cat(train_targets))

        # ---- Validation pass ----
        model.eval()
        with torch.no_grad():
            val_losses = []
            val_loss_count_all = []
            val_loss_citation_all = []
            all_preds = []
            all_y = []
            for batch_id, batch in enumerate(val_loader):
                x_seqs, y_seqs, tickers, y_years, indices = zip(*batch)
                lengths = torch.tensor([x.shape[0] for x in x_seqs]).to(device)
                x_seqs = pad_sequence(x_seqs, batch_first=True).to(device)
                y_seqs = torch.stack(y_seqs).to(device)

                preds = model(x_seqs, lengths)
                loss_total, loss_count, loss_citation = criterion(preds, y_seqs)

                val_losses.append(loss_total.item())
                val_loss_count_all.append(loss_count)
                val_loss_citation_all.append(loss_citation)

                all_preds.append(preds)
                all_y.append(y_seqs)

                # Per-batch log: losses, predictions and ground truth on the original scale.
                # The "Company" shown is only the ticker of the first sample in the batch.
                print(f"[Fold {fold_id+1}][Epoch {epoch+1}][Val][Batch] Company: {tickers[0]} | Samples: {len(indices)}")
                print(f"  Total Loss: {loss_total.item():.4f} | Count Loss: {loss_count:.4f} | Citation Loss: {loss_citation:.4f}")
                print(f"  Predictions (expm1, rounded): {np.round(torch.expm1(preds).detach().cpu().numpy())}")
                print(f"  Ground truth (expm1, rounded): {np.round(torch.expm1(y_seqs).detach().cpu().numpy())}")
                print(f"  Indices: {indices}")

            avg_val_loss = np.mean(val_losses)
            avg_val_count = np.mean(val_loss_count_all)
            avg_val_citation = np.mean(val_loss_citation_all)

            val_loss_total_list.append(avg_val_loss)
            val_loss_count_list.append(avg_val_count)
            val_loss_citation_list.append(avg_val_citation)

            all_preds = torch.cat(all_preds)
            all_y = torch.cat(all_y)

            # End-of-epoch summary.
            print(f"[Fold {fold_id+1}][Epoch {epoch+1}] Summary")
            print(f"  Train Avg Loss     -> Total: {avg_train_loss:.4f} | Count: {avg_train_count:.4f} | Citation: {avg_train_citation:.4f}")
            print(f"  Validation Avg Loss-> Total: {avg_val_loss:.4f} | Count: {avg_val_count:.4f} | Citation: {avg_val_citation:.4f}")

        metrics = compute_metrics(all_preds, all_y)
        for name, vals in metrics.items():
            print(f"  [{name.upper()}] Metrics -> MSE: {vals['MSE']:.4f} | MAE: {vals['MAE']:.4f} | "
                f"RMSE: {vals['RMSE']:.4f} | SMAPE: {vals['SMAPE']:.2f}%")

        # ---- Early stopping ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors, which the
            # optimizer keeps updating in place. Clone them, otherwise the "best" snapshot
            # silently turns into the weights of the last epoch.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            print(f"  ✅ Validation loss improved. Saving model.")
        else:
            epochs_no_improve += 1
            print(f"  ❌ No improvement. Patience: {epochs_no_improve}/{patience}")

        if epochs_no_improve >= patience:
            print(f"  ⏹ Early stopping triggered at epoch {epoch+1}. Best val_loss: {best_val_loss:.4f}")
            break
        # Nothing is loaded here; this only reports the best epoch so far (1-based, like the other logs).
        if best_epoch is not None:
            print(f"  📌 Best epoch so far: {best_epoch + 1}")

    # Restore the best snapshot and collect the validation outputs that are exported.
    if best_model_state is not None:
        # At least one epoch improved on the initial best_val_loss (with or without early stopping).
        model.load_state_dict(best_model_state)
        print("  🔄 Best model loaded for final evaluation or saving.")
        eval_epoch = best_epoch
    else:
        # No epoch ever improved (e.g. the validation loss was NaN): evaluate the final weights.
        print("  ⚠️ No best model found. Using final epoch model for evaluation.")
        eval_epoch = epoch  # index of the last epoch that ran
    fold_val_preds_epochs, fold_val_targets_epochs, fold_val_meta_epochs = evaluate_best_model(
        model, val_loader, fold_id, eval_epoch
    )

    # Store this fold's results in the cross-fold containers.
    all_train_losses_total.append(train_loss_total_list)
    all_train_losses_count.append(train_loss_count_list)
    all_train_losses_citation.append(train_loss_citation_list)

    all_train_preds_total.append(fold_train_preds_epochs)
    all_train_targets_total.append(fold_train_targets_epochs)
    all_train_meta_info_total.append(fold_train_meta_epochs)

    all_val_losses_total.append(val_loss_total_list)
    all_val_losses_count.append(val_loss_count_list)
    all_val_losses_citation.append(val_loss_citation_list)

    all_val_preds_total.append(fold_val_preds_epochs)
    all_val_targets_total.append(fold_val_targets_epochs)
    all_val_meta_info_total.append(fold_val_meta_epochs)

    best_epoch_list.append(best_epoch + 1 if best_epoch is not None else epochs)

## compute_sample_metrics

In [None]:
def compute_sample_metrics(y_true, y_pred):
    """Per-sample error measures for both targets.

    :param y_true: numpy array of shape [N, 2] on the original scale (already expm1'd).
    :param y_pred: numpy array of the same shape.
    :return: dict with keys ``count_*`` and ``citation_*`` for MSE, MAE, RMSE and
        SMAPE; each value is an array of length N. Per sample, "RMSE" is the
        square root of a single squared error, i.e. the absolute error, and
        SMAPE is a fraction (not multiplied by 100).
    """
    residual = y_pred - y_true
    per_metric = {
        "MSE": residual ** 2,
        "MAE": np.abs(residual),
    }
    per_metric["RMSE"] = np.sqrt(per_metric["MSE"])
    # The 1e-8 term keeps the ratio finite when prediction and target are both 0.
    per_metric["SMAPE"] = 2 * per_metric["MAE"] / (np.abs(y_pred) + np.abs(y_true) + 1e-8)

    # Column 0 holds the count, column 1 the citations.
    return {
        f"{target}_{metric}": values[:, col]
        for col, target in enumerate(("count", "citation"))
        for metric, values in per_metric.items()
    }

## Save training and validation results to CSV

In [None]:
# Column-name suffixes shared by the val_* and train_* column groups.
# Their order fixes the column order of the exported CSV files.
_VALUE_FIELDS = ("true_count", "pred_count", "true_citation", "pred_citation")
_LOSS_FIELDS = ("loss_total", "loss_count", "loss_citation")
_METRIC_FIELDS = (
    "count_MSE", "count_MAE", "count_SMAPE",
    "citation_MSE", "citation_MAE", "citation_SMAPE",
)


def _meta_fields(meta):
    """Identification columns shared by validation and training rows."""
    return {
        "fold": meta["fold"],
        "epoch": meta["epoch"],
        "batch": meta["batch"],
        "local_sample": meta["sample"],
        "index": meta["index"],
    }


def _filled_fields(prefix, y_true, y_pred, metrics, i):
    """Values and per-sample metrics of sample ``i``, rounded to 4 decimals.

    ``.item()`` turns the NumPy scalars into Python floats before rounding, so
    validation and training rows are rounded the same way.
    """
    row = {
        f"{prefix}_true_count": round(y_true[i, 0].item(), 4),
        f"{prefix}_pred_count": round(y_pred[i, 0].item(), 4),
        f"{prefix}_true_citation": round(y_true[i, 1].item(), 4),
        f"{prefix}_pred_citation": round(y_pred[i, 1].item(), 4),
    }
    for name in _METRIC_FIELDS:
        row[f"{prefix}_{name}"] = round(metrics[name][i].item(), 4)
    return row


def _blank_fields(prefix):
    """NaN placeholders for the column group that does not apply to a row.

    The ``*_loss_*`` columns only ever appear here, so they are always NaN.
    """
    return {
        f"{prefix}_{name}": np.nan
        for name in _VALUE_FIELDS + _LOSS_FIELDS + _METRIC_FIELDS
    }


val_summary_records, train_summary_records = [], []

for fold_id in range(len(folds)):
    # === Validation: the single evaluation collected after training for this fold ===
    fold_preds = all_val_preds_total[fold_id][0]
    fold_targets = all_val_targets_total[fold_id][0]
    fold_meta_info = all_val_meta_info_total[fold_id][0]

    # Undo the log1p applied to the targets.
    y_pred = torch.expm1(fold_preds).numpy()
    y_true = torch.expm1(fold_targets).numpy()

    metrics = compute_sample_metrics(y_true, y_pred)

    for i, meta in enumerate(fold_meta_info):
        val_summary_records.append({
            **_meta_fields(meta),
            **_filled_fields("val", y_true, y_pred, metrics, i),
            **_blank_fields("train"),
        })

    # === Training: one set of rows per epoch ===
    for epoch in range(len(all_train_preds_total[fold_id])):
        fold_preds = all_train_preds_total[fold_id][epoch].cpu()
        fold_targets = all_train_targets_total[fold_id][epoch].cpu()
        fold_meta_info = all_train_meta_info_total[fold_id][epoch]

        y_pred = torch.expm1(fold_preds).numpy()
        y_true = torch.expm1(fold_targets).numpy()

        metrics = compute_sample_metrics(y_true, y_pred)

        for i, meta in enumerate(fold_meta_info):
            train_summary_records.append({
                **_meta_fields(meta),
                **_blank_fields("val"),
                **_filled_fields("train", y_true, y_pred, metrics, i),
            })


# Write the CSV files: validation rows alone, and validation rows followed by training rows.
val_summary_df = pd.DataFrame(val_summary_records)
train_summary_df = pd.DataFrame(train_summary_records)
combined_df = pd.concat([val_summary_df, train_summary_df], ignore_index=True)

val_csv_path = "../output/attention_v0_log_csr_t_val_detailed.csv"
combined_csv_path = "../output/attention_v0_log_csr_t_train_val_detailed.csv"
val_summary_df.to_csv(val_csv_path, index=False)
combined_df.to_csv(combined_csv_path, index=False)
print(f"✅ Saved detailed results to '{val_csv_path}' and '{combined_csv_path}'")

## Plots

In [None]:
import matplotlib.pyplot as plt


def plot_loss_curves(train_histories, val_histories, task):
    """Plot per-fold training and validation loss curves for one task.

    :param train_histories: one list of per-epoch training losses per fold.
    :param val_histories: one list of per-epoch validation losses per fold.
    :param task: task name used in labels and titles, e.g. "Count".

    The number of curves follows the histories themselves rather than a fixed
    fold count. Curves can differ in length because a fold may stop early.
    """
    plt.figure(figsize=(8, 5))
    for fold_no, (train_curve, val_curve) in enumerate(zip(train_histories, val_histories), start=1):
        plt.plot(train_curve, label=f"Train {task} - Fold {fold_no}")
        plt.plot(val_curve, label=f"Val {task} - Fold {fold_no}", linestyle='--')
    plt.xlabel("Epoch")
    plt.ylabel(f"{task} Loss")
    plt.title(f"{task} Loss Learning Curve (Train vs. Validation)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Count loss learning curve
plot_loss_curves(all_train_losses_count, all_val_losses_count, "Count")

# Citation loss learning curve
plot_loss_curves(all_train_losses_citation, all_val_losses_citation, "Citation")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Pool the validation predictions and targets of all folds.
y_pred_all = torch.cat([tensor for fold in all_val_preds_total for tensor in fold])
y_true_all = torch.cat([tensor for fold in all_val_targets_total for tensor in fold])

# Undo the log1p transform and continue with NumPy arrays, as the CSV export cell does.
y_pred_all = torch.expm1(y_pred_all).numpy()
y_true_all = torch.expm1(y_true_all).numpy()


def plot_pred_vs_truth(y_true, y_pred, task):
    """Scatter predictions against ground truth with a y = x reference line.

    :param y_true: 1-D array of true values on the original scale.
    :param y_pred: 1-D array of predicted values on the original scale.
    :param task: task name used in labels and the title, e.g. "Count".
    """
    plt.figure(figsize=(6, 6))
    plt.scatter(y_true, y_pred, alpha=0.5)
    # The reference line must use the same coordinates on both axes. Joining
    # (min true, min pred) to (max true, max pred) is not the identity line.
    lo = min(y_true.min(), y_pred.min())
    hi = max(y_true.max(), y_pred.max())
    plt.plot([lo, hi], [lo, hi], 'r--')
    plt.xlabel(f"Ground Truth ({task})")
    plt.ylabel(f"Predictions ({task})")
    plt.title(f"Predictions vs Ground Truth ({task}) – All Folds")
    plt.grid(True)
    plt.show()


# Count
plot_pred_vs_truth(y_true_all[:, 0], y_pred_all[:, 0], "Count")

# Citation
plot_pred_vs_truth(y_true_all[:, 1], y_pred_all[:, 1], "Citation")

In [None]:
# Number of pooled validation samples: the row count of y_true_all built in the scatter-plot cell.
print(f"Total number of samples: {y_true_all.shape[0]}")