In [1]:
import re

import pandas as pd
import torch
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Locale code ("<language>-<REGION>") -> country/region name used in the prompt.
COUNTRY_NAME_MAP = {
    # Primary locale codes.
    "ar-EG": "Egypt",
    "ar-MA": "Morocco",
    "ar-SA": "Saudi Arabia",
    "bg-BG": "Bulgaria",
    "el-GR": "Greece",
    "en-AU": "Australia",
    "en-GB": "United Kingdom",
    "es-EC": "Ecuador",
    "es-ES": "Spain",
    "es-MX": "Mexico",
    "eu-ES": "Spain (Basque)",
    "fa-IR": "Iran",
    "fr-FR": "France",
    "ga-IE": "Ireland",
    "id-ID": "Indonesia",
    "ja-JP": "Japan",
    "ko-KR": "South Korea",
    "ms-SG": "Singapore (Malay)",
    "ta-LK": "Sri Lanka (Tamil)",
    "ta-SG": "Singapore (Tamil)",
    "tl-PH": "Philippines (Tagalog)",
    "zh-CN": "China",
    "zh-SG": "Singapore (Chinese)",
    "am-ET": "Ethiopia",
    "ar-DZ": "Algeria",
    "as-AS": "India (Assam)",
    "az-AZ": "Azerbaijan",
    "en-SG": "Singapore",
    "en-US": "United States",
    "eu-PV": "Basque Country (Spain)",
    "ha-NG": "Nigeria",
    "ko-KP": "North Korea",
    "su-JB": "Indonesia (West Java)",
    "sv-SE": "Sweden",
    "zh-TW": "Taiwan",
    # "en-" codes whose region suffix also appears in the entries above.
    "en-AS": "India (Assam)",
    "en-AZ": "Azerbaijan",
    "en-BG": "Bulgaria",
    "en-CN": "China",
    "en-DZ": "Algeria",
    "en-EC": "Ecuador",
    "en-EG": "Egypt",
    "en-ES": "Spain",
    "en-ET": "Ethiopia",
    "en-FR": "France",
    "en-GR": "Greece",
    "en-ID": "Indonesia",
    "en-IE": "Ireland",
    "en-IR": "Iran",
    "en-JB": "Indonesia (West Java)",
    "en-JP": "Japan",
    "en-KP": "North Korea",
    "en-KR": "South Korea",
    "en-LK": "Sri Lanka",
    "en-MA": "Morocco",
    "en-MX": "Mexico",
    "en-NG": "Nigeria",
    "en-PH": "Philippines",
    "en-PV": "Basque Country (Spain)",
    "en-SA": "Saudi Arabia",
    "en-SE": "Sweden",
    "en-TW": "Taiwan",
}

# ===============================================
# Load the tokenizer, base model and LoRA adapter
# ===============================================
# Directory holding the fine-tuned LoRA adapter; the tokenizer is loaded
# from this same directory.
lora_path = "/kaggle/input/track-2-instruction-tuning/final_lora"
# Hub id of the base checkpoint the adapter is applied to. This is the single
# source of truth for the id: previously a differently spelled, unused id was
# assigned to `base_model` and then overwritten by the model object.
base_model_id = "Qwen/Qwen3-4B"

tokenizer = AutoTokenizer.from_pretrained(lora_path, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    trust_remote_code=True,
    # NOTE(review): the run log shows "`torch_dtype` is deprecated! Use
    # `dtype` instead!"; switch to `dtype=` once the minimum supported
    # transformers version is confirmed to accept it.
    torch_dtype="bfloat16",
)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Wrap the base model with the LoRA adapter and switch to inference mode.
model = PeftModel.from_pretrained(base_model, lora_path)
model.eval()

print("模型加载完成！")

# ===============================================
# 推理辅助函数
# ===============================================

def create_prompt(question, options, country):
    """Build the instruction-style prompt for one multiple-choice question.

    The layout is meant to stay identical to the one used when training in
    track2_3.py, so the wording and line breaks must not be altered.

    Args:
        question: Question text.
        options: Option texts in order; they are labelled "A", "B", "C", ...
        country: Country/region name inserted into the instruction sentence.

    Returns:
        The full prompt string, ending with the "### Response:" header.
    """
    # One "<letter>. <option>" line per option, lettered from "A" onwards.
    option_lines = "".join(
        f"{chr(ord('A') + position)}. {text}\n"
        for position, text in enumerate(options)
    )

    instruction = (
        f"As a local resident of {country}, please answer the following question based on common knowledge in {country}.\n"
        f"Question: {question}\n"
        f"Options:\n{option_lines}"
        f"Answer with only the letter (A, B, C, or D):"
    )

    return f"### Instruction:\n{instruction}\n\n### Response:\n"

# ===============================================
# Main inference loop
# ===============================================

# First A-D letter that is not part of a longer ASCII word. ASCII lookarounds
# are used instead of \b so a letter adjacent to non-Latin text still matches.
_OPTION_PATTERN = re.compile(r"(?<![A-Za-z])([ABCD])(?![A-Za-z])")


def extract_option(pred_text):
    """Return the first standalone option letter in *pred_text*, or None.

    A plain substring test in A-to-D order would pick "A" out of a reply such
    as "Answer: C"; matching standalone letters by position avoids that.
    """
    match = _OPTION_PATTERN.search(pred_text)
    return match.group(1) if match else None


# Path of the test data.
test_file = "/kaggle/input/sem-eval-2026-task-7/track_2_mcq_input.tsv"

# Read the tab-separated test data (first row is the header).
test_df = pd.read_csv(test_file, header=0, delimiter="\t")
print(f"测试数据加载完成，共 {len(test_df)} 条记录")

predictions = []
# Number of rows where no option letter could be parsed from the output.
fallback_count = 0

for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="推理进度"):
    id_value = row["id"]
    question = row["question"]

    # The part of the id before the first "_" is the locale code
    # (e.g. "ga-IE_0001" -> "ga-IE").
    lang_reg = id_value.split("_")[0]

    # Unknown locale codes fall back to a neutral phrase.
    country = COUNTRY_NAME_MAP.get(lang_reg, "the relevant country")

    options = [row["option A"], row["option B"], row["option C"], row["option D"]]

    prompt = create_prompt(question, options, country)

    # Single-sequence batch, so no padding is needed.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
        padding=False,
        add_special_tokens=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # Greedy decoding is requested explicitly so predictions are
            # reproducible. The previous temperature/top_p arguments only
            # apply when sampling is on, which was left to the checkpoint's
            # own generation config.
            do_sample=False,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    pred_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    best_option = extract_option(pred_text)

    # No recognisable letter: default to the first option, but keep count so
    # the fallback is visible rather than silent.
    if best_option is None:
        best_option = "A"
        fallback_count += 1

    predictions.append(best_option)

if fallback_count:
    print(f"Defaulted to 'A' for {fallback_count} unparseable outputs")

# ===============================================
# Save the results
# ===============================================

# One indicator column per option letter: 1 where that letter was predicted,
# 0 otherwise.
indicator_columns = {
    letter: [int(pred == letter) for pred in predictions]
    for letter in ("A", "B", "C", "D")
}
result_df = pd.DataFrame({"id": test_df["id"], **indicator_columns})

# Write the table as a tab-separated file without the index column.
output_file = "./track_2_mcq_output.tsv"
result_df.to_csv(output_file, sep="\t", index=False)

print(f"推理完成！结果已保存到 {output_file}")
print(f"共处理 {len(result_df)} 条记录")

# Show the first few rows as a quick sanity check.
print("\n部分推理结果：")
print(result_df.head())

2026-01-25 07:18:23.012984: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769325503.227569      24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769325503.291083      24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769325503.758761      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769325503.758796      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769325503.758799      24 computation_placer.cc:177] computation placer alr

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

模型加载完成！
测试数据加载完成，共 47014 条记录


推理进度: 100%|██████████| 47014/47014 [6:51:05<00:00,  1.91it/s]

推理完成！结果已保存到 ./track_2_mcq_output.tsv
共处理 47014 条记录

部分推理结果：
           id  A  B  C  D
0  ga-IE_0001  1  0  0  0
1  ga-IE_0002  1  0  0  0
2  ga-IE_0003  1  0  0  0
3  ga-IE_0004  0  0  1  0
4  ga-IE_0005  0  0  1  0



