In [1]:
from tqdm.auto import tqdm
from bs4 import BeautifulSoup
import gc
import pandas as pd
import pickle
import sys
import numpy as np
from tqdm.autonotebook import trange
from sklearn.model_selection import GroupKFold
import json
import torch
from numpy.linalg import norm
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from peft import (
    LoraConfig,
    get_peft_model,
)
import json
import copy
import warnings
warnings.filterwarnings('ignore')


def apk(actual, predicted, k=25):
    """
    Computes the average precision at k.

    This function computes the average precision at k between two lists of
    items.

    Parameters
    ----------
    actual : list
             A list of elements that are to be predicted (order doesn't matter).
             May also be any sized container supporting ``in`` (e.g. a NumPy
             array) or ``None``.
    predicted : list
                A list of predicted elements (order does matter)
    k : int, optional
        The maximum number of predicted elements

    Returns
    -------
    score : double
            The average precision at k over the input lists; 0.0 when
            ``actual`` is ``None`` or empty.
    """

    # Explicit length check instead of `not actual`: the truthiness of a
    # multi-element NumPy array raises ValueError, while len() works for
    # lists, tuples and arrays alike.
    if actual is None or len(actual) == 0:
        return 0.0

    if len(predicted)>k:
        predicted = predicted[:k]

    score = 0.0
    num_hits = 0.0

    for i,p in enumerate(predicted):
        # first condition checks whether it is valid prediction
        # second condition checks if prediction is not repeated
        if p in actual and p not in predicted[:i]:
            num_hits += 1.0
            # precision at rank i+1, accumulated only at hit positions
            score += num_hits / (i+1.0)

    return score / min(len(actual), k)

def mapk(actual, predicted, k=25):
    """
    Mean of the per-query average precision at k.

    Parameters
    ----------
    actual : list
             One list of relevant elements per query
             (order inside each list is irrelevant).
    predicted : list
                One ranked list of predictions per query
                (order inside each list matters).
    k : int, optional
        Cut-off passed through to ``apk`` for every query.

    Returns
    -------
    score : double
            ``np.mean`` of the ``apk`` scores of the paired entries of
            ``actual`` and ``predicted``.
    """
    per_query_scores = []
    # Pair the i-th ground-truth list with the i-th ranked prediction list.
    for relevant, ranked in zip(actual, predicted):
        per_query_scores.append(apk(relevant, ranked, k))
    return np.mean(per_query_scores)

def batch_to_device(batch, target_device):
    """
    Move every tensor value of a batch mapping to ``target_device``.

    Values that are not ``torch.Tensor`` instances are left untouched. The
    mapping is updated in place and the same object is returned.
    """
    for name, value in batch.items():
        if not isinstance(value, Tensor):
            continue
        batch[name] = value.to(target_device)
    return batch

def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """
    Pick, for every sequence in the batch, the hidden state of its last
    attended token.

    If every row of ``attention_mask`` has a 1 in the final column the batch
    is treated as left-padded and the final position is returned for all rows.
    Otherwise the position ``attention_mask.sum(dim=1) - 1`` is gathered for
    each row.
    """
    n_rows = attention_mask.shape[0]
    # All sequences end at the last column -> left padding.
    if attention_mask[:, -1].sum() == n_rows:
        return last_hidden_states[:, -1]

    # Right padding: index of the last attended token per row.
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Prefix ``query`` with its task instruction in the 'Instruct: ...\\nQuery: ...' layout."""
    template = 'Instruct: {}\nQuery: {}'
    return template.format(task_description, query)

def inference(df, model, tokenizer, device, batch_size=32, max_length=512):
    """
    Embed the ``query_text`` column of ``df`` and key the vectors by ``order_index``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``query_text`` (texts to embed) and
        ``order_index`` (keys of the returned dict).
    model : callable
        Called as ``model(**features)``; its result must expose
        ``last_hidden_state``.
    tokenizer : callable
        Called with a list of strings plus ``max_length``, ``padding``,
        ``truncation`` and ``return_tensors="pt"``; its result must contain
        ``attention_mask``.
    device : str or torch.device
        Device the tokenized batch is moved to before the forward pass.
    batch_size : int, optional
        Number of texts per forward pass (default 32, the previous fixed value).
    max_length : int, optional
        Truncation length handed to the tokenizer (default 512, the previous
        fixed value).

    Returns
    -------
    dict
        ``{order_index: 1-D float64 numpy array}`` holding the L2-normalised
        last-token embedding of each row. An empty dict for an empty ``df``.
        Rows sharing an ``order_index`` overwrite each other (last one wins).
    """
    sentences = list(df['query_text'].values)
    pids = list(df['order_index'].values)
    # np.concatenate cannot handle an empty list of batches, so bail out early.
    if not sentences:
        return {}

    # Process the longest texts first so each batch holds texts of similar
    # length and is padded as little as possible.
    length_sorted_idx = np.argsort([-len(sen) for sen in sentences])
    sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

    batch_embeddings = []
    for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=False):
        sentences_batch = sentences_sorted[start_index: start_index + batch_size]
        features = tokenizer(sentences_batch, max_length=max_length, padding=True, truncation=True,
                             return_tensors="pt")
        features = batch_to_device(features, device)
        with torch.no_grad():
            outputs = model(**features)
            embeddings = last_token_pool(outputs.last_hidden_state, features['attention_mask'])
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        # Keep each batch as a float64 array (the dtype the former
        # list-of-floats round trip produced) instead of nested Python lists.
        batch_embeddings.append(embeddings.detach().cpu().numpy().astype(np.float64))

    sorted_embeddings = np.concatenate(batch_embeddings, axis=0)
    # argsort of the sorting permutation is its inverse: it restores the
    # original row order of `df`.
    sentence_embeddings = sorted_embeddings[np.argsort(length_sorted_idx)]
    result = {pids[i]: em for i, em in enumerate(sentence_embeddings)}
    return result

In [2]:
# Notebook configuration: data directory, local embedding checkpoint, and the
# device the model is placed on.
path_prefix = "../data"
model_path = "./SFR-Embedding-2_R"
device = "cuda:0"

In [3]:
# Read the four CSV tables from `path_prefix`, in the same order as before.
train, test, sample_submission, misconception_mapping = (
    pd.read_csv(f"{path_prefix}/{table_name}.csv")
    for table_name in ("train", "test", "sample_submission", "misconception_mapping")
)

# load model

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the weights directly in half precision. Previously the model was
# materialised on the device at the default dtype and only then cast with
# `.to(torch.float16)`, which needs more device memory while loading.
model = AutoModel.from_pretrained(model_path, device_map=device, torch_dtype=torch.float16)
# Inference only: switch off dropout and other train-time behaviour.
model = model.eval()

Loading checkpoint shards: 100%|██████████| 3/3 [00:17<00:00,  5.96s/it]


# 划分数据集

In [5]:
# Grouped 5-fold split keyed on QuestionId, so rows that share a question
# never end up on both sides of the split.
groups = train['QuestionId'].values
group_kfold = GroupKFold(n_splits=5)
train = train.reset_index(drop=True)
# Only the first fold is used.
train_index, test_index = next(group_kfold.split(train, groups=groups))
tra = train.iloc[train_index, :]
val = train.iloc[test_index, :]
tra.shape,val.shape

((1495, 15), (374, 15))

In [6]:
# Quick look at the raw SubjectName values in `train`.
train['SubjectName'].values

array(['BIDMAS', 'Simplifying Algebraic Fractions',
       'Range and Interquartile Range from a List of Data', ..., 'BIDMAS',
       'Congruency in Other Shapes', 'Rotation'], dtype=object)

# 获得query embedding

In [7]:
# Instruction that is prepended to every query before embedding.
# NOTE(review): "misconcepte" looks like a typo, but the text is model input,
# so it is reproduced verbatim; changing it would change the embeddings.
task_description = (
    'Given a math question and a misconcepte incorrect answer, '
    'please retrieve the most accurate reason for the misconception.'
)

In [8]:
# Load a pre-built frame from parquet (presumably one CV fold, going by the
# file name — TODO confirm).
# NOTE(review): this rebinds `tra`, replacing the GroupKFold training split
# that was assigned to the same name earlier in the notebook.
tra = pd.read_parquet("../create_data/save_data/cv1.parquet")

In [9]:
# Number of distinct `mis_id` values (the retrieval targets) in the loaded frame.
tra['mis_id'].nunique()

2585

In [10]:
# Build one retrieval query per row of `tra` and collect the rows in `train_df`.
train_data = []
for _,row in tra.iterrows():
    # Keep only the text after the first '.' (the whole string when it has no '.').
    # NOTE(review): presumably this strips an option label such as "A." —
    # confirm against the parquet contents; a value without a label but with
    # a '.' in it would lose everything before that '.'.
    real_text = row['CorrectAnswer'].split('.',1)[-1]
    SelectedAnswer = row['SelectedAnswer'].split('.',1)[-1]
    # Query body: subject, construct and question, then the correct and the selected (incorrect) answer.
    query_text =f"###question###:{row['SubjectName']}-{row['ConstructName']}-{row['Question']}\n###Correct Answer###:{real_text}\n###Misconcepte Incorrect answer###:{SelectedAnswer}"
    # `row` is the per-iteration Series yielded by iterrows(); the new fields
    # are added to it and a deep copy is stored.
    row['query_text'] = get_detailed_instruct(task_description,query_text)
    # `answer_id` is the ground-truth id later compared against the retrieved ids.
    row['answer_id'] = row['mis_id']
    train_data.append(copy.deepcopy(row))
train_df = pd.DataFrame(train_data)
# Positional id 0..n-1; used as the key of the query-embedding dict returned by `inference`.
train_df['order_index'] = list(range(len(train_df)))

In [11]:
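# NOTE: disabled alternative query builder (one query per distractor option A-D, with all candidate answers in the prompt); kept for reference only.
# train_data = []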
# for _,row in tra.iterrows():
#     for c in ['A','B','C','D']:
#         if str(row[f"Misconception{c}Id"])!="nan":
#             real_answer_id = row['CorrectAnswer']
#             real_text = row[f'Answer{real_answer_id}Text']
#             query_text = f"###question###:{row['SubjectName']}###{row['ConstructName']}###{row['QuestionText']}\n###CandidateAnswer:A.{row['AnswerAText']}\nB.{row['AnswerBText']}\nC.{row['AnswerCText']}\nD.{row['AnswerDText']}\n###CorrectAnswer:{real_answer_id}.{real_text}###distractor answer###:{c}.{row[f'Answer{c}Text']}"
#             row['query_text'] = get_detailed_instruct(task_description,query_text)
#             row['answer_id'] = row[f"Misconception{c}Id"]
#             train_data.append(copy.deepcopy(row))
# train_df = pd.DataFrame(train_data)
# train_df['order_index'] = list(range(len(train_df)))

In [12]:
# Distribution of the number of space-separated tokens per query.
train_df['query_text'].str.split(' ').str.len().describe()

count    7691.000000
mean       51.072032
std        13.770476
min        26.000000
25%        42.000000
50%        48.000000
75%        57.000000
max       177.000000
Name: query_text, dtype: float64

In [13]:
# Show the first fully formatted query as a sanity check of the prompt layout.
train_df['query_text'].iloc[0]

'Instruct: Given a math question and a misconcepte incorrect answer, please retrieve the most accurate reason for the misconception.\nQuery: ###question###:Number Properties-Identify prime numbers from a list-Which of the following numbers is a prime number?\n###Correct Answer###:\\( 11 \\)\n###Misconcepte Incorrect answer###:\\( 4 \\)'

# 推理query embedding

In [14]:
# Embed every query; the result maps order_index -> embedding vector.
train_embeddings = inference(df=train_df, model=model, tokenizer=tokenizer, device=device)

Batches: 100%|██████████| 241/241 [03:20<00:00,  1.20it/s]


In [15]:
# `inference` reads the columns `query_text` and `order_index`, so expose the
# misconception name and id under those names (adds two columns in place).
misconception_mapping['query_text'] = misconception_mapping['MisconceptionName']
misconception_mapping['order_index'] = misconception_mapping['MisconceptionId']
# Result maps MisconceptionId -> embedding vector.
doc_embeddings = inference(
    df=misconception_mapping,
    model=model,
    tokenizer=tokenizer,
    device=device,
)

Batches: 100%|██████████| 81/81 [00:14<00:00,  5.60it/s]


In [16]:
# Stack the document vectors into one 2-D matrix (one row per entry of
# `doc_embeddings`, in dict order) ...
sentence_embeddings = np.vstack(list(doc_embeddings.values()))
# ... and remember which document id each matrix row belongs to.
index_text_embeddings_index = dict(enumerate(doc_embeddings.keys()))

# 召回文本topn

In [17]:
predicts_test = []
# Iterate over the id column directly (iterrows built a full Series per row
# just to read this one value) and give tqdm the total so the progress bar
# shows a percentage and ETA.
for query_id in tqdm(train_df['order_index'], total=len(train_df)):
    query_em = train_embeddings[query_id].reshape(1, -1)

    # Dot product against every document vector. Both sides were
    # L2-normalised in `inference`, so this equals the cosine similarity.
    cosine_similarity = np.dot(query_em, sentence_embeddings.T).flatten()

    # Rank by descending similarity and keep the indices of the top 100.
    sort_index = np.argsort(-cosine_similarity)[:100]
    # Translate matrix row indices back to document ids.
    pids = [index_text_embeddings_index[index] for index in sort_index]
    predicts_test.append(pids)

7691it [00:04, 1904.89it/s]


In [18]:
# Store each query's ranked list of retrieved ids next to the query
# (`predicts_test` was filled in `train_df` row order).
train_df['recall_ids'] = predicts_test

In [19]:
# MAP at the default cut-off (k=25): each query has exactly one relevant id.
mapk(
    actual=[[answer_id] for answer_id in train_df['answer_id'].values],
    predicted=train_df['recall_ids'].values,
)

0.35846623332326505

In [20]:
def recall_score(reals,recalls,k=100):
    """
    Fraction of queries whose target appears among the first ``k`` retrieved items.

    Parameters
    ----------
    reals : sequence of sequences
        Ground truth per query; only the first element of each entry is used.
    recalls : sequence of sequences
        Ranked retrieved items per query, aligned with ``reals``.
    k : int, optional
        Number of leading retrieved items inspected per query.

    Returns
    -------
    float
        Hit rate in [0, 1]; 0.0 when ``reals`` is empty (previously this
        raised ZeroDivisionError).
    """
    total = len(reals)
    if total == 0:
        return 0.0
    hits = 0.
    for i in range(total):
        real = reals[i][0]
        # A query counts once, no matter how often its target is retrieved.
        if any(c == real for c in recalls[i][:k]):
            hits += 1
    return hits/total
# Share of queries whose correct id is among the top 40 retrieved ids.
recall_score(
    reals=[[answer_id] for answer_id in train_df['answer_id'].values],
    recalls=train_df['recall_ids'].values,
    k=40,
)

0.8361721492653751

# 构建训练集

In [21]:
# Number of top retrieved ids considered as hard-negative candidates per query.
cnt = 100
bge_train = []
for _, row in train_df.iterrows():
    pos = [int(row['answer_id'])]
    # Every retrieved id except the correct one becomes a hard negative.
    neg = [int(doc_id) for doc_id in row['recall_ids'][:cnt] if doc_id != pos[0]]
    bge_train.append({'query': row['query_text'], 'pos': pos, 'neg': neg})
# NOTE(review): despite the .jsonl extension, json.dump writes the whole list
# as one JSON array rather than one object per line — confirm that the
# consumer of this file expects that layout.
with open("../train_data/recall_v9_gen/train.jsonl", 'w') as f:
    json.dump(bge_train, f)