# Для обучения используем метод `SimCSE`

Подробнее можно узнать тут - https://sbert.net/examples/sentence_transformer/unsupervised_learning/SimCSE/README.html

In [1]:
import os
import time
import random

from pymilvus import MilvusClient, DataType, Collection
import numpy as np

from mistralai import Mistral
from dotenv import load_dotenv

from tqdm import tqdm
load_dotenv('./.env')

True

In [2]:
client = MilvusClient("http://localhost:19530")

In [3]:
client.list_collections()

['test']

## Выгрузка всех документов из базы

In [4]:
# Супер! работает
res = client.query(
    collection_name="test",
    filter = "text!=''",
    output_fields=["id", "vector", "text"]
)

new_data = []

for hits in res:
    new_data.append(hits)

print(new_data[:10])

[{'id': 456982585391863969, 'vector': [np.float32(0.016613783), np.float32(-0.03117604), np.float32(-0.06495151), np.float32(-0.04809748), np.float32(0.07030371), np.float32(-0.01678212), np.float32(0.0480928), np.float32(0.068178505), np.float32(0.026706189), np.float32(0.024004344), np.float32(0.032107104), np.float32(0.013239677), np.float32(0.02745225), np.float32(-0.07346909), np.float32(-0.02458688), np.float32(0.022508852), np.float32(0.0678414), np.float32(-0.0049844105), np.float32(-0.07537136), np.float32(-0.03835464), np.float32(0.013881684), np.float32(-0.0024716763), np.float32(-0.013568725), np.float32(0.012999376), np.float32(0.05845901), np.float32(0.05723197), np.float32(-0.06548442), np.float32(-0.00032750855), np.float32(0.017923556), np.float32(-0.07087158), np.float32(-0.0470262), np.float32(-0.039671287), np.float32(0.07738932), np.float32(-0.07757249), np.float32(0.07430752), np.float32(0.047685582), np.float32(-0.039626937), np.float32(-0.020585554), np.float32(

## Создаём датасет

In [5]:
# Перемешиваем данные
random.shuffle(new_data)

In [6]:
eval_split = int(len(new_data) * 0.3)
eval_split

400

In [18]:
train = {
    "anchor": [],
    "positive": []
}

eval = {
    "anchor": [],
    "positive": []
}

for i in tqdm(range(len(new_data[eval_split:]))):
    train["anchor"].append("query: " + new_data[i]['text'])
    train["positive"].append("passage: " + new_data[i]['text'])

for i in tqdm(range(len(new_data[:eval_split]))):
    eval["anchor"].append("query: " + new_data[i]['text'])
    eval["positive"].append("passage: " + new_data[i]['text'])

100%|██████████| 935/935 [00:00<00:00, 467422.44it/s]
100%|██████████| 400/400 [00:00<?, ?it/s]


## Ошибка `InfoNCE loss`

https://www.sbert.net/docs/package_reference/sentence_transformer/losses.html#multiplenegativesrankingloss

In [35]:
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator, RerankingEvaluator
from transformers import EarlyStoppingCallback
import numpy as np
import torch

In [36]:
model = SentenceTransformer("intfloat/multilingual-e5-small", device='cpu')
# model.eval()

In [37]:
train_dataset = Dataset.from_dict(
    train
)

eval_dataset = Dataset.from_dict(
    eval
)

In [38]:
eval_dataset

Dataset({
    features: ['anchor', 'positive'],
    num_rows: 400
})

In [39]:
loss = MultipleNegativesRankingLoss(model)

In [41]:
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="fine-tunning",
    # Optional training parameters:
    num_train_epochs=6,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=40,
    
    save_strategy="steps",
    save_steps=40,
    save_total_limit=5,
    logging_steps=100,
    run_name="e5",  # Will be used in W&B if `wandb` is installed,
    load_best_model_at_end = True
)

In [42]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)
trainer.train()

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss
40,No log,0.000105
80,No log,5.5e-05
120,0.161900,5e-05
160,0.161900,4.9e-05
200,0.000000,4.5e-05
240,0.000000,4.3e-05
280,0.000000,4.3e-05
320,0.000000,4.2e-05


TrainOutput(global_step=354, training_loss=0.04577225145139326, metrics={'train_runtime': 63.586, 'train_samples_per_second': 88.227, 'train_steps_per_second': 5.567, 'total_flos': 0.0, 'train_loss': 0.04577225145139326, 'epoch': 6.0})

## Сохраняем модель

In [74]:
# 8. Save the trained model
model.save_pretrained("e5-base-retrievers")

In [59]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id="e5-base-retrievers")

In [80]:
import os

folder_path = 'e5-base-retrievers'
repo_id = "YarKo69/e5-base-retrievers"

for filename in os.listdir(folder_path):
    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):  # Проверяем, что это файл, а не папка
        # print(file_path)
        api.upload_file(
            path_or_fileobj=file_path,
            repo_id=repo_id,
            path_in_repo=filename,
            repo_type="model",
            commit_message="Pushing retriever model",
            commit_description="Model trained on custom dataset (Article for Reinforcement Learning)"
        )
    else:
        api.upload_folder(
            folder_path=file_path,
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="model",
        )

No files have been modified since last commit. Skipping to prevent empty commit.


model.safetensors:   0%|          | 0.00/471M [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]