# In [None]:
# ====================================================
# Kaggle Inference Notebook (Llama 3.2 1B + 4 Post-Processing Strategies)
# ====================================================

import os
import gc
import sys
import math
import time
import random
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig
from tqdm.auto import tqdm
from scipy.stats import rankdata
import html

# ------------------- 配置 (Configuration) -------------------
class CFG:
    """Static settings for the inference run: paths, model options and loader parameters."""

    # Directory holding your trained weights (place the fine-tuned Llama 3.2 1B checkpoints here).
    model_dir = "/kaggle/input/llama"  # TODO: change to your dataset path under Kaggle Input

    # Tokenizer / base model (must be the same one that was used for training).
    base_model = "/kaggle/input/llama-3-2/transformers/1b/1"

    # Forwarded to QuestModel to select the pooling / head layout.
    pooling_strategy = 'arch1_6groups' 

    # [Key switch] Select the post-processing method applied to the averaged predictions.
    # Options: 'raw', 'optimized', 'voters', 'distribution'
    post_processing = 'voters' 

    max_len = 512          # tokenizer max_length; sequences are padded/truncated to this size
    batch_size = 16        # DataLoader batch size
    num_workers = 2        # DataLoader worker processes
    seed = 42              # NOTE(review): not referenced anywhere in the visible code
    n_fold = 5             # number of per-fold checkpoint file names that are probed
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trust_remote_code = False  # forwarded to the Hugging Face from_pretrained calls
    freeze_backbone = True     # forwarded to QuestModel (sets requires_grad=False on the backbone)

# Names of the 30 prediction columns: 21 `question_*` targets followed by 9
# `answer_*` targets. The order matters: column i of the prediction matrix is
# written to TARGET_COLS[i], and the index groups inside QuestModel refer to
# positions in this list.
TARGET_COLS = [
    'question_asker_intent_understanding', 'question_body_critical', 'question_conversational',
    'question_expect_short_answer', 'question_fact_seeking', 'question_has_commonly_accepted_answer',
    'question_interestingness_others', 'question_interestingness_self', 'question_multi_intent',
    'question_not_really_a_question', 'question_opinion_seeking', 'question_type_choice',
    'question_type_compare', 'question_type_consequence', 'question_type_definition',
    'question_type_entity', 'question_type_instructions', 'question_type_procedure',
    'question_type_reason_explanation', 'question_type_spelling', 'question_well_written',
    'answer_helpful', 'answer_level_of_information', 'answer_plausible', 'answer_relevance',
    'answer_satisfaction', 'answer_type_instructions', 'answer_type_procedure',
    'answer_type_reason_explanation', 'answer_well_written'
]

# ------------------- 資料處理 -------------------
def modern_preprocess(text):
    """Normalize one raw text field.

    Missing values (as judged by ``pd.isna``) become the empty string.
    Anything else is converted to ``str``, HTML entities are unescaped, and
    every run of whitespace is collapsed to a single space (which also strips
    leading and trailing whitespace).
    """
    if pd.isna(text):
        return ""
    unescaped = html.unescape(str(text))
    return " ".join(unescaped.split())

class QuestDataset(Dataset):
    """Dataset that tokenizes (question, answer) text pairs on access.

    The question text is the preprocessed title and body joined by a single
    space; the answer text is the preprocessed answer. Both are computed once
    in the constructor, while tokenization happens per item.
    """

    def __init__(self, df, tokenizer, max_len=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

        titles = df['question_title'].values
        bodies = df['question_body'].values
        self.questions = []
        for title, body in zip(titles, bodies):
            self.questions.append(modern_preprocess(title) + " " + modern_preprocess(body))

        self.answers = list(map(modern_preprocess, df['answer'].values))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict of ``torch.long`` tensors for row ``idx``."""
        encoded = self.tokenizer(
            self.questions[idx],
            self.answers[idx],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors=None,
        )

        fields = ['input_ids', 'attention_mask']
        # The Llama tokenizer may not provide token_type_ids; include them only
        # when the tokenizer actually returned them.
        if 'token_type_ids' in encoded:
            fields.append('token_type_ids')

        return {name: torch.tensor(encoded[name], dtype=torch.long) for name in fields}

# ------------------- 模型定義 (Single Regression) -------------------
class QuestModel(nn.Module):
    """Transformer backbone with a selectable pooling / regression-head layout.

    ``pooling_strategy`` selects how backbone hidden states are pooled and
    which heads produce the outputs:

    * ``'mean'``: masked mean pooling followed by one linear layer with
      ``num_targets`` outputs.
    * ``'arch1'``: two MLP heads (21 and 9 outputs) over the concatenation of
      [masked mean, last token, segment mean]; outputs are concatenated.
    * ``'arch1_6groups'``: six MLP heads over the same shared feature; each
      head's outputs are scattered into the columns listed in
      ``idx_g1`` .. ``idx_g6``.
    * ``'arch2'``: position-0 hidden state of the last four hidden-state
      layers, concatenated, then dropout + linear with ``num_targets`` outputs.

    ``forward`` returns the un-activated head outputs; no sigmoid is applied
    inside the model.

    Submodule attribute names (``backbone``, ``fc``, ``q_head``, ``a_head``,
    ``head_g1`` .. ``head_g6``) define the ``state_dict`` keys and must not be
    renamed if existing checkpoints are to load with ``strict=True``.

    Args:
        model_name: Name or path passed to ``AutoConfig`` / ``AutoModel``.
        num_targets: Output width for the ``'mean'`` and ``'arch2'`` layouts.
        pooling_strategy: One of ``'mean'``, ``'arch1'``, ``'arch1_6groups'``,
            ``'arch2'``.
        dropout_rate: Dropout probability used inside the heads.
        freeze_backbone: When true, backbone parameters get
            ``requires_grad = False``.
        trust_remote_code: Forwarded to the ``from_pretrained`` calls.

    Raises:
        ValueError: If ``pooling_strategy`` is not one of the supported names.
    """

    _POOLING_STRATEGIES = ('mean', 'arch1', 'arch1_6groups', 'arch2')

    def __init__(
        self,
        model_name,
        num_targets,
        pooling_strategy='arch1',
        dropout_rate=0.1,
        freeze_backbone=True,
        trust_remote_code=False
    ):
        super().__init__()
        # Fail before the (expensive) backbone load instead of hitting an
        # unbound local variable inside forward() later on.
        if pooling_strategy not in self._POOLING_STRATEGIES:
            raise ValueError(
                f"Unknown pooling_strategy {pooling_strategy!r}; "
                f"expected one of {self._POOLING_STRATEGIES}"
            )

        self.pooling_strategy = pooling_strategy
        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        if pooling_strategy == 'arch2':
            # 'arch2' pools from several layers, so every layer's output is needed.
            self.config.output_hidden_states = True

        self.backbone = AutoModel.from_pretrained(
            model_name,
            config=self.config,
            trust_remote_code=trust_remote_code
        )

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        hidden_size = self.config.hidden_size

        # Output-column indices handled by each of the six grouped heads.
        self.idx_g1 = [3, 4, 5, 16, 17]
        self.idx_g2 = [0, 1, 6, 7, 20]
        self.idx_g3 = [2, 10]
        self.idx_g4 = [8, 9, 11, 12, 13, 14, 15, 18, 19]
        self.idx_g5 = [26, 27]
        self.idx_g6 = [21, 22, 23, 24, 25, 28, 29]

        if self.pooling_strategy == 'mean':
            self.fc = nn.Linear(hidden_size, num_targets)

        elif self.pooling_strategy == 'arch1':
            self.q_head = self._make_head(hidden_size * 3, 21, dropout_rate)
            self.a_head = self._make_head(hidden_size * 3, 9, dropout_rate)

        elif self.pooling_strategy == 'arch1_6groups':
            self.head_g1 = self._make_head(hidden_size * 3, len(self.idx_g1), dropout_rate)
            self.head_g2 = self._make_head(hidden_size * 3, len(self.idx_g2), dropout_rate)
            self.head_g3 = self._make_head(hidden_size * 3, len(self.idx_g3), dropout_rate)
            self.head_g4 = self._make_head(hidden_size * 3, len(self.idx_g4), dropout_rate)
            self.head_g5 = self._make_head(hidden_size * 3, len(self.idx_g5), dropout_rate)
            self.head_g6 = self._make_head(hidden_size * 3, len(self.idx_g6), dropout_rate)

        elif self.pooling_strategy == 'arch2':
            self.fc = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size * 4, num_targets)
            )

    def _make_head(self, input_dim, output_dim, dropout_rate):
        """Build a Linear -> Tanh -> Dropout -> Linear head."""
        return nn.Sequential(
            nn.Linear(input_dim, self.config.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.config.hidden_size, output_dim)
        )

    def _group_heads(self):
        """Return the (column indices, head) pairs of the six-group layout."""
        return (
            (self.idx_g1, self.head_g1),
            (self.idx_g2, self.head_g2),
            (self.idx_g3, self.head_g3),
            (self.idx_g4, self.head_g4),
            (self.idx_g5, self.head_g5),
            (self.idx_g6, self.head_g6),
        )

    def _masked_mean_pooling(self, hidden_state, attention_mask):
        """Average ``hidden_state`` over dim 1, counting only positions where the mask is 1."""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
        sum_embeddings = torch.sum(hidden_state * input_mask_expanded, 1)
        # Clamp so a fully masked row divides by a tiny number instead of zero.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _last_token_pool(self, last_hidden_state, attention_mask):
        """Pick the hidden state of the last attended token of every row.

        Works for both left- and right-padded batches: if the final mask
        column is 1 for every row, position -1 is used directly; otherwise
        the position ``sum(mask) - 1`` of each row is gathered.
        """
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_state[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        rows = torch.arange(batch_size, device=last_hidden_state.device)
        return last_hidden_state[rows, sequence_lengths]

    def _get_pooling_features(self, last_hidden_state, attention_mask, token_type_ids):
        """Return (global mean, last token, question mean, answer mean) features.

        Without ``token_type_ids`` (e.g. Llama tokenizers) there is no segment
        information, so the global mean is reused for both segment features.
        """
        global_avg = self._masked_mean_pooling(last_hidden_state, attention_mask)
        last_token = self._last_token_pool(last_hidden_state, attention_mask)

        if token_type_ids is None:
            q_repr = global_avg
            a_repr = global_avg
        else:
            # Segment 0 is treated as the question, segment 1 as the answer.
            q_mask = attention_mask * (1 - token_type_ids)
            q_repr = self._masked_mean_pooling(last_hidden_state, q_mask)
            a_mask = attention_mask * token_type_ids
            a_repr = self._masked_mean_pooling(last_hidden_state, a_mask)

        return global_avg, last_token, q_repr, a_repr

    def _pool_arch2(self, all_hidden_states):
        """Concatenate the position-0 vectors of the last four hidden-state layers."""
        last_4_layers = all_hidden_states[-4:]
        cls_embeddings = [layer[:, 0, :] for layer in last_4_layers]
        return torch.cat(cls_embeddings, dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the backbone and the configured head(s).

        Args:
            input_ids: Token ids, one row per example.
            attention_mask: Mask with 1 for real tokens and 0 for padding.
            token_type_ids: Optional segment ids; omitted from the backbone
                call when ``None``.

        Returns:
            Tensor with one row per example and one column per target.

        Raises:
            ValueError: If ``self.pooling_strategy`` is not a supported name.
        """
        backbone_inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        # Only forward token_type_ids when they exist: some backbones do not
        # declare that keyword at all and would raise TypeError even for None.
        if token_type_ids is not None:
            backbone_inputs['token_type_ids'] = token_type_ids
        outputs = self.backbone(**backbone_inputs)
        last_hidden_state = outputs.last_hidden_state

        if self.pooling_strategy == 'mean':
            feature = self._masked_mean_pooling(last_hidden_state, attention_mask)
            output = self.fc(feature)

        elif self.pooling_strategy == 'arch1':
            glob, last_tok, q, a = self._get_pooling_features(last_hidden_state, attention_mask, token_type_ids)
            q_feat = torch.cat([glob, last_tok, q], dim=1)
            a_feat = torch.cat([glob, last_tok, a], dim=1)
            output = torch.cat([self.q_head(q_feat), self.a_head(a_feat)], dim=1)

        elif self.pooling_strategy == 'arch1_6groups':
            glob, last_tok, q, a = self._get_pooling_features(last_hidden_state, attention_mask, token_type_ids)
            # All six heads read the same shared feature.
            feat_shared = torch.cat([glob, last_tok, q], dim=1)

            group_heads = self._group_heads()
            # Output width follows from the index groups (5+5+2+9+2+7 = 30).
            n_outputs = sum(len(indices) for indices, _ in group_heads)

            output = None
            for indices, head in group_heads:
                group_out = head(feat_shared)
                if output is None:
                    output = group_out.new_zeros((input_ids.size(0), n_outputs))
                # Scatter this head's columns back to their target positions.
                output[:, indices] = group_out

        elif self.pooling_strategy == 'arch2':
            feature = self._pool_arch2(outputs.hidden_states)
            output = self.fc(feature)

        else:
            raise ValueError(f"Unknown pooling_strategy {self.pooling_strategy!r}")

        return output

# ------------------- Post-Processing Strategies -------------------

# 1. Optimized Rounder (閾值截斷)
class OptimizedRounder:
    """Clip predictions into a fixed ``[low, high]`` band.

    NaN predictions are replaced by 0.5 before clipping. If every clipped
    value is identical, a tiny epsilon is added at the position of the largest
    input so the result is not constant.
    """

    def __init__(self):
        # [lower bound, upper bound] applied by predict().
        self.coef_ = [0.05, 0.95]

    def predict(self, X):
        """Return the clipped copy of ``X``; ``X`` itself is not modified."""
        filled = np.nan_to_num(X, nan=0.5)
        lower = self.coef_[0]
        upper = self.coef_[1]
        clipped = np.clip(filled, lower, upper)

        if np.unique(clipped).size != 1:
            return clipped

        # Constant output: nudge one entry so the column has some variation.
        clipped[np.argmax(filled)] += 1e-6
        return clipped

# 2. Voters Rounder (網格吸附 - 防呆修正版)
class VotersRounder:
    """Deviation-aware rounder that snaps predictions onto the training grid.

    * Builds a grid from the distinct non-NaN training values.
    * Snaps every prediction to its nearest grid value (NaN predictions are
      treated as 0.5; on ties the smaller grid value wins).
    * If the snapped values collapse (standard deviation below
      ``dev_threshold``), the NaN-cleaned original predictions are returned
      instead.
    """

    def __init__(self, train_vals, dev_threshold=0.01):
        """
        Args:
            train_vals: Training target values; any array-like (list, tuple,
                ndarray). NaNs are ignored.
            dev_threshold: Minimum standard deviation the snapped predictions
                must have to be accepted.
        """
        # asarray lets plain lists through; a boolean mask cannot index a list.
        values = np.asarray(train_vals, dtype=float).ravel()
        # np.unique already returns its result sorted.
        self.unique_vals = np.unique(values[~np.isnan(values)])
        self.dev_threshold = dev_threshold

    def predict(self, X):
        """Snap 1-D predictions ``X`` to the grid, or fall back to the cleaned input."""
        X_clean = np.nan_to_num(X, nan=0.5)

        # With no grid (empty / all-NaN training column) or no predictions
        # there is nothing to snap to; argmin over an empty axis would raise.
        if self.unique_vals.size == 0 or X_clean.size == 0:
            return X_clean

        distances = np.abs(X_clean[:, None] - self.unique_vals[None, :])
        X_p = self.unique_vals[distances.argmin(axis=1)]

        if np.std(X_p) < self.dev_threshold:
            return X_clean
        return X_p

# 3. Distribution Rounder (分佈擬合)
class DistributionRounder:
    """Rank-preserving mapping of predictions onto the training distribution.

    The sorted training values are resampled (linear interpolation) to as
    many points as there are predictions; each prediction then receives the
    resampled value that has the same rank.
    """

    def __init__(self, train_vals):
        """
        Args:
            train_vals: Training target values (any array-like). NaNs are
                dropped, matching how VotersRounder builds its grid; otherwise
                they would be sorted to the end and leak into the output.
        """
        values = np.asarray(train_vals, dtype=float).ravel()
        self.train_vals = np.sort(values[~np.isnan(values)])

    def predict(self, X):
        """Return distribution-matched values for the 1-D predictions ``X``.

        NaN predictions are treated as 0.5, like the other rounders do. If
        there are no predictions or no usable training values, the
        NaN-cleaned predictions are returned unchanged.
        """
        X_clean = np.nan_to_num(X, nan=0.5)
        n = len(X_clean)
        if n == 0 or self.train_vals.size == 0:
            return X_clean

        # Zero-based ordinal ranks: ties are broken by position, so every
        # prediction maps to a distinct slot of the resampled distribution.
        ranks = rankdata(X_clean, method='ordinal') - 1
        resampled = np.interp(
            np.linspace(0, 1, n),
            np.linspace(0, 1, len(self.train_vals)),
            self.train_vals
        )
        return resampled[ranks]

# ------------------- Inference Logic -------------------
def inference_fn(test_loader, model, device):
    """Run ``model`` over every batch of ``test_loader`` and collect predictions.

    The model is switched to eval mode and called without gradient tracking.
    Its outputs are passed through a sigmoid, moved to the CPU and stacked.

    Args:
        test_loader: Iterable of batch dicts with ``'input_ids'`` and
            ``'attention_mask'`` and, optionally, ``'token_type_ids'``.
        model: Callable taking ``(input_ids, attention_mask, token_type_ids)``.
        device: Device the batch tensors are moved to.

    Returns:
        NumPy array with the per-batch predictions concatenated along axis 0.
    """
    model.eval()
    batch_outputs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Predicting"):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # Segment ids are optional; keep None when the batch has none.
            segments = batch.get('token_type_ids')
            segments = segments.to(device) if segments is not None else None

            logits = model(ids, mask, segments)
            batch_outputs.append(torch.sigmoid(logits).cpu().numpy())

    return np.concatenate(batch_outputs)

# ------------------- Main -------------------
if __name__ == '__main__':
    # Script entry point: load the data, run every compatible checkpoint over
    # the test set, average the predictions, apply the configured
    # post-processing and write submission.csv.

    # Reject a mistyped strategy before any expensive work. Without this
    # check an unknown value matches no branch below and the submission would
    # silently consist of zeros.
    POST_PROCESSING_OPTIONS = ('raw', 'optimized', 'voters', 'distribution')
    if CFG.post_processing not in POST_PROCESSING_OPTIONS:
        print(f"Unknown CFG.post_processing: {CFG.post_processing!r}. "
              f"Expected one of {POST_PROCESSING_OPTIONS}.")
        sys.exit(1)

    TEST_PATH = 'test.csv'
    TRAIN_PATH = 'train.csv'

    # Fall back to the Kaggle competition inputs when there is no local copy.
    if not os.path.exists(TEST_PATH):
        TEST_PATH = '/kaggle/input/google-quest-challenge/test.csv'
        TRAIN_PATH = '/kaggle/input/google-quest-challenge/train.csv'

    test = pd.read_csv(TEST_PATH)
    train = pd.read_csv(TRAIN_PATH)

    print(f"Base Model: {CFG.base_model}")
    print(f"Test Shape: {test.shape}, Train Shape: {train.shape}")
    print(f"Selected Post-Processing Strategy: {CFG.post_processing}")

    tokenizer = AutoTokenizer.from_pretrained(CFG.base_model, trust_remote_code=CFG.trust_remote_code)
    # The Llama tokenizer needs a pad token for padding='max_length'.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    test_dataset = QuestDataset(test, tokenizer, max_len=CFG.max_len)
    test_loader = DataLoader(
        test_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        # Pinned memory only helps host-to-GPU copies.
        pin_memory=(CFG.device.type == 'cuda')
    )

    # Scan for weights: the file names keep the original naming scheme
    # (consistent with the training script), fold checkpoints first.
    candidate_names = [f"deberta_v3_fold{fold}_best.pth" for fold in range(CFG.n_fold)]
    candidate_names.append("deberta_v3_single_run_best.pth")
    weight_paths = [
        path
        for path in (os.path.join(CFG.model_dir, name) for name in candidate_names)
        if os.path.exists(path)
    ]

    if not weight_paths:
        print("No weights found in model_dir. Please set CFG.model_dir to your Llama checkpoints.")
        sys.exit(1)

    print(f"Found {len(weight_paths)} candidate checkpoints. Building models and filtering incompatible ones...")

    model_preds = []
    for weight_path in weight_paths:
        model = QuestModel(
            CFG.base_model,
            num_targets=len(TARGET_COLS),
            pooling_strategy=CFG.pooling_strategy,
            freeze_backbone=CFG.freeze_backbone,
            trust_remote_code=CFG.trust_remote_code
        )
        model.to(CFG.device)

        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a source you trust.
            state = torch.load(weight_path, map_location='cpu')
            model.load_state_dict(state, strict=True)
        except Exception as e:
            # Deliberately broad: any checkpoint that cannot be loaded into
            # this architecture is reported and skipped, not fatal.
            print(f"[Skip] Incompatible checkpoint: {os.path.basename(weight_path)} -> {e}")
            del model
            torch.cuda.empty_cache()
            gc.collect()
            continue

        # The parameters now hold their own copy; release the loaded dict.
        del state

        model_preds.append(inference_fn(test_loader, model, CFG.device))

        # Free the model before the next checkpoint is built.
        del model
        torch.cuda.empty_cache()
        gc.collect()

    if not model_preds:
        print("All checkpoints were incompatible with the selected base model.\n"
              "- Ensure CFG.base_model matches the model used during training (e.g., meta-llama/Llama-3.2-1B).\n"
              "- Ensure CFG.model_dir points to your Llama-trained weights, not DeBERTa/Qwen weights.")
        sys.exit(1)

    # Ensemble: plain average over all checkpoints that loaded.
    avg_preds = np.mean(model_preds, axis=0)
    final_preds = np.zeros_like(avg_preds)

    print(f"Applying Post-Processing: {CFG.post_processing}")

    # Post-process column by column; the rounders that need it are fitted on
    # the matching training column.
    for i, col in enumerate(TARGET_COLS):
        train_col_values = train[col].values
        curr_preds = avg_preds[:, i]

        if CFG.post_processing == 'raw':
            final_preds[:, i] = curr_preds

        elif CFG.post_processing == 'optimized':
            final_preds[:, i] = OptimizedRounder().predict(curr_preds)

        elif CFG.post_processing == 'voters':
            final_preds[:, i] = VotersRounder(train_col_values).predict(curr_preds)

        elif CFG.post_processing == 'distribution':
            final_preds[:, i] = DistributionRounder(train_col_values).predict(curr_preds)

    # Save the submission file. The sample submission is used as a template
    # when it can be read; otherwise the frame is built from the test ids.
    try:
        sub_path = '/kaggle/input/google-quest-challenge/sample_submission.csv'
        submission = pd.read_csv(sub_path)
    except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
        submission = pd.DataFrame({'qa_id': test['qa_id']})
    submission[TARGET_COLS] = final_preds
    submission.to_csv('submission.csv', index=False)
    print("submission.csv saved successfully!")
