# In [None]:
# ====================================================
# Kaggle Inference Notebook
# Model: DeBERTa v3 (Arch1 - 6 Grouped Heads Strategy)
# ====================================================

import os
import gc
import sys
import math
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig
from tqdm.auto import tqdm
import html

# ------------------- 配置 (Configuration) -------------------
class CFG:
    """Static settings for the inference run (read as class attributes)."""

    # Folder holding the fine-tuned fold checkpoints
    # (point this at the directory containing the arch1_6groups weights).
    model_dir = "/kaggle/input/deberta-finetuned/pytorch/arch1/8"

    # Folder the tokenizer is loaded from; the main script also passes it to
    # QuestModel as the location of the backbone config.
    base_model = "/kaggle/input/deberta-tokenizer/deberta-v3-base-tokenizer"

    # Must be the same value that was used during training!
    pooling_strategy = "arch1_6groups"

    max_len = 512
    batch_size = 16
    num_workers = 2
    seed = 42
    n_fold = 5

    # Use the GPU when one is visible, otherwise run on CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output columns, in the order the model emits them.
# Positions 0-20 are question-level targets, 21-29 are answer-level targets.
TARGET_COLS = [
    # --- question targets (0-20) ---
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
    # --- answer targets (21-29) ---
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written',
]

# ------------------- 資料處理 -------------------
def modern_preprocess(text):
    """Return *text* as a clean single-line string.

    Missing values (anything ``pd.isna`` reports as NA) become ``""``.
    Otherwise the value is stringified, HTML entities are decoded, and every
    run of whitespace is collapsed to one space with the ends trimmed.
    """
    if pd.isna(text):
        return ""
    decoded = html.unescape(str(text))
    # split() with no argument drops leading/trailing whitespace and treats
    # any whitespace run as a single separator.
    return " ".join(decoded.split())

class QuestDataset(Dataset):
    """Dataset that tokenizes (question, answer) text pairs on access.

    The question text is the preprocessed title and body joined by a single
    space; the answer text is the preprocessed answer. Both lists are built
    once in ``__init__`` from the ``question_title``, ``question_body`` and
    ``answer`` columns of ``df``.
    """

    def __init__(self, df, tokenizer, max_len=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

        titles = df['question_title'].values
        bodies = df['question_body'].values
        self.questions = [
            f"{modern_preprocess(title)} {modern_preprocess(body)}"
            for title, body in zip(titles, bodies)
        ]
        self.answers = list(map(modern_preprocess, df['answer'].values))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return a dict of ``torch.long`` tensors.

        Always contains ``input_ids`` and ``attention_mask``;
        ``token_type_ids`` is included only when the tokenizer returns it.
        """
        encoded = self.tokenizer(
            self.questions[idx],
            self.answers[idx],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',   # every item is padded to max_len
            truncation=True,
            return_tensors=None,    # plain lists; converted to tensors below
        )

        wanted = ['input_ids', 'attention_mask']
        if 'token_type_ids' in encoded:
            wanted.append('token_type_ids')

        return {key: torch.tensor(encoded[key], dtype=torch.long) for key in wanted}

# ------------------- 模型定義 (6-Head 架構) -------------------
class QuestModel(nn.Module):
    """Transformer backbone with one of several pooling / head layouts.

    Supported ``pooling_strategy`` values:

    * ``'mean'``: masked mean of the last hidden state, then one linear layer.
    * ``'arch1'``: legacy two-head layout (21 question outputs followed by
      9 answer outputs); each head sees ``[CLS, global mean, segment mean]``.
    * ``'arch1_6groups'``: six heads, four on the question feature and two on
      the answer feature; their outputs are scattered back into the original
      target order using the ``idx_g*`` index lists.
    * ``'arch2'``: CLS vectors of the last four hidden-state layers,
      concatenated, then dropout + linear.

    The backbone is built with ``AutoModel.from_config`` (structure only, no
    pretrained weights) because a fine-tuned state dict is expected to be
    loaded over it. Sub-module attribute names (``backbone``, ``fc``,
    ``q_head``, ``a_head``, ``head_g1`` .. ``head_g6``) are state-dict keys
    and must not be renamed.

    Raises:
        ValueError: if ``pooling_strategy`` is not one of the supported values.
    """

    # Every layout this class can build heads for.
    SUPPORTED_POOLING = ('mean', 'arch1', 'arch1_6groups', 'arch2')

    def __init__(self, model_name, num_targets, pooling_strategy='arch1', dropout_rate=0.1):
        super().__init__()

        # Fail fast: an unknown strategy would otherwise create a model with
        # no head and only blow up later inside forward().
        if pooling_strategy not in self.SUPPORTED_POOLING:
            raise ValueError(
                f"Unknown pooling_strategy {pooling_strategy!r}; "
                f"expected one of {self.SUPPORTED_POOLING}"
            )

        self.pooling_strategy = pooling_strategy
        self.config = AutoConfig.from_pretrained(model_name)

        # arch2 pools over several layers, so the backbone must return them.
        if pooling_strategy == 'arch2':
            self.config.update({'output_hidden_states': True})

        # Structure only; weights are overwritten by the loaded checkpoint.
        self.backbone = AutoModel.from_config(self.config)
        hidden_size = self.config.hidden_size

        # ---------------------------------------------------------------
        # Target-index groups for the 6-head layout (question/answer split).
        # Indices refer to positions in the 30-column target order.
        # ---------------------------------------------------------------
        # Question groups (G1-G4)
        self.idx_g1 = [3, 4, 5, 16, 17]                     # Fact/Instructions
        self.idx_g2 = [0, 1, 6, 7, 20]                      # Quality/Intent
        self.idx_g3 = [2, 10]                               # Conversational
        self.idx_g4 = [8, 9, 11, 12, 13, 14, 15, 18, 19]    # Type/Class

        # Answer groups (G5-G6)
        self.idx_g5 = [26, 27]                              # Instructions
        self.idx_g6 = [21, 22, 23, 24, 25, 28, 29]          # Quality/Helpful

        # ---------------------------------------------------------------
        # Heads
        # ---------------------------------------------------------------
        if pooling_strategy == 'mean':
            self.fc = nn.Linear(hidden_size, num_targets)

        elif pooling_strategy == 'arch1':
            # Legacy two-head layout: fixed 21 question + 9 answer outputs.
            self.q_head = self._make_head(hidden_size * 3, 21, dropout_rate)
            self.a_head = self._make_head(hidden_size * 3, 9, dropout_rate)

        elif pooling_strategy == 'arch1_6groups':
            # Every head takes 3 * hidden features (CLS + global mean + either
            # the question mean or the answer mean).
            # Question heads
            self.head_g1 = self._make_head(hidden_size * 3, len(self.idx_g1), dropout_rate)
            self.head_g2 = self._make_head(hidden_size * 3, len(self.idx_g2), dropout_rate)
            self.head_g3 = self._make_head(hidden_size * 3, len(self.idx_g3), dropout_rate)
            self.head_g4 = self._make_head(hidden_size * 3, len(self.idx_g4), dropout_rate)

            # Answer heads
            self.head_g5 = self._make_head(hidden_size * 3, len(self.idx_g5), dropout_rate)
            self.head_g6 = self._make_head(hidden_size * 3, len(self.idx_g6), dropout_rate)

        elif pooling_strategy == 'arch2':
            self.fc = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size * 4, num_targets)
            )

    def _make_head(self, input_dim, output_dim, dropout_rate):
        """Build a Linear -> Tanh -> Dropout -> Linear head.

        The intermediate width is the backbone's ``hidden_size``.
        """
        return nn.Sequential(
            nn.Linear(input_dim, self.config.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.config.hidden_size, output_dim)
        )

    def _masked_mean_pooling(self, hidden_state, attention_mask):
        """Average ``hidden_state`` over dim 1, counting only masked-in positions.

        ``attention_mask`` is broadcast over the last dimension. The divisor is
        clamped to 1e-9 so an all-zero mask yields zeros instead of NaN.
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
        sum_embeddings = torch.sum(hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _get_pooling_features(self, last_hidden_state, attention_mask, token_type_ids):
        """Return ``(cls_token, global_avg, q_avg, a_avg)``.

        ``cls_token`` is the vector at sequence position 0. When
        ``token_type_ids`` is given, positions with type 0 form the question
        segment and positions with type 1 the answer segment (padding removed
        via ``attention_mask``). Without ``token_type_ids`` both segment
        means fall back to the global mean.
        """
        cls_token = last_hidden_state[:, 0, :]
        global_avg = self._masked_mean_pooling(last_hidden_state, attention_mask)

        if token_type_ids is None:
            q_avg = global_avg
            a_avg = global_avg
        else:
            q_mask = attention_mask * (1 - token_type_ids)
            q_avg = self._masked_mean_pooling(last_hidden_state, q_mask)
            a_mask = attention_mask * token_type_ids
            a_avg = self._masked_mean_pooling(last_hidden_state, a_mask)

        return cls_token, global_avg, q_avg, a_avg

    def _pool_arch2(self, all_hidden_states):
        """Concatenate the position-0 vectors of the last four layers along dim 1."""
        last_4_layers = all_hidden_states[-4:]
        cls_embeddings = [layer[:, 0, :] for layer in last_4_layers]
        return torch.cat(cls_embeddings, dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the backbone and the configured head(s); return raw logits."""
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        last_hidden_state = outputs.last_hidden_state

        if self.pooling_strategy == 'mean':
            feature = self._masked_mean_pooling(last_hidden_state, attention_mask)
            output = self.fc(feature)

        elif self.pooling_strategy == 'arch1':
            cls, glob, q, a = self._get_pooling_features(last_hidden_state, attention_mask, token_type_ids)
            q_feat = torch.cat([cls, glob, q], dim=1)
            a_feat = torch.cat([cls, glob, a], dim=1)
            output = torch.cat([self.q_head(q_feat), self.a_head(a_feat)], dim=1)

        elif self.pooling_strategy == 'arch1_6groups':
            cls, glob, q, a = self._get_pooling_features(last_hidden_state, attention_mask, token_type_ids)

            # Question feature: [CLS, global, question mean]
            feat_pure_q = torch.cat([cls, glob, q], dim=1)
            # Answer feature: [CLS, global, answer mean]
            feat_pure_a = torch.cat([cls, glob, a], dim=1)

            # (target indices, head, input feature) for each group, G1..G6.
            groups = (
                (self.idx_g1, self.head_g1, feat_pure_q),
                (self.idx_g2, self.head_g2, feat_pure_q),
                (self.idx_g3, self.head_g3, feat_pure_q),
                (self.idx_g4, self.head_g4, feat_pure_q),
                (self.idx_g5, self.head_g5, feat_pure_a),
                (self.idx_g6, self.head_g6, feat_pure_a),
            )
            group_outputs = [(idx, head(feat)) for idx, head, feat in groups]

            # Output width follows from the index lists (30 for the groups
            # defined in __init__) instead of a hard-coded constant.
            width = 1 + max(max(idx) for idx, _ in group_outputs)

            # Allocate in the heads' dtype so the scatter also works under
            # mixed precision.
            output = torch.zeros(
                input_ids.size(0), width,
                dtype=group_outputs[0][1].dtype, device=input_ids.device
            )
            for idx, group_out in group_outputs:
                output[:, idx] = group_out

        elif self.pooling_strategy == 'arch2':
            feature = self._pool_arch2(outputs.hidden_states)
            output = self.fc(feature)

        else:
            # Only reachable if pooling_strategy was changed after __init__.
            raise ValueError(f"Unknown pooling_strategy {self.pooling_strategy!r}")

        return output

# ------------------- 優化器 (Robust OptimizedRounder) -------------------
class OptimizedRounder:
    """Post-processor that clips predictions into a fixed ``[low, high]`` band.

    ``coef_`` holds the default band ``[0.015, 0.985]``.
    """

    def __init__(self):
        self.coef_ = [0.015, 0.985]

    def predict(self, X, coef):
        """Return a clipped copy of ``X``; the input is not modified.

        NaNs are replaced by 0.5 first, then values are clipped to
        ``[coef[0], coef[1]]``. If every clipped value is identical, the entry
        at the position of the largest NaN-filled input is raised by 1e-3 so
        the result is not constant (presumably to keep a rank-based metric
        defined; confirm against the competition metric).
        """
        filled = np.nan_to_num(X, nan=0.5)
        lower, upper = coef[0], coef[1]
        clipped = np.clip(filled, lower, upper)

        is_constant = np.unique(clipped).size == 1
        if is_constant:
            # Pick the position from the pre-clip values so the bump goes to
            # the element that was originally the largest.
            clipped[np.argmax(filled)] += 1e-3

        return clipped

# ------------------- 推論邏輯 -------------------
def inference_fn(test_loader, model, device):
    """Run ``model`` over ``test_loader`` and return sigmoid outputs.

    The model is switched to eval mode and gradients are disabled. Each batch
    must provide ``input_ids`` and ``attention_mask``; ``token_type_ids`` is
    forwarded when present, otherwise ``None`` is passed. Per-batch results
    are moved to CPU and concatenated along the first axis into one NumPy
    array.
    """
    model.eval()
    batch_probs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Predicting"):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            type_ids = batch.get('token_type_ids')
            type_ids = None if type_ids is None else type_ids.to(device)

            logits = model(ids, mask, type_ids)
            batch_probs.append(logits.sigmoid().cpu().numpy())

    return np.concatenate(batch_probs)

# ------------------- 主程式 -------------------
if __name__ == '__main__':
    # Prefer a local copy of each competition file; otherwise fall back to
    # the Kaggle input mount.
    TEST_PATH = 'test.csv'
    if not os.path.exists(TEST_PATH):
        TEST_PATH = '/kaggle/input/google-quest-challenge/test.csv'

    SAMPLE_SUB_PATH = 'sample_submission.csv'
    if not os.path.exists(SAMPLE_SUB_PATH):
        SAMPLE_SUB_PATH = '/kaggle/input/google-quest-challenge/sample_submission.csv'

    test = pd.read_csv(TEST_PATH)
    print(f"Test Data Shape: {test.shape}")

    tokenizer = AutoTokenizer.from_pretrained(CFG.base_model)

    test_dataset = QuestDataset(test, tokenizer, max_len=CFG.max_len)
    test_loader = DataLoader(
        test_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,  # keep row order so predictions line up with the CSV
        num_workers=CFG.num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; requesting it
        # on a CPU-only run just triggers a warning.
        pin_memory=CFG.device.type == 'cuda'
    )

    # One prediction array per fold whose checkpoint was found.
    fold_preds = []

    for fold in range(CFG.n_fold):
        weight_path = os.path.join(CFG.model_dir, f"deberta_v3_fold{fold}_best.pth")

        if not os.path.exists(weight_path):
            print(f"Warning: Weights for fold {fold} not found at {weight_path}. Skipping.")
            continue

        print(f"Loading Fold {fold} Model...")

        model = QuestModel(
            CFG.base_model,
            num_targets=len(TARGET_COLS),
            pooling_strategy=CFG.pooling_strategy
        )

        # Load onto the CPU first: the model is still on the CPU at this
        # point, so mapping the checkpoint to the GPU would only add a
        # device round-trip and a temporary second copy in GPU memory.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weight_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.to(CFG.device)

        fold_preds.append(inference_fn(test_loader, model, CFG.device))

        # Drop the references and collect garbage *before* emptying the CUDA
        # cache, so the freed blocks can actually be returned.
        del model, state_dict
        gc.collect()
        torch.cuda.empty_cache()

    if fold_preds:
        # Simple mean over whichever folds were available.
        avg_preds = np.mean(fold_preds, axis=0)

        print("Applying OptimizedRounder...")
        opt = OptimizedRounder()
        final_preds = np.zeros_like(avg_preds)

        # Post-process one target column at a time.
        for col in range(avg_preds.shape[1]):
            final_preds[:, col] = opt.predict(avg_preds[:, col], opt.coef_)

        # The sample submission supplies the id column and row order;
        # predictions are written into the target columns positionally.
        submission = pd.read_csv(SAMPLE_SUB_PATH)
        submission[TARGET_COLS] = final_preds
        submission.to_csv('submission.csv', index=False)
        print("submission.csv saved successfully!")
    else:
        print("Error: No predictions generated.")