In [1]:
# Google Colab setup: mount Google Drive, move into this experiment's directory
# and install the pinned training dependencies.
google_colab = True

if google_colab:
    from google.colab import drive
    # NOTE(review): `userdata` (Colab secret store) is only imported on Colab, but the
    # wandb login cell reads the API key from it — that cell needs another key source
    # when google_colab is False.
    from google.colab import userdata

    drive.mount("/content/drive")

    # Move into the experiment directory so the relative paths in CFG resolve from here
    %cd /content/drive/MyDrive/Python/kaggle_map/src/exp017_qwen2.5-14b-lora

    !pip install -q -U trl==0.23.0
    !pip install -q accelerate==1.10.1 bitsandbytes==0.47.0 mpi4py==4.1.0 deepspeed==0.17.6

Mounted at /content/drive
/content/drive/MyDrive/Python/kaggle_map/src/exp017_qwen2.5-14b-lora
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.7/564.7 kB[0m [31m38.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m82.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.0/54.0 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for deepspeed (setup.py) ... [?25l[?25hdone


In [2]:
import os
import gc
import time
import json
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import torch
import wandb
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    PreTrainedTokenizer
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

import warnings
warnings.filterwarnings("ignore")

In [3]:
# DeepSpeed expects a distributed launch even when only one process is used,
# so set the environment variables a launcher would normally provide.
os.environ.update(
    MASTER_ADDR="localhost",
    MASTER_PORT="9995",  # modify if RuntimeError: Address already in use
    RANK="0",
    LOCAL_RANK="0",
    WORLD_SIZE="1",
)

In [4]:
class CFG:
    """Experiment configuration.

    Every setting is a class attribute; the class itself is passed around
    (e.g. ``cfg_init(CFG)``, ``CFG.custom_prompts[...]``) rather than an instance.
    """

    # ============== Experiment info =============
    comp_name = "kaggle_map"  # also used as the wandb project name
    exp_name = "exp017_qwen2.5-14b-lora"
    model_name = "Qwen/Qwen2.5-14B-Instruct"

    # ============== File paths =============
    comp_dir_path = "../../kaggle/input/"
    comp_dataset_path = f"{comp_dir_path}/map-charting-student-math-misunderstandings/"
    output_dir_path = "output/"
    log_dir_path = "logs/"

    # ============== Model / training settings =============
    max_len = 400  # maximum tokenised sequence length (SFTConfig.max_length)
    num_train_epochs = 2
    per_device_train_batch_size = 8
    gradient_accumulation_steps = 2
    per_device_eval_batch_size = 8
    optim_type = "adamw_torch"
    learning_rate = 8e-5
    lr_scheduler_type = "cosine"
    warmup_ratio = 0.03
    weight_decay = 0.01

    # LoRA adapter settings
    lora_r = 64
    lora_alpha = 128
    lora_dropout = 0.01
    lora_bias = "none"
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    task_type = "CAUSAL_LM"

    # Dataset columns handed to SFTTrainer (prompt -> completion)
    cols = ["prompt", "completion"]

    # ============== Misc =============
    seed = 42
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ============== Prompt settings =============
    # Template filled by format_input(); {custom_prompt} receives the per-question
    # entry from `custom_prompts` below.
    prompt_format = """You are a specialist in identifying the types of misunderstandings that arise from students' answers to math problems.
Based on the information provided below, please determine what kind of misunderstanding the student has.

{custom_prompt}
Student Answer: {MC_Answer}
Student Explanation: {StudentExplanation}
Correct: {Correct}
"""

    # Hand-written question context keyed by QuestionId. Misconception names must use
    # the same spelling as the training labels (after replace_duplicate_misc, which
    # maps "Wrong_Fraction" -> "Wrong_fraction").
    custom_prompts = {
        31772: """Question: What fraction of the shape is not shaded? Give your answer in its simplest form. [Image: A triangle split into 9 equal smaller triangles. 6 of them are shaded.]
Choices: (A) \\( \\frac{1}{3} \\), (B) \\( \\frac{3}{6} \\), (C) \\( \\frac{3}{8} \\), (D) \\( \\frac{3}{9} \\)
Correct Answer: \\( \\frac{1}{3} \\)
Misconception List: WNB, Incomplete""",

        31774: """Question: Calculate \\( \\frac{1}{2} \\div 6 \\)
Choices: (A) \\( \\frac{1}{12} \\), (B) \\( 3 \\), (C)\\( \\frac{1}{3} \\), (D) \\( \\frac{6}{2} \\)
Correct Answer: \\( \\frac{1}{12} \\)
Misconception List: SwapDividend, Mult, FlipChange""",

        31777: """Question: A box contains \\( 120 \\) counters. The counters are red or blue. \\( \\frac{3}{5} \\) of the counters are red.
Choices: (A) \\( 72 \\), (B) \\( 24 \\), (C) \\( 60 \\), (D) \\( 48 \\)
Correct Answer: \\( 72 \\)
Misconception List: Incomplete, Irrelevant, Wrong_fraction""",

        31778: """Question: \\( \\frac{A}{10}=\\frac{9}{15} \\) What is the value of \\( A \\) ?
Choices: (A) \\( 3 \\), (B) \\( 4 \\), (C) \\( 6 \\), (D) \\( 9 \\)
Correct Answer: \\( 6 \\)
Misconception List: Additive, Irrelevant, WNB""",

        32829: """Question: \\( 2 y=24 \\) What is the value of \\( y \\) ?
Choices: (A) \\( 4 \\), (B) \\( 12 \\), (C) \\( 48 \\), (D) \\( 22 \\)
Correct Answer: \\( 12 \\)
Misconception List: Not_variable, Adding_terms, Inverse_operation""",

        32833: """Question: Calculate \\( \\frac{2}{3} \\times 5 \\)
Choices: (A) \\( \\frac{10}{15} \\), (B) \\( \\frac{2}{15} \\), (C) \\( 5 \\frac{2}{3} \\), (D) \\( 3 \\frac{1}{3} \\)
Correct Answer: \\( 3 \\frac{1}{3} \\)
Misconception List: Duplication, Inversion, Wrong_Operation""",

        32835: """Question: Which number is the greatest?
Choices: (A) \\( 6 \\), (B) \\( 6.2 \\), (C) \\( 6.079 \\), (D) \\( 6.0001 \\)
Correct Answer: \\( 6.2 \\)
Misconception List: Whole_numbers_larger, Longer_is_bigger, Ignores_zeroes, Shorter_is_bigger""",

        33471: """Question: A bag contains \\( 24 \\) yellow and green balls. \\( \\frac{3}{8} \\) of the balls are yellow. How many of the balls are green?
Choices: (A) \\( 8 \\), (B) \\( 9 \\), (C) \\( 15 \\), (D) \\( 3 \\)
Correct Answer: \\( 15 \\)
Misconception List: Wrong_fraction, Incomplete""",

        33472: """Question: \\( \\frac{1}{3}+\\frac{2}{5}= \\)
Choices: (A) \\( \\frac{3}{8} \\), (B) \\( \\frac{3}{15} \\), (C) \\( \\frac{11}{30} \\), (D) \\( \\frac{11}{15} \\)
Correct Answer: \\( \\frac{11}{15} \\)
Misconception List: Adding_across, Denominator-only_change, Incorrect_equivalent_fraction_addition""",

        33474: """Question: Sally has \\( \\frac{2}{3} \\) of a whole cake in the fridge. Robert eats \\( \\frac{1}{3} \\) of this piece. What fraction of the whole cake has Robert eaten?
Choices: (A) \\( \\frac{1}{3}+\\frac{2}{3} \\), (B) \\( \\frac{1}{3} \\times \\frac{2}{3} \\), (C) \\( \\frac{2}{3} \\div \\frac{1}{3} \\), (D) \\( \\frac{2}{3}-\\frac{1}{3} \\)
Correct Answer: \\( \\frac{1}{3} \\times \\frac{2}{3} \\)
Misconception List: Division, Subtraction""",

        76870: """Question: This is part of a regular polygon. How many sides does it have? [Image: A diagram showing an obtuse angle labelled 144 degrees]
Choices: (A) \\( 10 \\), (B) \\( 5 \\), (C) \\( 6 \\), (D) Not enough information
Correct Answer: \\( 10 \\)
Misconception List: Unknowable, Definition, Interior""",

        89443: """Question: What number belongs in the box?
\\( (-8)-(-5)=\\square \\)
Choices: (A) \\( -13 \\), (B) \\( 13 \\), (C) \\( 3 \\), (D) \\( -3 \\)
Correct Answer: \\( -3 \\)
Misconception List: Positive, Tacking""",

        91695: """Question: Dots have been arranged in these patterns: [Image: Pattern 1 consists of 6 dots, Pattern 2 consists of 10 dots, Pattern 3 consists of 14 dots and Pattern 4 consists of 18 dots] How many dots would there be in Pattern \\( 6 \\) ?
Choices: (A) \\( 26 \\), (B) \\( 20 \\), (C) \\( 36 \\), (D) \\( 22 \\)
Correct Answer: \\( 26 \\)
Misconception List: Wrong_term, Firstterm""",

        104665: """Question: It takes \\( 3 \\) people a total of \\( 192 \\) hours to build a wall.
How long would it take if \\( 12 \\) people built the same wall?
Choices: (A) \\( 768 \\) hours, (B) \\( 48 \\) hours, (C) \\( 192 \\) hours, (D) \\( 64 \\) hours
Correct Answer: \\( 48 \\) hours
Misconception List: Base_rate, Multiplying_by_4""",

        109465: """Question: The probability of an event occurring is \\( 0.9 \\).
Which of the following most accurately describes the likelihood of the event occurring?
Choices: (A) Certain, (B) Likely, (C) Unlikely, (D) Impossible
Correct Answer: Likely
Misconception List: Certainty, Scale"""
    }

In [5]:
def set_seed(seed=None, cudnn_deterministic=True):
    """Seed the Python, NumPy and PyTorch (CPU + current CUDA device) RNGs.

    Args:
        seed: Seed value; ``None`` falls back to 42.
        cudnn_deterministic: Value assigned to ``torch.backends.cudnn.deterministic``.
            cuDNN autotuning (``benchmark``) is always switched off.

    Note: ``PYTHONHASHSEED`` is exported for child processes only; the hash
    seed of the already-running interpreter is fixed at start-up.
    """
    seed = 42 if seed is None else seed

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = cudnn_deterministic
    cudnn.benchmark = False

def make_dirs(cfg):
    """Create ``cfg.output_dir_path`` and ``cfg.log_dir_path``; existing dirs are left alone."""
    for path in (cfg.output_dir_path, cfg.log_dir_path):
        os.makedirs(path, exist_ok=True)

def cfg_init(cfg):
    """Prepare a run: seed all RNGs from ``cfg.seed``, then create the output/log dirs."""
    set_seed(seed=cfg.seed)
    make_dirs(cfg)

In [6]:
cfg_init(CFG)  # seed all RNGs and create the output/log directories for this run

## データの読み込みと前処理

In [7]:
def add_folds_by_qid_cat_misc(df, n_splits=5, random_state=42, fallback="pair"):
    """Return a copy of ``df`` with a ``fold`` column from stratified K-fold splitting.

    Rows are stratified on the ``QuestionId|Category|Misconception`` triple. Triples
    that occur fewer than ``n_splits`` times are replaced by a coarser key chosen by
    ``fallback``:

    * ``"pair"``     -> ``Category|Misconception``
    * ``"category"`` -> ``Category``
    * ``"none"``     -> keep the rare triple as is

    Args:
        df: Frame with ``QuestionId``, ``Category`` and ``Misconception`` columns.
        n_splits: Number of folds.
        random_state: Seed for the shuffled ``StratifiedKFold``.
        fallback: Strategy for rare triples (see above).

    Returns:
        A copy of ``df`` with an integer ``fold`` column in ``[0, n_splits)``.

    Raises:
        ValueError: If ``fallback`` is not one of the supported strategies. This is
            checked up front, whether or not any rare triples exist.
    """
    if fallback not in ("pair", "category", "none"):
        raise ValueError("fallback は 'pair' / 'category' / 'none' のいずれかにしてください。")

    # NOTE: astype(str) runs before fillna, so with object-dtype input a missing value
    # becomes the string "nan" and fillna("NA") does nothing. The order is kept on purpose:
    # changing the key strings would change the class ordering and reshuffle the folds.
    s_qid = df["QuestionId"].astype(str).fillna("NA")
    s_cat = df["Category"].astype(str).fillna("NA")
    s_misc = df["Misconception"].astype(str).fillna("NA")

    y_triple = s_qid + "|" + s_cat + "|" + s_misc
    y_pair = s_cat + "|" + s_misc

    # Per-row flag: does this row's triple have fewer members than there are folds?
    rare = y_triple.map(y_triple.value_counts()) < n_splits

    if fallback == "none" or not rare.any():
        y = y_triple
    elif fallback == "pair":
        y = np.where(rare, y_pair, y_triple)
    else:  # fallback == "category"
        y = np.where(rare, s_cat, y_triple)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Positional assignment: each row gets the index of the fold it is validated in.
    folds = np.full(len(df), -1, dtype=int)
    for fold, (_, val_idx) in enumerate(skf.split(np.zeros(len(df)), y)):
        folds[val_idx] = fold

    out = df.copy()
    out["fold"] = folds
    return out

In [8]:
def wrong_corrections(df: pd.DataFrame) -> pd.DataFrame:
    """Fix ``Category`` labels known to be wrong in the training data.

    For each listed ``row_id`` the substring "False" is swapped for "True" (or the
    reverse) inside its ``Category`` value. ``df`` is modified in place and returned.
    """
    fixes = (
        ([12878, 12901, 13876, 14089, 14159, 14185], "False", "True"),
        ([14280, 14305, 14321, 14335, 14338,  14352, 14355, 14403, 14407, 14412, 14413, 14418], "True", "False"),
    )
    for row_ids, wrong, right in fixes:
        targeted = df["row_id"].isin(row_ids)
        # Keep untouched rows; substitute the replaced string only on targeted rows.
        df["Category"] = df["Category"].where(~targeted, df["Category"].str.replace(wrong, right))
    return df


def replace_duplicate_misc(df: pd.DataFrame) -> pd.DataFrame:
    """Merge duplicate spellings of the same misconception label.

    Only exact matches of "Wrong_Fraction" are rewritten to "Wrong_fraction".
    ``df`` is modified in place and returned.
    """
    misconceptions = df["Misconception"]
    df["Misconception"] = misconceptions.replace(to_replace="Wrong_Fraction", value="Wrong_fraction")
    return df


def make_completion(df: pd.DataFrame) -> pd.DataFrame:
    """Build the ``completion`` target as ``"<Category>:<Misconception>"``.

    Missing misconceptions are first written back as the string "NA", so every row
    gets a label. ``df`` is modified in place and returned.
    """
    misconception = df["Misconception"].fillna("NA")
    df["Misconception"] = misconception
    category = df["Category"]
    df["completion"] = category + ":" + misconception
    return df


def add_is_correct(df: pd.DataFrame) -> pd.DataFrame:
    """Flag rows whose ``MC_Answer`` is the correct choice for their question.

    The correct answer per ``QuestionId`` is inferred from the labels. Among rows whose
    ``Category`` prefix (text before the first "_") is "True", the most frequent
    ``MC_Answer`` is taken as correct.

    Returns:
        A new DataFrame from a left merge (index reset) with an integer ``is_correct``
        column: 1 for the inferred correct answer, 0 otherwise.
    """
    # Vectorised prefix check. It is much faster than a row-wise apply(axis=1), and a
    # missing Category simply yields False instead of raising AttributeError.
    is_true_label = df["Category"].str.split("_").str[0] == "True"

    correct = df.loc[is_true_label, ["QuestionId", "MC_Answer"]].copy()
    correct["count"] = correct.groupby(["QuestionId", "MC_Answer"])["MC_Answer"].transform("count")
    # Keep one answer per question: the one seen most often with a True label.
    correct = (
        correct.sort_values("count", ascending=False)
        .drop_duplicates(["QuestionId"])[["QuestionId", "MC_Answer"]]
        .assign(is_correct=1)
    )

    df = df.merge(correct, on=["QuestionId", "MC_Answer"], how="left")
    # Unmatched rows get NaN from the merge; store the flag as a plain 0/1 integer.
    df["is_correct"] = df["is_correct"].fillna(0).astype(int)
    return df


def format_input(row) -> str:
    """Render the model prompt for a single data row.

    The lookup key is ``row["QuestionId"]``. Questions with an entry in
    ``CFG.custom_prompts`` are rendered with ``CFG.prompt_format``. Any other question
    prints a notice and falls back to a generic template that embeds
    ``row["QuestionText"]`` directly.

    NOTE(review): the fallback template labels the answer "Answer:" (not "Student
    Answer:") and puts "Correct" before the explanation, so its layout differs from
    ``CFG.prompt_format``.
    """
    question_id = row["QuestionId"]

    if question_id in CFG.custom_prompts:
        return CFG.prompt_format.format(
            custom_prompt=CFG.custom_prompts[question_id],
            MC_Answer=row["MC_Answer"],
            Correct="Yes" if row["is_correct"] else "No",
            StudentExplanation=row["StudentExplanation"],
        )

    # Unknown question: use the generic template built from the raw question text.
    print("QuestionIdが未知です。")
    fallback_template = """You are a specialist in identifying the types of misunderstandings that arise from students' answers to math problems.
Based on the information provided below, please determine what kind of misunderstanding the student has.

Question: {QuestionText}
Answer: {MC_Answer}
Correct: {Correct}
Student Explanation: {StudentExplanation}
"""
    return fallback_template.format(
        QuestionText=row["QuestionText"],
        MC_Answer=row["MC_Answer"],
        Correct="Yes" if row["is_correct"] else "No",
        StudentExplanation=row["StudentExplanation"],
    )

In [9]:
# Load the training data and run the preprocessing steps in order.
# NOTE(review): folds are assigned before the label fixes, so stratification sees the
# uncorrected Category values and the un-merged "Wrong_Fraction"/"Wrong_fraction" labels.
train = (
    pd.read_csv(f"{CFG.comp_dataset_path}/train.csv")
    .pipe(add_folds_by_qid_cat_misc, n_splits=5, random_state=42, fallback="pair")  # fold split
    .pipe(wrong_corrections)        # fix known mislabelled Category values
    .pipe(replace_duplicate_misc)   # unify duplicate Misconception spellings
    .pipe(make_completion)          # build the "completion" target column
    .pipe(add_is_correct)           # add the is_correct flag
)

# Build the model input prompt for every row
train["prompt"] = train.apply(format_input, axis=1)

In [10]:
# Show the first rendered prompt as a sanity check
print(train["prompt"].iloc[0])

You are a specialist in identifying the types of misunderstandings that arise from students' answers to math problems.
Based on the information provided below, please determine what kind of misunderstanding the student has.

Question: What fraction of the shape is not shaded? Give your answer in its simplest form. [Image: A triangle split into 9 equal smaller triangles. 6 of them are shaded.]
Choices: (A) \( \frac{1}{3} \), (B) \( \frac{3}{6} \), (C) \( \frac{3}{8} \), (D) \( \frac{3}{9} \)
Correct Answer: \( \frac{1}{3} \)
Misconception List: WNB, Incomplete
Student Answer: \( \frac{1}{3} \)
Student Explanation: 0ne third is equal to tree nineth
Correct: Yes



In [11]:
# Hold out fold 0 for validation and train on the remaining folds
is_val = train["fold"] == 0
train_df = train.loc[~is_val].reset_index(drop=True)
val_df = train.loc[is_val].reset_index(drop=True)

# Keep only the prompt/completion columns expected by SFTTrainer
train_ds, val_ds = (
    Dataset.from_pandas(frame[CFG.cols], preserve_index=False)
    for frame in (train_df, val_df)
)

## 学習設定

In [12]:
# 4-bit NF4 weight quantisation with double quantisation; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [13]:
# Load the base model, quantised according to bnb_config
model = AutoModelForCausalLM.from_pretrained(
    CFG.model_name,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=bnb_config,
    trust_remote_code=True,
    # attn_implementation="flash_attention_2",
)

# Load the matching tokenizer.
# NOTE(review): add_bos_token / add_eos_token are forwarded to the tokenizer. Confirm
# whether the Qwen2 tokenizer honours them, because SFTTrainer also appends EOS itself.
tokenizer = AutoTokenizer.from_pretrained(
    CFG.model_name,
    add_bos_token=True,
    add_eos_token=True,
    trust_remote_code=True,
)

# Use EOS for padding only when the tokenizer defines no pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# compute_metrics ignores label positions holding this id
eos_token_id = tokenizer.eos_token_id

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/3.89G [00:00<?, ?B/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/3.98G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/1.70G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

In [14]:
def add_completion_token(
        model: "AutoModelForCausalLM",
        tokenizer: "PreTrainedTokenizer",
        completions: list[str]
    ) -> "tuple[AutoModelForCausalLM, PreTrainedTokenizer]":
    """Register each completion label as a special token and grow the embeddings.

    Every label (e.g. ``"True_Correct:NA"``) becomes one token, so a whole label
    can be predicted at a single position.

    Args:
        model: Causal LM whose token embeddings are resized to ``len(tokenizer)``.
        tokenizer: Tokenizer that receives the new special tokens (mutated in place).
        completions: Label strings to register.

    Returns:
        ``(model, tokenizer)`` — the same objects, after modification.
    """
    # Extend rather than replace: with the default replace_additional_special_tokens=True,
    # the special tokens the base tokenizer already had would be un-flagged as special.
    num_added = tokenizer.add_special_tokens(
        {"additional_special_tokens": completions},
        replace_additional_special_tokens=False,
    )
    # Report what was actually added; labels already in the vocabulary are not re-added.
    print(f"Added {num_added} special tokens.")

    model.resize_token_embeddings(len(tokenizer))
    print(f"Resized model embeddings to {len(tokenizer)} tokens.")

    return model, tokenizer

In [15]:
# Save every label (in order of first appearance) to the output directory
all_completions = train["completion"].unique().tolist()
labels_path = f"{CFG.output_dir_path}/all_completions.json"
with open(labels_path, "w", encoding="utf-8") as fp:
    fp.write(json.dumps(all_completions, ensure_ascii=False, indent=2))

# Register every label as a special token so that each label is a single token
model, tokenizer = add_completion_token(model, tokenizer, all_completions)

Added 64 special tokens.
Resized model embeddings to 151729 tokens.


In [16]:
# Log in to Weights & Biases. On Colab the key comes from the Colab secret store
# (`userdata`, imported only in the Colab setup cell). Elsewhere, fall back to the
# WANDB_API_KEY environment variable; with key=None, wandb uses its own credential lookup.
wandb_api_key = userdata.get("WANDB_API_KEY") if google_colab else os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_api_key)
wandb.init(project=CFG.comp_name, name=CFG.exp_name)

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtomokazu_rikioka[0m ([33mtomokazu_rikioka_[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [17]:
%%bash
# Write the DeepSpeed ZeRO stage-1 config referenced by SFTConfig(deepspeed="ds_config_zero1.json").
# "auto" entries are resolved by the Hugging Face Trainer's DeepSpeed integration from the
# training arguments (precision flags, batch sizes, gradient accumulation, clipping).
# Optimizer offload is disabled ("device": "none").
# JSON has no comment syntax, so the config body itself stays uncommented.
cat <<'EOT' > ds_config_zero1.json
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 1,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        }
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
EOT

In [18]:
# Training arguments for SFTTrainer
sft_config = SFTConfig(
    # --- output / logging ---
    output_dir=CFG.output_dir_path,
    logging_dir=CFG.log_dir_path,
    logging_steps=10,
    report_to="wandb",
    # --- what to run ---
    do_train=True,
    do_eval=True,
    # --- optimisation ---
    num_train_epochs=CFG.num_train_epochs,
    per_device_train_batch_size=CFG.per_device_train_batch_size,
    gradient_accumulation_steps=CFG.gradient_accumulation_steps,
    learning_rate=CFG.learning_rate,
    optim=CFG.optim_type,
    lr_scheduler_type=CFG.lr_scheduler_type,
    warmup_ratio=CFG.warmup_ratio,
    weight_decay=CFG.weight_decay,
    max_grad_norm=1.0,
    gradient_checkpointing=False,
    deepspeed="ds_config_zero1.json",
    # --- evaluation / checkpointing (step values < 1 are fractions of total steps) ---
    per_device_eval_batch_size=CFG.per_device_eval_batch_size,
    eval_strategy="steps",
    eval_steps=0.2,
    save_strategy="steps",
    save_steps=0.2,
    save_total_limit=1,
    metric_for_best_model="map@3",  # key returned by compute_metrics
    greater_is_better=True,
    load_best_model_at_end=True,
    # --- precision ---
    bf16=True,
    bf16_full_eval=True,
    fp16=False,  # training runs in bf16; FP16 is only used later for inference on Kaggle's T4
    # --- SFT-specific ---
    max_length=CFG.max_len,
    completion_only_loss=True,
    dataset_num_proc=8,
)

In [19]:
# Token ids of the label special tokens registered earlier
new_token_ids = tokenizer.convert_tokens_to_ids(all_completions)

# LoRA on all attention/MLP projections, plus:
#  - a fully trainable copy of lm_head (modules_to_save)
#  - trainable embedding rows for the new label tokens only (trainable_token_indices)
lora_config = LoraConfig(
    task_type=CFG.task_type,
    target_modules=CFG.target_modules,
    r=CFG.lora_r,
    lora_alpha=CFG.lora_alpha,
    lora_dropout=CFG.lora_dropout,
    bias=CFG.lora_bias,
    modules_to_save=["lm_head"],
    trainable_token_indices={"embed_tokens": new_token_ids},
)

## モデルの学習

In [20]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to the ids of the highest-scoring tokens at each position.

    If ``logits`` is a tuple, its first element is used. Returns
    ``(topk_indices, labels)``. ``topk_indices`` keeps every leading dimension of the
    logits and replaces the last one with ``min(3, vocab_size)`` token ids, ordered
    from highest to lowest score. ``labels`` is passed through unchanged.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    top_k = min(3, scores.size(-1))
    topk_indices = scores.topk(top_k, dim=-1, largest=True, sorted=True).indices
    return topk_indices, labels

In [21]:
def compute_metrics(eval_preds):
    """Compute top-1 accuracy and MAP@3 on the label tokens of the eval set.

    ``eval_preds`` is ``(predictions, labels)``. The predictions hold the top-k token ids
    per position from ``preprocess_logits_for_metrics``, possibly wrapped in a tuple.
    Label positions equal to -100 or to ``eos_token_id`` are ignored. Every remaining
    label token is compared, after decoding and stripping whitespace, with the top-k
    predictions made one position earlier.

    Returns:
        ``{"accuracy": ..., "map@3": ...}``. ``accuracy`` is the top-1 hit rate.
        ``map@3`` is the mean reciprocal rank of the first correct prediction within
        the top-k. Both are 0.0 when there is no valid label token.
    """
    topk_preds, labels = eval_preds

    if isinstance(topk_preds, tuple):
        topk_preds = topk_preds[0]

    # Ignore masked positions (-100) and EOS
    ignore_ids = [-100, eos_token_id]
    valid_labels_mask = ~np.isin(labels, ignore_ids)
    valid_labels = labels[valid_labels_mask]

    # The prediction at position t is for token t+1, so shift the mask one step left
    # to select the prediction row for each valid label.
    shifted_mask = np.roll(valid_labels_mask, shift=-1, axis=1)
    shifted_mask[:, -1] = False  # drop the value that np.roll wrapped around from column 0
    aligned_topk_preds = topk_preds[shifted_mask]  # shape: (N, topk)

    # np.mean of an empty list would give NaN
    if valid_labels.size == 0:
        return {"accuracy": 0.0, "map@3": 0.0}

    top1_hits = []
    reciprocal_ranks = []
    for preds_topk, label_id in zip(aligned_topk_preds, valid_labels):
        label = tokenizer.decode([int(label_id)]).strip()
        decoded = [tokenizer.decode([int(pred)]).strip() for pred in preds_topk]

        top1_hits.append(decoded[0] == label)
        # Single gold label: only the FIRST matching rank scores, so a row can never
        # exceed 1 even if two different ids decode to the same text.
        first_hit = next((rank for rank, pred in enumerate(decoded) if pred == label), None)
        reciprocal_ranks.append(0.0 if first_hit is None else 1.0 / (first_hit + 1))

    return {
        "accuracy": float(np.mean(top1_hits)),
        "map@3": float(np.mean(reciprocal_ranks)),
    }

In [22]:
# Connect the model, data, LoRA config and the top-k metric hooks in the SFT trainer
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    peft_config=lora_config,
    processing_class=tokenizer,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

Adding EOS to train dataset (num_proc=8):   0%|          | 0/29356 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=8):   0%|          | 0/29356 [00:00<?, ? examples/s]

Truncating train dataset (num_proc=8):   0%|          | 0/29356 [00:00<?, ? examples/s]

Adding EOS to eval dataset (num_proc=8):   0%|          | 0/7340 [00:00<?, ? examples/s]

Tokenizing eval dataset (num_proc=8):   0%|          | 0/7340 [00:00<?, ? examples/s]

Truncating eval dataset (num_proc=8):   0%|          | 0/7340 [00:00<?, ? examples/s]

In [23]:
# Sanity check: which embedding / LM-head parameters require gradients after PEFT wrapping
for param_name, param in trainer.model.named_parameters():
    if any(key in param_name for key in ("embed_tokens", "lm_head")):
        print(param_name, param.requires_grad)

base_model.model.model.embed_tokens.token_adapter.base_layer.weight False
base_model.model.model.embed_tokens.token_adapter.trainable_tokens_delta.default True
base_model.model.lm_head.original_module.weight False
base_model.model.lm_head.modules_to_save.default.weight True


In [24]:
# Run fine-tuning. Evaluation and checkpointing happen every 20% of total steps
# (eval_steps / save_steps = 0.2), and the best map@3 checkpoint is reloaded at the end.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Step,Training Loss,Validation Loss,Accuracy,Map@3,Entropy,Num Tokens,Mean Token Accuracy
734,0.2233,0.254368,0.837466,0.911262,1.705937,2152911.0,0.91905
1468,0.1726,0.200083,0.86921,0.929814,1.966179,4307560.0,0.934777
2202,0.1102,0.178634,0.890191,0.94228,1.631132,6455858.0,0.945125
2936,0.0598,0.162629,0.896866,0.946299,1.536656,8605659.0,0.949142
3670,0.1151,0.15082,0.900136,0.947866,1.567976,10762568.0,0.95064


TrainOutput(global_step=3670, training_loss=0.29896748814660784, metrics={'train_runtime': 10111.9322, 'train_samples_per_second': 5.806, 'train_steps_per_second': 0.363, 'total_flos': 1.201907299170386e+18, 'train_loss': 0.29896748814660784, 'epoch': 2.0})

In [25]:
# Save the trained model and the tokenizer (including the added label tokens) together
model_save_dir = CFG.output_dir_path + "/model"
trainer.save_model(model_save_dir)
tokenizer.save_pretrained(model_save_dir)

('output//model/tokenizer_config.json',
 'output//model/special_tokens_map.json',
 'output//model/chat_template.jinja',
 'output//model/vocab.json',
 'output//model/merges.txt',
 'output//model/added_tokens.json',
 'output//model/tokenizer.json')