In [1]:
# data preprocessing
import pandas as pd
import numpy as np

# load xlsx file
df = pd.read_excel('citiao.xlsx')
# keep only point_name, item_id_cate, point_content, from_name
df = df[['point_name','point_content']]
# remove html tags
from bs4 import BeautifulSoup
df['point_content'] = df['point_content'].apply(lambda x: BeautifulSoup(x, 'html.parser').get_text())
df.head()

  df['point_content'] = df['point_content'].apply(lambda x: BeautifulSoup(x, 'html.parser').get_text())


Unnamed: 0,point_name,point_content
0,载波从节点,采集器或计量点（电能表）所在的载波节点。
1,电容器,用来提供电容的器件。
2,供电客户服务,电力供应过程中，企业为满足客户获得和使用电力产品的各种相关需求的一系列活动的总称。
3,从节点附属节点,指与从节点具有绑定关系的附加设备，简称附属节点。
4,供电客户服务渠道,供电企业与客户进行交互、提供服务的具体途径。以下简称“服务渠道”。


In [2]:
import torch
from datasets import Dataset, load_dataset, concatenate_datasets

# 1. Create & Prepare embedding dataset

# Load your dataset (replace with your actual dataset loading method)
dataset = Dataset.from_pandas(df)

# Convert term-definition pairs to Q&A format
def convert_to_qa(example):
    return {
        "question": f"什么是 {example['point_name']}?",
        "answer": example['point_content']
    }

dataset = dataset.map(convert_to_qa)

# Rename columns
dataset = dataset.rename_column("question", "anchor")
dataset = dataset.rename_column("answer", "positive")

# delete columns
dataset = dataset.remove_columns(["point_name", "point_content"])

# Add an id column to the dataset
dataset = dataset.add_column("id", range(len(dataset)))

# Split dataset into a 10% test set
dataset = dataset.train_test_split(test_size=0.1)

# Save datasets to disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")

Map:   0%|          | 0/2960 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

98009

In [3]:
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
    losses
)
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from sentence_transformers.training_args import BatchSamplers
# 2. Create baseline and evaluate pretrained model

model_id = "lier007/xiaobu-embedding-v2"
matryoshka_dimensions = [768, 512, 256, 128, 64]  # Adjust if necessary based on the model's capabilities

# Load the model
model = SentenceTransformer(
    model_id,
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Xiaobu Embedding V2 QA Matryoshka",
    ),
)

# Load test dataset
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

# Convert the datasets to dictionaries
corpus = dict(zip(corpus_dataset["id"], corpus_dataset["positive"]))
queries = dict(zip(test_dataset["id"], test_dataset["anchor"]))

# Create a mapping of relevant document (1 in our case) for each query
relevant_docs = {q_id: [q_id] for q_id in queries}

matryoshka_evaluators = []
# Iterate over the different dimensions
for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)

# Create a sequential evaluator
evaluator = SequentialEvaluator(matryoshka_evaluators)

# Evaluate the pretrained model
results = evaluator(model)

print("Baseline Results:")
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

  return self.fget.__get__(instance, owner)()


Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Baseline Results:
dim_768_cosine_ndcg@10: 0.5315972259369834
dim_512_cosine_ndcg@10: 0.5221719774892145
dim_256_cosine_ndcg@10: 0.5266728324925485
dim_128_cosine_ndcg@10: 0.5067767666832826
dim_64_cosine_ndcg@10: 0.4465020550092783


In [4]:
# 3. Define loss function with Matryoshka Representation

inner_train_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.MatryoshkaLoss(model, inner_train_loss, matryoshka_dims=matryoshka_dimensions)

# 4. Fine-tune embedding model with SentenceTransformersTrainer

args = SentenceTransformerTrainingArguments(
    output_dir="xiaobu-v2-qa-matryoshka",
    num_train_epochs=4,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    tf32=True,
    bf16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.select_columns(["positive", "anchor"]),
    loss=train_loss,
    evaluator=evaluator,
)

# Start training
trainer.train()

# Save the best model
trainer.save_model()

# Push model to hub (uncomment if you want to push to Hugging Face Hub)
# trainer.model.push_to_hub("xiaobu-v2-qa-matryoshka")

# 5. Evaluate fine-tuned model against baseline

fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
results = evaluator(fine_tuned_model)

print("\nFine-tuned Model Results:")
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_dim_768_cosine_accuracy@1': 0.4358108108108108, 'eval_dim_768_cosine_accuracy@3': 0.6351351351351351, 'eval_dim_768_cosine_accuracy@5': 0.7128378378378378, 'eval_dim_768_cosine_accuracy@10': 0.8074324324324325, 'eval_dim_768_cosine_precision@1': 0.4358108108108108, 'eval_dim_768_cosine_precision@3': 0.2117117117117117, 'eval_dim_768_cosine_precision@5': 0.14256756756756755, 'eval_dim_768_cosine_precision@10': 0.08074324324324324, 'eval_dim_768_cosine_recall@1': 0.4358108108108108, 'eval_dim_768_cosine_recall@3': 0.6351351351351351, 'eval_dim_768_cosine_recall@5': 0.7128378378378378, 'eval_dim_768_cosine_recall@10': 0.8074324324324325, 'eval_dim_768_cosine_ndcg@10': 0.6151730383055712, 'eval_dim_768_cosine_mrr@10': 0.5542819605319604, 'eval_dim_768_cosine_map@100': 0.5614997143794617, 'eval_dim_512_cosine_accuracy@1': 0.42567567567567566, 'eval_dim_512_cosine_accuracy@3': 0.6351351351351351, 'eval_dim_512_cosine_accuracy@5': 0.7195945945945946, 'eval_dim_512_cosine_accuracy@10': 

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

{'loss': 3.2768, 'grad_norm': 8.516111373901367, 'learning_rate': 1.1736481776669307e-05, 'epoch': 1.9}
{'eval_dim_768_cosine_accuracy@1': 0.4594594594594595, 'eval_dim_768_cosine_accuracy@3': 0.6554054054054054, 'eval_dim_768_cosine_accuracy@5': 0.7466216216216216, 'eval_dim_768_cosine_accuracy@10': 0.8175675675675675, 'eval_dim_768_cosine_precision@1': 0.4594594594594595, 'eval_dim_768_cosine_precision@3': 0.21846846846846843, 'eval_dim_768_cosine_precision@5': 0.1493243243243243, 'eval_dim_768_cosine_precision@10': 0.08175675675675675, 'eval_dim_768_cosine_recall@1': 0.4594594594594595, 'eval_dim_768_cosine_recall@3': 0.6554054054054054, 'eval_dim_768_cosine_recall@5': 0.7466216216216216, 'eval_dim_768_cosine_recall@10': 0.8175675675675675, 'eval_dim_768_cosine_ndcg@10': 0.636712413310936, 'eval_dim_768_cosine_mrr@10': 0.5789427820677819, 'eval_dim_768_cosine_map@100': 0.5860482941975683, 'eval_dim_512_cosine_accuracy@1': 0.46621621621621623, 'eval_dim_512_cosine_accuracy@3': 0.6621