In [1]:
# --- Auto-reload imported modules ---
%reload_ext autoreload
%autoreload 2

import os, shutil, warnings
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
    LogitsProcessor,
)

from trl import SFTTrainer, SFTConfig

from peft import LoraConfig, get_peft_model, PeftModel
import joblib

In [None]:
# --- Config ---
VER = 10
DIR = f"ver_{VER}"; os.makedirs(DIR, exist_ok=True)

MODEL_NAME = "./models/Qwen2.5-14B-Instruct"   # or "Qwen/Qwen2.5-14B-Instruct"
# MODEL_NAME = "./models/Qwen2.5-0.5B-Instruct"

MAX_LEN = 1610
TRAIN_MODEL = True

CV_FOLD = 5
CV_SEED = 42
USE_SINGLE_FOLD = True
EVAL_MODE = "vote@3"  # "vote" or "vote@3" (use this one)

TRAIN_CSV = "./raw_data/train.csv"
TEST_CSV  = "./raw_data/test.csv"
CLEAN_MISLABEL = "ignore"   # ignore | fix | remove

# TODO: Add flash-attention-2
# TODO: Train Val Split, some low counts < 5.


In [3]:
# from transformers import AutoTokenizer, AutoModelForSequenceClassification

# from huggingface_hub import login
# login(token="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")

# # Choose your model Qwen/Qwen2.5-0.5B-Instruct
# model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# save_path = "./models/Qwen2.5-0.5B-Instruct"

# model_name = "Qwen/Qwen2.5-14B-Instruct"
# save_path = "./models/Qwen2.5-14B-Instruct"

# # Download and save model + tokenizer locally
# tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)

# tokenizer.save_pretrained(save_path)
# model.save_pretrained(save_path)

In [4]:
special_character_list = [
    '■','□','▲','△','▼','▽','◆','◇','○','●','★','☆','♦','♥','♠','♣',
    '§','†','‡','※','∞','±','≠','≈','√','∑','∏','∆','Ω','μ','∂',
    '→','←','↑','↓','↔','↕','〈','〉','『','』','│','─','┌','┐','└','┘','┼','█','▓','▒',
    '£','¥','€','₩','©','®','™','♪','♫','☀','☁','☂','☃','☎'
]

target_values = ['False_Correct:NA',
 'False_Misconception:Adding_across',
 'False_Misconception:Adding_terms',
 'False_Misconception:Additive',
 'False_Misconception:Base_rate',
 'False_Misconception:Certainty',
 'False_Misconception:Definition',
 'False_Misconception:Denominator-only_change',
 'False_Misconception:Division',
 'False_Misconception:Duplication',
 'False_Misconception:Firstterm',
 'False_Misconception:FlipChange',
 'False_Misconception:Ignores_zeroes',
 'False_Misconception:Incomplete',
 'False_Misconception:Incorrect_equivalent_fraction_addition',
 'False_Misconception:Interior',
 'False_Misconception:Inverse_operation',
 'False_Misconception:Inversion',
 'False_Misconception:Irrelevant',
 'False_Misconception:Longer_is_bigger',
 'False_Misconception:Mult',
 'False_Misconception:Multiplying_by_4',
 'False_Misconception:Not_variable',
 'False_Misconception:Positive',
 'False_Misconception:Scale',
 'False_Misconception:Shorter_is_bigger',
 'False_Misconception:Subtraction',
 'False_Misconception:SwapDividend',
 'False_Misconception:Tacking',
 'False_Misconception:Unknowable',
 'False_Misconception:WNB',
 'False_Misconception:Whole_numbers_larger',
 'False_Misconception:Wrong_Fraction',
 'False_Misconception:Wrong_Operation',
 'False_Misconception:Wrong_fraction',
 'False_Misconception:Wrong_term',
 'False_Neither:NA',
 'True_Correct:NA',
 'True_Misconception:Adding_across',
 'True_Misconception:Additive',
 'True_Misconception:Base_rate',
 'True_Misconception:Definition',
 'True_Misconception:Denominator-only_change',
 'True_Misconception:Division',
 'True_Misconception:Duplication',
 'True_Misconception:Firstterm',
 'True_Misconception:FlipChange',
 'True_Misconception:Incomplete',
 'True_Misconception:Incorrect_equivalent_fraction_addition',
 'True_Misconception:Inversion',
 'True_Misconception:Irrelevant',
 'True_Misconception:Longer_is_bigger',
 'True_Misconception:Mult',
 'True_Misconception:Multiplying_by_4',
 'True_Misconception:Not_variable',
 'True_Misconception:Positive',
 'True_Misconception:Shorter_is_bigger',
 'True_Misconception:Subtraction',
 'True_Misconception:SwapDividend',
 'True_Misconception:Tacking',
 'True_Misconception:WNB',
 'True_Misconception:Whole_numbers_larger',
 'True_Misconception:Wrong_fraction',
 'True_Misconception:Wrong_term',
 'True_Neither:NA']

target_token_map = {label: char for label, char in zip(target_values, special_character_list)}
token_target_map = {v: k for k, v in target_token_map.items()}

SYSTEM_PROMPT = """You are now tasked with analyzing math problems and classifying student responses. Given a math problem, the student's chosen answer, whether it's correct, and the student's explanation, you need to determine the appropriate Category and Misconception classification.
Below are the available Category:Misconception classifications you can choose from.

Your job is to output exactly ONE classification token from the allowed set below.

OUTPUT RULES (READ CAREFULLY):
1) Your entire reply must be exactly one character: a single token from the allowed set.
2) Do NOT output any words, labels, punctuation, quotes, spaces, or newlines.
3) Do NOT explain your reasoning or restate the problem.
4) Choose the token that best matches the Category:Misconception for the given input.

ALLOWED OUTPUT TOKENS (token : meaning):
""" + '\n'.join([f"{k} : {v}" for k, v in token_target_map.items() ]) +"""

Please analyze the given input and provide your classification."""

In [5]:
def clean_mislabel_entries(train: pd.DataFrame) -> pd.DataFrame:
    print(f"Using {CLEAN_MISLABEL} for data cleaning Strat")
    qid = 31778
    correct_answer = r"\( 6 \)"
    rows_to_fix = []
    for idx, row in train[train['QuestionId'] == qid].iterrows():
        is_correct_answer = row['MC_Answer'] == correct_answer
        is_true = str(row['Category']).startswith("True")
        if is_correct_answer and not is_true:
            rows_to_fix.append(idx)
        elif not is_correct_answer and is_true:
            rows_to_fix.append(idx)
    assert len(rows_to_fix) == 18, "Expected 18 mislabeled entries to fix, found a different number."

    if CLEAN_MISLABEL == "ignore":
        return train
    elif CLEAN_MISLABEL == "remove":
        return train.drop(index=rows_to_fix).reset_index(drop=True)
    elif CLEAN_MISLABEL == "fix":
        for idx in rows_to_fix:
            row = train.loc[idx]
            cat = str(row['Category']).split("_", 1)[-1]
            prefix = "True" if row['MC_Answer'] == correct_answer else "False"
            train.at[idx, 'Category'] = f"{prefix}_{cat}"
        return train
    else:
        raise ValueError("CLEAN_MISLABEL must be 'ignore', 'remove', or 'fix'")

def load_and_preprocess_data():
    train = pd.read_csv(TRAIN_CSV)
    train = clean_mislabel_entries(train)
    train['Misconception'] = train['Misconception'].fillna('NA')
    train['target'] = train['Category'] + ":" + train['Misconception']

    le = LabelEncoder()
    train['label'] = le.fit_transform(train['target'])

    idx = train['Category'].str.startswith("True")
    correct = (
        train[idx].groupby(['QuestionId','MC_Answer']).size()
        .reset_index(name='c').sort_values('c', ascending=False)
        .drop_duplicates(['QuestionId']).assign(is_correct=1)[['QuestionId','MC_Answer','is_correct']]
    )
    train = train.merge(correct, on=['QuestionId','MC_Answer'], how='left')
    train['is_correct'] = train['is_correct'].fillna(0)

    train["split_key"] = (train['QuestionId'].astype(str) + "_" + train['label'].astype(str)).astype('category').cat.codes
    return train, le

def format_input(row):
    x = "Yes" if row['is_correct'] else "No"
    return (
        f"Question: {row['QuestionText']}\n"
        f"Student Answer: {row['MC_Answer']}\n"
        f"Correct? {x}\n"
        f"Student Explanation: {row['StudentExplanation']}"
    )

def preprocess_function_conversational_prompt_completion(row):
    assiant_token = target_token_map[ row['target'] ]
    return {
        "prompt": [{"role": "system", "content": SYSTEM_PROMPT}, 
                   {"role": "user", "content": format_input(row)}],
        "completion": [
            {"role": "assistant", "content": assiant_token}
        ],
    }

In [6]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # important for causal LM

allowed_token_ids = [tokenizer.encode(str(i), add_special_tokens=False)[0] for i in special_character_list]

# Make sure each special symbol is a single token for this tokenizer
special_token_ids = []
for ch in special_character_list:
    ids = tokenizer.encode(ch, add_special_tokens=False)
    assert len(ids) == 1, f"Symbol {ch} is not a single token with this tokenizer."
    special_token_ids.append(ids[0])

def load_base_model_bf16():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2"
    )
    # model.config.use_cache = False  # better for training/checkpointing
    return model

In [None]:
# --- Compute MAP@3 ---
def compute_map3(eval_pred):
    logits_ex, labels_ex = eval_pred

    start_token = 77091
    offset = 2

    # --- Find all (b, t) positions of the start token ---
    b_idx, t_idx = np.where(labels_ex == start_token)
    target_positions = t_idx + offset

    labels_ex = labels_ex[b_idx, target_positions]
    logits_ex = logits_ex[b_idx, target_positions-1, :]

    # --- Mask logits outside allowed_token_ids ---
    mask = np.full(logits_ex.shape, -np.inf, dtype=np.float32)
    mask[:, allowed_token_ids] = logits_ex[:, allowed_token_ids]
    logits_ex = mask

    # --- Top 3 predictions ---
    top3_ids = np.argsort(-logits_ex, axis=1)[:, :3]  # (N, 3)

    match = (top3_ids == labels_ex[:, None])
    map3 = np.mean([1 if m[0] else 0.5 if m[1] else 1/3 if m[2] else 0 for m in match])
    return {"map@3": map3}

In [None]:
allowed_token_ids_tensor = torch.tensor(allowed_token_ids, dtype=torch.long)

def preprocess_logits_for_metrics(logits, labels):
    # logits: [B, T, V]  (for causal LM heads)
    # labels: [B, T]
    start_token_id = 77091
    offset = 2
    
    if logits.ndim == 3:
        # Find positions where label at t == start_token
        # Note: doing this per-batch, so labels is smaller
        b_idx, t_idx = torch.where(labels == start_token_id)
        target_positions = t_idx + offset

        # Slice logits to these positions
        logits = logits[b_idx, target_positions-1, :]

    # Mask: keep only allowed token ids
    logits = logits.index_select(dim=-1, index=allowed_token_ids_tensor.to(logits.device))

    # Optionally: top-3 on GPU to shrink size even more
    top3_vals, top3_ids = torch.topk(logits, k=3, dim=-1)
    return top3_ids  # Pass only top3 IDs to compute_metrics

def compute_map3(eval_pred):
    start_token_id = 77091
    offset = 2
    
    top3_ids, labels_ex = eval_pred  # top3_ids already on CPU
    top3_ids = top3_ids.cpu().numpy() if torch.is_tensor(top3_ids) else top3_ids
    labels_ex = labels_ex.cpu().numpy() if torch.is_tensor(labels_ex) else labels_ex

    # For labels_ex, we still need to slice to the target positions
    # Same logic as in preprocess: find start token positions
    b_idx, t_idx = np.where(labels_ex == start_token_id)
    target_positions = t_idx + offset
    labels_ex = labels_ex[b_idx, target_positions]

    match = (top3_ids == labels_ex[:, None])
    map3 = np.mean([1 if m[0] else 0.5 if m[1] else 1/3 if m[2] else 0 for m in match])
    return {"map@3": map3}

In [None]:
param_sets = [
    # 4) Attention + MLP (heavier); more epochs, lower LR
    dict(
        name="attn_mlp_r8_alpha32",
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
        r=8, lora_alpha=32,
        lr=8e-5,
        epochs=4
    ),
    
    # 0) Minimal, fast, strong baseline for your setup
    dict(
        name="qv_r8_alpha32_fast",
        target_modules=["q_proj","v_proj"],
        r=8, lora_alpha=32,           # α/r = 4
        lr=2.5e-4,
        epochs=3
    ),

    # 1) Same topology, higher capacity (r), keep scaling constant
    dict(
        name="qv_r16_alpha64_cap",
        target_modules=["q_proj","v_proj"],
        r=16, lora_alpha=64,          # α/r = 4
        lr=1.8e-4,                    # little lower LR for more params
        epochs=3
    ),

    # 2) Add o_proj to mix heads; keep rank moderate
    dict(
        name="qvo_r8_alpha32_mix",
        target_modules=["q_proj","v_proj","o_proj"],
        r=8, lora_alpha=32,
        lr=2.0e-4,
        epochs=3
    ),

    # 3) Full attention stack; step down LR
    dict(
        name="attn_full_r8_alpha32",
        target_modules=["q_proj","k_proj","v_proj","o_proj"],
        r=8, lora_alpha=32,
        lr=1.2e-4,
        epochs=3
    ),

    # 5) Heavier rank on attention-only (stress test capacity); reduce LR
    dict(
        name="attn_full_r16_alpha64",
        target_modules=["q_proj","k_proj","v_proj","o_proj"],
        r=16, lora_alpha=64,
        lr=9e-5,
        epochs=3
    ),

    # 6) Low-rank but bigger α/r (stronger injected update); watch stability
    dict(
        name="qv_r4_alpha32_boost",   # α/r = 8 (intentional)
        target_modules=["q_proj","v_proj"],
        r=4, lora_alpha=32,
        lr=2.0e-4,
        epochs=3
    ),

    # 7) Medium: q+v plus MLP, but small r to keep params in check
    dict(
        name="qv_mlp_r4_alpha16",
        target_modules=["q_proj","v_proj","gate_proj","up_proj","down_proj"],
        r=4, lora_alpha=16,           # α/r = 4
        lr=1.2e-4,
        epochs=4
    ),
]

param_sets = [
    # 4) Attention + MLP (heavier); more epochs, lower LR
    dict(
        name="attn_mlp_r8_alpha32",
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
        r=8, lora_alpha=32,
        lr=8e-5,
        epochs=2,
    ),
    
    # 0) Minimal, strong baseline
    dict(
        name="qv_r8_alpha32_e2",
        target_modules=["q_proj","v_proj"],
        r=8, lora_alpha=32,   # α/r = 4
        lr=2.5e-4,
        epochs=2
    ),
    # 1) More capacity, reduce LR
    dict(
        name="qv_r16_alpha64_e2",
        target_modules=["q_proj","v_proj"],
        r=16, lora_alpha=64,  # α/r = 4
        lr=1.6e-4,
        epochs=2
    ),
    # 2) Add o_proj head-mixing, moderate LR
    dict(
        name="qvo_r8_alpha32_e2",
        target_modules=["q_proj","v_proj","o_proj"],
        r=8, lora_alpha=32,
        lr=2.0e-4,
        epochs=2
    ),
    # 3) Full attention, step down LR
    dict(
        name="attn_full_r8_alpha32_e1",
        target_modules=["q_proj","k_proj","v_proj","o_proj"],
        r=8, lora_alpha=32,
        lr=1.2e-4,
        epochs=1
    ),
]


train_df, le = load_and_preprocess_data()
n_classes = train_df['label'].nunique()
train_df['text'] = train_df.apply(format_input, axis=1)

skf = StratifiedKFold(n_splits=CV_FOLD, shuffle=True, random_state=CV_SEED)
fold_indices = list(skf.split(train_df, train_df['split_key']))
if USE_SINGLE_FOLD:
    fold_indices = [fold_indices[0]]
    

for fold, (tr_idx, va_idx) in enumerate(fold_indices):
    tr, va = train_df.iloc[tr_idx].copy(), train_df.iloc[va_idx].copy()

    ds_tr = Dataset.from_pandas(tr, preserve_index=False)
    ds_tr = ds_tr.map(preprocess_function_conversational_prompt_completion, num_proc=1, remove_columns=ds_tr.column_names)
    
    ds_va = Dataset.from_pandas(va, preserve_index=False)
    ds_va = ds_va.map(preprocess_function_conversational_prompt_completion, num_proc=1, remove_columns=ds_va.column_names)
    
    best_map = -1.0
    for repeat_idx in range(len(param_sets)):
        print(f"\n=== Fold {fold+1}/{CV_FOLD} REPEAT {repeat_idx} ===")

        cfg = param_sets[repeat_idx]
        print(f"Trying {cfg['name']}")

        # LoRA config for this repeat
        lora_config = LoraConfig(
            r=cfg["r"],
            lora_alpha=cfg["lora_alpha"],
            target_modules=cfg["target_modules"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        
        model = load_base_model_bf16()
        model = get_peft_model(model, lora_config)
        # model = PeftModel.from_pretrained(model, "./ver_10/fold_0/checkpoint-200/")

        LR_RATE = cfg["lr"]
        EPOCHS = cfg["epochs"]

        BATCH_SIZE = 16

        training_args = SFTConfig(
            output_dir=f"{DIR}/fold_{fold}",
            num_train_epochs=EPOCHS,
            per_device_train_batch_size=2,
            per_device_eval_batch_size=24,
            eval_strategy="steps", # Evaluate every 'eval_steps'
            save_strategy="steps", # Save model every 'save_steps'
            eval_steps=len(tr)/BATCH_SIZE//2,
            save_steps=len(tr)/BATCH_SIZE//2,
            save_total_limit=1,
            learning_rate=LR_RATE,
            metric_for_best_model="map@3",
            greater_is_better=True,
            load_best_model_at_end=True,
            logging_dir=f"{DIR}/logs_fold_{fold}/repeat_{repeat_idx}",
            logging_steps=len(tr)/BATCH_SIZE//2,
            report_to="tensorboard",
            bf16=True, # TRAIN WITH BF16 IF LOCAL GPU IS NEWER GPU          
            fp16=False, # INFER WITH FP16 BECAUSE KAGGLE IS T4 GPU
            eval_accumulation_steps=1,
            gradient_accumulation_steps=BATCH_SIZE//2,
            completion_only_loss=True,
            max_length = MAX_LEN,
            packing = True,
        )
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            # peft_config=lora_config,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            compute_metrics=compute_map3,
        )

        if TRAIN_MODEL:
            initial_map = trainer.evaluate()["eval_map@3"]  # Initial evaluation before training
            print(f"Initial eval/map@3 = {initial_map:.6f}")
            if initial_map > best_map:
                best_map = initial_map
                save_dir = f"{DIR}/fold_{fold}/best"
                os.makedirs(save_dir, exist_ok=True)
                # Save LoRA adapters
                trainer.save_model(save_dir)
                # Save label encoder once per fold
                joblib.dump(le, f"{DIR}/fold_{fold}/label_encoder.joblib")
            
            trainer.train()

            final_map = eval_result["eval_map@3"]
            print(f"Repeat {repeat_idx} eval/map@3 = {final_map:.6f}")

            if final_map > best_map:
                best_map = final_map
                save_dir = f"{DIR}/fold_{fold}/best"
                os.makedirs(save_dir, exist_ok=True)
                # Save LoRA adapters
                trainer.save_model(save_dir)
                # Save label encoder once per fold
                joblib.dump(le, f"{DIR}/fold_{fold}/label_encoder.joblib")

        # cleanup HF checkpoints if any
        for ckpt in sorted(Path(f"{DIR}/fold_{fold}").glob("checkpoint-*")):
            shutil.rmtree(ckpt, ignore_errors=True)
            

Using ignore for data cleaning Strat




Map:   0%|          | 0/29356 [00:00<?, ? examples/s]

Map:   0%|          | 0/7340 [00:00<?, ? examples/s]


=== Fold 1/5 REPEAT 0 ===
Trying attn_mlp_r8_alpha32


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Tokenizing train dataset:   0%|          | 0/29356 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/29356 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/7340 [00:00<?, ? examples/s]

Packing eval dataset:   0%|          | 0/7340 [00:00<?, ? examples/s]

Initial eval/map@3 = 0.000000


Step,Training Loss,Validation Loss,Model Preparation Time,Map@3
917,0.1164,0.084372,0.0215,0.0
