In [None]:
# Google Colab setup. Set the flag to False to skip this whole cell
# when running outside Colab.
google_colab = True

if google_colab:
    from google.colab import drive
    from google.colab import userdata

    # Mount Google Drive under /content/drive.
    drive.mount("/content/drive")

    # Change the working directory; the relative paths in CFG are resolved from here.
    %cd /content/drive/MyDrive/Python/kaggle_map/src/validation

    # Install the AWQ-related libraries (versions pinned).
    !pip install -q autoawq==0.2.9 accelerate==1.2.1 transformers==4.51.3

In [None]:
import os
import gc
import json
import random
import shutil

import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, AwqConfig
from datasets import Dataset

import warnings
warnings.filterwarnings("ignore")

In [None]:
class CFG:
    """Configuration holder for the AWQ quantization run.

    All values are plain class attributes and are read as ``CFG.<name>``
    by the cells below; the class is never instantiated in this notebook.
    """

    # ============== Experiment info =============
    comp_name = "kaggle_map"
    exp_name = "exp026_qwen2.5-14b-lora-softlabel"
    fold = 0  # rows carrying this fold id are removed before calibration sampling

    # ============== Paths =============
    # NOTE: comp_dir_path already ends with "/", so the derived paths contain "//".
    comp_dir_path = "../../kaggle/input/"
    comp_dataset_path = f"{comp_dir_path}/map-charting-student-math-misunderstandings/"

    model_path = f"{exp_name}/model/"  # input: model and tokenizer are loaded from here
    awq_path = f"{exp_name}/awq/"  # output: quantized model and tokenizer are saved here

    # ============== AWQ settings =============
    max_calib_seq_len = 256  # forwarded to model.quantize(max_calib_seq_len=...)
    n_calib_samples = 1000  # upper bound on the number of rows sampled for calibration
    n_parallel_calib_samples = 4  # forwarded to model.quantize(n_parallel_calib_samples=...)

    # ============== Prompt settings =============
    # Template filled by format_input(); the placeholders are QuestionText,
    # MC_Answer, Correct and StudentExplanation.
    prompt_format = """You are a specialist in identifying the types of misunderstandings that arise from students' answers to math problems.
Based on the information provided below, please determine what kind of misunderstanding the student has.

Question: {QuestionText}
Answer: {MC_Answer}
Correct: {Correct}
Student Explanation: {StudentExplanation}
"""

    # ============== Miscellaneous =============
    seed = 42
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Fix the random seeds
def set_seed(seed=None):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed value applied everywhere. ``None`` falls back to 42.

    Returns:
        None. The function only has side effects: it sets the
        ``PYTHONHASHSEED`` environment variable, seeds the global RNGs and
        flips the cuDNN ``deterministic`` / ``benchmark`` flags.
    """
    seed_value = 42 if seed is None else seed

    os.environ["PYTHONHASHSEED"] = str(seed_value)

    # Same seeding order as before: NumPy, random, torch (CPU), torch (CUDA).
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed_value)

    # Disable the cuDNN auto-tuner and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def make_dirs(cfg):
    """Create the output directory named by ``cfg.awq_path``.

    Missing parent directories are created as well, and an already existing
    directory is not an error.

    Args:
        cfg: Any object exposing an ``awq_path`` attribute.
    """
    output_dir = cfg.awq_path
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Initialization: fix the random seeds, then create the output directory.
set_seed(CFG.seed)
make_dirs(CFG)

## データの読み込みと前処理

In [None]:
def add_folds_by_qid_cat_misc(df, n_splits=5, random_state=42, fallback="pair"):
    """Return a copy of ``df`` with a stratified ``fold`` column.

    The stratification label is ``QuestionId|Category|Misconception``. When
    at least one such label occurs fewer than ``n_splits`` times, the rows
    carrying a rare label get a coarser one, chosen by ``fallback``.

    Args:
        df: Frame with ``QuestionId``, ``Category`` and ``Misconception`` columns.
        n_splits: Number of folds.
        random_state: Seed passed to ``StratifiedKFold`` (shuffling is enabled).
        fallback: Label used for rare rows:
            ``"pair"`` -> ``Category|Misconception``,
            ``"category"`` -> ``Category``,
            ``"none"`` -> keep the full triple label.

    Returns:
        A copy of ``df`` with an integer ``fold`` column in ``[0, n_splits)``.

    Raises:
        ValueError: If rare labels exist and ``fallback`` is not one of the
            three supported values. The value is not validated otherwise.
    """
    qid, cat, misc = (
        df[column].astype(str).fillna("NA")
        for column in ("QuestionId", "Category", "Misconception")
    )

    triple_key = qid + "|" + cat + "|" + misc
    pair_key = cat + "|" + misc

    # A row is "rare" when its triple label appears fewer than n_splits times.
    key_counts = triple_key.value_counts()
    is_rare = triple_key.map(key_counts) < n_splits

    target = triple_key
    if is_rare.any():
        if fallback == "pair":
            target = np.where(is_rare, pair_key, triple_key)
        elif fallback == "category":
            target = np.where(is_rare, cat, triple_key)
        elif fallback != "none":
            raise ValueError("fallback は 'pair' / 'category' / 'none' のいずれかにしてください。")

    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Each row lands in exactly one validation split; that split index is its fold id.
    n_rows = len(df)
    fold_ids = np.full(n_rows, -1, dtype=int)
    for fold_no, (_, holdout_idx) in enumerate(splitter.split(np.zeros(n_rows), target)):
        fold_ids[holdout_idx] = fold_no

    return df.assign(fold=fold_ids)

In [None]:
def wrong_corrections(df: pd.DataFrame) -> pd.DataFrame:
    """Fix known errors in the ``MC_Answer`` text of specific rows.

    For a fixed list of ``row_id`` values the answer text ``\\( 6 \\)`` is
    rewritten to ``\\( 9 \\)``, and for a second list ``\\( 9 \\)`` is rewritten
    to ``\\( 6 \\)``. All other rows are left unchanged.

    The substitution is a literal one (``regex=False``). The default of
    ``Series.str.replace`` differs between pandas versions, and in regex mode
    these patterns would match the text without backslashes and the
    replacement would be an invalid escape sequence.

    Args:
        df: Frame with ``row_id`` and ``MC_Answer`` columns. Modified in place.

    Returns:
        The same ``df`` object, for chaining.
    """
    # (row ids, text to find, replacement text)
    corrections = (
        (
            [12878, 12901, 13876, 14089, 14159, 14185],
            r"\( 6 \)",
            r"\( 9 \)",
        ),
        (
            [14280, 14305, 14321, 14335, 14338, 14352, 14355, 14403, 14407, 14412, 14413, 14418],
            r"\( 9 \)",
            r"\( 6 \)",
        ),
    )

    for row_ids, old_text, new_text in corrections:
        df["MC_Answer"] = np.where(
            df["row_id"].isin(row_ids),
            df["MC_Answer"].str.replace(old_text, new_text, regex=False),
            df["MC_Answer"],
        )
    return df


def replace_duplicate_misc(df: pd.DataFrame) -> pd.DataFrame:
    """Unify a misconception label that appears under two spellings.

    ``"Wrong_Fraction"`` is rewritten to ``"Wrong_fraction"`` in the
    ``Misconception`` column. The frame is modified in place and returned.
    """
    canonical_spelling = {"Wrong_Fraction": "Wrong_fraction"}
    unified = df["Misconception"].replace(canonical_spelling)
    df["Misconception"] = unified
    return df


def make_completion(df: pd.DataFrame) -> pd.DataFrame:
    """Add a ``completion`` column of the form ``"<Category>:<Misconception>"``.

    Missing ``Misconception`` values are first replaced by the string
    ``"NA"``, and that replacement is written back to the column. The frame
    is modified in place and returned.
    """
    misconception = df["Misconception"].fillna("NA")
    df["Misconception"] = misconception
    df["completion"] = df["Category"] + ":" + misconception
    return df


def add_is_correct(df: pd.DataFrame) -> pd.DataFrame:
    """Add an ``is_correct`` flag marking each question's reference answer.

    Among the rows whose ``Category`` starts with the token ``"True"`` (the
    text before the first ``"_"``), the most frequent ``MC_Answer`` of every
    ``QuestionId`` is taken as that question's reference answer. Every row
    whose ``(QuestionId, MC_Answer)`` equals the reference gets
    ``is_correct == 1``; all other rows get ``0``.

    Args:
        df: Frame with ``QuestionId``, ``MC_Answer`` and ``Category`` columns.
            It is not modified. An existing ``is_correct`` column is ignored
            and recomputed, so the function can be re-run on its own output.

    Returns:
        A new frame with the rows of ``df`` in the same order plus the
        ``is_correct`` column.
    """
    # Drop a stale flag so the merge below cannot produce is_correct_x / _y.
    df = df.drop(columns=["is_correct"], errors="ignore")

    # Vectorised prefix test; a missing Category compares as False.
    is_true_row = df["Category"].str.split("_").str[0].eq("True")
    true_rows = df.loc[is_true_row].copy()

    # Number of "True" rows per (question, answer) pair.
    true_rows["count"] = true_rows.groupby(["QuestionId", "MC_Answer"])["MC_Answer"].transform("count")

    # Keep the most frequent answer of each question.
    reference = (
        true_rows.sort_values("count", ascending=False)
        .drop_duplicates(["QuestionId"])[["QuestionId", "MC_Answer"]]
        .assign(is_correct=1)
    )

    df = df.merge(reference, on=["QuestionId", "MC_Answer"], how="left")
    df["is_correct"] = df["is_correct"].fillna(0)
    return df


def format_input(row) -> str:
    """Render the prompt text for one row using ``CFG.prompt_format``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) providing
            ``QuestionText``, ``MC_Answer``, ``is_correct`` and
            ``StudentExplanation``.

    Returns:
        The filled-in template. ``Correct`` is rendered as ``"Yes"`` when
        ``row["is_correct"]`` is truthy and ``"No"`` otherwise.
    """
    template_fields = {
        "QuestionText": row["QuestionText"],
        "MC_Answer": row["MC_Answer"],
        "Correct": "Yes" if row["is_correct"] else "No",
        "StudentExplanation": row["StudentExplanation"],
    }
    return CFG.prompt_format.format(**template_fields)

In [None]:
# Load the training data
train = pd.read_csv(f"{CFG.comp_dataset_path}/train.csv")

# Assign folds (done before the label clean-up below, i.e. on the raw labels)
train = add_folds_by_qid_cat_misc(train, n_splits=5, random_state=42, fallback="pair")

# Fix the known data errors
train = wrong_corrections(train)

# Unify duplicated Misconception labels
train = replace_duplicate_misc(train)

# Build the completion column
train = make_completion(train)

# Build the correct-answer flag
train = add_is_correct(train)

# Build the input prompts
train["prompt"] = train.apply(format_input, axis=1)

In [None]:
# Keep only the rows whose fold id differs from CFG.fold, then renumber the index.
train = train[train["fold"] != CFG.fold].reset_index(drop=True)

## Calibrationデータセットの作成

In [None]:
# Sample the rows used for calibration (at most CFG.n_calib_samples, without replacement)
set_seed(CFG.seed)
sampled_df = train.sample(n=min(CFG.n_calib_samples, len(train)), random_state=CFG.seed)

print(f"Sampled {len(sampled_df)} samples for calibration from {len(train)} total samples")

## AWQ量子化処理

In [None]:
# Load the merged model and its tokenizer from CFG.model_path
print(f"Loading merged model from {CFG.model_path}...")
model = AutoAWQForCausalLM.from_pretrained(CFG.model_path)

tokenizer = AutoTokenizer.from_pretrained(
    CFG.model_path,
    trust_remote_code=True,
)

print("Model and tokenizer loaded successfully")

In [None]:
# Build the calibration data: one chat-formatted string per sampled row
print("Creating calibration dataset...")
calib_data = []

for _, row in sampled_df.iterrows():
    # Wrap the prompt in a single user message
    messages = [
        {"role": "user", "content": row["prompt"]}
    ]

    # Apply the chat template; tokenize=False returns the rendered string
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    calib_data.append(text)

print(f"Created {len(calib_data)} calibration samples")

In [None]:
# Preview the last rendered calibration prompt. Read it from calib_data
# rather than from the loop variable leaked by the cell that built the list.
calib_data[-1]

In [None]:
# AWQ quantization settings (passed to model.quantize and mirrored into AwqConfig when saving)
quant_config = {
    "zero_point": True,
    "q_group_size": 64,
    "w_bit": 4,
    "version": "GEMM"
}

print("Quantization config:")
for key, value in quant_config.items():
    print(f"  {key}: {value}")

In [None]:
# Run the quantization, calibrating on the chat-formatted strings in calib_data
print("Starting quantization process...")
print(f"Max calibration sequence length: {CFG.max_calib_seq_len}")

model.quantize(
    tokenizer,
    quant_config=quant_config,
    calib_data=calib_data,
    max_calib_seq_len=CFG.max_calib_seq_len,
    n_parallel_calib_samples=CFG.n_parallel_calib_samples,
)

print("Quantization completed successfully!")

## モデルの保存

In [None]:
# Attach the quantization settings to the model config.
# The values mirror quant_config; "version" is lower-cased for AwqConfig.
quantization_config = AwqConfig(
    bits=quant_config["w_bit"],
    group_size=quant_config["q_group_size"],
    zero_point=quant_config["zero_point"],
    version=quant_config["version"].lower(),
).to_dict()

model.model.config.quantization_config = quantization_config

# Save the AWQ model and the tokenizer into CFG.awq_path
print(f"Saving quantized model to {CFG.awq_path}...")
model.save_quantized(CFG.awq_path)
tokenizer.save_pretrained(CFG.awq_path)
print("Model saved successfully!")

# Copy all_completions.json from the source model directory next to the quantized model
source_file = os.path.join(CFG.model_path, "all_completions.json")
destination_file = os.path.join(CFG.awq_path, "all_completions.json")

# Best effort: a missing file or a failed copy is reported but does not stop the cell.
try:
    shutil.copyfile(source_file, destination_file)
    print(f"Successfully copied all_completions.json to {destination_file}")
except FileNotFoundError:
    print(f"Warning: all_completions.json not found at {source_file}")
except Exception as e:
    print(f"Error copying all_completions.json: {e}")

# Verify the save by listing the output directory with file sizes
print("\nSaved files:")
for file in os.listdir(CFG.awq_path):
    file_path = os.path.join(CFG.awq_path, file)
    file_size = os.path.getsize(file_path) / (1024**3)  # bytes -> GiB (printed with a "GB" label)
    print(f"  {file}: {file_size:.2f} GB")