In [23]:
!pip install sentence-transformers scikit-learn datasets numpy pandas torch



In [34]:
import os
import random
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import GroupKFold
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers import SentenceTransformerTrainer
import warnings
warnings.filterwarnings("ignore", module="accelerate")

def seed_everything(seed: int):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (via ``torch.manual_seed`` and ``torch.cuda.manual_seed``), stores the seed
    in the ``PYTHONHASHSEED`` environment variable, and sets the cuDNN flags
    ``deterministic=True`` / ``benchmark=False``.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Each of these callables accepts the integer seed as its only argument.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once, up front, before any model or data work.
seed_everything(0)

def preprocess(df):
    """Return a cleaned copy of a raw question table.

    Two clean-ups are applied:

    * every text column (object or dedicated string dtype) has leading and
      trailing whitespace stripped;
    * in the four ``Answer{A,B,C,D}Text`` columns the literal ``"Only\\n"`` is
      replaced by ``"Only "``.

    Args:
        df: DataFrame that contains at least the columns ``AnswerAText``,
            ``AnswerBText``, ``AnswerCText`` and ``AnswerDText``.

    Returns:
        A new DataFrame; ``df`` itself is not modified.

    Raises:
        KeyError: If one of the answer text columns is missing.
    """
    df_new = df.copy()

    # Select by "is a string-like dtype" rather than `== "object"` so that
    # pandas' dedicated string dtypes are stripped as well.
    for col, dtype in df.dtypes.items():
        if pd.api.types.is_string_dtype(dtype):
            df_new[col] = df_new[col].str.strip()

    for option in ["A", "B", "C", "D"]:
        answer_col = f"Answer{option}Text"
        # regex=False: this is a literal replacement; do not depend on the
        # version-specific default of `Series.str.replace`.
        df_new[answer_col] = df_new[answer_col].str.replace("Only\n", "Only ", regex=False)

    return df_new


def wide_to_long(df):
    """Expand one-row-per-question data into one row per incorrect answer.

    For every question, each option in A-D other than ``CorrectAnswer`` yields
    one output row. When the frame has misconception labels (detected by the
    presence of a ``MisconceptionAId`` column), options whose
    ``Misconception{option}Id`` is missing are skipped and the label is stored
    as an ``int`` in ``MisconceptionId``.

    Args:
        df: Wide DataFrame with ``QuestionId``, ``CorrectAnswer``,
            ``QuestionText`` and ``Answer{A,B,C,D}Text`` columns, and optionally
            ``Misconception{A,B,C,D}Id`` columns. ``CorrectAnswer`` must be
            located at or before ``QuestionText`` in the column order.

    Returns:
        DataFrame whose columns are ``QuestionId_Answer`` (``"<QuestionId>_<option>"``),
        every input column from the first one through ``QuestionText`` except
        ``CorrectAnswer``, then ``CorrectAnswerText``, ``AnswerText`` and, for
        labelled data, ``MisconceptionId``. The index is a fresh 0..n-1 range.

    Raises:
        KeyError: If a required column is missing.
    """
    is_train = "MisconceptionAId" in df.columns

    # Columns up to and including "QuestionText" are carried over unchanged.
    base_columns = list(df.loc[:, :"QuestionText"].columns)
    out_columns = base_columns + ["CorrectAnswerText", "Answer", "AnswerText"]
    if is_train:
        out_columns.append("MisconceptionId")

    records = []
    for _, row in df.iterrows():
        correct_option = row["CorrectAnswer"]
        correct_text = row[f"Answer{correct_option}Text"]
        # Plain dict per question: avoids writing into a slice of `row`.
        base = {col: row[col] for col in base_columns}

        for option in ["A", "B", "C", "D"]:
            if option == correct_option:
                continue

            record = dict(
                base,
                CorrectAnswerText=correct_text,
                Answer=option,
                AnswerText=row[f"Answer{option}Text"],
            )
            if is_train:
                misconception_id = row[f"Misconception{option}Id"]
                # pd.isna also handles None / pd.NA, where np.isnan would raise.
                if pd.isna(misconception_id):
                    continue
                record["MisconceptionId"] = int(misconception_id)
            records.append(record)

    # Passing `columns` keeps the schema stable even when no record was produced.
    df_long = pd.DataFrame(records, columns=out_columns)
    df_long.insert(0, "QuestionId_Answer", df_long["QuestionId"].astype(str) + "_" + df_long["Answer"])
    df_long = df_long.drop(["Answer", "CorrectAnswer"], axis=1)

    return df_long

In [35]:
INPUT_DIR_0 = "/content"  # Kaggle data directory

# Read the raw training table, clean it, and expand it to one row per wrong answer.
df = preprocess(pd.read_csv(f"{INPUT_DIR_0}/train.csv"))
df_long = wide_to_long(df)

# Query template: one text per (question, incorrect answer) pair.
prompt = (
    "Subject: {SubjectName}\n"
    "Construct: {ConstructName}\n"
    "Question: {QuestionText}\n"
    "Incorrect Answer: {AnswerText}"
)

# Fill the template column-wise instead of iterating over rows.
queries = [
    prompt.format(
        SubjectName=subject,
        ConstructName=construct,
        QuestionText=question,
        AnswerText=answer,
    )
    for subject, construct, question, answer in zip(
        df_long["SubjectName"],
        df_long["ConstructName"],
        df_long["QuestionText"],
        df_long["AnswerText"],
    )
]
df_long["anchor"] = queries

# Persist the long-format training frame, including the rendered anchors.
df_long.to_parquet("train_long.parquet")

In [36]:
# Hyper-parameters read by the later cells: the checkpoint to fine-tune, the
# number of mined negatives per anchor, and the trainer settings.
config = {
    "model": "sentence-transformers/all-MiniLM-L6-v2",
    "n_negatives": 10,
    "per_device_train_batch_size": 512,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-5,
    "n_epochs": 10,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.1,
}

model = SentenceTransformer(config["model"])
# encode() is documented for a string or a list of strings. A pandas Series is
# indexed by label inside encode(), so pass a plain list to stay positional.
# NOTE(review): `model` and `embs_query` are both reassigned by later cells
# before they are read, so this cell mainly checks that loading and encoding work.
embs_query = model.encode(df_long["anchor"].tolist(), show_progress_bar=True, normalize_embeddings=True)
print("Model loaded and embeddings created.")


Batches:   0%|          | 0/137 [00:00<?, ?it/s]

Model loaded and embeddings created.


In [37]:
# Switch read by the next cell: when True the evaluation frame is a copy of the
# full training frame; when False the first GroupKFold fold is held out instead.
FULL_TRAIN = True

In [38]:
if not FULL_TRAIN:
    # Hold out the first of five folds, grouped by MisconceptionId so that no
    # misconception is shared between the training and validation parts.
    # Note: this branch overwrites `df_long` with the training part.
    groups = df_long["MisconceptionId"]
    gkf = GroupKFold(n_splits=5)
    fold, (idx_train, idx_valid) = next(enumerate(gkf.split(df_long, groups=groups)))
    df_long_eval = df_long.iloc[idx_valid].reset_index(drop=True)
    df_long = df_long.iloc[idx_train].reset_index(drop=True)
else:
    # Train on everything; the evaluation frame is then the training data itself.
    df_long_eval = df_long.copy()

In [39]:
# Misconception table read from INPUT_DIR_0 (presumably one row per misconception — verify).
df_map = pd.read_parquet(f"{INPUT_DIR_0}/misconception_mapping.parquet")
# Lookup Series: MisconceptionId -> MisconceptionName.
sr_map = df_map.set_index("MisconceptionId")["MisconceptionName"]

In [40]:
# Re-instantiate the checkpoint named in `config`; this replaces the `model` object created earlier.
model = SentenceTransformer(config["model"])

In [41]:
# Embed the evaluation anchors with the current model.
# encode() is documented for a string or a list of strings; a pandas Series is
# indexed by label inside encode(), so pass plain lists to stay positional.
embs_query = model.encode(df_long_eval["anchor"].tolist(), show_progress_bar=False, normalize_embeddings=True)

# One embedding matrix per explanation variant (columns Misconception_explain_0..9).
# Each matrix is also written to disk as embs_misconception-<i>.npy.
list_embs_misconception = []
for i in range(10):
    misconceptions = df_map["MisconceptionName"] + "\n" + df_map[f"Misconception_explain_{i}"]
    embs_misconception = model.encode(misconceptions.tolist(), show_progress_bar=False, normalize_embeddings=True)
    np.save(f"embs_misconception-{i}.npy", embs_misconception)
    list_embs_misconception.append(embs_misconception)

In [42]:
# Retrieval metrics for the model as loaded, before any fine-tuning in this notebook.
# For each explanation variant, convert query/misconception similarities into
# 0-based ranks (0 = most similar).
rank = np.zeros((len(list_embs_misconception), len(df_long_eval), len(df_map)), dtype=float)
for i, embs_misconception in enumerate(list_embs_misconception):
    similarities = cosine_similarity(embs_query, embs_misconception)
    descending_order = np.argsort(-similarities)
    rank[i, :, :] = np.argsort(descending_order)

# Blend the variants by averaging rank**(1/4), then order candidates by the blend.
rank_ave = np.mean(rank**(1/4), axis=0)
argsort = np.argsort(rank_ave, axis=1, kind="stable")

# Reciprocal rank of the true label when it lands in the top 25, otherwise 0.
# NOTE(review): `labels` holds MisconceptionId values while `argsort` holds
# column positions; this comparison assumes the two coincide — confirm.
labels = df_long_eval["MisconceptionId"].values
scores = np.zeros(len(labels))
for i, label in enumerate(labels):
    hit_idx = np.where(label == argsort[i])[0][0]
    if hit_idx < 25:
        scores[i] = 1 / (hit_idx + 1)

map_25 = scores.mean()
recall_25 = (scores > 0).mean()

print(f"MAP@25: {map_25:.4f}  Recall@25: {recall_25:.4f}")

MAP@25: 0.2205  Recall@25: 0.6407


In [43]:
# Display the evaluation frame; the notebook renders a truncated view of it.
df_long_eval

Unnamed: 0,QuestionId_Answer,QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,QuestionText,CorrectAnswerText,AnswerText,MisconceptionId,anchor
0,0_D,0,856,Use the order of operations to carry out calcu...,33,BIDMAS,\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,\( 3 \times(2+4)-5 \),Does not need brackets,1672,Subject: BIDMAS\nConstruct: Use the order of o...
1,1_A,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,"Simplify the following, if possible: \( \frac{...",Does not simplify,\( m+1 \),2142,Subject: Simplifying Algebraic Fractions\nCons...
2,1_B,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,"Simplify the following, if possible: \( \frac{...",Does not simplify,\( m+2 \),143,Subject: Simplifying Algebraic Fractions\nCons...
3,1_C,1,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,"Simplify the following, if possible: \( \frac{...",Does not simplify,\( m-1 \),2142,Subject: Simplifying Algebraic Fractions\nCons...
4,2_A,2,2774,Calculate the range from a list of data,339,Range and Interquartile Range from a List of Data,Tom and Katie are discussing the \( 5 \) plant...,Only Katie,Only Tom,1287,Subject: Range and Interquartile Range from a ...
...,...,...,...,...,...,...,...,...,...,...,...
4365,1867_C,1867,2634,Distinguish between congruency and similarity,274,Congruency in Other Shapes,Tom and Katie are discussing congruence and si...,Only Katie,Both Tom and Katie,2312,Subject: Congruency in Other Shapes\nConstruct...
4366,1867_D,1867,2634,Distinguish between congruency and similarity,274,Congruency in Other Shapes,Tom and Katie are discussing congruence and si...,Only Katie,Neither is correct,2312,Subject: Congruency in Other Shapes\nConstruct...
4367,1868_A,1868,2680,Describe a 90° or 270° rotation giving the ang...,93,Rotation,Jo and Paul are arguing about how to fully des...,Only Paul,Only Jo,801,Subject: Rotation\nConstruct: Describe a 90° o...
4368,1868_C,1868,2680,Describe a 90° or 270° rotation giving the ang...,93,Rotation,Jo and Paul are arguing about how to fully des...,Only Paul,Both Jo and Paul,801,Subject: Rotation\nConstruct: Describe a 90° o...


In [45]:
# Outer fine-tuning loop. Each iteration:
#   1. mines hard negatives with the current model,
#   2. trains for exactly one trainer epoch on (anchor, positive, negative) triplets,
#   3. saves the model and re-embeds every misconception text variant,
#   4. prints MAP@25 / Recall@25 on `df_long_eval`.
# Note: with FULL_TRAIN = True, `df_long_eval` is a copy of the training frame,
# so the printed metrics are measured on training data.
# Note: a new trainer and new training arguments are built on every iteration,
# so the warmup / scheduler settings are configured afresh for each outer epoch.
for epoch in range(config["n_epochs"]):
    print(f"epoch: {epoch}")

    # Continue from the checkpoint written at the end of the previous iteration
    if epoch > 0:
        model = SentenceTransformer(f"./model-{epoch-1}")

    # ----- Generate negative samples -----

    queries = df_long["anchor"]
    # encode() is documented for a string or a list of strings; a pandas Series
    # is indexed by label inside encode(), so pass plain lists to stay positional.
    embs_query = model.encode(queries.tolist(), show_progress_bar=False, normalize_embeddings=True)

    # Change misconception explanation texts in each epoch
    # (this also adds/overwrites the "Misconception_concat" column on df_map)
    df_map["Misconception_concat"] = df_map["MisconceptionName"] + "\n" + df_map[f"Misconception_explain_{epoch%len(list_embs_misconception)}"]
    sr_map = df_map.set_index("MisconceptionId")["Misconception_concat"]
    embs_misconception = model.encode(sr_map.tolist(), show_progress_bar=False, normalize_embeddings=True)

    similarities = cosine_similarity(embs_query, embs_misconception)
    argsort = np.argsort(-similarities, axis=1)

    # Anchors whose misconception occurs often receive fewer negatives each:
    # ceil(n_negatives / number of training rows sharing that misconception).
    n_negatives = np.repeat(config["n_negatives"], len(df_long))
    dup_count = df_long.groupby("MisconceptionId")["QuestionId_Answer"].transform("count")
    n_negatives = np.ceil(n_negatives / dup_count.to_numpy()).astype(int)

    # Exclude misconceptions that don't appear in the training data from negative samples
    # NOTE(review): `mids` are row positions of df_map whereas `labels` / `valid_ids`
    # are MisconceptionId values; this assumes ids equal row positions — confirm.
    labels = df_long["MisconceptionId"].to_numpy()
    valid_ids = set(df_long["MisconceptionId"])
    negative_mids = [[mid for mid in mids if mid != labels[i] and mid in valid_ids][:n_negatives[i]] for i, mids in enumerate(argsort)]

    # Attach positionally so the assignment does not depend on df_long's index labels
    df_long["MisconceptionId_negative"] = pd.Series(negative_mids, index=df_long.index)
    df_train = df_long.explode("MisconceptionId_negative")

    df_train["positive"] = df_train["MisconceptionId"].map(sr_map)
    df_train["negative"] = df_train["MisconceptionId_negative"].map(sr_map)

    # Shuffle training data manually (because NoDuplicatesBatchSampler never shuffles data due to a bug)
    # NOTE(review): no batch sampler is set in the training arguments below;
    # confirm that the sampler named above is the one actually in use.
    df_train = df_train.sample(frac=1.0, random_state=epoch, axis=0, ignore_index=True)

    # ----- Train -----

    # Load a dataset to finetune on
    ds_train = Dataset.from_pandas(df_train).select_columns(["anchor", "positive", "negative"])

    # Define a loss function
    loss = CachedMultipleNegativesRankingLoss(model)

    # Specify training arguments
    args = SentenceTransformerTrainingArguments(
        output_dir="tmp",
        num_train_epochs=1,
        per_device_train_batch_size=config["per_device_train_batch_size"],
        gradient_accumulation_steps=config["gradient_accumulation_steps"],
        learning_rate=config["learning_rate"],
        lr_scheduler_type=config["lr_scheduler_type"],
        warmup_ratio=config["warmup_ratio"],
        fp16=True,
        bf16=False,
        save_strategy="epoch",
        save_total_limit=50,
        report_to="none"
    )

    # Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=ds_train,
        loss=loss,
    )
    trainer.train()

    # Save the trained model and misconception embeddings
    model.save_pretrained(f"./model-{epoch}")
    for i in range(len(list_embs_misconception)):
        misconceptions = df_map["MisconceptionName"] + "\n" + df_map[f"Misconception_explain_{i}"]
        embs_misconception = model.encode(misconceptions.tolist(), show_progress_bar=False, normalize_embeddings=True)
        np.save(f"embs_misconception-{i}-{epoch}.npy", embs_misconception)
        list_embs_misconception[i] = embs_misconception

    # ----- Evaluate -----

    embs_query = model.encode(df_long_eval["anchor"].tolist(), normalize_embeddings=True)

    # Per-variant 0-based ranks (0 = most similar), blended by averaging rank**(1/4)
    rank = np.zeros((len(list_embs_misconception), len(df_long_eval), len(df_map)), dtype=float)
    for i, embs_misconception in enumerate(list_embs_misconception):
        similarities = cosine_similarity(embs_query, embs_misconception)
        rank[i, :, :] = np.argsort(np.argsort(-similarities))

    rank_ave = np.mean(rank**(1/4), axis=0)
    argsort = np.argsort(rank_ave, axis=1, kind="stable")

    # Reciprocal rank of the true label if it is within the top 25, else 0
    labels = df_long_eval["MisconceptionId"].values
    scores = np.zeros(len(labels))
    for i in range(len(labels)):
        hit_idx = np.where(labels[i] == argsort[i])[0][0]
        if hit_idx < 25:
            scores[i] = 1 / (hit_idx + 1)

    map_25 = scores.mean()
    recall_25 = (scores > 0).mean()

    print(f"MAP@25: {map_25:.4f}  Recall@25: {recall_25:.4f}")

epoch: 0


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.2992  Recall@25: 0.7741
epoch: 1


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.3687  Recall@25: 0.8293
epoch: 2


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.4245  Recall@25: 0.8803
epoch: 3


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.4733  Recall@25: 0.9053
epoch: 4


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.5209  Recall@25: 0.9284
epoch: 5


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.5544  Recall@25: 0.9435
epoch: 6


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.5826  Recall@25: 0.9531
epoch: 7


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.6126  Recall@25: 0.9586
epoch: 8


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.6358  Recall@25: 0.9625
epoch: 9


Step,Training Loss


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

MAP@25: 0.6564  Recall@25: 0.9691


In [46]:
# 테스트 데이터 로드 및 전처리
test_df = preprocess(pd.read_csv(f"{INPUT_DIR_0}/test.csv"))  # path to the test file
test_df_long = wide_to_long(test_df)

# Build the query (anchor) text for every test row from the shared prompt template
queries_test = [
    prompt.format(
        SubjectName=subject,
        ConstructName=construct,
        QuestionText=question,
        AnswerText=answer,
    )
    for subject, construct, question, answer in zip(
        test_df_long["SubjectName"],
        test_df_long["ConstructName"],
        test_df_long["QuestionText"],
        test_df_long["AnswerText"],
    )
]
test_df_long["anchor"] = queries_test


In [51]:
# 모델 로드 (최종 훈련된 모델 사용)
# Load the model saved at the end of the last training epoch
model = SentenceTransformer(f"./model-{config['n_epochs']-1}")

# Embed the test anchors (a plain list keeps encode() positional)
embs_test_query = model.encode(test_df_long["anchor"].tolist(), normalize_embeddings=True)

# Load the misconception embeddings saved after the last epoch, one file per
# explanation variant. Count the variants by column prefix: the training loop
# adds a "Misconception_concat" column to df_map, so `len(df_map.columns) - 2`
# would overcount by one.
n_explanations = sum(str(col).startswith("Misconception_explain_") for col in df_map.columns)

list_embs_misconception = []
for i in range(n_explanations):
    # Use the loop index in the file name; previously variant 9 was loaded every time.
    embs_misconception = np.load(f"embs_misconception-{i}-{config['n_epochs']-1}.npy")
    list_embs_misconception.append(embs_misconception)


In [52]:
# 유사도 계산 및 순위 산출
# Similarity -> 0-based rank (0 = most similar) for every explanation variant
rank_test = np.zeros((len(list_embs_misconception), len(test_df_long), len(df_map)), dtype=float)
for i, embs_misconception in enumerate(list_embs_misconception):
    similarities = cosine_similarity(embs_test_query, embs_misconception)
    rank_test[i, :, :] = np.argsort(np.argsort(-similarities))

# Blend the variants by averaging rank**(1/4), then order candidates by the blend
rank_ave_test = np.mean(rank_test**(1/4), axis=0)
argsort_test = np.argsort(rank_ave_test, axis=1, kind="stable")

# Store the top-25 predictions as MisconceptionId values.
# `argsort_test` holds row positions of df_map, so translate them through the
# id column instead of assuming that ids equal row positions.
misconception_ids = df_map["MisconceptionId"].to_numpy()
test_df_long["PredictedMisconceptions"] = [
    misconception_ids[argsort_test[i, :25]].tolist() for i in range(len(argsort_test))
]


In [53]:
# 결과 저장
# Write the predictions to disk.
# There is one row per (question, incorrect answer) pair, so "QuestionId" alone
# does not identify a row. "QuestionId_Answer" is appended as a trailing column
# so the original two columns keep their positions.
test_results = test_df_long[["QuestionId", "PredictedMisconceptions", "QuestionId_Answer"]]
test_results.to_csv("test_predictions.csv", index=False)
print("Test predictions saved to test_predictions.csv")


Test predictions saved to test_predictions.csv


In [54]:
# 예시로 첫 번째 질문의 예측 확인
# Spot-check: print the anchor of the first test row and its best predictions
sample_idx = 0
print("Anchor:", test_df_long.iloc[sample_idx]["anchor"])

top_predictions = argsort_test[sample_idx, :5]  # five best candidates (row positions of df_map)
print("\nTop 5 Predicted Misconceptions:")
# The counter is named `position` so it does not overwrite the global `rank`
# array created by the evaluation cells.
for position, pred_idx in enumerate(top_predictions, 1):
    print(f"{position}. {sr_map.iloc[pred_idx]}")


Anchor: Subject: BIDMAS
Construct: Use the order of operations to carry out calculations involving powers
Question: \[
3 \times 2+4-5
\]
Where do the brackets need to go to make the answer equal \( 13 \) ?
Incorrect Answer: \( 3 \times 2+(4-5) \)

Top 5 Predicted Misconceptions:
1. Applies BIDMAS in strict order (does not realize addition and subtraction, and multiplication and division, are of equal priority)
The passage is discussing a common misconception about the order of operations in mathematics, often remembered by the acronym BIDMAS (Brackets, Indices, Division/Multiplication, Addition/Subtraction). The misconception here is that some people believe BIDMAS should be followed in a strict, sequential order. However, in reality, addition and subtraction have equal precedence and should be performed from left to right in the order they appear. Similarly, multiplication and division also have equal precedence and should be carried out from left to right as they occur in the express