# 加载模型

In [None]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
import pandas as pd
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets
from sentence_transformers import SentenceTransformerTrainer

# Local checkpoint directory of the base embedding model to fine-tune.
model_id = "/nfs_data_new/wp/weights/bge-small-zh"

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
target_device = "cuda" if torch.cuda.is_available() else "cpu"

model = SentenceTransformer(
    model_id,
    device=target_device,
    # Forwarded to the underlying model loader by SentenceTransformer.
    model_kwargs={'ignore_mismatched_sizes': True},
)

# 加载数据

In [None]:
# Read the CSV of (query, context) pairs, rename the columns to the
# anchor/positive naming used in the rest of the notebook, and hold out 20%
# of the rows for evaluation.
# NOTE(review): no seed is passed to train_test_split, so the split may
# differ between runs.
dataset = (
    load_dataset("csv", data_files="embedding_data.csv")['train']
    .rename_column("context", "positive")
    .rename_column("query", "anchor")
    .train_test_split(test_size=0.2, shuffle=True)
)

train_dataset, test_dataset = dataset['train'], dataset['test']

In [1]:
# 模型评估

In [None]:
# Build the inputs for InformationRetrievalEvaluator from the held-out split.
#
# dict.fromkeys de-duplicates while keeping first-seen order, so the ids
# assigned below are stable from run to run (iteration order of a set of
# strings is not).
corpus = list(dict.fromkeys(row['positive'] for row in test_dataset))
queries = list(dict.fromkeys(row['anchor'] for row in test_dataset))

# The evaluator expects string ids mapped to their text.
id2corpus = {str(idx): text for idx, text in enumerate(corpus)}
id2queries = {str(idx): text for idx, text in enumerate(queries)}

# Reverse lookups so each row is resolved in O(1) instead of list.index().
corpus2id = {text: cid for cid, text in id2corpus.items()}
query2id = {text: qid for qid, text in id2queries.items()}

# relevant_docs must map every query id to a *set* of corpus ids. A bare
# string here would make the evaluator's membership test a substring match
# (e.g. "1" in "12") and would keep only one positive per repeated query.
relevant_docs = {}
for row in test_dataset:
    query_id = query2id[row['anchor']]
    relevant_docs.setdefault(query_id, set()).add(corpus2id[row['positive']])

ir_evaluator = InformationRetrievalEvaluator(
    queries=id2queries,
    corpus=id2corpus,
    relevant_docs=relevant_docs,
    # The name prefixes the reported metric keys; the training arguments
    # select the best checkpoint via "bge_m3_cosine_ndcg@10".
    name="bge_m3",
    score_functions={"cosine": cos_sim},
    batch_size=8,
)

# Baseline scores of the model before any fine-tuning.
evaluator = SequentialEvaluator([ir_evaluator])
results = evaluator(model)

In [None]:
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# Contrastive loss over (anchor, positive) pairs.
train_loss = MultipleNegativesRankingLoss(model)

# The model is loaded with a CPU fallback, so only request GPU-specific
# features when the hardware can actually provide them. tf32 and bf16 are
# gated on CUDA compute capability >= 8; the fused AdamW needs CUDA.
has_cuda = torch.cuda.is_available()
ampere_or_newer = has_cuda and torch.cuda.get_device_capability()[0] >= 8

args = SentenceTransformerTrainingArguments(
    output_dir="bge-m3-ft",
    num_train_epochs=4,
    per_device_train_batch_size=16,
    # Effective batch size per device is 16 * 2 = 32.
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused" if has_cuda else "adamw_torch",
    tf32=ampere_or_newer,
    bf16=ampere_or_newer,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Evaluate and checkpoint once per epoch; the two strategies must match
    # for load_best_model_at_end to work.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    save_total_limit=3,
    load_best_model_at_end=True,
    # Metric key produced by the evaluator named "bge_m3".
    metric_for_best_model="bge_m3_cosine_ndcg@10",
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    # The loss consumes the dataset columns by position, not by name, so the
    # anchor (query) column has to come before the positive (context) column.
    train_dataset=train_dataset.select_columns(
        ["anchor", "positive"]
    ),
    loss=train_loss,
    evaluator=evaluator,
)
trainer.train()

In [None]:
# Persist the fine-tuned model. The trainer exposes save_model(), not save();
# with no argument it writes to the output_dir given in the training args.
trainer.save_model()