In [None]:
# Google Colabでの設定
google_colab = True

if google_colab:
    from google.colab import drive
    from google.colab import userdata

    drive.mount("/content/drive")

    # ディレクトリ移動
    %cd /content/drive/MyDrive/Python/kaggle_map/src/exp013_qwen2.5-7b-lora

    !pip install -q trl==0.23.0

In [None]:
import os
import gc
import time
import json
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import torch
import wandb
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    TrainingArguments,
    PreTrainedTokenizer,
    Trainer
)
from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer

import warnings
warnings.filterwarnings("ignore")

In [None]:
class CFG:
    """Central container for the experiment settings.

    All values are plain class attributes and are read directly as
    ``CFG.<name>`` by the rest of the notebook.
    """

    # ============== Experiment info =============
    comp_name = "kaggle_map"
    exp_name = "exp013_qwen2.5-7b-lora"
    model_name = "Qwen/Qwen2.5-7B"

    # ============== File paths =============
    comp_dir_path = "../../kaggle/input/"
    comp_dataset_path = f"{comp_dir_path}/map-charting-student-math-misunderstandings/"
    output_dir_path = "output/"
    log_dir_path = "logs/"

    # ============== Model / training settings =============
    max_len = 256
    num_train_epochs = 2
    per_device_train_batch_size = 16
    gradient_accumulation_steps = 1
    per_device_eval_batch_size = 8
    optim_type = "adamw_torch"
    learning_rate = 2e-5
    lr_scheduler_type = "cosine"
    warmup_ratio = 0.03
    weight_decay = 0.01

    # ============== LoRA settings =============
    lora_r = 32
    lora_alpha = 64
    lora_dropout = 0.05
    lora_bias = "none"
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    # The notebook trains an AutoModelForSequenceClassification model, so the
    # PEFT task type must be sequence classification. With "SEQ_CLS" PEFT keeps
    # the newly initialised classification head trainable and saves it with
    # the adapter; "CAUSAL_LM" would leave that head frozen at random init.
    task_type = "SEQ_CLS"

    # Columns passed on to the Hugging Face datasets
    cols = ["prompt", "label"]

    # ============== Misc =============
    seed = 42
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ============== Prompt template =============
    prompt_format = """You are a specialist in identifying the types of misunderstandings that arise from students' answers to math problems.
Based on the information provided below, please determine what kind of misunderstanding the student has.

Question: {QuestionText}
Answer: {MC_Answer}
Correct: {Correct}
Student Explanation: {StudentExplanation}
"""

In [None]:
# Fix the random seeds
def set_seed(seed=None, cudnn_deterministic=True):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed value to use; ``None`` falls back to 42.
        cudnn_deterministic: Value assigned to
            ``torch.backends.cudnn.deterministic``.
    """
    seed = 42 if seed is None else seed

    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seeding order as before: NumPy, random, torch (CPU), torch (CUDA)
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = cudnn_deterministic
    cudnn.benchmark = False

def make_dirs(cfg):
    """Create the output and log directories named on ``cfg`` if they are missing.

    Args:
        cfg: Object exposing ``output_dir_path`` and ``log_dir_path``.
    """
    targets = (cfg.output_dir_path, cfg.log_dir_path)
    for path in targets:
        os.makedirs(path, exist_ok=True)

def cfg_init(cfg):
    """Run the one-time setup for an experiment.

    Seeds the random number generators from ``cfg.seed`` and then creates the
    directories handled by ``make_dirs``.

    Args:
        cfg: Settings object exposing ``seed`` plus the directory attributes
            that ``make_dirs`` reads.
    """
    set_seed(seed=cfg.seed)
    make_dirs(cfg=cfg)

In [None]:
# Seed the RNGs and create the output/log directories for this experiment
cfg_init(CFG)

## データの読み込みと前処理

In [None]:
def add_folds_by_qid_cat_misc(df, n_splits=5, random_state=42, fallback="pair"):
    """Add a ``fold`` column from a stratified K-fold split.

    Rows are stratified on the combined key ``QuestionId|Category|Misconception``.
    Keys that occur fewer than ``n_splits`` times are replaced by a coarser key
    chosen through ``fallback``.

    Args:
        df: DataFrame with ``QuestionId``, ``Category`` and ``Misconception``
            columns.
        n_splits: Number of folds.
        random_state: Seed passed to ``StratifiedKFold`` (shuffling is enabled).
        fallback: Key used for rare rows: ``"pair"`` uses
            ``Category|Misconception``, ``"category"`` uses ``Category`` only,
            ``"none"`` keeps the full key.

    Returns:
        A copy of ``df`` with an integer ``fold`` column in ``[0, n_splits)``.

    Raises:
        ValueError: If rare keys exist and ``fallback`` is not one of the
            supported values.
    """

    def _as_key(column):
        # Substitute the missing-value marker *before* casting to str. Casting
        # first turns NaN into the literal string "nan", after which a
        # fillna("NA") can never match anything.
        values = df[column]
        return values.astype(object).where(values.notna(), "NA").astype(str)

    s_qid = _as_key("QuestionId")
    s_cat = _as_key("Category")
    s_misc = _as_key("Misconception")

    y_triple = s_qid + "|" + s_cat + "|" + s_misc
    y = y_triple

    # Flag rows whose full key is too rare to appear in every fold
    rare = y_triple.map(y_triple.value_counts()) < n_splits
    if rare.any():
        if fallback == "pair":
            y = np.where(rare, s_cat + "|" + s_misc, y_triple)
        elif fallback == "category":
            y = np.where(rare, s_cat, y_triple)
        elif fallback != "none":
            raise ValueError("fallback は 'pair' / 'category' / 'none' のいずれかにしてください。")

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Only the labels drive the split, so a zero placeholder stands in for X
    folds = np.full(len(df), -1, dtype=int)
    for fold, (_, val_idx) in enumerate(skf.split(np.zeros(len(df)), y)):
        folds[val_idx] = fold

    out = df.copy()
    out["fold"] = folds
    return out

In [None]:
def wrong_corrections(df: pd.DataFrame) -> pd.DataFrame:
    """Fix known labelling errors in the ``Category`` column.

    For two hard-coded lists of ``row_id`` values the ``True``/``False`` part
    of ``Category`` is swapped. ``df`` is modified in place and also returned.
    """
    # (row ids to patch, substring to replace, replacement), applied in order
    swaps = (
        ([12878, 12901, 13876, 14089, 14159, 14185], "False", "True"),
        (
            [14280, 14305, 14321, 14335, 14338, 14352, 14355, 14403, 14407, 14412, 14413, 14418],
            "True",
            "False",
        ),
    )
    for row_ids, old, new in swaps:
        needs_fix = df["row_id"].isin(row_ids)
        patched = df["Category"].str.replace(old, new)
        df["Category"] = np.where(needs_fix, patched, df["Category"])
    return df


def replace_duplicate_misc(df: pd.DataFrame) -> pd.DataFrame:
    """Unify misconception labels that differ only in spelling.

    ``Wrong_Fraction`` is rewritten to ``Wrong_fraction``. ``df`` is modified
    in place and also returned.
    """
    canonical_names = {"Wrong_Fraction": "Wrong_fraction"}
    misconception = df["Misconception"]
    df["Misconception"] = misconception.replace(canonical_names)
    return df


def make_completion(df: pd.DataFrame) -> pd.DataFrame:
    """Add the ``completion`` target string and its integer ``label``.

    Missing ``Misconception`` values become ``"NA"``; ``completion`` is
    ``"<Category>:<Misconception>"`` and ``label`` is its ``LabelEncoder``
    code. ``df`` is modified in place and also returned.
    """
    misconception = df["Misconception"].fillna("NA")
    df["Misconception"] = misconception
    df["completion"] = df["Category"] + ":" + misconception

    # Label-encode the completion strings
    df["label"] = LabelEncoder().fit_transform(df["completion"])
    return df


def add_is_correct(df: pd.DataFrame) -> pd.DataFrame:
    """Add an ``is_correct`` flag marking rows that gave the correct answer.

    The correct answer of a question is taken to be the ``MC_Answer`` that
    appears most often among rows whose ``Category`` starts with ``True``.
    Rows whose ``(QuestionId, MC_Answer)`` matches it get ``is_correct == 1``,
    all others ``0``.

    Args:
        df: DataFrame with ``QuestionId``, ``MC_Answer`` and ``Category``.

    Returns:
        A new DataFrame (result of a left merge, index reset) with the extra
        ``is_correct`` column.
    """
    # Vectorised form of `Category.split("_")[0] == "True"`; avoids a Python
    # call per row.
    is_true = df["Category"].str.split("_").str[0] == "True"

    # Count how often each answer was labelled True for each question
    votes = (
        df.loc[is_true]
        .groupby(["QuestionId", "MC_Answer"])
        .size()
        .reset_index(name="count")
    )

    # Keep the most frequent answer per question. The stable sort makes the
    # pick deterministic when two answers are equally frequent.
    correct = (
        votes.sort_values("count", ascending=False, kind="stable")
        .drop_duplicates(["QuestionId"])[["QuestionId", "MC_Answer"]]
        .assign(is_correct=1)
    )

    df = df.merge(correct, on=["QuestionId", "MC_Answer"], how="left")
    df["is_correct"] = df["is_correct"].fillna(0)
    return df


def format_input(row) -> str:
    """Render the prompt text for a single row.

    Fills ``CFG.prompt_format`` with the row's question, answer and student
    explanation; ``Correct`` is ``"Yes"`` when ``row["is_correct"]`` is truthy
    and ``"No"`` otherwise.
    """
    correct_flag = "Yes" if row["is_correct"] else "No"
    fields = {
        name: row[name]
        for name in ("QuestionText", "MC_Answer", "StudentExplanation")
    }
    return CFG.prompt_format.format(Correct=correct_flag, **fields)

In [None]:
# Load the training data and run the preprocessing steps in order:
#   1. fold assignment
#   2. known label corrections
#   3. unification of duplicated Misconception names
#   4. completion / label columns
#   5. correctness flag
# NOTE(review): folds are assigned before the label corrections, so the
# stratification key uses the uncorrected Category values.
train = (
    pd.read_csv(f"{CFG.comp_dataset_path}/train.csv")
    .pipe(add_folds_by_qid_cat_misc, n_splits=5, random_state=42, fallback="pair")
    .pipe(wrong_corrections)
    .pipe(replace_duplicate_misc)
    .pipe(make_completion)
    .pipe(add_is_correct)
)

# Build the input prompt for every row
train["prompt"] = train.apply(format_input, axis=1)

In [None]:
# Show the first prompt as a sanity check
print(train["prompt"].values[0])

In [None]:
# Split the data: fold 0 is held out for validation, the rest is for training
train_df = train.query("fold != 0").reset_index(drop=True)
val_df = train.query("fold == 0").reset_index(drop=True)

# Only the columns listed in CFG.cols are passed on to the HF datasets
train_ds = Dataset.from_pandas(train_df.loc[:, CFG.cols], preserve_index=False)
val_ds = Dataset.from_pandas(val_df.loc[:, CFG.cols], preserve_index=False)

## 学習設定

In [None]:
# Load the tokenizer first so its pad token can be shared with the model config
tokenizer = AutoTokenizer.from_pretrained(
    CFG.model_name,
    trust_remote_code=True,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
eos_token_id = tokenizer.eos_token_id

# Load the base model with a classification head sized to the number of labels
n_classes = train["label"].nunique()
model = AutoModelForSequenceClassification.from_pretrained(
    CFG.model_name,
    num_labels=n_classes,
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Decoder-only sequence-classification heads in transformers use
# config.pad_token_id to find the last non-padding token of each sequence and
# reject batch sizes above 1 when it is unset. Use the tokenizer's id so the
# model agrees with the padding the tokenizer actually produces.
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Tokenisation function
def tokenize(batch):
    """Tokenise a batch of prompts, padded/truncated to ``CFG.max_len`` tokens."""
    return tokenizer(
        batch["prompt"],
        max_length=CFG.max_len,
        truncation=True,
        padding="max_length",
    )

train_ds, val_ds = [split.map(tokenize, batched=True) for split in (train_ds, val_ds)]

# Expose only the model inputs and the label, as PyTorch tensors
columns = ["input_ids", "attention_mask", "label"]
train_ds.set_format(columns=columns, type="torch")
val_ds.set_format(columns=columns, type="torch")

In [None]:
# Log in to Weights & Biases.
# `userdata` is only imported when running on Colab, so outside Colab the key
# is read from the WANDB_API_KEY environment variable instead.
if google_colab:
    wandb_api_key = userdata.get("WANDB_API_KEY")
else:
    wandb_api_key = os.environ.get("WANDB_API_KEY")

wandb.login(key=wandb_api_key)
wandb.init(project=CFG.comp_name, name=CFG.exp_name)

In [None]:
# Training configuration
training_args = TrainingArguments(
    # --- output and logging ---
    output_dir=CFG.output_dir_path,
    logging_dir=CFG.log_dir_path,
    logging_steps=50,
    report_to="wandb",
    # --- run mode ---
    do_train=True,
    do_eval=True,
    # --- schedule and batch sizes ---
    num_train_epochs=CFG.num_train_epochs,
    per_device_train_batch_size=CFG.per_device_train_batch_size,
    per_device_eval_batch_size=CFG.per_device_eval_batch_size,
    gradient_accumulation_steps=CFG.gradient_accumulation_steps,
    # --- optimisation ---
    optim=CFG.optim_type,
    learning_rate=CFG.learning_rate,
    lr_scheduler_type=CFG.lr_scheduler_type,
    warmup_ratio=CFG.warmup_ratio,
    weight_decay=CFG.weight_decay,
    max_grad_norm=1.0,
    # --- evaluation and checkpointing ---
    # eval_steps / save_steps are floats below 1; transformers documents such
    # values as a ratio of the total number of training steps.
    eval_strategy="steps",
    eval_steps=0.2,
    save_strategy="steps",
    save_steps=0.2,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="map@3",
    greater_is_better=True,
    # --- precision and memory ---
    bf16=True,
    fp16=False,  # Kaggle provides T4 GPUs, so inference there runs in FP16
    bf16_full_eval=True,
    gradient_checkpointing=True,
)

In [None]:
# LoRA adapter configuration; every value comes from CFG
lora_config = LoraConfig(
    task_type=CFG.task_type,
    target_modules=CFG.target_modules,
    r=CFG.lora_r,
    lora_alpha=CFG.lora_alpha,
    lora_dropout=CFG.lora_dropout,
    bias=CFG.lora_bias,
)

## モデルの学習

In [None]:
# def preprocess_logits_for_metrics(logits, labels):
#     """logitsから上位3件のトークンIDを抽出"""
#     if isinstance(logits, tuple):
#         logits = logits[0]

#     # top-kの値とインデックスを取得
#     k = 3
#     _, topk_indices = torch.topk(
#         logits,
#         k=min(k, logits.size(-1)),
#         dim=-1,
#         largest=True,
#         sorted=True
#     )

#     return topk_indices, labels

In [None]:
def compute_metrics(eval_preds):
    """Compute top-1 accuracy and MAP@3 from classification logits.

    Args:
        eval_preds: ``(logits, labels)`` pair. ``logits`` is a
            ``(n_samples, n_classes)`` array, or a tuple whose first element
            is that array; ``labels`` holds the integer class id per sample.

    Returns:
        Dict with float entries ``"accuracy"`` and ``"map@3"``. Both are 0.0
        when there are no samples.
    """
    logits, labels = eval_preds
    if isinstance(logits, tuple):
        logits = logits[0]

    logits = np.asarray(logits)
    labels = np.asarray(labels).reshape(-1)
    n_samples = labels.shape[0]
    if n_samples == 0:
        return {"accuracy": 0.0, "map@3": 0.0}

    # Softmax is monotonic, so ranking the raw logits gives the same order
    # without the ties that underflowing probabilities can introduce.
    k = min(3, logits.shape[-1])
    topk = np.argsort(-logits, axis=1, kind="stable")[:, :k]
    hits = topk == labels[:, None]

    # Top-1 accuracy
    accuracy = hits[:, 0].mean()

    # MAP@3: a hit at rank r scores 1/r. The top-k indices are distinct, so
    # each row has at most one hit.
    rank_weights = 1.0 / np.arange(1, k + 1)
    map3 = (hits * rank_weights).sum() / n_samples

    return {"accuracy": float(accuracy), "map@3": float(map3)}

In [None]:
# Trainer setup
# transformers.Trainer has no `peft_config` argument (that belongs to
# trl.SFTTrainer), so the LoRA adapters are attached explicitly here.
# enable_input_require_grads() makes the embedding output require grad, which
# gradient checkpointing needs once the base weights are frozen.
model.enable_input_require_grads()
peft_model = get_peft_model(model, lora_config)

trainer = Trainer(
    model=peft_model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)
trainer.model.print_trainable_parameters()

In [None]:
# Run training
trainer.train()

In [None]:
# Save the model and the tokenizer to /content/model
trainer.save_model("/content/model")
tokenizer.save_pretrained("/content/model")

In [None]:
# Upload the saved model to the Hugging Face Hub
from huggingface_hub import HfApi
from huggingface_hub import upload_folder

# One repo id for both creation and upload, so the two calls cannot diverge
repo_id = "ricky0526/" + CFG.exp_name

api = HfApi()
# exist_ok=True keeps this cell re-runnable once the repository exists
api.create_repo(repo_id, private=True, exist_ok=True)

upload_folder(
    repo_id=repo_id,
    folder_path="/content/model",
    commit_message="Initial upload"
)