In [1]:
!pip uninstall -y torch
!pip install -q --no-index --find-links=/kaggle/input/wheels-vllm-0-6-3-post1 torchvision==0.19.1
!pip install -q --no-index --find-links=/kaggle/input/wheels-vllm-0-6-3-post1 vllm
!pip install -q -U /kaggle/input/vllm-t4-fix/grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install -q -U /kaggle/input/vllm-t4-fix/ray-2.11.0-cp310-cp310-manylinux2014_x86_64.whl
!pip install -q /kaggle/input/eedi-library/autoawq-0.2.7.post2-py3-none-any.whl --no-index --find-links=/kaggle/input/eedi-library 
!pip install -q /kaggle/input/eedi-library/peft-0.13.2-py3-none-any.whl --no-index --find-links=/kaggle/input/eedi-library 
!pip install -q /kaggle/input/eedi-library/bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl --no-index --find-links=/kaggle/input/eedi-library 
!pip install -q --no-deps --no-index /kaggle/input/logits-processor-zoo/logits_processor_zoo-0.1.0-py3-none-any.whl

Found existing installation: torch 2.4.0
Uninstalling torch-2.4.0:
  Successfully uninstalled torch-2.4.0


# Retriever

In [2]:
import os, math, numpy as np
import sys
import os
from transformers import AutoTokenizer
import pandas as pd
from tqdm import tqdm
import re, gc
import torch

model_path = "/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1"
df_train = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/train.csv").fillna(-1).sample(10, random_state=42).reset_index(drop=True)
df_test = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/test.csv")
df_misconception_mapping = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/misconception_mapping.csv")

df_ret = df_test.copy()

TEMPLATE_INPUT_V3 = '{QUESTION}\nCorrect answer: {CORRECT_ANSWER}\nStudent wrong answer: {STUDENT_WRONG_ANSWER}'
def format_input_v3(row, wrong_choice):

    assert wrong_choice in "ABCD"
    # Extract values from the row
    question_text = row.get("QuestionText", "No question text provided")
    subject_name = row.get("SubjectName", "Unknown subject")
    construct_name = row.get("ConstructName", "Unknown construct")
    # Extract the correct and wrong answer text based on the choice
    correct_answer = row.get("CorrectAnswer", "Unknown")
    assert wrong_choice != correct_answer
    correct_answer_text = row.get(f"Answer{correct_answer}Text", "No correct answer text available")
    wrong_answer_text = row.get(f"Answer{wrong_choice}Text", "No wrong answer text available")

    # Construct the question format
    formatted_question = f"""Question: {question_text}
    
SubjectName: {subject_name}
ConstructName: {construct_name}"""

    # Return the extracted data
    ret = {
        "QUESTION": formatted_question,
        "CORRECT_ANSWER": correct_answer_text,
        "STUDENT_WRONG_ANSWER": wrong_answer_text,
        "MISCONCEPTION_ID": row.get('Misconception{wrong_choice}Id'),
    }
    ret["PROMPT"] = TEMPLATE_INPUT_V3.format(**ret)

    return ret


items = []
target_ids = []
for _, row in df_ret.iterrows():
    for choice in ['A', 'B', 'C', 'D']:
        if choice == row["CorrectAnswer"]:
            continue
            
        correct_col = f"Answer{row['CorrectAnswer']}Text"
        item = {'QuestionId_Answer': '{}_{}'.format(row['QuestionId'], choice)}
        item['Prompt'] = format_input_v3(row, choice)['PROMPT']
        items.append(item)
        target_ids.append(int(row.get(f'Misconception{choice}Id', -1)))
        
df_input = pd.DataFrame(items)

def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}'

def get_detailed_example(task_description: str, query: str, response: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'

def get_new_queries(queries, query_max_len, examples_prefix, tokenizer):
    inputs = tokenizer(
        queries,
        max_length=query_max_len - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
            tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']),
        return_token_type_ids=False,
        truncation=True,
        return_tensors=None,
        add_special_tokens=False
    )
    prefix_ids = tokenizer(examples_prefix, add_special_tokens=False)['input_ids']
    suffix_ids = tokenizer('\n<response>', add_special_tokens=False)['input_ids']
    new_max_length = (len(prefix_ids) + len(suffix_ids) + query_max_len + 8) // 8 * 8 + 8
    new_queries = tokenizer.batch_decode(inputs['input_ids'])
    for i in range(len(new_queries)):
        new_queries[i] = examples_prefix + new_queries[i] + '\n<response>'
    return new_max_length, new_queries
task =  "Given a math multiple-choice problem with a student's wrong answer, retrieve the math misconceptions"
queries = [
    get_detailed_instruct(task, q) for q in df_input['Prompt']
]
documents = df_misconception_mapping['MisconceptionName'].tolist()
query_max_len, doc_max_len = 320, 48
# LORA_PATH = '/kaggle/input/lora-14b-1126/transformers/default/1'
LORA_PATH = '/kaggle/input/2211-lora-14b/transformers/default/1'
tokenizer = AutoTokenizer.from_pretrained(LORA_PATH)
examples_prefix = ''
new_query_max_len, new_queries = get_new_queries(queries, query_max_len, examples_prefix, tokenizer)


import json
with open('data.json', 'w') as f:
    data = {'texts': new_queries+ documents}
    f.write(json.dumps(data))

In [3]:
%%writefile run_embed.py
import argparse
import os
import json
import torch
import torch.nn.functional as F
import gc
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import peft

MAX_LENGTH = 320


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
        ]


def get_embeddings_in_batches(model, tokenizer, texts, max_length, batch_size=32):
    embeddings = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
        batch_texts = texts[i : i + batch_size]
        batch_dict = tokenizer(
            batch_texts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to("cuda")
        with torch.no_grad(), torch.amp.autocast("cuda"):
            outputs = model(**batch_dict)
            batch_embeddings = last_token_pool(
                outputs.last_hidden_state, batch_dict["attention_mask"]
            )
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1).cpu()
        embeddings.append(batch_embeddings)
    return torch.cat(embeddings, dim=0)


def load_model_and_tokenizer(base_model_path, lora_path, load_in_4bit=True):
    model = AutoModel.from_pretrained(
        base_model_path,
        device_map=0,
        torch_dtype=torch.float16,
        load_in_4bit=load_in_4bit,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        lora_path if lora_path else base_model_path
    )
    model.resize_token_embeddings(len(tokenizer))
    if lora_path:
        model = peft.PeftModel.from_pretrained(model, lora_path)
    return model, tokenizer


def main(args):
    output_file = args.input_text.replace(
        ".json", ".pt.fold.{}.{}.embed".format(*args.fold)
    )
    if os.path.exists(output_file):
        print(f"Output file {output_file} already exists. Skipping...")
        return
    model, tokenizer = load_model_and_tokenizer(
        args.base_model, args.lora_path, load_in_4bit=args.load_in_4bit
    )
    texts = json.load(open(args.input_text))["texts"][args.fold[0] :: args.fold[1]]
    embeddings = get_embeddings_in_batches(
        model,
        tokenizer,
        texts,
        max_length=MAX_LENGTH,
        batch_size=4,
    )
    text2embeds = {text: emb for text, emb in zip(texts, embeddings)}
    torch.save(text2embeds, output_file)

    del output_file, model, tokenizer, texts, embeddings, text2embeds
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_model",
        type=str,
        default="Qwen/Qwen2.5-7B",
        help="Path to the base model",
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        default=None,
        help="Path to the LoRA model",
    )
    parser.add_argument(
        "--input_text",
        type=str,
        default=".cache/data.json",
    )
    parser.add_argument(
        "--load_in_4bit",
        action="store_true",
        help="Load model in 4-bit mode",
    )
    parser.add_argument("--fold", nargs=2, type=int, default=[0, 1])
    args = parser.parse_args()
    if not os.path.exists(args.lora_path):
        args.lora_path = None
    main(args)


Writing run_embed.py


In [4]:
lora_path = '/kaggle/input/2211-lora-14b/transformers/default/1'
cmd = f"(CUDA_VISIBLE_DEVICES=0 python run_embed.py --base_model /kaggle/input/qw14b-awq/transformers/default/1 --lora_path {lora_path} --input_text data.json --fold 0 2) & (CUDA_VISIBLE_DEVICES=1 python run_embed.py --base_model /kaggle/input/qw14b-awq/transformers/default/1 --lora_path {lora_path} --input_text data.json --fold 1 2)"
os.system(cmd)

Loading checkpoint shards: 100%|██████████| 2/2 [01:55<00:00, 57.70s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [01:55<00:00, 57.87s/it]
Embedding: 100%|██████████| 325/325 [03:13<00:00,  1.68it/s]
Embedding:  97%|█████████▋| 314/325 [03:16<00:06,  1.63it/s]

0

In [5]:
from glob import glob
import time
text_to_embed = {}
files = glob('*.pt*')
while len(files) != 2:
    time.sleep(1)
    files = glob('*.pt*')


time.sleep(3)    
for path in files:
    print(path)
    text_to_embed.update(torch.load(path))

Embedding: 100%|██████████| 325/325 [03:22<00:00,  1.60it/s]


data.pt.fold.1.2.embed
data.pt.fold.0.2.embed


  text_to_embed.update(torch.load(path))


In [6]:
query_embeddings = torch.stack([text_to_embed[t] for t in new_queries])
doc_embeddings = torch.stack([text_to_embed[t] for t in documents])

scores = query_embeddings @ doc_embeddings.T  # Shape: (M, N)
sorted_indices = torch.argsort(scores,1, descending=True)[:,:25].tolist()

df_input["MisconceptionId"] = [" ".join([str(x) for x in row]) for row in sorted_indices]
# df_input[["QuestionId_Answer", "MisconceptionId"]].to_csv("submission.csv", index=False)

In [7]:
import copy

full_df = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/test.csv")

rows = []
for idx, row in full_df.iterrows():
    for option in ["A", "B", "C", "D"]:
        if option == row.CorrectAnswer:
            continue
            
        correct_answer = row[f"Answer{row.CorrectAnswer}Text"]

        query_text =f"### SubjectName: {row['SubjectName']}\n### ConstructName: {row['ConstructName']}\n### Question: {row['QuestionText']}\n### Correct Answer: {correct_answer}\n### Misconcepte Incorrect answer: {option}.{row[f'Answer{option}Text']}"
        row["query_text"] = query_text
        row["QuestionId_Answer"] = f"{row.QuestionId}_{option}"
        row["answer_name"] = option
        row["correct_answer"] = correct_answer
        row["incorrect_answer"] = row[f"Answer{option}Text"]
        rows.append(copy.deepcopy(row))

df = pd.DataFrame(rows)
df['order_index'] = list(range(len(df)))
df["MisconceptionId"] = [" ".join([str(x) for x in row]) for row in sorted_indices]
df.to_parquet("df_target.parquet", index=False)
df.head()

Unnamed: 0,QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectAnswer,QuestionText,AnswerAText,AnswerBText,AnswerCText,AnswerDText,query_text,QuestionId_Answer,answer_name,correct_answer,incorrect_answer,order_index,MisconceptionId
0,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),\( 3 \times(2+4-5) \),Does not need brackets,### SubjectName: BIDMAS\n### ConstructName: Us...,1869_B,B,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),0,706 1507 1345 2306 328 1672 1005 2518 1963 253...
0,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),\( 3 \times(2+4-5) \),Does not need brackets,### SubjectName: BIDMAS\n### ConstructName: Us...,1869_C,C,\( 3 \times(2+4)-5 \),\( 3 \times(2+4-5) \),1,2306 1507 706 1005 1345 1999 2488 2532 987 251...
0,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),\( 3 \times(2+4-5) \),Does not need brackets,### SubjectName: BIDMAS\n### ConstructName: Us...,1869_D,D,\( 3 \times(2+4)-5 \),Does not need brackets,2,1005 328 1507 2532 1672 1516 706 1345 2306 248...
1,1870,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,"Simplify the following, if possible: \( \frac{...",\( m+1 \),\( m+2 \),\( m-1 \),Does not simplify,### SubjectName: Simplifying Algebraic Fractio...,1870_A,A,Does not simplify,\( m+1 \),3,2142 2068 167 891 418 1755 979 113 1421 320 22...
1,1870,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,"Simplify the following, if possible: \( \frac{...",\( m+1 \),\( m+2 \),\( m-1 \),Does not simplify,### SubjectName: Simplifying Algebraic Fractio...,1870_B,B,Does not simplify,\( m+2 \),4,2142 2068 167 891 341 979 1755 1871 143 418 11...


# LLM Reranker

In [8]:
%%writefile run_vllm_logits.py

import re

import numpy as np
import pandas as pd
import vllm
from logits_processor_zoo.vllm import MultipleChoiceLogitsProcessor
from transformers import AutoTokenizer

model_path = "/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1"
tokenizer = AutoTokenizer.from_pretrained(model_path)

def preprocess_text(x):
    x = re.sub("http\w+", "", x)  # Delete URL
    x = re.sub(r"\.+", ".", x)  # Replace consecutive commas and periods with one comma and period character
    x = re.sub(r"\,+", ",", x)
    x = re.sub(r"\\\(", " ", x)
    x = re.sub(r"\\\)", " ", x)
    x = re.sub(r"[ ]{1,}", " ", x)
    x = x.strip()  # Remove empty characters at the beginning and end
    return x


PROMPT = """Here is a question about {ConstructName}({SubjectName}).
Question: {Question}
Correct Answer: {CorrectAnswer}
Incorrect Answer: {IncorrectAnswer}

You are a Mathematics teacher. Your task is to reason and identify the misconception behind the Incorrect Answer with the Question.
Answer concisely what misconception it is to lead to getting the incorrect answer.
Pick the correct misconception number from the below:

{Retrival}
"""


def apply_template(row, tokenizer):
    messages = [
        {
            "role": "user",
            "content": preprocess_text(
                PROMPT.format(
                    ConstructName=row["ConstructName"],
                    SubjectName=row["SubjectName"],
                    Question=row["QuestionText"],
                    IncorrectAnswer=row["incorrect_answer"],
                    CorrectAnswer=row["correct_answer"],
                    Retrival=row["retrieval"],
                )
            ),
        }
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return text


misconception_df = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/misconception_mapping.csv")

df = pd.read_parquet("df_target.parquet")
indices = np.stack(df["MisconceptionId"].apply(lambda x: np.array(list(map(int, x.split())))))

llm = vllm.LLM(
    model_path,
    quantization="awq",
    tensor_parallel_size=2,
    gpu_memory_utilization=0.90,
    trust_remote_code=True,
    dtype="half",
    enforce_eager=True,
    max_model_len=5120,
    disable_log_stats=True,
)
tokenizer = llm.get_tokenizer()


def get_candidates(c_indices):
    candidates = []

    mis_names = misconception_df["MisconceptionName"].values
    for ix in c_indices:
        c_names = []
        for i, name in enumerate(mis_names[ix]):
            c_names.append(f"{i+1}. {name}")

        candidates.append("\n".join(c_names))

    return candidates


survivors = indices[:, -1:]

for i in range(3):
    c_indices = np.concatenate([indices[:, -8 * (i + 1) - 1 : -8 * i - 1], survivors], axis=1)

    df["retrieval"] = get_candidates(c_indices)
    df["text"] = df.apply(lambda row: apply_template(row, tokenizer), axis=1)

    responses = llm.generate(
        df["text"].values,
        vllm.SamplingParams(
            n=1,  # Number of output sequences to return for each prompt.
            top_k=1,  # Float that controls the cumulative probability of the top tokens to consider.
            temperature=0,  # randomness of the sampling
            seed=777,  # Seed for reprodicibility
            skip_special_tokens=False,  # Whether to skip special tokens in the output.
            max_tokens=1,  # Maximum number of tokens to generate per output sequence.
            logits_processors=[MultipleChoiceLogitsProcessor(tokenizer, choices=["1", "2", "3", "4", "5", "6", "7", "8", "9"])],
        ),
        use_tqdm=True,
    )

    responses = [x.outputs[0].text for x in responses]
    df["response"] = responses

    llm_choices = df["response"].astype(int).values - 1

    survivors = np.array([cix[best] for best, cix in zip(llm_choices, c_indices, strict=False)]).reshape(-1, 1)
    df[f"s{i}"] = survivors


def create_reranker_result(row):
    originals = row.MisconceptionId.split()
    rerank_result = [str(row.s2)] + originals[:8] + [str(row.s1)] + originals[8:16] + [str(row.s0)] + originals[16:]
    rerank_result = list(dict.fromkeys(rerank_result))[:25]
    return " ".join(rerank_result)


df["reranker_results"] = df.apply(create_reranker_result, axis=1)

##########################
# 2,3位もLLMに抽出させる
##########################


# def extract_candidates(row, target_rank=2):
#     target_ids = list(map(int, row.reranker_results.split()))[1:]
#     if target_rank == 2:
#         target_ids = target_ids[:9]
#     if target_rank == 3:
#         target_ids = [id for id in target_ids if id != row.f2][:9]
#     return target_ids


# for i in range(2):
#     target_rank = i + 2
#     df["candidates"] = df.apply(lambda row: extract_candidates(row, target_rank=target_rank), axis=1)

#     df["retrieval"] = get_candidates(df["candidates"].values)
#     df["text"] = df.apply(lambda row: apply_template(row, tokenizer), axis=1)
#     responses = llm.generate(
#         df["text"].values,
#         vllm.SamplingParams(
#             n=1,  # Number of output sequences to return for each prompt.
#             top_k=1,  # Float that controls the cumulative probability of the top tokens to consider.
#             temperature=0,  # randomness of the sampling
#             seed=777,  # Seed for reprodicibility
#             skip_special_tokens=False,  # Whether to skip special tokens in the output.
#             max_tokens=1,  # Maximum number of tokens to generate per output sequence.
#             logits_processors=[MultipleChoiceLogitsProcessor(tokenizer, choices=["1", "2", "3", "4", "5", "6", "7", "8", "9"])],
#         ),
#         use_tqdm=True,
#     )

#     responses = [x.outputs[0].text for x in responses]
#     df["response"] = responses

#     llm_choices = df["response"].astype(int).values - 1

#     survivors = np.array([cix[best] for best, cix in zip(llm_choices, df["candidates"].values, strict=False)]).reshape(-1, 1)
#     df[f"f{target_rank}"] = survivors


# def create_reranker_result_v2(row):
#     originals = row.reranker_results.split()
#     rerank_result = [originals[0]] + [str(row.f2), str(row.f3)] + originals[1:]
#     rerank_result = list(dict.fromkeys(rerank_result))[:25]
#     return " ".join(rerank_result)

# df["reranker_results_v2"] = df.apply(create_reranker_result_v2, axis=1)

##########################

df.to_parquet("df_target.parquet", index=False)
df_sub = df[["QuestionId_Answer", "reranker_results"]].copy()
df_sub.columns = ["QuestionId_Answer", "MisconceptionId"]
df_sub.to_csv("submission.csv", index=False)

Writing run_vllm_logits.py


In [9]:
!python run_vllm_logits.py

  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


INFO 12-12 04:18:41 config.py:905] Defaulting to use mp for distributed inference
INFO 12-12 04:18:41 llm_engine.py:237] Initializing an LLM engine (v0.6.3.post1) with config: model='/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1', speculative_config=None, tokenizer='/kaggle/input/qwen2.5/transformers/32b-instruct-awq/1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=5120, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0,

# LLM Reranker (finetuned 32B)

In [10]:
%%writefile run_finetuned_model.py

import re
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, LogitsProcessor, BitsAndBytesConfig
from peft import PeftModel

model_path = "/kaggle/input/m/qwen-lm/qwen2.5/transformers/32b-instruct/1"
CHECKPOINT_PATH = "/kaggle/input/eedi-llm-hzghnvz"
misconception_df = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/misconception_mapping.csv")
df = pd.read_parquet("df_target.parquet")

base_column_name = "reranker_results"
indices = np.stack(df[base_column_name].apply(lambda x: np.array(list(map(int, x.split()))[:9])))
indices_original = np.stack(df[base_column_name].apply(lambda x: np.array(list(map(int, x.split())))))

def preprocess_text(x):
    x = re.sub("http\w+", "", x)  # Delete URL
    x = re.sub(r"\.+", ".", x)  # Replace consecutive dots with a single dot
    x = re.sub(r"\,+", ",", x)   # Replace consecutive commas with a single comma
    x = re.sub(r"\\\(", " ", x)
    x = re.sub(r"\\\)", " ", x)
    x = re.sub(r"[ ]{1,}", " ", x)
    x = x.strip()  # Remove empty characters at the beginning and end
    return x

PROMPT = """<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Here is a question about {ConstructName}({SubjectName}).
Question: {Question}
Correct Answer: {CorrectAnswer}
Incorrect Answer: {IncorrectAnswer}

You are a Mathematics teacher. Your task is to reason and identify the misconception behind the Incorrect Answer with the Question.
Answer concisely what misconception it is to lead to getting the incorrect answer.
Pick the correct misconception number from the below:

{Retrival}<|im_end|>
<|im_start|>assistant
Answer:"""

def apply_template(row, tokenizer):
    return preprocess_text(PROMPT.format(
        ConstructName=row["ConstructName"],
        SubjectName=row["SubjectName"],
        Question=row["QuestionText"],
        IncorrectAnswer=row["incorrect_answer"],
        CorrectAnswer=row["correct_answer"],
        Retrival=row["retrieval"],
    ))

def get_candidates(c_indices):
    candidates = []
    mis_names = misconception_df["MisconceptionName"].values
    for ix in c_indices:
        c_names = []
        for i, name in enumerate(mis_names[ix]):
            c_names.append(f"{i+1}. {name}")
        candidates.append("\n".join(c_names))
    return candidates

# 推論用関数を定義
def batched_inference(df, batch_size, tokenizer, model, generation_config, logits_processor):
    responses = []
    for start in range(0, len(df), batch_size):
        end = start + batch_size
        batch_df = df.iloc[start:end]
        
        # テキストをトークン化
        input_texts = batch_df["text"].values.tolist()
        inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(model.device)

        # 推論
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs, 
                generation_config=generation_config, 
                logits_processor=logits_processor
            )
        
        # 結果をデコード
        batch_responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        responses.extend([r.strip()[-1] for r in batch_responses])
    return responses

# ロジット処理用クラスの実装
class MultipleChoiceLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer, choices):
        super().__init__()
        # choicesは["1","2","3",...,"9"]といったシングルトークンを想定
        self.choice_ids = []
        for c in choices:
            c_ids = tokenizer(c, add_special_tokens=False)["input_ids"]
            if len(c_ids) != 1:
                raise ValueError(f"Choice {c} is not a single token.")
            self.choice_ids.append(c_ids[0])

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 全てのトークンを -inf でマスクし、選択肢のトークンのみ元のスコアを残す
        mask = torch.full_like(scores, float("-inf"))
        for cid in self.choice_ids:
            mask[..., cid] = scores[..., cid]
        return mask

########################### 量子化あり

# 4bit量子化設定
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# トークナイザーとモデルのロード
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config
)

# LoRA を適用
model = PeftModel.from_pretrained(model, CHECKPOINT_PATH)


# 乱数シード設定（任意）
torch.manual_seed(777)

survivors = indices[:, -1:]

BATCH_SIZE = 8
generation_config = GenerationConfig(
    max_new_tokens=1,
    do_sample=False,
    top_k=1,
)

for i in range(1):
    c_indices = np.concatenate([indices[:, -8 * (i + 1) - 1 : -8 * i - 1], survivors], axis=1)
    df["retrieval"] = get_candidates(c_indices)
    df["text"] = df.apply(lambda row: apply_template(row, tokenizer), axis=1)

    if i == 0:
        print('text 0')
        print(df["text"].values[0])

    # ロジットプロセッサを設定
    logits_processor = [MultipleChoiceLogitsProcessor(tokenizer, choices=["1","2","3","4","5","6","7","8","9"])]
    
    # バッチ推論
    responses = batched_inference(df, BATCH_SIZE, tokenizer, model, generation_config, logits_processor)
    df["response"] = responses

    # 回答を整数として変換
    llm_choices = df["response"].astype(int).values - 1
    survivors = np.array([cix[best] for best, cix in zip(llm_choices, c_indices, strict=False)]).reshape(-1, 1)

results = []
for i in range(indices_original.shape[0]):
    ix = indices_original[i]
    llm_choice = survivors[i, 0]
    results.append(" ".join([str(llm_choice)] + [str(x) for x in ix if x != llm_choice]))

df["reranker_results_v3"] = results
df.to_parquet("df_target.parquet", index=False)
df_sub = df[["QuestionId_Answer", "reranker_results_v3"]].copy()
df_sub.columns = ["QuestionId_Answer", "MisconceptionId"]
df_sub.to_csv("submission.csv", index=False)

Writing run_finetuned_model.py


In [11]:
!python run_finetuned_model.py

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loading checkpoint shards: 100%|████████████████| 17/17 [06:27<00:00, 22.82s/it]
text 0
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Here is a question about Use the order of operations to carry out calculations involving powers(BIDMAS).
Question: \[
3 \times 2+4-5
\]
Where do the brackets need to go to make the answer equal 13 ?
Correct Answer: 3 \times(2+4)-5 
Incorrect Answer: 3 \times 2+(4-5) 

You are a Mathematics teacher. Your task is to reason and identify the misconception behind the Incorrect Answer with the Question.
Answer concisely what misconception it is to lead to getting the incorrect answer.
Pick the correct misconception number from the below:

1. Inserts brackets but not changed order of operation
2. Carries out operations from right to left regardless of priority order
3. Carries out operations from left to right regardless of priority order
4. Applies BIDMAS in strict order (

# Check Result

In [12]:
df_target = pd.read_parquet("df_target.parquet")
df_target.head()

Unnamed: 0,QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectAnswer,QuestionText,AnswerAText,AnswerBText,AnswerCText,...,order_index,MisconceptionId,retrieval,text,response,s0,s1,s2,reranker_results,reranker_results_v3
0,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),\( 3 \times(2+4-5) \),...,0,706 1507 1345 2306 328 1672 1005 2518 1963 253...,1. Inserts brackets but not changed order of o...,"<|im_start|>system\nYou are Qwen, created by A...",3,1054,1963,1345,1345 706 1507 2306 328 1672 1005 2518 1963 253...,1507 1345 706 2306 328 1672 1005 2518 1963 253...
1,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),\( 3 \times(2+4-5) \),...,1,2306 1507 706 1005 1345 1999 2488 2532 987 251...,1. Inserts brackets but not changed order of o...,"<|im_start|>system\nYou are Qwen, created by A...",3,2449,2518,1345,1345 2306 1507 706 1005 1999 2488 2532 2518 98...,1507 1345 2306 706 1005 1999 2488 2532 2518 98...
2,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),\( 3 \times(2+4-5) \),...,2,1005 328 1507 2532 1672 1516 706 1345 2306 248...,1. Believes order of operations does not affec...,"<|im_start|>system\nYou are Qwen, created by A...",4,1941,315,2532,2532 1005 328 1507 1672 1516 706 1345 315 2306...,1507 2532 1005 328 1672 1516 706 1345 315 2306...
3,1870,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,"Simplify the following, if possible: \( \frac{...",\( m+1 \),\( m+2 \),\( m-1 \),...,3,2142 2068 167 891 418 1755 979 113 1421 320 22...,1. Incorrectly cancels what they believe is a ...,"<|im_start|>system\nYou are Qwen, created by A...",6,143,1871,891,891 2142 2068 167 418 1755 979 113 1871 1421 3...,1755 891 2142 2068 167 418 979 113 1871 1421 3...
4,1870,1612,Simplify an algebraic fraction by factorising ...,1077,Simplifying Algebraic Fractions,D,"Simplify the following, if possible: \( \frac{...",\( m+1 \),\( m+2 \),\( m-1 \),...,4,2142 2068 167 891 341 979 1755 1871 143 418 11...,1. Incorrectly cancels what they believe is a ...,"<|im_start|>system\nYou are Qwen, created by A...",7,885,418,891,891 2142 2068 167 341 979 1755 1871 418 143 11...,1755 891 2142 2068 167 341 979 1871 418 143 11...


In [13]:
for idx, row in df_target.iterrows():
    print(row.MisconceptionId)
    print('------------------')
    print(row.reranker_results)
    print('------------------')
    # print(row.reranker_results_v2)
    # print('------------------')
    print(row.reranker_results_v3)
    print('==================')
    if idx >= 3:
        break

706 1507 1345 2306 328 1672 1005 2518 1963 2532 1516 2488 2181 1999 1941 987 158 2449 234 15 1862 315 657 1054 77
------------------
1345 706 1507 2306 328 1672 1005 2518 1963 2532 1516 2488 2181 1999 1941 987 1054 158 2449 234 15 1862 315 657 77
------------------
1507 1345 706 2306 328 1672 1005 2518 1963 2532 1516 2488 2181 1999 1941 987 1054 158 2449 234 15 1862 315 657 77
2306 1507 706 1005 1345 1999 2488 2532 987 2518 1672 328 1963 791 1516 1392 2392 2449 2181 1338 1214 2515 1248 158 657
------------------
1345 2306 1507 706 1005 1999 2488 2532 2518 987 1672 328 1963 791 1516 1392 2449 2392 2181 1338 1214 2515 1248 158 657
------------------
1507 1345 2306 706 1005 1999 2488 2532 2518 987 1672 328 1963 791 1516 1392 2449 2392 2181 1338 1214 2515 1248 158 657
1005 328 1507 2532 1672 1516 706 1345 2306 2488 1392 2518 158 1999 1862 315 1856 15 2181 1941 1416 2449 1319 2326 987
------------------
2532 1005 328 1507 1672 1516 706 1345 315 2306 2488 1392 2518 158 1999 1862 1941 1856 15

In [14]:
df_sub = pd.read_csv("submission.csv")
df_sub

Unnamed: 0,QuestionId_Answer,MisconceptionId
0,1869_B,1507 1345 706 2306 328 1672 1005 2518 1963 253...
1,1869_C,1507 1345 2306 706 1005 1999 2488 2532 2518 98...
2,1869_D,1507 2532 1005 328 1672 1516 706 1345 315 2306...
3,1870_A,1755 891 2142 2068 167 418 979 113 1871 1421 3...
4,1870_B,1755 891 2142 2068 167 341 979 1871 418 143 11...
5,1870_C,1755 891 2142 2068 167 418 113 2078 265 143 97...
6,1871_A,1287 1073 2439 1665 2551 1306 1059 1098 1866 1...
7,1871_C,1287 1073 2439 1665 2551 1098 1059 912 1866 13...
8,1871_D,1073 1287 1059 1866 903 2471 912 2439 1975 206...
