# Overview

- [Infer Notebook](https://www.kaggle.com/code/sinchir0/fine-tuning-bge-infer/notebook)

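- Retrieve the top-25 candidate misconceptions for every answer option with the base `bge-large-en-v1.5` model
- Fine-tune `bge-large-en-v1.5` on (anchor, positive, negative) triplets built from the retrieved candidates
  - `anchor`: `ConstructName` + `SubjectName` + `QuestionText` + `Answer[A-D]Text`
  - `positive`: the labelled (correct) `MisconceptionName`
  - `negative`: a retrieved `MisconceptionName` that differs from the label (hard negative)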

ref: https://sbert.net/docs/sentence_transformer/training_overview.html#trainer

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]

PyTorch version: 2.4.0

CUDA version: 12.3

# Settings

In [1]:
# Experiment identifiers and file locations.
EXP_NAME = "fine-tuning-bge"
DATA_PATH = "./eedi-mining-misconceptions-in-mathematics"
MODEL_NAME = "BAAI/bge-large-en-v1.5"
COMPETITION_NAME = "eedi-mining-misconceptions-in-mathematics"
OUTPUT_PATH = "."  # trainer output_dir (checkpoints are written here)
MODEL_OUTPUT_PATH = f"{OUTPUT_PATH}/trained_model"  # where the fine-tuned model is saved

# Number of nearest misconceptions retrieved per answer to build training rows.
RETRIEVE_NUM = 25

# Optimisation hyper-parameters.
EPOCH = 2
LR = 2e-05
BS = 8  # per-device micro-batch size
EFFECTIVE_BS = 128  # target number of examples per optimizer step
# Guard so the effective batch is exactly EFFECTIVE_BS rather than silently rounded down.
assert EFFECTIVE_BS % BS == 0, "EFFECTIVE_BS must be a multiple of BS"
GRAD_ACC_STEP = EFFECTIVE_BS // BS  # micro-batches accumulated per optimizer step
WEIGHT_DECAY = 0.01

# Run switches.
TRAINING = True  # presumably toggles the fine-tuning step
DEBUG = False  # train on a small subset for a single epoch
WANDB = False  # log the run to Weights & Biases

# Import

In [2]:
import os
import numpy as np

from datasets import load_dataset, Dataset

import wandb
import polars as pl

from sklearn.metrics.pairwise import cosine_similarity

from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.losses import ContrastiveLoss 
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

In [3]:
import datasets
import sentence_transformers

# Library versions this notebook was developed against:
# polars 1.7.1, datasets 3.0.0, sentence-transformers 3.1.0.
print(
    *(module.__version__ for module in (pl, datasets, sentence_transformers)),
    sep="\n",
)

1.7.1
3.0.0
3.1.0


In [4]:
NUM_PROC = os.cpu_count()

# WANDB

In [5]:
# Choose the Trainer's `report_to` backend. W&B logging needs a Kaggle secret
# named "wandbkey" that holds the W&B API key.
if not WANDB:
    REPORT_TO = "none"
else:
    from kaggle_secrets import UserSecretsClient

    wandb.login(key=UserSecretsClient().get_secret("wandbkey"))
    wandb.init(project=COMPETITION_NAME, name=EXP_NAME)
    REPORT_TO = "wandb"

REPORT_TO

'none'

# Data Load

In [6]:
# Question-level training table (one column per answer option A-D) and the
# MisconceptionId -> MisconceptionName lookup table.
train = pl.read_csv(source=f"{DATA_PATH}/train.csv")
misconception_mapping = pl.read_csv(source=f"{DATA_PATH}/misconception_mapping.csv")

In [7]:
# Question-level columns repeated for every answer option (the unpivot index).
common_col = [
    "QuestionId",
    "ConstructName",
    "SubjectName",
    "QuestionText",
    "CorrectAnswer",
]

# Reshape to one row per (question, answer option). AllText is the text that is
# embedded for retrieval and later used as the training anchor; QuestionId_Answer
# (e.g. "0_A") is the key shared with the misconception labels.
train_long = (
    train.select(common_col + [f"Answer{letter}Text" for letter in "ABCD"])
    .unpivot(
        index=common_col,
        variable_name="AnswerType",
        value_name="AnswerText",
    )
    .with_columns(
        AllText=pl.concat_str(
            "ConstructName",
            "SubjectName",
            "QuestionText",
            "AnswerText",
            separator=" ",
        ),
        AnswerAlphabet=pl.col("AnswerType").str.extract(r"Answer([A-D])Text$"),
    )
    .with_columns(
        QuestionId_Answer=pl.concat_str("QuestionId", "AnswerAlphabet", separator="_"),
    )
    .sort("QuestionId_Answer")
)
train_long.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer
i64,str,str,str,str,str,str,str,str,str
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""0_A"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""0_B"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""0_C"""
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D"""
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction…","""A""","""1000_A"""


In [9]:
# One row per (question, answer option) with its labelled MisconceptionId (null
# when unlabelled), keyed by the same "QuestionId_Answer" string as train_long.
train_misconception_long = (
    train.select(common_col + [f"Misconception{letter}Id" for letter in "ABCD"])
    .unpivot(
        index=common_col,
        variable_name="MisconceptionType",
        value_name="MisconceptionId",
    )
    .with_columns(
        AnswerAlphabet=pl.col("MisconceptionType").str.extract(
            r"Misconception([A-D])Id$"
        ),
    )
    .with_columns(
        QuestionId_Answer=pl.concat_str("QuestionId", "AnswerAlphabet", separator="_"),
    )
    .sort("QuestionId_Answer")
    .select("QuestionId_Answer", pl.col("MisconceptionId").cast(pl.Int64))
)

train_misconception_long.head()

QuestionId_Answer,MisconceptionId
str,i64
"""0_A""",
"""0_B""",
"""0_C""",
"""0_D""",1672.0
"""1000_A""",891.0


In [10]:
# join MisconceptionId
train_long = train_long.join(train_misconception_long, on="QuestionId_Answer")
train_long.head()

QuestionId,ConstructName,SubjectName,QuestionText,CorrectAnswer,AnswerType,AnswerText,AllText,AnswerAlphabet,QuestionId_Answer,MisconceptionId
i64,str,str,str,str,str,str,str,str,str,i64
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerAText""","""\( 3 \times(2+4)-5 \)""","""Use the order of operations to…","""A""","""0_A""",
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerBText""","""\( 3 \times 2+(4-5) \)""","""Use the order of operations to…","""B""","""0_B""",
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerCText""","""\( 3 \times(2+4-5) \)""","""Use the order of operations to…","""C""","""0_C""",
0,"""Use the order of operations to…","""BIDMAS""","""\[ 3 \times 2+4-5 \] Where do …","""A""","""AnswerDText""","""Does not need brackets""","""Use the order of operations to…","""D""","""0_D""",1672.0
1000,"""Simplify an algebraic fraction…","""Simplifying Algebraic Fraction…","""Simplify the following, if pos…","""B""","""AnswerAText""","""\( t \)""","""Simplify an algebraic fraction…","""A""","""1000_A""",891.0


# Make retrieval data

In [11]:
# Base (not yet fine-tuned) model, used here to mine retrieval candidates.
model = SentenceTransformer(MODEL_NAME)

# Embeddings are L2-normalised, so row-wise dot products equal cosine similarities.
train_long_vec = model.encode(
    sentences=train_long["AllText"].to_list(),
    normalize_embeddings=True,
)
misconception_mapping_vec = model.encode(
    sentences=misconception_mapping["MisconceptionName"].to_list(),
    normalize_embeddings=True,
)
print(train_long_vec.shape, misconception_mapping_vec.shape, sep="\n")

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/94.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/191 [00:00<?, ?B/s]

Batches:   0%|          | 0/234 [00:00<?, ?it/s]

Batches:   0%|          | 0/81 [00:00<?, ?it/s]

(7476, 1024)
(2587, 1024)


In [12]:
# Similarity of every answer text to every misconception name. Each row is then
# ranked best-first; the result holds row positions in misconception_mapping.
train_cos_sim_arr = cosine_similarity(train_long_vec, misconception_mapping_vec)
train_sorted_indices = (-train_cos_sim_arr).argsort(axis=1)

In [13]:
# Keep the top RETRIEVE_NUM candidates per answer. argsort returns row positions in
# misconception_mapping, so map them through the MisconceptionId column rather than
# assuming the id equals the row position (the join below matches on the id).
_misconception_ids = misconception_mapping["MisconceptionId"].to_numpy()
train_long = train_long.with_columns(
    pl.Series(
        "PredictMisconceptionId",
        _misconception_ids[train_sorted_indices[:, :RETRIEVE_NUM]].tolist(),
    )
)

In [14]:
# Labelled answers only, expanded to one row per retrieved candidate, with the
# ground-truth misconception and the candidate (Predict*) columns attached.
# TODO: Consider ways to utilize data where MisconceptionId is NaN.
train_retrieved = (
    train_long.drop_nulls(subset="MisconceptionId")
    .explode("PredictMisconceptionId")
    .join(misconception_mapping, on="MisconceptionId")
    .join(
        misconception_mapping.rename(
            {col: f"Predict{col}" for col in misconception_mapping.columns}
        ),
        on="PredictMisconceptionId",
    )
)
train_retrieved.shape

(109250, 14)

# Fine-Tune bge

In [15]:
# Build (anchor, positive, negative) rows. A retrieved candidate that *is* the
# labelled misconception cannot act as a negative, so those rows are dropped.
# The filter runs vectorised in polars (same rows, same order) instead of a
# per-row Python lambda dispatched to forked worker processes.
train = Dataset.from_polars(
    train_retrieved.filter(
        pl.col("MisconceptionId") != pl.col("PredictMisconceptionId")
    )
)

  self.pid = os.fork()


Filter (num_proc=4):   0%|          | 0/109250 [00:00<?, ? examples/s]

In [16]:
# Smoke-test mode: train on at most 1000 rows for a single epoch.
if DEBUG:
    # Clamp to the dataset size; selecting indices past the end would fail.
    train = train.select(range(min(1000, len(train))))
    EPOCH = 1

In [17]:
# Fresh copy of the base model to fine-tune.
model = SentenceTransformer(MODEL_NAME)

# Ranking loss over (anchor, positive, negative) columns; the other rows in the batch
# also act as negatives for each anchor.
# Alternative: ContrastiveLoss(model) — NOTE(review): it expects labelled pairs,
# not the three text columns passed below, so it is not a drop-in swap.
loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    # Required: directory for checkpoints.
    output_dir=OUTPUT_PATH,
    # Optimisation.
    num_train_epochs=EPOCH,
    per_device_train_batch_size=BS,  # micro-batch size per device
    gradient_accumulation_steps=GRAD_ACC_STEP,  # micro-batches per optimizer step
    per_device_eval_batch_size=BS,
    # Eval prediction steps to accumulate on device before moving outputs to CPU;
    # this is NOT gradient accumulation, and is unused while do_eval=False.
    eval_accumulation_steps=GRAD_ACC_STEP,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=0.1,  # linear LR warm-up over the first 10% of training steps
    # Mixed precision: fp16 needs a CUDA GPU; use bf16=True instead on GPUs that support it.
    fp16=True,
    bf16=False,
    # Keep duplicate texts out of a batch: with in-batch negatives a duplicate would be
    # treated as a negative for its own anchor. Drop this if switching to a loss that
    # does not use in-batch negatives.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Tracking / checkpointing.
    # NOTE(review): with the default num_cycles=1 this is a single cosine decay with no
    # restarts; pass lr_scheduler_kwargs={"num_cycles": k} to get actual restarts.
    lr_scheduler_type="cosine_with_restarts",
    save_strategy="steps",
    save_steps=0.1,  # a value < 1 is a fraction of total training steps
    save_total_limit=2,  # older checkpoints beyond this count are deleted
    logging_steps=100,
    report_to=REPORT_TO,
    run_name=EXP_NAME,
    do_eval=False,  # no eval dataset is passed to the trainer
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    # Column order (anchor, positive, negative) is what the loss consumes.
    train_dataset=train.select_columns(
        ["AllText", "MisconceptionName", "PredictMisconceptionName"]
    ),
    loss=loss,
)

# Set TRAINING = False in the config cell to build the trainer without fitting or saving.
if TRAINING:
    trainer.train()
    model.save_pretrained(MODEL_OUTPUT_PATH)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss
100,1.7786
200,1.4053
300,1.4183
400,1.4172
500,1.2773
600,1.2027
700,1.2535
800,1.2134
900,1.1186
1000,0.8603


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]