In [None]:
import json
import re
from dataclasses import dataclass
from typing import List, Dict, Optional

import torch
from torch.utils.data import Dataset, random_split
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from peft import LoraConfig, get_peft_model
import matplotlib.pyplot as plt
import numpy as np
import os
import warnings

# Suppress FutureWarning messages for the rest of the session.
warnings.filterwarnings(action="ignore", category=FutureWarning)

# Checkpoint id handed to every `from_pretrained` call in this file.
MODEL_ID = "meta-llama/Llama-3.2-1B"
# JSON file read by TeamBinaryDataset to build the training samples.
TRAIN_PATH = "commonsense_15k.json"
# Root output folder; each LoRA setting is saved in its own sub-folder.
OUTPUT_ROOT = "./llama32_1b_team_lora"


def extract_options_from_instruction(instruction: str, answer_format: str) -> Optional[List[str]]:
    """Split the candidate options out of a multiple-choice instruction.

    Args:
        instruction: Full prompt text, e.g. ``"... Answer1: foo Answer2: bar"``.
        answer_format: The declared answer format; when it contains
            ``"true/false"`` (case-insensitive) the options are fixed.

    Returns:
        ``["true", "false"]`` for true/false questions; otherwise one string
        per option, each still carrying its label (``"Answer1: foo"``).
        ``None`` when the instruction contains no recognised option label.
    """
    if "true/false" in answer_format.lower():
        return ["true", "false"]

    # Probe the label kinds in a fixed priority order; the first kind that
    # occurs anywhere in the instruction decides how the text is split.
    detected = next(
        (
            kind
            for kind in ("Answer", "Solution", "Ending", "Option")
            if re.search(rf"{kind}\d+:", instruction)
        ),
        None,
    )
    if detected is None:
        return None

    # An option ends where the next label of the same kind starts, where the
    # trailing "... format:" line starts, or at the end of the text.
    boundary = rf"(?=\s*{detected}\d+:|\s*Answer format:|\s*Solution format:|$)"
    option_pattern = rf"({detected}\d+:\s*.*?){boundary}"

    # DOTALL lets a single option span several lines.
    return [
        match.strip()
        for match in re.findall(option_pattern, instruction, re.DOTALL)
    ]


def parse_answer_key(answer: str) -> str:
    """Normalise a gold-answer string to its canonical key.

    The text is lower-cased and stripped. A bare ``true`` / ``false`` is
    returned as is. Otherwise the first ``<prefix><digits>`` token found is
    returned, trying the prefixes in the order solution, answer, ending,
    option. When nothing matches, the normalised text itself is returned.
    """
    normalized = answer.lower().strip()

    if normalized in ("true", "false"):
        return normalized

    # The prefix order is significant: an earlier prefix wins even when a
    # later one appears first in the text.
    for prefix in ("solution", "answer", "ending", "option"):
        hit = re.search(rf"{prefix}(\d+)", normalized)
        if hit:
            return f"{prefix}{hit.group(1)}"

    return normalized


def get_correct_option_index(parsed_answer: str, options: List[str]) -> int:
    """Map a normalised answer key to the 0-based index of the gold option.

    Args:
        parsed_answer: Key such as ``"true"`` or ``"answer3"``.
        options: Candidate options of the question.

    Returns:
        The 0-based index into ``options``, or ``-1`` when the key cannot be
        resolved: unknown true/false value, no number in the key, or a number
        that falls outside ``options``.
    """
    if parsed_answer in ("true", "false"):
        try:
            return options.index(parsed_answer)
        except ValueError:
            return -1

    number_match = re.search(r"\d+", parsed_answer)
    if number_match is None:
        return -1

    # Keys are 1-based ("answer1" is the first option).
    index = int(number_match.group()) - 1

    # Reject out-of-range keys so callers never build a question whose gold
    # option does not exist ("answer0" or "answer5" with four options).
    return index if 0 <= index < len(options) else -1


class TeamBinaryDataset(Dataset):
    """Training set that turns each multiple-choice item into binary samples.

    Every option of a question becomes one ``{"text", "label"}`` sample;
    the label is 1 for the gold option and 0 for the others. Tokenisation
    happens lazily in ``__getitem__`` without padding.
    """

    def __init__(self, path: str, tokenizer, max_length: int = 512):
        """Load the JSON file at ``path`` and expand it into samples.

        Args:
            path: JSON file holding a list of objects with ``"instruction"``
                and ``"answer"`` entries.
            tokenizer: Callable tokenizer used by ``__getitem__``.
            max_length: Truncation length passed to the tokenizer.
        """
        self.samples = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        for item in data:
            instruction = item["instruction"]
            question = self._extract_question(instruction)
            answer_format = self._extract_answer_format(instruction)

            options = extract_options_from_instruction(instruction, answer_format)
            if not options:
                continue

            parsed_answer = parse_answer_key(item["answer"])
            correct_idx = get_correct_option_index(parsed_answer, options)

            # Skip items whose gold index is unresolved or points outside
            # the option list; they would otherwise yield only negatives.
            if not 0 <= correct_idx < len(options):
                continue

            for i, option in enumerate(options):
                self.samples.append(
                    {
                        "text": self.build_text(question, option),
                        "label": 1 if i == correct_idx else 0,
                    }
                )

    def _extract_question(self, instruction: str) -> str:
        """Return the bare question text of ``instruction``.

        When the instruction contains ``question:`` (any case), the question
        is the text after it, cut at the first option label
        (``Answer1:`` / ``Solution1:`` / ``Ending1:`` / ``Option1:``) or at
        ``Answer format:`` (any case), whichever comes first. Cutting at the
        option label keeps training prompts in the same shape as the ones
        TestDataset builds for inference.

        Without a ``question:`` marker, the first non-empty line that
        mentions none of solution/answer/ending/option is used, falling back
        to the first line.
        """
        marker = re.search(r"question:", instruction, re.IGNORECASE)
        if marker:
            question_part = instruction[marker.end():]
            cut_points = [
                found.start()
                for found in (
                    re.search(r"(?:Answer|Solution|Ending|Option)1:", question_part),
                    re.search(r"answer format:", question_part, re.IGNORECASE),
                )
                if found
            ]
            end = min(cut_points) if cut_points else len(question_part)
            return question_part[:end].strip()

        lines = instruction.split('\n')
        for line in lines:
            if line.strip() and not any(
                keyword in line.lower()
                for keyword in ('solution', 'answer', 'ending', 'option')
            ):
                return line.strip()

        return lines[0].strip() if instruction else ""

    def _extract_answer_format(self, instruction: str) -> str:
        """Return the text following ``Answer format:`` on its line, or ``""``."""
        format_match = re.search(r"Answer format:\s*([^\n]+)", instruction, re.IGNORECASE)
        if format_match:
            return format_match.group(1).strip()
        return ""

    def build_text(self, question: str, candidate_answer: str) -> str:
        """Format one (question, candidate) pair as the model input text.

        Candidates longer than 200 characters are cut to 200 and suffixed
        with ``...``.
        """
        if len(candidate_answer) > 200:
            candidate_answer = candidate_answer[:200] + "..."

        return f"Question: {question}\nCandidate answer: {candidate_answer}\n"

    def __len__(self):
        """Number of (question, option) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenise sample ``idx`` and attach its label under ``"labels"``."""
        item = self.samples[idx]
        encoded = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=self.max_length,
        )
        encoded["labels"] = item["label"]
        return encoded

# ================== LoRA settings ==================
# setting name -> {"description": label text, "target_modules": module names
# passed to make_lora_model}. The key doubles as the output sub-folder name.
LORA_SETTINGS = {
    # Two attention projections only.
    "attn_light": {
        "description": 'ATTN—light: ["q_proj", "v_proj"]',
        "target_modules": ["q_proj", "v_proj"],
    },
    # Three attention projections plus two feed-forward projections.
    "attn_ffn_medium": {
        "description": 'ATTN+FFN—medium: ["q_proj","k_proj","v_proj","up_proj","down_proj"]',
        "target_modules": [
            "q_proj",
            "k_proj",
            "v_proj",
            "up_proj",
            "down_proj",
        ],
    },
    # All seven projection modules listed in this file.
    "full_heavy": {
        "description": 'Full—heavy: ["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]',
        "target_modules": [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "up_proj", "down_proj", "gate_proj",
        ],
    },
}


def make_lora_model(
    base_model,
    target_modules: List[str],
    r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
):
    """Wrap ``base_model`` with LoRA adapters on ``target_modules``.

    Args:
        base_model: Model to adapt; passed to ``get_peft_model``.
        target_modules: Names of the modules that receive adapters.
        r: LoRA rank (default 8, the value previously hard-coded).
        lora_alpha: LoRA scaling factor (default 16).
        lora_dropout: Dropout applied in the adapters (default 0.1).

    Returns:
        The PEFT-wrapped model. Its trainable-parameter summary is printed
        as a side effect.
    """
    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        target_modules=target_modules,
        task_type="SEQ_CLS",  # sequence classification
    )
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()
    return model


# ================== Training + learning-curve logging ==================
def train_one_setting(
    setting_name: str,
    target_modules: List[str],
    dataset: TeamBinaryDataset,
    tokenizer,
    num_epochs: int = 3,
    seed: int = 42,
):
    """Fine-tune one LoRA configuration and return its loss curves.

    Args:
        setting_name: Name of the setting; also the sub-folder of
            ``OUTPUT_ROOT`` where the adapter and tokenizer are saved.
        target_modules: Module names that receive LoRA adapters.
        dataset: Full dataset; 10% of it is held out for validation.
        tokenizer: Tokenizer with a pad token already set; it is used for
            batch padding and supplies the model's ``pad_token_id``.
        num_epochs: Number of training epochs.
        seed: Seed for the train/validation split and for the Trainer, so
            every setting is trained and evaluated on the same split.

    Returns:
        ``{"train_losses": [(epoch, loss), ...],
        "eval_losses": [(epoch, loss), ...]}`` taken from the Trainer log.
    """
    print(f"\n========== Training setting: {setting_name} ==========")
    print(f"Target modules: {target_modules}")

    # 1. Train/validation split. The generator is seeded so that the split
    #    is identical for every setting and the curves stay comparable.
    val_ratio = 0.1
    val_size = int(len(dataset) * val_ratio)
    train_size = len(dataset) - val_size
    split_generator = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(
        dataset, [train_size, val_size], generator=split_generator
    )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # 2. Base model + LoRA
    base_model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        num_labels=2,
        problem_type="single_label_classification",
    )
    # The classification head needs config.pad_token_id to handle padded
    # batches; the inference path in this file sets it the same way.
    base_model.config.pad_token_id = tokenizer.pad_token_id

    model = make_lora_model(base_model, target_modules)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    out_dir = os.path.join(OUTPUT_ROOT, setting_name)
    os.makedirs(out_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_ratio=0.03,
        logging_steps=50,
        save_strategy="epoch",
        eval_strategy="epoch",
        fp16=torch.cuda.is_available(),
        report_to=[],
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        seed=seed,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    trainer.train()
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)

    # 3. Collect (epoch, loss) pairs from the Trainer log. Entries with
    #    "loss" are training logs; entries with "eval_loss" are evaluations.
    train_losses = []
    eval_losses = []
    for log in trainer.state.log_history:
        if "epoch" not in log:
            continue
        if "eval_loss" in log:
            eval_losses.append((log["epoch"], log["eval_loss"]))
        elif "loss" in log:
            train_losses.append((log["epoch"], log["loss"]))

    return {
        "train_losses": train_losses,
        "eval_losses": eval_losses,
    }

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Reuse EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

full_dataset = TeamBinaryDataset(TRAIN_PATH, tokenizer, max_length=512)

all_curves = {}  # setting_name -> dict(train_losses, eval_losses)

for setting_name, cfg in LORA_SETTINGS.items():
    curves = train_one_setting(
        setting_name=setting_name,
        target_modules=cfg["target_modules"],
        dataset=full_dataset,
        tokenizer=tokenizer,
        num_epochs=3,
    )
    all_curves[setting_name] = curves

# ================== Plot learning curves ==================
# All settings share one figure: training loss on the left subplot,
# validation loss on the right.
plt.figure(figsize=(10, 5))

panels = (
    ("train_losses", "train", "Training Loss", "Training Loss per Setting"),
    ("eval_losses", "val", "Validation Loss", "Validation Loss per Setting"),
)
for position, (curve_key, tag, y_label, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    for setting_name, curves in all_curves.items():
        points = curves[curve_key]
        # A run can end with no logged points (e.g. fewer optimisation steps
        # than logging_steps); unpacking zip(*[]) would raise, so skip it.
        if not points:
            continue
        epochs, losses = zip(*points)
        plt.plot(epochs, losses, marker="o", label=f"{setting_name} ({tag})")
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
os.makedirs(OUTPUT_ROOT, exist_ok=True)
plt.savefig(os.path.join(OUTPUT_ROOT, "learning_curves.png"))
plt.show()

In [None]:
import pandas as pd
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

class TestDataset(Dataset):
    """Inference set: one sample per (test row, candidate option).

    Attributes set by ``__init__``:
        samples: dicts with ``text``, ``test_id``, ``option_idx`` and
            ``option_text``.
        test_ids: the ``test_id`` of every sample, in sample order.
        skipped_ids: ids of rows for which no options could be parsed;
            these rows produce no sample and therefore no prediction.
    """

    def __init__(self, test_data: pd.DataFrame, tokenizer, max_length: int = 512):
        """Expand every row of ``test_data`` into per-option samples.

        Args:
            test_data: Frame with ``id`` and ``instruction`` columns.
            tokenizer: Callable tokenizer used by ``__getitem__``.
            max_length: Truncation length passed to the tokenizer.
        """
        self.samples = []
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.test_ids = []
        self.skipped_ids = []

        for _, row in test_data.iterrows():
            test_id = row['id']
            instruction = row['instruction']

            # Extract the question and its options.
            question = self._extract_question(instruction)
            answer_format = self._extract_answer_format(instruction)
            options = extract_options_from_instruction(instruction, answer_format)

            if not options:
                # Remember the row instead of dropping it without a trace.
                self.skipped_ids.append(test_id)
                continue

            # One sample per option.
            for i, option in enumerate(options):
                self.samples.append({
                    "text": self.build_text(question, option),
                    "test_id": test_id,
                    "option_idx": i,
                    "option_text": option
                })
                self.test_ids.append(test_id)

        if self.skipped_ids:
            warnings.warn(
                f"TestDataset skipped {len(self.skipped_ids)} row(s) with no "
                "parsable options; they will receive no prediction "
                "(see `skipped_ids`).",
                stacklevel=2,
            )

    def _extract_question(self, instruction: str) -> str:
        """Return the question text, without the option list.

        With a ``question:`` marker (any case) the question is the text
        after it, cut at the first of ``Answer1:`` / ``Solution1:`` /
        ``Ending1:`` / ``Option1:``. Without the marker, the first non-empty
        line that holds none of those labels is used, falling back to the
        first line.
        """
        first_labels = ("Answer1:", "Solution1:", "Ending1:", "Option1:")

        if "question:" in instruction.lower():
            question_start = instruction.lower().find("question:")
            question_part = instruction[question_start + 9:]

            # Position of the earliest option label, if any.
            label_positions = [
                pos
                for pos in (question_part.find(label) for label in first_labels)
                if pos != -1
            ]
            if label_positions:
                question_part = question_part[:min(label_positions)]
            return question_part.strip()

        # No "question:" marker: use the first suitable line.
        lines = instruction.split('\n')
        for line in lines:
            if line.strip() and not any(label in line for label in first_labels):
                return line.strip()

        return lines[0].strip() if instruction else ""

    def _extract_answer_format(self, instruction: str) -> str:
        """Infer the answer-format string from the option labels present.

        NOTE(review): this never returns "true/false", so true/false
        questions get no options and end up in ``skipped_ids``.
        """
        if "Answer1:" in instruction:
            return "answer1/answer2/answer3/answer4"
        elif "Solution1:" in instruction:
            return "solution1/solution2"
        elif "Ending1:" in instruction:
            return "ending1/ending2/ending3/ending4"
        elif "Option1:" in instruction:
            return "option1/option2"
        return ""

    def build_text(self, question: str, candidate_answer: str) -> str:
        """Format one (question, candidate) pair as the model input text.

        Candidates longer than 200 characters are cut to 200 and suffixed
        with ``...``.
        """
        if len(candidate_answer) > 200:
            candidate_answer = candidate_answer[:200] + "..."

        return f"Question: {question}\nCandidate answer: {candidate_answer}\n"

    def __len__(self):
        """Number of (row, option) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenise sample ``idx`` and return it with its bookkeeping fields."""
        item = self.samples[idx]
        encoded = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=self.max_length,
        )

        # Tokenizer output plus the fields used to regroup the predictions.
        return {
            **encoded,
            "test_id": item["test_id"],
            "option_idx": item["option_idx"],
            "option_text": item["option_text"]
        }

# Lazily-built padding collator shared by every call to custom_collate_fn.
_PADDING_COLLATOR = None


def _get_padding_collator():
    """Create the padding collator on first use and reuse it afterwards."""
    global _PADDING_COLLATOR
    if _PADDING_COLLATOR is None:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        _PADDING_COLLATOR = DataCollatorWithPadding(tokenizer=tokenizer)
    return _PADDING_COLLATOR


def custom_collate_fn(batch):
    """Collate TestDataset items into one padded batch.

    ``input_ids`` and ``attention_mask`` are padded by a
    ``DataCollatorWithPadding``; ``test_id``, ``option_idx`` and
    ``option_text`` are passed through as plain Python lists.

    The tokenizer is loaded once and cached, rather than on every batch.

    Args:
        batch: Non-empty list of item dicts as returned by
            ``TestDataset.__getitem__``.

    Returns:
        The padded batch with the extra per-sample lists attached.
    """
    tokenizer_keys = ['input_ids', 'attention_mask']
    extra_keys = ['test_id', 'option_idx', 'option_text']

    # The first item decides which keys are present in the whole batch.
    first = batch[0]

    tokenizer_batch = {
        key: [item[key] for item in batch]
        for key in tokenizer_keys
        if key in first
    }
    collated = _get_padding_collator()(tokenizer_batch)

    # Attach the bookkeeping fields untouched.
    for key in extra_keys:
        if key in first:
            collated[key] = [item[key] for item in batch]

    return collated

# Label at the start of an option string, e.g. "Solution2: ...".
_OPTION_LABEL_RE = re.compile(r"(Answer|Solution|Ending|Option)\d+:")


def _answer_label_for(option_text: str, option_idx: int) -> str:
    """Build the submission label for the chosen option.

    The prefix comes from the option's own label ("Solution2: ..." gives
    "solution"), not from a substring search over the instruction, where
    "Answer format:" would wrongly match "Answer". Falls back to "answer"
    when the option carries no recognised label.
    """
    if option_text in ("true", "false"):
        return option_text
    label = _OPTION_LABEL_RE.match(option_text)
    prefix = label.group(1).lower() if label else "answer"
    return f"{prefix}{option_idx + 1}"


def predict_with_lora_model(model_path: str, test_csv_path: str, output_csv_path: str, tokenizer):
    """Score every option of every test question and write the best ones.

    Args:
        model_path: Directory of the trained LoRA adapter.
        test_csv_path: CSV with ``id`` and ``instruction`` columns.
        output_csv_path: Destination CSV with ``id`` and ``answer`` columns.
        tokenizer: Tokenizer with a pad token set.

    Returns:
        The ``pandas.DataFrame`` that was written to ``output_csv_path``.
        Rows whose options could not be parsed are absent from it; their
        count is printed.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Load the test data.
    test_data = pd.read_csv(test_csv_path)
    print(f"載入 {len(test_data)} 條測試數據")

    # 2. Build the per-option dataset and loader.
    test_dataset = TestDataset(test_data, tokenizer, max_length=512)
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    # 3. Load the base model with the pad token id set, then the adapter.
    from peft import PeftModel
    base_model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        num_labels=2,
        problem_type="single_label_classification",
        pad_token_id=tokenizer.pad_token_id,
    )
    base_model.config.pad_token_id = tokenizer.pad_token_id

    model = PeftModel.from_pretrained(base_model, model_path)
    model.to(device)
    model.eval()

    # 4. Score all options: test_id -> {option_idx: (probability, option_text)}
    predictions = {}

    print("開始預測...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="預測中"):
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            # Probability of class 1 ("this candidate is the right answer").
            probabilities = torch.softmax(outputs.logits, dim=1)[:, 1]

            for test_id, option_idx, option_text, prob in zip(
                batch["test_id"],
                batch["option_idx"],
                batch["option_text"],
                probabilities.tolist(),
            ):
                predictions.setdefault(test_id, {})[option_idx] = (prob, option_text)

    # 5. Keep the most probable option per question (first one wins a tie).
    results = []
    for test_id, scored in predictions.items():
        best_option_idx = max(scored, key=lambda idx: scored[idx][0])
        best_option_text = scored[best_option_idx][1]
        results.append({
            "id": test_id,
            "answer": _answer_label_for(best_option_text, best_option_idx)
        })

    # Report ids that produced no sample instead of dropping them silently.
    missing = sum(1 for test_id in test_data['id'] if test_id not in predictions)
    if missing:
        print(f"警告: {missing} 條測試數據無法解析選項，未產生預測")

    # 6. Save. Explicit columns keep the header even when nothing was scored.
    result_df = pd.DataFrame(results, columns=["id", "answer"])
    result_df.to_csv(output_csv_path, index=False)
    print(f"預測結果已保存到: {output_csv_path}")
    print(f"預測了 {len(results)} 條數據")

    return result_df

def run_inference_for_all_settings(test_csv_path: str, output_dir: str):
    """Run prediction once per entry of ``LORA_SETTINGS``.

    Each setting's adapter is read from ``OUTPUT_ROOT/<setting>`` and its
    predictions are written to ``output_dir/predictions_<setting>.csv``.
    A failure in one setting is printed and the remaining settings still run.

    Args:
        test_csv_path: Test CSV forwarded to ``predict_with_lora_model``.
        output_dir: Folder for the prediction files; created if missing.
    """
    shared_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if shared_tokenizer.pad_token is None:
        shared_tokenizer.pad_token = shared_tokenizer.eos_token

    os.makedirs(output_dir, exist_ok=True)

    for name in LORA_SETTINGS:
        print(f"\n========== 使用 {name} 模型進行預測 ==========")

        adapter_dir = os.path.join(OUTPUT_ROOT, name)
        csv_target = os.path.join(output_dir, f"predictions_{name}.csv")

        # Best effort by design: report the failure and move on.
        try:
            predict_with_lora_model(
                model_path=adapter_dir,
                test_csv_path=test_csv_path,
                output_csv_path=csv_target,
                tokenizer=shared_tokenizer,
            )
        except Exception as err:
            print(f"預測 {name} 時發生錯誤: {err}")

In [None]:
# Input CSV and the folder that receives one prediction file per setting.
test_csv_path = "test.csv"
output_dir = "./predictions"

# Predict with every trained LoRA setting.
run_inference_for_all_settings(test_csv_path=test_csv_path, output_dir=output_dir)

# Tokenizer for the final run; EOS is reused as pad token when none is set.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Final predictions come from the "attn_light" adapter.
predict_with_lora_model(
    "./llama32_1b_team_lora/attn_light",
    test_csv_path,
    "final_predictions.csv",
    tokenizer,
)