In [1]:
%%capture
%pip install load_dotenv sentence-transformers>=3 transformers==4.41.2

In [2]:
import torch
from huggingface_hub import login
from datasets import load_dataset, concatenate_datasets
import os
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer, SentenceTransformerTrainingArguments, SentenceTransformerTrainer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

In [3]:
# Pull environment variables from the local .env file, then authenticate
# against the Hugging Face Hub with the token stored under the "hf" variable.
load_dotenv("./.env")
hf_token = os.getenv("hf")
login(token=hf_token)

In [4]:
# Load the raw instruction/output pairs and rename the columns to the
# (anchor, positive) layout used by the rest of the notebook.
dataset = load_dataset("mp-ac/mpac-dataset", split="train")

dataset = dataset.rename_column("instruction", "anchor")
dataset = dataset.rename_column("output", "positive")
# Stable integer id per row; the evaluation cells key queries and corpus on it.
# add_column is documented to take a list / array, so materialize the range.
dataset = dataset.add_column("id", list(range(len(dataset))))

# Fixed seed: without it every run produces a different train/test split,
# which makes the reported metrics impossible to reproduce or compare.
dataset = dataset.train_test_split(test_size=0.2, seed=42)

In [5]:
# Persist both splits as JSON records so later cells can reload them from disk.
# Create the target directory first so the export also works on a fresh checkout
# where "datasets/mpac" does not exist yet.
os.makedirs("datasets/mpac", exist_ok=True)
dataset["train"].to_json("datasets/mpac/train_dataset.json", orient="records")
dataset["test"].to_json("datasets/mpac/test_dataset.json", orient="records")

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

2281

In [6]:
# Baseline checkpoint, loaded as published (no fine-tuning yet).
model_id = "BAAI/bge-large-en-v1.5"
# Embedding sizes used for truncated evaluation and for the Matryoshka loss, largest first.
matryoshka_dimensions = [768, 512, 256, 128, 64]

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = SentenceTransformer(model_id, device=device)



In [7]:
# Reload the two splits written to disk earlier. The "json" loader exposes a
# single data file under the split name "train", hence split="train" for both.
train_dataset = load_dataset(
    "json",
    data_files="datasets/mpac/train_dataset.json",
    split="train",
)
test_dataset = load_dataset(
    "json",
    data_files="datasets/mpac/test_dataset.json",
    split="train",
)

# Retrieval corpus = every document from both splits (train rows first).
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [8]:
# Convert the datasets into plain {id: text} dictionaries.
# corpus: every "positive" text from train + test, keyed by row id.
corpus = {
    doc_id: text
    for doc_id, text in zip(corpus_dataset["id"], corpus_dataset["positive"])
}
# queries: the "anchor" texts of the held-out test split only, keyed by row id.
queries = {
    query_id: text
    for query_id, text in zip(test_dataset["id"], test_dataset["anchor"])
}

In [9]:
# Each query has exactly one relevant document: the corpus entry sharing its id.
relevant_docs = {query_id: [query_id] for query_id in queries}

In [10]:
# One information-retrieval evaluator per Matryoshka dimension. Each one
# truncates the embeddings to `dim` before scoring with cosine similarity,
# and reports its metrics under the "dim_<dim>" name prefix.
matryoshka_evaluators = [
    InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,
        score_functions={"cosine": cos_sim},
    )
    for dim in matryoshka_dimensions
]

# Run all per-dimension evaluators one after another as a single evaluator.
evaluator = SequentialEvaluator(matryoshka_evaluators)

In [11]:
# Evaluate the baseline model at every Matryoshka dimension.
results = evaluator(model)

# Show the main score (NDCG@10 under cosine similarity) for each dimension.
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.7367699726190509
dim_512_cosine_ndcg@10: 0.7367699726190509
dim_256_cosine_ndcg@10: 0.7222222222222222
dim_128_cosine_ndcg@10: 0.6686425067562222
dim_64_cosine_ndcg@10: 0.7222222222222222


### Flash Attention

Flash Attention é uma implementação de **atenção eficiente** que reduz significativamente o uso de memória e melhora a velocidade de cálculo. Ela foi projetada para lidar com sequências longas (muitos tokens) e é executada diretamente na GPU, otimizando as operações de leitura e escrita na memória.

### SDPA

**SDPA**, ou **Scaled Dot Product Attention**, é a operação central do mecanismo de atenção nos Transformers. Flash Attention 2 otimiza o cálculo do produto escalar e da normalização exponencial (softmax), que são as etapas essenciais do SDPA.

In [12]:
model_id = "BAAI/bge-large-en-v1.5"

# Metadata attached to the model (language, license, display name).
card_data = SentenceTransformerModelCardData(
    language="en",
    license="apache-2.0",
    model_name="MPAC BGE Large",
)

# Request the SDPA attention implementation when loading the model that will be
# fine-tuned (the original note states the intent is to use Flash Attention 2).
attention_kwargs = {"attn_implementation": "sdpa"}

model = SentenceTransformer(
    model_id,
    model_kwargs=attention_kwargs,
    model_card_data=card_data,
)

In [13]:
# Dimensions at which the loss is applied, largest first.
matryoshka_dimensions = [768, 512, 256, 128, 64]

# Base ranking loss over (anchor, positive) pairs ...
inner_train_loss = MultipleNegativesRankingLoss(model=model)
# ... wrapped so it is applied at each of the Matryoshka dimensions above.
train_loss = MatryoshkaLoss(
    model=model,
    loss=inner_train_loss,
    matryoshka_dims=matryoshka_dimensions,
)

In [14]:
# Reload the training split from the JSON file written earlier.
train_dataset = load_dataset(
    "json",
    data_files="datasets/mpac/train_dataset.json",
    split="train",
)

In [15]:
args = SentenceTransformerTrainingArguments(
    # --- output / checkpointing ---
    output_dir="mpac-base_v1.2",
    save_strategy="epoch",
    save_total_limit=3,
    # --- evaluation and best-model selection ---
    eval_strategy="epoch",
    per_device_eval_batch_size=16,  # batch size used for evaluation
    load_best_model_at_end=True,
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",
    # --- batching ---
    num_train_epochs=5,
    per_device_train_batch_size=32,  # batch size used for training
    # 32 x 16 accumulation steps = global batch of 512, reachable even on GPUs
    # that cannot hold a 512-example batch in memory at once.
    gradient_accumulation_steps=16,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # no duplicate examples within a batch
    # --- optimizer / schedule ---
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    optim="adamw_torch_fused",
    # --- numeric precision ---
    # NOTE(review): tf32/bf16 and the fused optimizer presumably need a recent
    # CUDA GPU, while other cells fall back to CPU - confirm the target hardware.
    tf32=True,
    bf16=True,
    # --- logging ---
    logging_steps=10,
)

In [16]:
# The trainer only receives the two text columns; the helper "id" column is dropped.
train_pairs = train_dataset.select_columns(["anchor", "positive"])

trainer = SentenceTransformerTrainer(
    model=model,  # bge-large-en-v1.5
    args=args,
    train_dataset=train_pairs,
    loss=train_loss,
    evaluator=evaluator,  # per-dimension retrieval evaluators, run at each evaluation
)

In [17]:
# Run the fine-tuning loop; the returned training summary is shown as the cell output.
train_output = trainer.train()
train_output

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
1,No log,No log,0.666667,0.777778,0.777778,0.777778,0.666667,0.259259,0.155556,0.077778,0.666667,0.777778,0.777778,0.777778,0.73677,0.722222,0.731628,0.666667,0.777778,0.777778,0.777778,0.666667,0.259259,0.155556,0.077778,0.666667,0.777778,0.777778,0.777778,0.73677,0.722222,0.733032,0.666667,0.777778,0.777778,0.777778,0.666667,0.259259,0.155556,0.077778,0.666667,0.777778,0.777778,0.777778,0.722222,0.703704,0.712869,0.555556,0.666667,0.777778,0.777778,0.555556,0.222222,0.155556,0.077778,0.555556,0.666667,0.777778,0.777778,0.668643,0.633333,0.640868,0.666667,0.777778,0.777778,0.777778,0.666667,0.259259,0.155556,0.077778,0.666667,0.777778,0.777778,0.777778,0.722222,0.703704,0.710074,0.722222
2,No log,No log,0.777778,0.777778,0.777778,0.888889,0.777778,0.259259,0.155556,0.088889,0.777778,0.777778,0.777778,0.888889,0.812829,0.791667,0.800926,0.666667,0.777778,0.777778,0.888889,0.666667,0.259259,0.155556,0.088889,0.666667,0.777778,0.777778,0.888889,0.773807,0.738095,0.747354,0.555556,0.777778,0.777778,0.888889,0.555556,0.259259,0.155556,0.088889,0.555556,0.777778,0.777778,0.888889,0.72921,0.679012,0.688272,0.666667,0.777778,0.777778,0.888889,0.666667,0.259259,0.155556,0.088889,0.666667,0.777778,0.777778,0.888889,0.773807,0.738095,0.746032,0.666667,0.777778,0.777778,0.888889,0.666667,0.259259,0.155556,0.088889,0.666667,0.777778,0.777778,0.888889,0.770218,0.734568,0.741104,0.770218
3,No log,No log,0.777778,0.777778,0.888889,0.888889,0.777778,0.259259,0.177778,0.088889,0.777778,0.777778,0.888889,0.888889,0.825631,0.805556,0.814815,0.666667,0.888889,0.888889,1.0,0.666667,0.296296,0.177778,0.1,0.666667,0.888889,0.888889,1.0,0.825773,0.771605,0.771605,0.777778,0.777778,0.888889,1.0,0.777778,0.259259,0.177778,0.1,0.777778,0.777778,0.888889,1.0,0.854209,0.812346,0.812346,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.879999,0.844444,0.844444,0.777778,0.777778,0.888889,1.0,0.777778,0.259259,0.177778,0.1,0.777778,0.777778,0.888889,1.0,0.859079,0.817901,0.817901,0.859079
4,No log,No log,0.777778,0.888889,0.888889,0.888889,0.777778,0.296296,0.177778,0.088889,0.777778,0.888889,0.888889,0.888889,0.833333,0.814815,0.824074,0.666667,0.888889,0.888889,1.0,0.666667,0.296296,0.177778,0.1,0.666667,0.888889,0.888889,1.0,0.825773,0.771605,0.771605,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.87037,0.830688,0.830688,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.881329,0.845679,0.845679,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.882933,0.847222,0.847222,0.882933
5,No log,No log,0.777778,0.888889,0.888889,0.888889,0.777778,0.296296,0.177778,0.088889,0.777778,0.888889,0.888889,0.888889,0.833333,0.814815,0.824916,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.881329,0.845679,0.845679,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.884918,0.849206,0.849206,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.881329,0.845679,0.845679,0.777778,0.888889,0.888889,1.0,0.777778,0.296296,0.177778,0.1,0.777778,0.888889,0.888889,1.0,0.884918,0.849206,0.849206,0.884918


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

TrainOutput(global_step=5, training_loss=0.7644370555877685, metrics={'train_runtime': 180.8303, 'train_samples_per_second': 0.94, 'train_steps_per_second': 0.028, 'total_flos': 0.0, 'train_loss': 0.7644370555877685, 'epoch': 5.0})

In [18]:
# Save the trained model into the run's output directory. The training
# arguments set load_best_model_at_end=True, so this is expected to be the
# best-performing checkpoint.
trainer.save_model(args.output_dir)

## Evaluate fine-tuned model

In [19]:
# Reload the saved model from disk and score it with the same evaluator used
# for the baseline, so the numbers are directly comparable.
# (SentenceTransformer is already imported in the notebook's import cell.)
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)

results = evaluator(fine_tuned_model)

# Show the main score (NDCG@10 under cosine similarity) for each dimension.
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.8333333333333334
dim_512_cosine_ndcg@10: 0.8257733054706043
dim_256_cosine_ndcg@10: 0.8703703703703703
dim_128_cosine_ndcg@10: 0.8813288610261599
dim_64_cosine_ndcg@10: 0.8829327367063541


In [22]:
# Publish the trained model to the Hugging Face Hub; the call's return value
# (the commit URL) is shown as the cell output.
hub_repo_id = "mp-ac/mpac-bge-large-v1.2"
trainer.model.push_to_hub(hub_repo_id)

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

'https://huggingface.co/mp-ac/mpac-bge-large-v1.2/commit/a37710204fc620f50b7a0f12cfac198da1ec01da'