In [None]:
# ====================================================
# Kaggle Inference Notebook (Single Regression + 4 Post-Processing Strategies)
# ====================================================

import os
import gc
import sys
import math
import time
import random
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig
from tqdm.auto import tqdm
from scipy.stats import rankdata
import html

# ------------------- Configuration -------------------
class CFG:
    """Static settings for this inference run, read as class attributes (e.g. ``CFG.max_len``)."""

    # --- Paths ---
    # Directory holding the fine-tuned checkpoint files.
    model_dir = "/kaggle/input/deberta-finetuned/pytorch/arch1/10"
    # Tokenizer directory (the main script also passes it to QuestModel as model_name).
    base_model = "/kaggle/input/deberta-tokenizer/deberta-v3-base-tokenizer"

    # --- Model head layout ---
    # Must match the layout the checkpoints were trained with, or load_state_dict fails.
    pooling_strategy = 'arch1_6groups'

    # --- Key switch: post-processing applied to the averaged predictions ---
    # Options: 'raw', 'optimized', 'voters', 'distribution'
    post_processing = 'voters'

    # --- Data loading / inference ---
    max_len = 512
    batch_size = 16
    num_workers = 2
    seed = 42
    n_fold = 5  # number of per-fold checkpoints looked up in model_dir
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The 30 target columns, in model-output and submission column order:
# column i of the predictions is written to TARGET_COLS[i].
# Positions matter: QuestModel's 'arch1' heads emit the first 21 columns from
# question features and the last 9 from answer features, and its group heads
# (idx_g1..idx_g6) index directly into this order.
TARGET_COLS = [
    # Question-level targets: indices 0-20.
    'question_asker_intent_understanding', 'question_body_critical', 'question_conversational',
    'question_expect_short_answer', 'question_fact_seeking', 'question_has_commonly_accepted_answer',
    'question_interestingness_others', 'question_interestingness_self', 'question_multi_intent',
    'question_not_really_a_question', 'question_opinion_seeking', 'question_type_choice',
    'question_type_compare', 'question_type_consequence', 'question_type_definition',
    'question_type_entity', 'question_type_instructions', 'question_type_procedure',
    'question_type_reason_explanation', 'question_type_spelling', 'question_well_written',
    # Answer-level targets: indices 21-29.
    'answer_helpful', 'answer_level_of_information', 'answer_plausible', 'answer_relevance',
    'answer_satisfaction', 'answer_type_instructions', 'answer_type_procedure',
    'answer_type_reason_explanation', 'answer_well_written'
]

# ------------------- Data preprocessing -------------------
def modern_preprocess(text):
    """Normalize one raw text field before tokenization.

    Missing values (anything ``pd.isna`` reports as missing) become ``""``.
    Other values are converted to ``str``, HTML entities are decoded, and every
    whitespace run collapses to one space (leading/trailing whitespace is dropped).
    """
    if not pd.isna(text):
        decoded = html.unescape(str(text))
        return " ".join(decoded.split())
    return ""

class QuestDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        self.questions = [
            modern_preprocess(t) + " " + modern_preprocess(b) 
            for t, b in zip(df['question_title'].values, df['question_body'].values)
        ]
        self.answers = [modern_preprocess(a) for a in df['answer'].values]
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        question = self.questions[idx]
        answer = self.answers[idx]
        
        inputs = self.tokenizer(
            question,
            answer,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None
        )
        
        item = {
            'input_ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(inputs['attention_mask'], dtype=torch.long)
        }
        
        if 'token_type_ids' in inputs:
            item['token_type_ids'] = torch.tensor(inputs['token_type_ids'], dtype=torch.long)
            
        return item

# ------------------- Model definition (single regression) -------------------
class QuestModel(nn.Module):
    """Transformer backbone with a configurable pooling strategy and regression head(s).

    ``forward`` returns raw (pre-sigmoid) scores, one column per target.

    Pooling strategies:
      * ``'mean'``: masked mean of the last hidden state -> one linear layer
        with ``num_targets`` outputs.
      * ``'arch1'``: [CLS, global mean, question mean] -> 21 outputs and
        [CLS, global mean, answer mean] -> 9 outputs, concatenated in that order.
      * ``'arch1_6groups'``: the same two feature vectors feed six heads.
        Groups 1-4 read the question features; groups 5-6 read the answer features.
      * ``'arch2'``: CLS vectors of the last four hidden states -> dropout + linear.
      * ``'cls_all'``: CLS vectors of every hidden state, shared by the six heads.
      * ``'mlp_only'``: an MLP collapses the sequence axis (which must be exactly
        ``max_len`` long), then the six heads.

    The grouped strategies scatter each head's outputs back to the target
    positions listed in ``idx_g1`` .. ``idx_g6``.
    """

    # Strategies this class can build and run.
    _SUPPORTED_POOLING = ('mean', 'arch1', 'arch1_6groups', 'arch2', 'cls_all', 'mlp_only')

    def __init__(self, model_name, num_targets, pooling_strategy='arch1', dropout_rate=0.1, max_len=512):
        """
        Args:
            model_name: Path/name passed to ``AutoConfig.from_pretrained``. The
                backbone is built from that config via ``AutoModel.from_config``,
                so its weights start randomly initialised.
            num_targets: Output width of the 'mean' and 'arch2' heads. The 'arch1'
                heads are fixed at 21 + 9 outputs; the grouped heads cover the
                positions in ``idx_g1`` .. ``idx_g6``.
            pooling_strategy: One of ``_SUPPORTED_POOLING`` (see class docstring).
            dropout_rate: Dropout probability used inside the heads.
            max_len: Sequence length collapsed by the 'mlp_only' MLP; inputs must be
                padded to exactly this length. Unused by the other strategies.

        Raises:
            ValueError: If ``pooling_strategy`` is not supported.
        """
        super().__init__()
        # Validate before loading anything. An unknown strategy used to build no
        # head at all and only failed later in forward() on an unbound `output`.
        if pooling_strategy not in self._SUPPORTED_POOLING:
            raise ValueError(
                f"Unknown pooling_strategy {pooling_strategy!r}; "
                f"expected one of {self._SUPPORTED_POOLING}"
            )
        self.pooling_strategy = pooling_strategy
        self.config = AutoConfig.from_pretrained(model_name)
        if pooling_strategy in ('arch2', 'cls_all'):
            # These strategies pool over per-layer hidden states.
            self.config.update({'output_hidden_states': True})

        self.backbone = AutoModel.from_config(self.config)
        hidden_size = self.config.hidden_size

        # Target positions (indices into the output columns) owned by each group head.
        self.idx_g1 = [3, 4, 5, 16, 17]
        self.idx_g2 = [0, 1, 6, 7, 20]
        self.idx_g3 = [2, 10]
        self.idx_g4 = [8, 9, 11, 12, 13, 14, 15, 18, 19]
        self.idx_g5 = [26, 27]
        self.idx_g6 = [21, 22, 23, 24, 25, 28, 29]

        if pooling_strategy == 'mean':
            self.fc = nn.Linear(hidden_size, num_targets)

        elif pooling_strategy == 'arch1':
            self.q_head = self._make_head(hidden_size * 3, 21, dropout_rate)
            self.a_head = self._make_head(hidden_size * 3, 9, dropout_rate)

        elif pooling_strategy == 'arch1_6groups':
            # Each head sees [CLS, global mean, question-or-answer mean].
            self._build_group_heads(hidden_size * 3, dropout_rate)

        elif pooling_strategy == 'arch2':
            self.fc = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size * 4, num_targets)
            )

        elif pooling_strategy == 'cls_all':
            # One CLS vector per hidden state: every transformer layer plus the
            # embedding output (+1), e.g. 13 for a 12-layer backbone.
            num_hidden_states = self.config.num_hidden_layers + 1
            self._build_group_heads(hidden_size * num_hidden_states, dropout_rate)

        else:  # 'mlp_only'
            # The MLP mixes along the sequence axis, so its input width is the
            # padded sequence length.
            self.mlp = nn.Sequential(
                nn.Linear(max_len, hidden_size // 2),
                nn.Tanh(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size // 2, 1)
            )
            self._build_group_heads(hidden_size, dropout_rate)

    def _make_head(self, input_dim, output_dim, dropout_rate):
        """Two-layer regression head: Linear -> Tanh -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(input_dim, self.config.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.config.hidden_size, output_dim)
        )

    def _group_indices(self):
        """Return the six target-index lists in head order g1..g6."""
        return (self.idx_g1, self.idx_g2, self.idx_g3, self.idx_g4, self.idx_g5, self.idx_g6)

    def _group_heads(self):
        """Return the six group heads, in the same order as ``_group_indices``."""
        return (self.head_g1, self.head_g2, self.head_g3, self.head_g4, self.head_g5, self.head_g6)

    def _build_group_heads(self, input_dim, dropout_rate):
        """Register ``head_g1`` .. ``head_g6``, each sized to its index group."""
        # setattr on an nn.Module registers the submodule under the same name as
        # `self.head_gK = ...`, so checkpoint keys are unchanged.
        for k, idx in enumerate(self._group_indices(), start=1):
            setattr(self, f'head_g{k}', self._make_head(input_dim, len(idx), dropout_rate))

    def _scatter_groups(self, group_outputs):
        """Assemble per-group head outputs into one (batch, n_targets) tensor.

        ``group_outputs[k]`` has shape (batch, len(idx_g{k+1})); its columns are
        written to the positions listed in the matching index group. The width is
        one past the largest group index (30 for the groups defined above).
        """
        groups = self._group_indices()
        first = group_outputs[0]
        n_targets = 1 + max(max(idx) for idx in groups)
        # new_zeros keeps the heads' dtype and device.
        output = first.new_zeros((first.size(0), n_targets))
        for idx, out in zip(groups, group_outputs):
            output[:, idx] = out
        return output

    def _masked_mean_pooling(self, hidden_state, attention_mask):
        """Mean of ``hidden_state`` over the sequence axis, counting only masked-in tokens.

        ``attention_mask`` is broadcast over the last (hidden) axis. Rows whose
        mask is all zeros yield zeros (the denominator is clamped at 1e-9).
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
        sum_embeddings = torch.sum(hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _get_pooling_features(self, last_hidden_state, attention_mask, token_type_ids):
        """Return (CLS vector, global mean, question mean, answer mean).

        Question tokens are those with token_type_id 0 and answer tokens those
        with 1. Without token_type_ids, both the question and answer means fall
        back to the global mean.
        """
        cls_token = last_hidden_state[:, 0, :]
        global_avg = self._masked_mean_pooling(last_hidden_state, attention_mask)

        if token_type_ids is None:
            q_avg = global_avg
            a_avg = global_avg
        else:
            q_mask = attention_mask * (1 - token_type_ids)
            q_avg = self._masked_mean_pooling(last_hidden_state, q_mask)
            a_mask = attention_mask * token_type_ids
            a_avg = self._masked_mean_pooling(last_hidden_state, a_mask)

        return cls_token, global_avg, q_avg, a_avg

    def _pool_arch2(self, all_hidden_states):
        """Concatenate the CLS vectors of the last four hidden states -> (batch, 4 * hidden)."""
        last_4_layers = all_hidden_states[-4:]
        cls_embeddings = [layer[:, 0, :] for layer in last_4_layers]
        return torch.cat(cls_embeddings, dim=1)

    def _pool_cls_concat(self, all_hidden_states):
        """Concatenate the CLS vectors of every hidden state -> (batch, hidden * num_states)."""
        cls_embeddings = [layer[:, 0, :] for layer in all_hidden_states]
        return torch.cat(cls_embeddings, dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the backbone and the configured head(s); return raw scores (no sigmoid).

        Args:
            input_ids, attention_mask: Token ids and padding mask for the backbone.
            token_type_ids: Optional segment ids (0 = question, 1 = answer). They
                are passed to the backbone and used for the question/answer means.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        last_hidden_state = outputs.last_hidden_state

        if self.pooling_strategy == 'mean':
            feature = self._masked_mean_pooling(last_hidden_state, attention_mask)
            return self.fc(feature)

        if self.pooling_strategy == 'arch1':
            cls, glob, q, a = self._get_pooling_features(last_hidden_state, attention_mask, token_type_ids)
            q_feat = torch.cat([cls, glob, q], dim=1)
            a_feat = torch.cat([cls, glob, a], dim=1)
            return torch.cat([self.q_head(q_feat), self.a_head(a_feat)], dim=1)

        if self.pooling_strategy == 'arch1_6groups':
            cls, glob, q, a = self._get_pooling_features(last_hidden_state, attention_mask, token_type_ids)
            feat_pure_q = torch.cat([cls, glob, q], dim=1)
            feat_pure_a = torch.cat([cls, glob, a], dim=1)
            # Groups 1-4 read question-side features; groups 5-6 answer-side ones.
            feats = (feat_pure_q,) * 4 + (feat_pure_a,) * 2
            outs = [head(feat) for head, feat in zip(self._group_heads(), feats)]
            return self._scatter_groups(outs)

        if self.pooling_strategy == 'arch2':
            return self.fc(self._pool_arch2(outputs.hidden_states))

        if self.pooling_strategy == 'cls_all':
            feat = self._pool_cls_concat(outputs.hidden_states)
            return self._scatter_groups([head(feat) for head in self._group_heads()])

        if self.pooling_strategy == 'mlp_only':
            # (batch, seq, hidden) -> (batch, hidden, seq) so the MLP acts on the sequence axis,
            # then (batch, hidden, 1) -> (batch, hidden).
            reduced = self.mlp(last_hidden_state.transpose(1, 2))
            feat = reduced.squeeze(-1)
            return self._scatter_groups([head(feat) for head in self._group_heads()])

        raise ValueError(f"Unknown pooling_strategy {self.pooling_strategy!r}")

# ------------------- Post-Processing Strategies -------------------

# 1. Optimized Rounder (threshold clipping)
class OptimizedRounder:
    """Clip predictions into a fixed band given by ``coef_ = [low, high]``.

    NaNs are replaced by 0.5 before clipping. If clipping leaves every value
    identical, the element with the largest (NaN-cleaned) input gets a 1e-6 bump
    so the output is not constant. This is presumably to keep a rank-based
    metric defined.
    """

    def __init__(self):
        self.coef_ = [0.05, 0.95]

    def predict(self, X):
        cleaned = np.nan_to_num(X, nan=0.5)
        low = self.coef_[0]
        high = self.coef_[1]
        clipped = np.clip(cleaned, low, high)
        # Break the all-equal case.
        if len(np.unique(clipped)) == 1:
            clipped[np.argmax(cleaned)] += 1e-6
        return clipped

# 2. Voters Rounder (grid snapping, with a fallback for degenerate output)
class VotersRounder:
    """Snap each prediction to the nearest value observed in the training labels."""

    def __init__(self, train_vals, dev_threshold=0.01):
        """
        Voters Rounder with deviation-based fallback.

        Args:
            train_vals: Array-like of training values. NaNs are ignored, and the
                distinct remaining values form the (sorted) snapping grid.
            dev_threshold: If the snapped predictions' standard deviation falls
                below this, the grid is considered too coarse and the cleaned,
                unsnapped predictions are returned instead.
        """
        vals = np.asarray(train_vals, dtype=float)
        # Filter NaNs so the grid is clean; np.unique already returns sorted values.
        self.unique_vals = np.unique(vals[~np.isnan(vals)])
        self.dev_threshold = dev_threshold

    def predict(self, X):
        """Return X (NaN -> 0.5) snapped to the grid, or the cleaned X if snapping degenerates."""
        # 1. Replace NaN inputs with 0.5.
        X_clean = np.nan_to_num(X, nan=0.5)

        # No grid to snap to (e.g. an all-NaN training column), or nothing to snap:
        # keep the cleaned predictions instead of failing in argmin.
        if self.unique_vals.size == 0 or X_clean.size == 0:
            return X_clean

        # 2. Snap every prediction to its nearest grid value.
        idx = np.abs(X_clean[:, None] - self.unique_vals[None, :]).argmin(axis=1)
        X_p = self.unique_vals[idx]

        # 3. A near-constant snapped column means the grid was too coarse;
        #    fall back to the cleaned raw predictions.
        if np.std(X_p) < self.dev_threshold:
            return X_clean

        return X_p

# 3. Distribution Rounder (quantile mapping onto the training distribution)
class DistributionRounder:
    """Re-map predictions so they follow the training label distribution while keeping their order."""

    def __init__(self, train_vals):
        """Store the sorted, NaN-free training values as the reference distribution."""
        vals = np.asarray(train_vals, dtype=float)
        # Drop NaNs: np.sort places them last, and np.interp would then turn the
        # highest-ranked predictions into NaN.
        self.train_vals = np.sort(vals[~np.isnan(vals)])

    def predict(self, X):
        """Replace each prediction by the training quantile of the same rank.

        NaN predictions are treated as 0.5. Predictions are ranked 0..n-1 (ties
        broken by position), rank r is mapped to quantile r / (n - 1), and that
        quantile is read off the sorted training values by linear interpolation.
        With no usable training values, or no predictions, the cleaned
        predictions are returned unchanged.
        """
        X_clean = np.nan_to_num(np.asarray(X, dtype=float), nan=0.5)
        if self.train_vals.size == 0 or X_clean.size == 0:
            return X_clean

        n = X_clean.size
        # 0-based ordinal ranks. Cast explicitly to an index type rather than
        # relying on rankdata's return dtype.
        ranks = rankdata(X_clean, method='ordinal').astype(np.intp) - 1

        # Training value at each of n evenly spaced quantiles (quantile mapping).
        quantile_values = np.interp(
            np.linspace(0, 1, n),
            np.linspace(0, 1, len(self.train_vals)),
            self.train_vals,
        )
        return quantile_values[ranks]

# ------------------- Inference Logic -------------------
def inference_fn(test_loader, model, device):
    """Run ``model`` over every batch of ``test_loader`` and collect sigmoid outputs.

    The model is switched to eval mode and run without gradient tracking.
    Returns a NumPy array of the per-batch sigmoid outputs, concatenated along
    axis 0 in loader order.
    """
    model.eval()
    batch_outputs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Predicting"):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # token_type_ids are optional; move them only if the dataset produced them.
            type_ids = batch.get('token_type_ids')
            type_ids = None if type_ids is None else type_ids.to(device)

            logits = model(ids, mask, type_ids)
            batch_outputs.append(torch.sigmoid(logits).cpu().numpy())

    return np.concatenate(batch_outputs)

# ------------------- Main -------------------
if __name__ == '__main__':
    TEST_PATH = 'test.csv'
    TRAIN_PATH = 'train.csv'  # training labels, used by post-processing to learn target distributions

    # Fall back to the Kaggle competition directory when the files are not in the CWD.
    if not os.path.exists(TEST_PATH):
        TEST_PATH = '/kaggle/input/google-quest-challenge/test.csv'
        TRAIN_PATH = '/kaggle/input/google-quest-challenge/train.csv'

    # Look for the submission template next to test.csv (os.path.dirname('test.csv') is '',
    # giving 'sample_submission.csv'); fall back to the Kaggle copy.
    SAMPLE_SUB_PATH = os.path.join(os.path.dirname(TEST_PATH), 'sample_submission.csv')
    if not os.path.exists(SAMPLE_SUB_PATH):
        SAMPLE_SUB_PATH = '/kaggle/input/google-quest-challenge/sample_submission.csv'

    # Fail fast on a mistyped strategy: an unknown value used to silently leave
    # every prediction at 0 in the submission.
    POST_PROCESSING_OPTIONS = ('raw', 'optimized', 'voters', 'distribution')
    if CFG.post_processing not in POST_PROCESSING_OPTIONS:
        raise ValueError(
            f"Unknown CFG.post_processing {CFG.post_processing!r}; "
            f"expected one of {POST_PROCESSING_OPTIONS}"
        )

    test = pd.read_csv(TEST_PATH)
    train = pd.read_csv(TRAIN_PATH)

    print(f"Test Shape: {test.shape}, Train Shape: {train.shape}")
    print(f"Selected Post-Processing Strategy: {CFG.post_processing}")

    tokenizer = AutoTokenizer.from_pretrained(CFG.base_model)
    test_dataset = QuestDataset(test, tokenizer, max_len=CFG.max_len)
    test_loader = DataLoader(
        test_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,  # row order must line up with the submission template
        num_workers=CFG.num_workers,
        pin_memory=True
    )

    # Collect the per-fold checkpoints plus an optional single-run checkpoint.
    weight_paths = []
    for fold in range(CFG.n_fold):
        path = os.path.join(CFG.model_dir, f"deberta_v3_fold{fold}_best.pth")
        if os.path.exists(path):
            weight_paths.append(path)
    single_run_path = os.path.join(CFG.model_dir, "deberta_v3_single_run_best.pth")
    if os.path.exists(single_run_path):
        weight_paths.append(single_run_path)

    if not weight_paths:
        print("No weights found!")
        sys.exit(1)

    print(f"Ensembling {len(weight_paths)} models...")

    model_preds = []
    for weight_path in weight_paths:
        model = QuestModel(
            CFG.base_model,
            num_targets=len(TARGET_COLS),
            pooling_strategy=CFG.pooling_strategy
        )
        model.load_state_dict(torch.load(weight_path, map_location=CFG.device))
        model.to(CFG.device)

        preds = inference_fn(test_loader, model, CFG.device)
        model_preds.append(preds)

        # Release this model's memory before building the next one.
        del model
        torch.cuda.empty_cache()
        gc.collect()

    # 1. Average the ensemble (weight_paths is non-empty here, so model_preds is too).
    avg_preds = np.mean(model_preds, axis=0)
    final_preds = np.zeros_like(avg_preds)

    # 2. Apply the selected post-processing strategy column by column.
    print(f"Applying Post-Processing: {CFG.post_processing}")

    for i, col in enumerate(TARGET_COLS):
        # Training labels for this column (reference for the voters/distribution strategies).
        train_col_values = train[col].values
        # Averaged predictions for this column.
        curr_preds = avg_preds[:, i]

        if CFG.post_processing == 'raw':
            final_preds[:, i] = curr_preds
        elif CFG.post_processing == 'optimized':
            final_preds[:, i] = OptimizedRounder().predict(curr_preds)
        elif CFG.post_processing == 'voters':
            final_preds[:, i] = VotersRounder(train_col_values).predict(curr_preds)
        elif CFG.post_processing == 'distribution':
            final_preds[:, i] = DistributionRounder(train_col_values).predict(curr_preds)

    submission = pd.read_csv(SAMPLE_SUB_PATH)
    submission[TARGET_COLS] = final_preds
    submission.to_csv('submission.csv', index=False)
    print("submission.csv saved successfully!")