In [None]:
import json
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

import polars as pl
import torch
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import (
    BatchAllTripletLoss,
    BatchHardSoftMarginTripletLoss,
    BatchHardTripletLoss,
    BatchSemiHardTripletLoss,  # worth trying this loss
)
from sentence_transformers.losses.BatchHardTripletLoss import (
    BatchHardTripletLossDistanceFunction,
)
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import dot_score
from torch.optim import AdamW
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup


def dot_distance(embeddings: torch.Tensor) -> torch.Tensor:
    """Return a pairwise "distance" matrix derived from dot-product scores.

    ``dot_score`` is evaluated for ``embeddings`` against itself and the
    result is subtracted from 1.

    Args:
        embeddings: Batch of embeddings passed to ``dot_score`` as both inputs.

    Returns:
        ``1 - dot_score(embeddings, embeddings)``.
    """
    pairwise_similarity = dot_score(embeddings, embeddings)
    return 1 - pairwise_similarity


# Point HTTP(S) clients of this process at the corporate proxy; NO_PROXY lists
# the hosts/domains that are meant to bypass it.
os.environ.update(
    {
        "HTTPS_PROXY": "http://proxy-server.sovcombank.group:3128",
        "HTTP_PROXY": "http://proxy-server.sovcombank.group:3128",
        "NO_PROXY": "localhost,127.0.0.1,.sovcombank.group",
    }
)

In [None]:
@dataclass
class Config:
    """Settings for one fine-tuning experiment.

    The declared fields are the tunable knobs. ``__post_init__`` derives the
    scheduler kwargs, the dataset file paths (from ``data_folds_path``) and the
    output directories (from ``artifacts_dir`` and ``base_model``).

    Note: creating an instance has a side effect - all output directories are
    created on disk immediately.
    """

    # Columns
    text_col: str = "text_col"
    label_col: str = "label_col"
    preprocessed_text_col: str = "preprocessed_text"
    preprocessed_label_col: str = "preprocessed_label"

    # Paths
    data_volume: Path = Path(".")
    data_folds_path: Path = Path("./data/dataset-1")
    base_model: str = "sergeyzh/rubert-tiny-turbo"
    artifacts_dir: Path = Path("./artifacts")

    # Model parameters
    max_seq_length: int = 512

    # Training parameters
    distance_metric: str = "cosine"  # or "dot"
    loss_func: str = "BatchSemiHardTripletLoss"
    # NOTE(review): 5.0 looks large for a cosine-based distance - confirm it is intended.
    margin: float = 5.0
    seed: int = 545454663
    max_steps: int = 1000
    log_steps: int = 100
    per_device_train_batch_size: int = 256
    per_device_eval_batch_size: int = 32
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    gradient_checkpointing: bool = False
    metric_for_best_model: str = "eval_loss"
    # Filled in by __post_init__ from warmup_ratio / max_steps.
    lr_scheduler_kwargs: dict = field(init=False)
    torch_compile: bool = True
    dataloader_pin_memory: bool = True
    use_cpu: bool = False

    def __post_init__(self):
        """Derive scheduler kwargs, data paths and output dirs; create the dirs."""
        self.lr_scheduler_kwargs = {
            "num_warmup_steps": int(self.warmup_ratio * self.max_steps),
            "num_training_steps": self.max_steps,
            "num_cycles": 10,
            "last_epoch": -1,
        }

        # Derive every dataset path from data_folds_path so that overriding the
        # field actually changes which files are read (previously the paths
        # were hard-coded to ./data/dataset-1 and the field was ignored).
        folds_dir = Path(self.data_folds_path)
        self.eval_data_path = folds_dir / "eval"
        self.train_path = folds_dir / "train_data.parquet"
        self.val_path = folds_dir / "val_data.parquet"
        self.test_path = folds_dir / "test_data.parquet"
        self.val_triplets_path = self.eval_data_path / "val_triplets.parquet"
        self.test_triplets_path = self.eval_data_path / "test_triplets.parquet"

        # Use .name rather than .stem: .stem strips everything after the last
        # dot, so "org/model-v1.5" would collapse to "model-v1".
        model_name = Path(self.base_model).name
        artifacts_dir = Path(self.artifacts_dir)

        # 12-hour clock with an AM/PM marker (%I ... %p).
        self.current_time_str = datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")
        self.experiment_name = f"{self.current_time_str}_{model_name}"
        self.outputs_dir = artifacts_dir / self.experiment_name
        self.checkpoints_dir = artifacts_dir / model_name

        self.result_dir = self.outputs_dir / "weights"
        self.metrics_dir = self.outputs_dir / "metrics"
        self.logs_dir = self.outputs_dir / "logs"

        for directory in (
            self.outputs_dir,
            self.checkpoints_dir,
            self.result_dir,
            self.metrics_dir,
            self.logs_dir,
        ):
            directory.mkdir(parents=True, exist_ok=True)

Гиперпараметры, влияющие на потребляемую память и на скорость обучения:
1. gradient_checkpointing
2. per_device_train_batch_size (желательно побольше, >= 128)
3. per_device_eval_batch_size
4. distance_metric (cosine или dot)
5. torch_compile
6. dataloader_pin_memory
7. max_seq_length

Гиперпараметры, влияющие на качество обучения:
1. предобработка датасета
2. max_steps
3. warmup_ratio
4. learning_rate
5. loss_func, margin
6. per_device_train_batch_size
7. distance_metric
8. metric_for_best_model
9. max_seq_length

Тут подробно описаны гиперпараметры библиотеки transformers, влияющие на обучение:
https://huggingface.co/docs/transformers/perf_train_gpu_one

In [None]:
def prepare_datasets(
    train_path,
    val_path,
    test_path,
    text_col="text",
    label_col="label",
    preprocessed_label_col="label_",
    min_num_examples_threshold=2,
):
    """Load the three parquet folds and convert them to Hugging Face datasets.

    Labels found in any fold are mapped to integer ids (stored in
    ``preprocessed_label_col``). Training rows whose label occurs fewer than
    ``min_num_examples_threshold`` times in the training fold are dropped.

    Args:
        train_path: Parquet file with the training fold.
        val_path: Parquet file with the validation fold.
        test_path: Parquet file with the test fold.
        text_col: Name of the text column in the parquet files.
        label_col: Name of the raw label column in the parquet files.
        preprocessed_label_col: Name of the column that receives integer ids.
        min_num_examples_threshold: Minimum number of training rows a label
            needs for its rows to be kept in the training fold.

    Returns:
        Dict with keys ``"train"``, ``"val"`` and ``"test"``; each value is a
        ``Dataset`` with the columns ``"sentence"`` and ``"label"``.
    """
    frames = {
        "train": pl.read_parquet(train_path),
        "val": pl.read_parquet(val_path),
        "test": pl.read_parquet(test_path),
    }

    classes = set()
    for frame in frames.values():
        classes.update(frame[label_col].unique().to_list())

    # Sort before enumerating: set iteration order is not stable between
    # processes, which made the label -> id mapping differ from run to run.
    label_to_id = {label: idx for idx, label in enumerate(sorted(classes, key=str))}

    encode_label = (
        pl.col(label_col)
        .replace(label_to_id, return_dtype=pl.Int32)
        .alias(preprocessed_label_col)
    )
    frames = {name: frame.with_columns(encode_label) for name, frame in frames.items()}

    # Honour the caller's threshold (it used to be overwritten with a literal 2).
    frames["train"] = frames["train"].filter(
        pl.col(label_col).count().over(label_col) >= min_num_examples_threshold
    )

    def process_hf_dataset(dataset):
        # Keep only the text and encoded label, renamed to "sentence"/"label".
        dataset = dataset.select_columns([text_col, preprocessed_label_col])
        return dataset.rename_columns(
            {text_col: "sentence", preprocessed_label_col: "label"}
        )

    return {
        name: process_hf_dataset(Dataset.from_polars(frame))
        for name, frame in frames.items()
    }

In [None]:
def prepare_evaluator(triplets_path, fold="val", batch_size=32):
    """Build a ``TripletEvaluator`` from a parquet file of triplets.

    Args:
        triplets_path: Parquet file with the columns ``anchors``,
            ``positives`` and ``negatives``.
        fold: Fold name; used in the evaluator name ``ffl_triplet_<fold>``.
        batch_size: Batch size handed to the evaluator.

    Returns:
        The configured ``TripletEvaluator`` (main distance function: cosine).
    """
    triplets = pl.read_parquet(triplets_path)
    anchor_texts, positive_texts, negative_texts = (
        triplets[column].to_list()
        for column in ("anchors", "positives", "negatives")
    )
    return TripletEvaluator(
        anchors=anchor_texts,
        positives=positive_texts,
        negatives=negative_texts,
        name=f"ffl_triplet_{fold}",
        batch_size=batch_size,
        main_distance_function="cosine",
    )

In [None]:
def init_model(base_model, max_seq_length=512):
    """Instantiate a ``SentenceTransformer`` and set its ``max_seq_length``.

    Args:
        base_model: Value passed to the ``SentenceTransformer`` constructor.
        max_seq_length: Value assigned to the model's ``max_seq_length``.

    Returns:
        The configured ``SentenceTransformer`` instance.
    """
    encoder = SentenceTransformer(base_model)
    encoder.max_seq_length = max_seq_length
    return encoder


def get_distance_func(distance_metric):
    """Resolve a metric name to the distance callable given to the losses.

    Args:
        distance_metric: ``"dot"`` selects ``dot_distance``; every other value
            selects the library's cosine distance.

    Returns:
        The chosen distance callable.
    """
    cosine = BatchHardTripletLossDistanceFunction.cosine_distance
    return dot_distance if distance_metric == "dot" else cosine


def prepare_loss(model, loss_name, **args):
    """Create the batch triplet loss selected by ``loss_name``.

    Args:
        model: Model the loss is attached to.
        loss_name: One of ``"BatchSemiHardTripletLoss"``,
            ``"BatchHardTripletLoss"``, ``"BatchHardSoftMarginTripletLoss"``
            or ``"BatchAllTripletLoss"``.
        **args: Must contain ``distance_metric``; must also contain ``margin``
            for every loss except the soft-margin one.

    Returns:
        The instantiated loss object.

    Raises:
        ValueError: If ``loss_name`` is not one of the supported names.
        KeyError: If a required entry is missing from ``args``.
    """
    if loss_name == "BatchHardSoftMarginTripletLoss":
        # The soft-margin variant takes no margin. This branch used to build a
        # plain BatchHardTripletLoss, so the requested loss was never used.
        return BatchHardSoftMarginTripletLoss(
            model, distance_metric=args["distance_metric"]
        )

    if loss_name == "BatchSemiHardTripletLoss":
        loss_cls = BatchSemiHardTripletLoss
    elif loss_name == "BatchHardTripletLoss":
        loss_cls = BatchHardTripletLoss
    elif loss_name == "BatchAllTripletLoss":
        loss_cls = BatchAllTripletLoss
    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(
            f"Unsupported loss_name {loss_name!r}; expected one of: "
            "BatchSemiHardTripletLoss, BatchHardTripletLoss, "
            "BatchHardSoftMarginTripletLoss, BatchAllTripletLoss"
        )

    return loss_cls(
        model, distance_metric=args["distance_metric"], margin=args["margin"]
    )

In [None]:
def train(
    checkpoints_dir="checkpoints/",
    max_steps=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    log_steps=50,
    seed=42,
    metric_for_best_model="accuracy",
    gradient_checkpointing=False,
    torch_compile=False,
    dataloader_pin_memory=True,
    model=None,
    train_dataset=None,
    val_dataset=None,
    loss=None,
    val_evaluator=None,
    optimizer=None,
    lr_scheduler=None,
    use_cpu=False,
):
    """Fine-tune ``model`` with ``SentenceTransformerTrainer``.

    Evaluation, checkpointing and logging all happen every ``log_steps``
    steps; at most two checkpoints are kept and the best one (according to
    ``metric_for_best_model``) is loaded back at the end of training.

    Args:
        checkpoints_dir: Directory for trainer checkpoints.
        max_steps: Total number of optimisation steps.
        per_device_train_batch_size: Training batch size per device.
        per_device_eval_batch_size: Evaluation batch size per device.
        learning_rate: Learning rate passed to the training arguments.
        warmup_ratio: Warm-up ratio passed to the training arguments.
        log_steps: Interval (in steps) for eval, save and logging.
        seed: Random seed passed to the training arguments.
        metric_for_best_model: Metric name used to pick the best checkpoint.
            Names ending in ``"loss"`` are minimised, all others maximised.
        gradient_checkpointing: Forwarded to the training arguments.
        torch_compile: Forwarded to the training arguments.
        dataloader_pin_memory: Forwarded to the training arguments.
        model: Model to train.
        train_dataset: Training dataset.
        val_dataset: Evaluation dataset.
        loss: Loss object used by the trainer.
        val_evaluator: Optional evaluator run during evaluation.
        optimizer: Optimizer handed to the trainer.
        lr_scheduler: LR scheduler handed to the trainer.
        use_cpu: Train on CPU; also disables fp16.

    Returns:
        None. The model is trained in place.
    """
    # A loss-like metric must be minimised. Hard-coding greater_is_better=True
    # made "eval_loss" select the *worst* checkpoint as the best one.
    if metric_for_best_model is None:
        greater_is_better = None  # leave the decision to the library
    else:
        greater_is_better = not metric_for_best_model.endswith("loss")

    train_args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=str(checkpoints_dir),  # accept both str and Path
        # Optional training parameters:
        max_steps=max_steps,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        learning_rate=learning_rate,
        # NOTE(review): presumably unused when a custom lr_scheduler is passed - confirm.
        warmup_ratio=warmup_ratio,
        fp16=not use_cpu,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=False,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=log_steps,
        save_strategy="steps",
        save_steps=log_steps,
        save_total_limit=2,
        logging_steps=log_steps,
        run_name="train-experiment",
        report_to="none",
        use_cpu=use_cpu,
        seed=seed,
        load_best_model_at_end=True,
        metric_for_best_model=metric_for_best_model,
        greater_is_better=greater_is_better,
        dataloader_num_workers=4,
        gradient_checkpointing=gradient_checkpointing,
        torch_compile=torch_compile,
        dataloader_pin_memory=dataloader_pin_memory,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        loss=loss,
        evaluator=val_evaluator,
        optimizers=(optimizer, lr_scheduler),
    )
    trainer.train()

In [None]:
def save_experiment_params(config, result_dir):
    """Write the experiment's parameters to ``experiment_params.json``.

    The JSON document has three sections: ``dataset_params``,
    ``train_params`` and ``model_params``.

    Args:
        config: Object exposing the configuration attributes read below.
            ``dataset_path`` is optional and stored as ``null`` when absent.
        result_dir: Directory that receives the JSON file; created if missing.
    """
    # The config may not define `dataset_path`; reading it directly raised
    # AttributeError after training had already finished.
    dataset_path = getattr(config, "dataset_path", None)

    dataset_params = dict(
        text_col=config.text_col,
        label_col=config.label_col,
        preprocessed_text_col=config.preprocessed_text_col,
        preprocessed_label_col=config.preprocessed_label_col,
        data_volume=str(config.data_volume),
        dataset_path=None if dataset_path is None else str(dataset_path),
        data_folds_path=str(config.data_folds_path),
        eval_data_path=str(config.eval_data_path),
    )

    model_params = dict(
        base_model=config.base_model, max_seq_length=config.max_seq_length
    )

    train_params = dict(
        distance_metric=config.distance_metric,
        loss=config.loss_func,
        margin=config.margin,
        max_steps=config.max_steps,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        learning_rate=config.learning_rate,
        warmup_ratio=config.warmup_ratio,
        seed=config.seed,
        gradient_checkpointing=config.gradient_checkpointing,
        lr_scheduler_kwargs=config.lr_scheduler_kwargs,
        torch_compile=config.torch_compile,
        dataloader_pin_memory=config.dataloader_pin_memory,
    )

    params = {
        "dataset_params": dataset_params,
        "train_params": train_params,
        "model_params": model_params,
    }

    result_dir = Path(result_dir)
    result_dir.mkdir(parents=True, exist_ok=True)
    with open(result_dir / "experiment_params.json", "w", encoding="utf-8") as f:
        json.dump(params, f)


def evaluate_and_save_results(model, evaluator, fold="test", metrics_dir="metrics/"):
    """Run ``evaluator`` on ``model`` and dump the metrics to JSON.

    Args:
        model: Object passed to ``evaluator``.
        evaluator: Callable returning something ``dict()`` can consume.
        fold: Fold name; the file is written as ``<fold>_metrics.json``.
        metrics_dir: Existing directory that receives the metrics file.
    """
    target_file = Path(metrics_dir) / f"{fold}_metrics.json"
    metrics = dict(evaluator(model))
    with target_file.open("w", encoding="utf-8") as handle:
        json.dump(metrics, handle)

In [None]:
# --- Experiment driver: data -> evaluators -> model/loss/optimizer -> train -> save ---
config = Config()

datasets = prepare_datasets(
    config.train_path,
    config.val_path,
    config.test_path,
    text_col=config.text_col,
    label_col=config.label_col,
    preprocessed_label_col=config.preprocessed_label_col,
    min_num_examples_threshold=2,
)

# Triplet evaluators are optional: they are built only when the eval directory
# exists and is not empty.
if config.eval_data_path.exists() and any(config.eval_data_path.iterdir()):
    val_evaluator = prepare_evaluator(
        config.val_triplets_path,
        fold="val",
        batch_size=config.per_device_eval_batch_size,
    )

    test_evaluator = prepare_evaluator(
        config.test_triplets_path,
        fold="test",
        batch_size=config.per_device_eval_batch_size,
    )
else:
    val_evaluator = None
    test_evaluator = None

logging.basicConfig(
    filename=str(config.logs_dir / f"{config.experiment_name}_script.log"),
    level=logging.INFO,
)

model = init_model(config.base_model, max_seq_length=config.max_seq_length)
distance_metric = get_distance_func(config.distance_metric)
loss = prepare_loss(
    model, config.loss_func, distance_metric=distance_metric, margin=config.margin
)

# Request the fused AdamW kernel only when training really runs on CUDA.
# An unconditional fused=True can fail on CPU-only runs (use_cpu=True or no GPU).
# `None` keeps PyTorch's default implementation choice.
use_fused_optimizer = torch.cuda.is_available() and not config.use_cpu
optimizer = AdamW(
    model.parameters(),
    lr=config.learning_rate,
    fused=True if use_fused_optimizer else None,
)
lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer, **config.lr_scheduler_kwargs
)

train(
    checkpoints_dir=config.checkpoints_dir,
    max_steps=config.max_steps,
    per_device_train_batch_size=config.per_device_train_batch_size,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    learning_rate=config.learning_rate,
    warmup_ratio=config.warmup_ratio,
    seed=config.seed,
    gradient_checkpointing=config.gradient_checkpointing,
    log_steps=config.log_steps,
    metric_for_best_model=config.metric_for_best_model,
    torch_compile=config.torch_compile,
    dataloader_pin_memory=config.dataloader_pin_memory,
    model=model,
    train_dataset=datasets["train"],
    val_dataset=datasets["val"],
    loss=loss,
    val_evaluator=val_evaluator,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    use_cpu=config.use_cpu,
)

model.save_pretrained(str(config.result_dir))

save_experiment_params(config, config.result_dir)

if test_evaluator is not None:
    evaluate_and_save_results(model, test_evaluator, metrics_dir=config.metrics_dir)