# Дообучение encoder'а для классификации токенов.

Реализация классификатора.

In [None]:
!pip install --upgrade --no-cache-dir \
  --extra-index-url https://download.pytorch.org/whl/cu124 \
  numpy==1.26.4 \
  pandas==2.2.3 \
  tqdm==4.67.1 \
  transformers==4.51.3 \
  evaluate==0.4.5 \
  torch==2.6.0+cu124 \
  seqeval==1.2.2

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gc
import math
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch
import evaluate
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForTokenClassification,
    TrainerCallback,
    PrinterCallback,
    EarlyStoppingCallback,
)
from transformers.modeling_outputs import ModelOutput
import pandas as pd
from datasets import Dataset
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm


def set_seed(seed: int = 42):
    """
    Seed every random number generator used by this file.

    Seeds, in order: Python's ``random``, NumPy's legacy global generator,
    PyTorch's CPU generator and PyTorch's CUDA generators.

    :param seed: value handed to each seeding function.
    :return: None
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class PbarConsoleLogger(TrainerCallback):
    """
    Trainer callback that mirrors training progress onto an externally owned tqdm bar.

    The bar position follows the step reported by ``_step`` (capped at ``pbar.total``),
    the postfix shows the latest numeric log values, and one summary line is written
    through ``tqdm.write`` for each step at which ``eval_*`` values are logged.
    """

    # Logged keys that are not shown as extra metrics (printed separately or not metrics).
    _NON_METRIC_KEYS = frozenset({
        'eval_loss', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch'
    })

    def __init__(self, pbar):
        """
        :param pbar: tqdm-like progress bar; the callback updates it but never closes it.
        """
        self.pbar = pbar
        self.last_logs = {}
        self.last_train_loss = None
        self.printed_eval_steps = set()

    def _step(self, state) -> int:
        """Return the step number used for the bar position and the printed lines."""
        return int(state.global_step or 0)

    def _sync_position(self, state):
        """Advance the bar to the current step; it never moves backwards or past ``total``."""
        target = min(self._step(state), self.pbar.total)
        behind = target - self.pbar.n
        if behind > 0:
            self.pbar.update(behind)

    def _fmt_postfix(self):
        """Build the postfix text: train loss, validation loss, then remaining eval metrics."""
        logs = self.last_logs
        pieces = []
        if 'loss' in logs:
            pieces.append(f"loss {logs['loss']:.4f}")
        if 'eval_loss' in logs:
            pieces.append(f"val {logs['eval_loss']:.4f}")
        for key, value in logs.items():
            if not key.startswith('eval_') or key in self._NON_METRIC_KEYS:
                continue
            try:
                pieces.append(f"{key.replace('eval_', '')} {float(value):.4f}")
            except Exception:
                # Values that cannot be rendered as a float are left out of the postfix.
                continue
        return " | ".join(pieces)

    def on_step_end(self, args, state, control, **kwargs):
        """Move the bar forward and redraw it with the latest known metrics."""
        self._sync_position(state)
        if self.last_logs:
            self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Remember numeric log values and print one line per evaluation step."""
        if not logs:
            return

        numeric = {key: float(value) for key, value in logs.items() if isinstance(value, (int, float))}
        self.last_logs.update(numeric)
        if 'loss' in numeric:
            self.last_train_loss = numeric['loss']

        self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

        if not any(key.startswith('eval_') for key in logs.keys()):
            return

        # Print at most one summary line per step.
        step = self._step(state)
        if step in self.printed_eval_steps:
            return
        self.printed_eval_steps.add(step)

        if self.last_train_loss is None:
            train_loss_str = "n/a"
        else:
            train_loss_str = f"{self.last_train_loss:.10f}"

        val_loss = logs.get('eval_loss', None)
        if isinstance(val_loss, (int, float)):
            val_loss_str = f"{float(val_loss):.10g}"
        else:
            val_loss_str = "n/a"

        fields = [f"step: {step}, train loss: {train_loss_str}, val loss: {val_loss_str}"]
        for key, value in logs.items():
            if not key.startswith('eval_') or key in self._NON_METRIC_KEYS:
                continue
            metric_name = key.replace('eval_', '')
            try:
                fields.append(f"val {metric_name}: {float(value):.10f}")
            except Exception:
                # Non-numeric entries are skipped rather than breaking the log line.
                continue
        tqdm.write(", ".join(fields))

    def on_train_end(self, args, state, control, **kwargs):
        """Bring the bar to its final position when a training run finishes."""
        self._sync_position(state)
        self.pbar.refresh()


class WeightedTokenCETrainer(Trainer):
    """
    Trainer whose loss is a class-weighted token-level cross-entropy.

    Positions labelled ``-100`` are ignored by the loss.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        """
        :param class_weights: optional per-class weights (anything ``torch.as_tensor``
                              accepts); ``None`` means an unweighted loss.
        Every other argument is forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        if class_weights is None:
            self.class_weights = None
        else:
            self.class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        self._warned_label_tiling = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Run the model and compute the weighted cross-entropy over all labelled positions.

        :param model: model to call with ``inputs`` (after ``labels`` has been removed).
        :param inputs: batch dict; its ``labels`` entry is popped, so the dict is mutated.
        :param return_outputs: when true, return ``(loss, outputs)`` instead of the loss alone.
        :param num_items_in_batch: accepted for signature compatibility; not used here.
        :return: the loss tensor, or ``(loss, outputs)``.
        :raises ValueError: if the batch sizes of logits and labels cannot be reconciled.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits if hasattr(outputs, "logits") else outputs["logits"]

        if logits.size(0) != labels.size(0):
            ngpu = torch.cuda.device_count()
            replicated = ngpu > 1 and logits.size(0) == labels.size(0) * ngpu
            if not replicated:
                raise ValueError(f"Batch size mismatch: logits {tuple(logits.shape)} vs labels {tuple(labels.shape)}")
            # The logits batch is exactly ngpu times larger: repeat each label row to match.
            labels = labels.repeat_interleave(ngpu, dim=0)
            if not self._warned_label_tiling:
                print(f"[Warning] DataParallel удвоил batch для logits. "
                      f"Повторяем labels x{ngpu}. logits: {tuple(logits.shape)}, labels: {tuple(labels.shape)}")
                self._warned_label_tiling = True

        weight = None if self.class_weights is None else self.class_weights.to(logits.device)
        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)
        loss = torch.nn.functional.cross_entropy(
            flat_logits,
            flat_labels,
            weight=weight,
            ignore_index=-100,
        )
        if return_outputs:
            return loss, outputs
        return loss


class TokenClassification:
    """
    Класс-обёртка для обучения и инференса моделей токен-классификации (NER/POS и т.д.)
    на базе Hugging Face Transformers. Поддерживает обучение со «скользящим окном»,
    выравнивание меток слов с субтокенами, расчёт весов классов по словам, раннюю остановку,
    агрегирование логитов по перекрывающимся окнам и извлечение эмбеддингов слов.
    """

    def __init__(
        self,
        checkpoint: str,
        label2id: Dict[str, int],
        tokens_column_name: str,
        tags_column_name: str
    ):
        """
        Инициализирует модель, токенайзер и инфраструктуру для обучения/инференса.

        :param checkpoint: имя/путь модели в Hugging Face (например, 'bert-base-cased').
        :param label2id: словарь отображения строковых меток в целочисленные id.
        :param tokens_column_name: имя колонки DataFrame с токенами (словами).
        :param tags_column_name: имя колонки DataFrame с метками (строки или уже id).
        :return: None
        :raises: ValueError при некорректных входных параметрах (например, пустой label2id).
        """
        self.id2label = {v: k for k, v in label2id.items()}
        self.label2id = label2id

        self.model = AutoModelForTokenClassification.from_pretrained(
            checkpoint,
            num_labels=len(self.id2label),
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.tokens_column_name = tokens_column_name
        self.tags_column_name = tags_column_name

        # Градиентный чекпоинтинг (если поддерживается моделью)
        try:
            self.model.gradient_checkpointing_enable()
        except Exception:
            pass

        self.trainer: Optional[Trainer] = None
        self.compute_metrics = None
        self.data_collator = DataCollatorForTokenClassification(tokenizer=self.tokenizer)
        self.progress_callback: Optional[TrainerCallback] = None

    # ------------------------------
    # Вспомогательные хелперы
    # ------------------------------
    @staticmethod
    def _labels_are_strings(labels_col_list) -> bool:
        """
        Определяет, представлены ли метки строками (а не id).

        :param labels_col_list: итерируемая коллекция списков меток (по документам).
        :return: True, если метки строковые; False, если уже id или все пусто.
        :raises: None
        """
        for tags in labels_col_list:
            if isinstance(tags, (list, tuple)) and len(tags) > 0:
                return isinstance(tags[0], str)
        return False

    def _label_to_id(self, tag: str) -> int:
        """
        Преобразует строковую метку в id согласно self.label2id.

        :param tag: строковая метка.
        :return: целочисленный id метки.
        :raises ValueError: если метка отсутствует в label2id.
        """
        if tag not in self.label2id:
            raise ValueError(
                f"Unknown label encountered: '{tag}'. "
                f"Known labels: {sorted(self.label2id.keys())}"
            )
        return int(self.label2id[tag])

    def _assert_tokens_labels_same_len(self, tokens_seq, labels_seq):
        """
        Проверяет совпадение длины списков токенов и меток для каждого документа.

        :param tokens_seq: iterable со списками токенов (по документам).
        :param labels_seq: iterable со списками меток (по документам).
        :return: None
        :raises ValueError: если типы неверны или длины не совпадают.
        """
        for i, (toks, labs) in enumerate(zip(tokens_seq, labels_seq)):
            if not isinstance(toks, (list, tuple)) or not isinstance(labs, (list, tuple)):
                raise ValueError(
                    f"Row {i}: tokens/labels must be lists, got "
                    f"{type(toks).__name__} and {type(labs).__name__}"
                )
            if len(toks) != len(labs):
                raise ValueError(
                    f"Row {i}: tokens and labels length mismatch: "
                    f"{len(toks)} vs {len(labs)}"
                )

    @staticmethod
    def _to_token_list(obj):
        """
        Приводит значение ячейки к списку токенов.

        :param obj: значение колонки токенов (list/tuple/np.ndarray/None/другое).
        :return: список токенов (или пустой список при неподдерживаемом типе).
        :raises: None
        """
        if obj is None:
            return []
        if isinstance(obj, (list, tuple)):
            return list(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return []

    def _get_effective_max_length(self) -> int:
        """
        Возвращает безопасную максимальную длину контекста:
        min(model.config.max_position_embeddings, tokenizer.model_max_length),
        игнорируя «бесконечные» значения токенайзера.

        :return: целочисленное значение безопасной максимальной длины.
        :raises: None
        """
        m_conf = int(getattr(self.model.config, "max_position_embeddings", 512) or 512)
        m_tok = int(getattr(self.tokenizer, "model_max_length", 512) or 512)
        if m_tok > 100000:
            return m_conf
        return min(m_conf, m_tok)

    @staticmethod
    def _sanitize_stride(stride: int, max_length: int) -> int:
        """
        Ограничивает stride до [0, max_length - 2], учитывая спецтокены.

        :param stride: желаемый страйд перекрытия.
        :param max_length: безопасная максимальная длина контекста.
        :return: целочисленное безопасное значение stride.
        :raises: None
        """
        stride = int(max(0, stride))
        return int(min(stride, max(0, max_length - 2)))

    # ------------------------------
    # Алгоритмика
    # ------------------------------
    @staticmethod
    def _align_labels_with_word_ids(labels_ids: List[int], word_ids: List[Optional[int]]) -> List[int]:
        """
        Выравнивает метки слов по субтокенам: первый субтокен слова получает метку,
        последующие субтокены — -100 (игнор в CrossEntropy).

        :param labels_ids: список меток по словам (id), длина = числу слов.
        :param word_ids: список индексов слов для каждого субтокена (tokenizer.word_ids()).
        :return: список меток по длине субтокенов, с -100 для игнорируемых позиций.
        :raises: None
        """
        new_labels = []
        prev_word_id = None
        L = len(labels_ids)
        for wid in word_ids:
            if wid is None or wid < 0 or wid >= L:
                new_labels.append(-100)
            else:
                if wid != prev_word_id:
                    new_labels.append(labels_ids[wid])
                else:
                    new_labels.append(-100)
            prev_word_id = wid
        return new_labels

    def _tokenize_and_align_chunk(
        self,
        docs_tokens: List[List[str]],
        docs_labels_ids: List[List[int]],
        max_length: int,
        stride: int
    ) -> Dataset:
        """
        Tokenize documents into overlapping windows and align the labels to them.

        :param docs_tokens: per-document token lists.
        :param docs_labels_ids: per-document label-id lists.
        :param max_length: safe maximum context length.
        :param stride: overlap between consecutive windows.
        :return: HF Dataset with input_ids, attention_mask and labels (one row per window).
        :raises: None
        """
        encoded = self.tokenizer(
            docs_tokens,
            is_split_into_words=True,
            return_overflowing_tokens=True,
            max_length=max_length,
            stride=stride,
            truncation=True
        )
        # For each produced window: index of the document it was cut from.
        window_to_doc = encoded.pop("overflow_to_sample_mapping")

        window_labels = []
        for window in range(len(encoded["input_ids"])):
            source_doc = int(window_to_doc[window])
            window_word_ids = encoded.word_ids(batch_index=window)
            window_labels.append(
                self._align_labels_with_word_ids(docs_labels_ids[source_doc], window_word_ids)
            )

        columns = {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": window_labels
        }
        return Dataset.from_dict(columns)

    def _count_total_chunks(
        self,
        docs_tokens: List[List[str]],
        max_length: int,
        stride: int,
        batch_docs: int = 64
    ) -> int:
        """
        Подсчитывает число чанков (окон) после токенизации набора документов.

        :param docs_tokens: списки токенов по документам.
        :param max_length: безопасная максимальная длина контекста.
        :param stride: перекрытие между окнами.
        :param batch_docs: размер батча документов при токенизации.
        :return: общее число чанков (int).
        :raises: None
        """
        total = 0
        for i in range(0, len(docs_tokens), batch_docs):
            batch = docs_tokens[i:i + batch_docs]
            enc = self.tokenizer(
                batch,
                is_split_into_words=True,
                return_overflowing_tokens=True,
                max_length=max_length,
                stride=stride,
                truncation=True
            )
            total += len(enc["input_ids"])
        return total

    def _compute_class_weights_over_words(self, docs_labels_ids: List[List[int]]) -> np.ndarray:
        """
        Считает веса классов по словам (без влияния overlap-окон).

        :param docs_labels_ids: списки меток (id) по документам.
        :return: массив весов классов shape=(num_labels,), dtype=float32.
        :raises: None
        """
        num_labels = len(self.id2label)
        counts = np.zeros(num_labels, dtype=np.int64)
        for labs in docs_labels_ids:
            if isinstance(labs, (list, tuple)) and len(labs) > 0:
                arr = np.asarray(labs, dtype=np.int64)
                arr = arr[(arr >= 0) & (arr < num_labels)]
                if arr.size > 0:
                    counts += np.bincount(arr, minlength=num_labels)
        N = counts.sum()
        weights = np.zeros(num_labels, dtype=np.float32)
        if N > 0:
            nonzero = counts > 0
            weights[nonzero] = N / (num_labels * counts[nonzero].astype(np.float32))
        return weights

    @staticmethod
    def _normalize_clip_weights(w: np.ndarray, clip: float = 5.0) -> np.ndarray:
        """
        Нормирует и клипует веса классов: клип сверху до clip и нормировка
        положительных весов к среднему ~1.0.

        :param w: исходные веса классов.
        :param clip: верхняя граница клипа (None/<=0 — без клипа).
        :return: нормированные веса dtype=float32.
        :raises: None
        """
        w = np.asarray(w, dtype=np.float32)
        if clip is not None and clip > 0:
            w = np.minimum(w, clip)
        mask = w > 0
        mean = float(np.mean(w[mask])) if np.any(mask) else 1.0
        if mean > 0:
            w = w / mean
        return w

    def _setup_compute_metrics(self):
        """
        Build the seqeval-based metric function and store it in self.compute_metrics.

        Reported metrics:
        - precision / recall / f1 / accuracy: the overall seqeval scores;
        - f1_{entity}: F1 for each entity type present in the seqeval result.

        :return: None
        :raises: None
        """
        seqeval = evaluate.load("seqeval")

        def compute_seqeval_metrics(p):
            # Accept either a (predictions, labels) pair or an object exposing both.
            if isinstance(p, (tuple, list)):
                logits, gold = p
            else:
                logits, gold = p.predictions, p.label_ids

            predicted_ids = np.argmax(logits, axis=2)

            predicted_tags, gold_tags = [], []
            for predicted_row, gold_row in zip(predicted_ids, gold):
                # Keep only the positions that carry a real label (-100 marks ignored ones).
                scored = [(pid, gid) for pid, gid in zip(predicted_row, gold_row) if gid != -100]
                predicted_tags.append([self.id2label[pid] for pid, _ in scored])
                gold_tags.append([self.id2label[gid] for _, gid in scored])

            scores = seqeval.compute(predictions=predicted_tags, references=gold_tags)

            report = {
                "precision": scores.get("overall_precision", 0.0),
                "recall": scores.get("overall_recall", 0.0),
                "f1": scores.get("overall_f1", 0.0),
                "accuracy": scores.get("overall_accuracy", 0.0),
            }
            # Per-entity results are dicts; the overall_* entries are plain numbers.
            for entity, stats in scores.items():
                if isinstance(stats, dict) and "f1" in stats:
                    report[f"f1_{entity}"] = float(stats["f1"])
            return report

        self.compute_metrics = compute_seqeval_metrics

    def _prepare_dataset_with_sliding_window(self, df: pd.DataFrame, max_length: int, stride: int) -> Dataset:
        """
        Готовит HF Dataset для оценки/валидации со «скользящим окном».

        :param df: DataFrame с колонками токенов и меток.
        :param max_length: безопасная максимальная длина контекста.
        :param stride: перекрытие между окнами.
        :return: HF Dataset с полями input_ids, attention_mask, labels.
        :raises ValueError: при неверных типах или несовпадении длины токенов и меток.
        """
        docs_tokens = df[self.tokens_column_name].tolist()
        docs_labels = df[self.tags_column_name].tolist()

        if self._labels_are_strings(docs_labels):
            docs_labels = [[self._label_to_id(tag) for tag in tags] for tags in docs_labels]

        filtered_tokens, filtered_labels = [], []
        for i, (toks, labs) in enumerate(zip(docs_tokens, docs_labels)):
            if not isinstance(toks, (list, tuple)) or not isinstance(labs, (list, tuple)):
                raise ValueError(
                    f"Row {i}: tokens/labels must be lists, got "
                    f"{type(toks).__name__} and {type(labs).__name__}"
                )
            if len(toks) == 0 and len(labs) == 0:
                continue
            if len(toks) != len(labs):
                raise ValueError(
                    f"Row {i}: tokens and labels length mismatch: "
                    f"{len(toks)} vs {len(labs)}"
                )
            filtered_tokens.append(list(toks))
            filtered_labels.append(list(labs))

        if len(filtered_tokens) == 0:
            return Dataset.from_dict({"input_ids": [], "attention_mask": [], "labels": []})

        return self._tokenize_and_align_chunk(filtered_tokens, filtered_labels, max_length, stride)

    # ------------------------------
    # Обучение
    # ------------------------------
    def fit(
        self,
        train_data: pd.DataFrame,
        epochs: int = 3,
        per_device_train_batch_size: int = 16,
        gradient_accumulation_steps: int = 1,
        test_size: float = 0.2,
        learning_rate: float = 2e-5,
        fp16: bool = True,
        stride: int = 128,
        logging_steps: int = 50,
        eval_steps: int = 100,
        output_dir: str = "./result",
        seed: int = 42,
        fit_chunk_size_docs: Optional[int] = None,
        early_stopping_patience: Optional[int] = 3,
        early_stopping_threshold: float = 0.0,
    ):
        """
        Fine-tune the token-classification model on ``train_data``.

        Training is organised as a plan of ``Trainer.train()`` calls:

        - without chunking (``fit_chunk_size_docs`` is None, <= 0 or covers all
          documents) the training set is tokenized once and a single call runs
          for ``epochs`` passes, so evaluation, early stopping and best-model
          selection span the whole run;
        - with chunking, each epoch reshuffles the documents and every chunk is
          tokenized and trained in its own call. ``Trainer.train()`` starts its
          step counter from zero on every call, so each call is limited to the
          steps of its own chunk, and the eval/save cadence (``eval_steps``)
          counts steps within a call.

        One learning-rate schedule covering all planned steps is installed on
        the trainer so that it is shared by every call.

        :param train_data: DataFrame with the token and label columns.
        :param epochs: number of passes over the training documents.
        :param per_device_train_batch_size: training batch size per device.
        :param gradient_accumulation_steps: number of gradient accumulation steps.
        :param test_size: validation share (float) or absolute size (int); evaluation is
                          disabled when there are fewer than two documents or nothing
                          is left to evaluate.
        :param learning_rate: learning rate.
        :param fp16: use fp16 mixed precision (only applied when CUDA is available).
        :param stride: overlap between windows when long documents are split.
        :param logging_steps: logging frequency in steps.
        :param eval_steps: evaluation/checkpoint frequency in steps (when evaluation is on).
        :param output_dir: directory for training artefacts.
        :param seed: seed for reproducibility.
        :param fit_chunk_size_docs: number of documents tokenized and trained per chunk
                                    (None = all documents at once).
        :param early_stopping_patience: number of consecutive evaluations without improvement
                                        before training stops; None or <= 0 disables it.
        :param early_stopping_threshold: minimum absolute improvement of the monitored metric
                                         (eval_f1) that counts as an improvement.
        :return: self (for chaining).
        :raises ValueError: on inconsistent data (row types, token/label length mismatch,
                            unknown string labels).
        """
        set_seed(seed)

        max_length = self._get_effective_max_length()
        stride = self._sanitize_stride(stride, max_length)

        df_all = train_data.copy()
        if self._labels_are_strings(df_all[self.tags_column_name].tolist()):
            df_all[self.tags_column_name] = df_all[self.tags_column_name].apply(
                lambda tags: [self._label_to_id(tag) for tag in tags]
            )

        self._assert_tokens_labels_same_len(
            df_all[self.tokens_column_name].tolist(),
            df_all[self.tags_column_name].tolist()
        )

        # Robust train/validation split: always keep at least one training document.
        n_total = len(df_all)
        use_eval = False
        test_size_abs = 0
        if n_total >= 2 and test_size and float(test_size) > 0:
            if isinstance(test_size, float):
                test_size_abs = int(round(n_total * float(test_size)))
            else:
                test_size_abs = int(test_size)
            if test_size_abs <= 0:
                test_size_abs = 1
            if test_size_abs >= n_total:
                test_size_abs = n_total - 1
            use_eval = test_size_abs > 0

        if use_eval:
            df_train, df_eval = train_test_split(df_all, test_size=test_size_abs, random_state=seed, shuffle=True)
        else:
            df_train = df_all
            df_eval = df_all.iloc[0:0]

        eval_dataset = None
        if len(df_eval) > 0:
            eval_dataset = self._prepare_dataset_with_sliding_window(df_eval, max_length, stride)
            # Every validation document may have been empty; evaluating on nothing is pointless.
            if len(eval_dataset) == 0:
                eval_dataset = None

        # Empty documents are dropped, as the validation path already does.
        train_docs_tokens, train_docs_labels = [], []
        for toks, labs in zip(
            df_train[self.tokens_column_name].tolist(),
            df_train[self.tags_column_name].tolist()
        ):
            if len(toks) > 0:
                train_docs_tokens.append(list(toks))
                train_docs_labels.append(list(labs))

        class_weights = self._compute_class_weights_over_words(train_docs_labels)
        class_weights = self._normalize_clip_weights(class_weights, clip=5.0)

        self._setup_compute_metrics()

        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=max(4, per_device_train_batch_size // 2),
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=max(1, gradient_accumulation_steps * 2),
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.01,
            eval_strategy="steps" if eval_dataset is not None else "no",
            eval_steps=eval_steps,
            save_strategy="steps" if eval_dataset is not None else "no",
            save_steps=eval_steps,
            load_best_model_at_end=bool(eval_dataset is not None),
            metric_for_best_model="eval_f1",
            save_total_limit=1,
            logging_strategy="steps",
            logging_steps=logging_steps,
            report_to="none",
            fp16=bool(fp16 and torch.cuda.is_available()),
            dataloader_num_workers=min(4, os.cpu_count() or 4),
            seed=seed,
            remove_unused_columns=False,
            disable_tqdm=True,
            dataloader_pin_memory=True,
            gradient_checkpointing=True,
        )

        data_collator = self.data_collator

        def steps_for_size(n_samples: int, bsz: int, accum: int) -> int:
            return max(0, math.ceil(math.ceil(n_samples / max(1, bsz)) / max(1, accum)))

        # Batch size of one optimisation micro-step as the Trainer's dataloader sees it
        # (per-device size times the number of GPUs driven by this process).
        effective_bsz = args.train_batch_size

        n_docs = len(train_docs_tokens)
        chunk_docs = int(fit_chunk_size_docs) if (fit_chunk_size_docs and fit_chunk_size_docs > 0) else n_docs

        # Training plan: (document indices, passes over them) per Trainer.train() call.
        plan = []
        if n_docs > 0 and epochs > 0:
            if chunk_docs >= n_docs:
                plan.append((np.arange(n_docs), int(epochs)))
            else:
                for ep in range(epochs):
                    order = np.arange(n_docs)
                    np.random.default_rng(seed + ep).shuffle(order)
                    for start in range(0, n_docs, chunk_docs):
                        plan.append((order[start:start + chunk_docs], 1))

        # The estimate walks the very same plan that is trained below, so the
        # scheduler length and the progress-bar total match the real step count.
        total_steps = 0
        for doc_idx, passes in plan:
            n_samples = self._count_total_chunks(
                [train_docs_tokens[i] for i in doc_idx], max_length, stride, batch_docs=64
            )
            total_steps += passes * steps_for_size(n_samples, effective_bsz, gradient_accumulation_steps)

        if n_docs > 0:
            init_chunk_ds = self._tokenize_and_align_chunk(
                [train_docs_tokens[0]], [train_docs_labels[0]], max_length, stride
            )
        else:
            init_chunk_ds = Dataset.from_dict({"input_ids": [], "attention_mask": [], "labels": []})

        self.trainer = WeightedTokenCETrainer(
            model=self.model,
            args=args,
            data_collator=data_collator,
            compute_metrics=self.compute_metrics if eval_dataset is not None else None,
            train_dataset=init_chunk_ds,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            class_weights=class_weights
        )
        try:
            self.trainer.remove_callback(PrinterCallback)
        except Exception:
            pass

        if total_steps > 0:
            # A scheduler created through Trainer.create_scheduler() is thrown away and
            # rebuilt at the start of every train() call. Assigning one built here keeps
            # a single warmup + cosine schedule across all calls.
            from transformers import get_scheduler

            self.trainer.create_optimizer()
            self.trainer.lr_scheduler = get_scheduler(
                args.lr_scheduler_type,
                optimizer=self.trainer.optimizer,
                num_warmup_steps=args.get_warmup_steps(total_steps),
                num_training_steps=total_steps,
            )

        class _CumulativeStepLogger(PbarConsoleLogger):
            """Progress logger that adds the steps of earlier train() calls to the current one."""

            step_offset = 0

            def _step(self, state) -> int:
                return self.step_offset + super()._step(state)

        pbar = tqdm(total=total_steps, desc="Training Progress", unit="step")
        cb = _CumulativeStepLogger(pbar)
        self.trainer.add_callback(cb)

        # Early stopping (only when evaluation is on and it was requested).
        early_cb = None
        if eval_dataset is not None and (early_stopping_patience is not None) and (early_stopping_patience > 0):
            early_cb = EarlyStoppingCallback(
                early_stopping_patience=int(early_stopping_patience),
                early_stopping_threshold=float(early_stopping_threshold),
            )
            self.trainer.add_callback(early_cb)

        self.progress_callback = cb

        steps_done = 0
        try:
            for doc_idx, passes in plan:
                toks_chunk = [train_docs_tokens[i] for i in doc_idx]
                labs_chunk = [train_docs_labels[i] for i in doc_idx]

                ds_chunk = self._tokenize_and_align_chunk(toks_chunk, labs_chunk, max_length, stride)
                call_steps = passes * steps_for_size(len(ds_chunk), effective_bsz, gradient_accumulation_steps)
                if call_steps == 0:
                    del ds_chunk
                    continue

                self.trainer.train_dataset = ds_chunk
                cb.step_offset = steps_done
                # Every train() call counts its steps from zero, so the limit is the
                # step budget of this call, not a cumulative total.
                self.trainer.args.max_steps = call_steps
                self.trainer.train()

                reached = int(self.trainer.state.global_step or 0)
                steps_done += reached

                del ds_chunk
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Finishing short of the budget means early stopping fired:
                # do not start the remaining chunks/epochs.
                if early_cb is not None and reached < call_steps:
                    break
        finally:
            pbar.close()
        return self

    # ------------------------------
    # Инференс
    # ------------------------------
    def _predict_single_document(self, tokens: List[str], stride: int) -> List[str]:
        """
        Predict one label per word of a single document using overlapping windows.

        The document is tokenized into windows, all windows are run through
        ``self.trainer.predict`` and, for 3-D outputs, the logits of the first
        sub-token of every word are averaged over all windows containing that
        word before taking the argmax. Words that never appear in any window get
        the label of the smallest id in ``self.id2label``.

        ``self.trainer`` must already exist: ``__init__`` leaves it as None and
        ``fit`` assigns it. If it is still None the ``predict`` call below fails.

        :param tokens: list of words (tokens) of the document.
        :param stride: overlap between consecutive windows.
        :return: list of string labels with the same length as ``tokens``;
                 an empty list when ``tokens`` is empty or not a list/tuple.
        :raises: None explicitly.
        """
        if not isinstance(tokens, (list, tuple)) or len(tokens) == 0:
            return []

        max_length = self._get_effective_max_length()
        stride = self._sanitize_stride(stride, max_length)

        tokenized_inputs = self.tokenizer(
            [tokens],
            is_split_into_words=True,
            return_overflowing_tokens=True,
            max_length=max_length,
            stride=stride,
            truncation=True,
        )
        # There is a single document, so the window -> document mapping is not needed.
        tokenized_inputs.pop("overflow_to_sample_mapping", None)

        # Fallback label for words without any prediction: the smallest known id.
        default_id = int(min(self.id2label.keys())) if len(self.id2label) else 0
        default_label = self.id2label.get(default_id, str(default_id))

        if not isinstance(tokenized_inputs.get("input_ids", None), list) or len(tokenized_inputs["input_ids"]) == 0:
            return [default_label] * len(tokens)

        chunk_dataset = Dataset.from_dict(tokenized_inputs)
        outputs = self.trainer.predict(chunk_dataset)

        if hasattr(outputs, "predictions"):
            preds = outputs.predictions
        else:
            preds = outputs["predictions"]

        num_original_words = len(tokens)

        # Main path: 3-D logits (num_chunks, seq_len, num_labels)
        if isinstance(preds, np.ndarray) and preds.ndim == 3:
            num_labels = preds.shape[-1]
            # Running sum of logits per word and the number of windows that contributed.
            word_logits = np.zeros((num_original_words, num_labels), dtype=np.float32)
            word_counts = np.zeros((num_original_words,), dtype=np.float32)

            for i in range(preds.shape[0]):
                chunk_logits = preds[i]
                try:
                    chunk_word_ids = tokenized_inputs.word_ids(batch_index=i)
                except Exception:
                    continue
                if chunk_word_ids is None:
                    continue

                for token_pos, word_id in enumerate(chunk_word_ids):
                    if word_id is None:
                        continue
                    # Only the first sub-token of a word within this window contributes.
                    if token_pos == 0 or chunk_word_ids[token_pos - 1] != word_id:
                        if 0 <= word_id < num_original_words:
                            word_logits[word_id] += chunk_logits[token_pos]
                            word_counts[word_id] += 1.0

            # Average over windows for every word that was seen at least once.
            mask = word_counts > 0
            if np.any(mask):
                word_logits[mask] /= word_counts[mask, None]

            pred_ids = np.full(num_original_words, default_id, dtype=np.int32)
            if np.any(mask):
                pred_ids[mask] = word_logits[mask].argmax(-1)

            filled = [self.id2label.get(int(x), str(int(x))) for x in pred_ids]
            return filled

        # Fallback: predictions are not 3-D; bring them to a (num_chunks, seq_len) layout
        # and treat the values as label ids.
        if isinstance(preds, np.ndarray):
            if preds.ndim == 2:
                predictions = preds
            elif preds.ndim == 1:
                predictions = preds[None, :]
            else:
                predictions = preds.reshape(len(tokenized_inputs["input_ids"]), -1)
        else:
            predictions = np.asarray(preds)

        # -1 marks words that have not received a prediction yet.
        final_predictions = np.full(num_original_words, -1, dtype=np.int32)
        num_chunks = predictions.shape[0]
        for i in range(num_chunks):
            chunk_preds = predictions[i]
            try:
                chunk_word_ids = tokenized_inputs.word_ids(batch_index=i)
            except Exception:
                continue
            if chunk_word_ids is None:
                continue

            chunk_len = len(chunk_preds)
            for token_pos, word_id in enumerate(chunk_word_ids):
                if token_pos >= chunk_len:
                    break
                if word_id is None:
                    continue
                # The first prediction seen for a word wins; later windows do not overwrite it.
                if 0 <= word_id < num_original_words and final_predictions[word_id] == -1:
                    final_predictions[word_id] = int(chunk_preds[token_pos])

        filled = [
            self.id2label.get(pid, default_label) if pid != -1 else default_label
            for pid in final_predictions
        ]
        return filled

    def predict(self, df: pd.DataFrame, stride: int = 128) -> List[List[str]]:
        """
        Predict word-level labels for every document in a DataFrame.

        Rows are handled one at a time: the token cell is normalized with
        ``_to_token_list`` and then passed to ``_predict_single_document``.
        A row whose token list comes back empty yields an empty label list
        and never touches the model.

        :param df: DataFrame holding the token column (``self.tokens_column_name``).
        :param stride: overlap between neighbouring windows.
        :return: one list of string labels per row, in row order.
        :raises RuntimeError: if a non-empty document is met while no trained
            model is available.
        """
        predicted_docs: List[List[str]] = []
        row_iter = tqdm(df.iterrows(), total=len(df), desc="Предсказание (sliding window)")
        for _, record in row_iter:
            # Best-effort read: any failure is treated like a missing value.
            try:
                raw_tokens = record.get(self.tokens_column_name, None)
            except Exception:
                raw_tokens = None

            words = self._to_token_list(raw_tokens)
            if len(words) == 0:
                predicted_docs.append([])
                continue

            # Checked per document (not up front) so that a frame containing
            # only empty documents still works without a trained model.
            model_missing = self.trainer is None or self.trainer.model is None
            if model_missing:
                raise RuntimeError("Модель не обучена. Вызовите .fit() перед .predict() для непустых документов.")

            predicted_docs.append(self._predict_single_document(words, stride))
        return predicted_docs

    # ------------------------------
    # Эмбеддинги
    # ------------------------------
    def _get_embeddings_single_document(self, tokens: List[str], stride: int, device: torch.device) -> np.ndarray:
        """
        Extract one embedding vector per word of a single document.

        The document is tokenized into overlapping windows, every window is run
        through the base encoder (the sub-module named by ``base_model_prefix``),
        and each word's vector is the mean of the last-hidden-state vectors of
        all of its sub-tokens over all windows in which they occur.

        :param tokens: words of the document.
        :param stride: overlap between windows (passed through ``_sanitize_stride``).
        :param device: device the model lives on (CPU/GPU).
        :return: float32 array of shape (num_words, hidden_size); a word that
            produced no sub-tokens gets an all-zero row.
        """
        max_length = self._get_effective_max_length()
        stride = self._sanitize_stride(stride, max_length)
        num_original_words = len(tokens)

        # The last window is usually shorter than the others, and windows of
        # different length cannot be stacked into one tensor without padding.
        # Padding is only requested when the tokenizer actually has a pad token,
        # so tokenizers without one keep their previous behaviour.
        can_pad = getattr(self.tokenizer, "pad_token_id", None) is not None

        chunk_inputs = self.tokenizer(
            [tokens],
            is_split_into_words=True,
            return_overflowing_tokens=True,
            max_length=max_length,
            stride=stride,
            truncation=True,
            padding=can_pad,
            return_tensors="pt",
        ).to(device)

        # Bookkeeping key, not a model input; tolerate its absence the same way
        # the prediction path does.
        chunk_inputs.pop("overflow_to_sample_mapping", None)

        with torch.no_grad():
            base_model = getattr(self.trainer.model, self.trainer.model.base_model_prefix)
            outputs = base_model(**chunk_inputs)

        # Indexed below as [window][token position] -> hidden vector.
        chunk_embeddings = outputs.last_hidden_state

        hidden_size = int(chunk_embeddings.shape[-1])
        accum_device = chunk_embeddings.device
        embedding_sums = torch.zeros(num_original_words, hidden_size, dtype=torch.float32, device=accum_device)
        token_counts = torch.zeros(num_original_words, dtype=torch.float32, device=accum_device)

        for i in range(chunk_embeddings.shape[0]):
            chunk_word_ids = chunk_inputs.word_ids(batch_index=i)
            # Special and padding tokens map to None and belong to no word.
            positions = [pos for pos, wid in enumerate(chunk_word_ids) if wid is not None]
            if not positions:
                continue
            pos_index = torch.tensor(positions, dtype=torch.long, device=accum_device)
            word_index = torch.tensor(
                [chunk_word_ids[pos] for pos in positions], dtype=torch.long, device=accum_device
            )
            # One scatter-add per window instead of two kernel launches per token.
            embedding_sums.index_add_(
                0, word_index, chunk_embeddings[i].index_select(0, pos_index).to(embedding_sums.dtype)
            )
            # Count every sub-token that was summed, so the division below is a
            # true mean and does not grow with the number of sub-tokens per word.
            token_counts.index_add_(0, word_index, torch.ones_like(word_index, dtype=token_counts.dtype))

        # clamp keeps words without any sub-token at exactly zero.
        average_embeddings = embedding_sums / token_counts.clamp(min=1.0).unsqueeze(1)
        return average_embeddings.detach().cpu().numpy()

    def get_embeddings(self, df: pd.DataFrame, stride: int = 128) -> List[np.ndarray]:
        """
        Extract word embeddings for every document in a DataFrame.

        The token cell of each row is normalized with ``_to_token_list`` — the
        same helper ``predict`` uses — before it is checked for emptiness, so
        cells that are not plain lists (e.g. NumPy arrays, missing values) do
        not break the truthiness test or reach the tokenizer unnormalized.

        :param df: DataFrame holding the token column (``self.tokens_column_name``).
        :param stride: overlap between neighbouring windows.
        :return: one array per row, shaped (num_words, hidden_size); rows with
            no tokens give a (0, hidden_size) float32 array.
        :raises RuntimeError: if the model has not been trained.
        """
        if self.trainer is None or self.trainer.model is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        self.trainer.model.eval()
        device = self.trainer.model.device
        all_final_embeddings: List[np.ndarray] = []

        for _, row in tqdm(df.iterrows(), total=len(df), desc="Генерация эмбеддингов (sliding window)"):
            tokens = self._to_token_list(row[self.tokens_column_name])
            if len(tokens) == 0:
                all_final_embeddings.append(np.zeros((0, int(self.model.config.hidden_size)), dtype=np.float32))
                continue

            document_embeddings = self._get_embeddings_single_document(tokens, stride, device)
            all_final_embeddings.append(document_embeddings)

        return all_final_embeddings

Пример использования 1.

In [None]:
import pandas as pd

# Data: string labels (converted to ids inside .fit)
tokens_col, tags_col = "tokens", "tags"
label2id = {
    "O": 0,
    "B-PER": 1, "I-PER": 2,
    "B-LOC": 3, "I-LOC": 4,
    "B-ORG": 5, "I-ORG": 6,
}

df_train = pd.DataFrame([
    {tokens_col: ["John", "Doe", "lives", "in", "Berlin"], tags_col: ["B-PER","I-PER","O","O","B-LOC"]},
    {tokens_col: ["Mary", "works", "at", "Google"], tags_col: ["B-PER","O","O","B-ORG"]},
    {tokens_col: ["Alice", "is", "from", "Paris"], tags_col: ["B-PER","O","O","B-LOC"]},
    {tokens_col: ["IBM", "is", "in", "Armonk"], tags_col: ["B-ORG","O","O","B-LOC"]},
    {tokens_col: ["Bob", "moved", "to", "London"],  tags_col: ["B-PER","O","O","B-LOC"]},
    {tokens_col: ["Google", "is", "in", "California"], tags_col: ["B-ORG","O","O","B-LOC"]},
])

# Initialization (the minimum the class requires)
CKPT = "prajjwal1/bert-tiny"
tc = TokenClassification(
    checkpoint=CKPT,
    label2id=label2id,
    tokens_column_name=tokens_col,
    tags_column_name=tags_col
)

# Training with as many parameters set explicitly as possible
tc.fit(
    train_data=df_train,
    epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    test_size=0.33,             # enables validation
    learning_rate=3e-5,
    fp16=True,                  # turns fp16 on when CUDA is available
    stride=64,                  # sliding window
    logging_steps=1,
    eval_steps=2,
    output_dir="./tokcls_max_param",
    seed=123,
    fit_chunk_size_docs=2,      # train in chunks of 2 documents
    early_stopping_patience=2,  # early stopping after 2 evaluations without improvement
    early_stopping_threshold=0.0,
)

# Metrics from the Trainer (including per-entity F1: eval_f1_PER/LOC/ORG etc., when those entities occur)
metrics = tc.trainer.evaluate()
print("Eval metrics (subset):", {k: float(v) for k, v in metrics.items() if isinstance(v, (int, float))})

# Predict on a subset of the data (with a separate stride at inference time)
df_infer = df_train.iloc[:3]
preds = tc.predict(df_infer, stride=32)
for i, (tokens, pred) in enumerate(zip(df_infer[tokens_col], preds), 1):
    print(f"Doc {i}:")
    print(list(zip(tokens, pred)))

# Word embeddings (each document -> array [num_words, hidden_size])
embs = tc.get_embeddings(df_infer, stride=32)
print("Embeddings shapes:", [e.shape for e in embs])

Пример использования 2.

In [None]:
import pandas as pd

tokens_col, tags_col = "tokens", "tags"
label2id = {"O": 0, "B-PER": 1}  # minimal label set

# Data already given as ids (minimal annotation)
df_small = pd.DataFrame([
    {tokens_col: ["John", "works"], tags_col: [1, 0]},  # ["B-PER","O"]
    {tokens_col: ["Mary", "smiles"], tags_col: [1, 0]},  # ["B-PER","O"]
])

# Initialization
CKPT = "prajjwal1/bert-tiny"
tc = TokenClassification(
    checkpoint=CKPT,
    label2id=label2id,
    tokens_column_name=tokens_col,
    tags_column_name=tags_col
)

# Training — all parameters left at their defaults
tc.fit(train_data=df_small)

# Basic prediction — defaults as well
preds = tc.predict(df_small)
print("Preds:", preds)

# Embeddings, if needed (also with default parameters)
embs = tc.get_embeddings(df_small)
print("Embeddings shape for doc 0:", embs[0].shape)

# Дообучение классификатора, который работает с данными разных модальностей.

Реализация классификатора.

In [None]:
!pip install --upgrade --no-cache-dir \
  --extra-index-url https://download.pytorch.org/whl/cu124 \
  pillow==11.1.0 \
  numpy==1.26.4 \
  pandas==2.2.3 \
  tqdm==4.67.1 \
  transformers==4.51.3 \
  evaluate==0.4.5 \
  wav2clip==0.1.0 \
  torch==2.6.0+cu124 \
  torchaudio==2.6.0+cu124
# !pip install evaluate wav2clip

import os
import time

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import math
import random
import gc
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Union

from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from transformers import TrainingArguments, Trainer
from transformers.trainer_callback import TrainerCallback, PrinterCallback, EarlyStoppingCallback
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate

# =========================
# Утилиты
# =========================

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def to_pil(x: Union[str, os.PathLike, np.ndarray, 'Image.Image']) -> 'Image.Image':
    """
    Convert an image given in one of several forms to an RGB ``PIL.Image``.

    :param x: an existing PIL image, a filesystem path (``str`` or any
        ``os.PathLike`` such as ``pathlib.Path``), or a NumPy array accepted by
        ``Image.fromarray``.
    :return: a new image in RGB mode.
    :raises ValueError: if ``x`` is none of the supported types.
    """
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    if isinstance(x, (str, os.PathLike)):
        # convert() loads the pixel data into a new image, so the source file
        # can be closed deterministically instead of waiting for the GC.
        with Image.open(x) as source:
            return source.convert("RGB")
    if isinstance(x, np.ndarray):
        return Image.fromarray(x).convert("RGB")
    raise ValueError("Ожидается путь/np.ndarray/PIL.Image")

def load_audio(path: str, target_sr: int) -> np.ndarray:
    """
    Load an audio file as a mono float32 waveform at the requested sample rate.

    :param path: path handed to ``torchaudio.load``.
    :param target_sr: sample rate of the returned signal; the audio is
        resampled only when the file's rate differs.
    :return: float32 NumPy array obtained after averaging channels (when there
        is more than one) and dropping the channel axis.
    :raises RuntimeError: if ``torchaudio`` cannot be imported.
    """
    try:
        import torchaudio
    except Exception as e:
        raise RuntimeError("Требуется torchaudio: pip install torchaudio") from e

    signal, source_sr = torchaudio.load(path)

    # Down-mix multi-channel audio by averaging over the channel axis.
    has_several_channels = signal.size(0) > 1
    if has_several_channels:
        signal = signal.mean(dim=0, keepdim=True)

    if source_sr != target_sr:
        signal = torchaudio.functional.resample(signal, orig_freq=source_sr, new_freq=target_sr)

    mono = signal.squeeze(0)
    return mono.numpy().astype(np.float32)

def safe_load(component_cls, checkpoint: str, local_cache_dir: str = "./model_cache",
              local_files_only: Optional[bool] = None, **kwargs):
    """
    Call ``component_cls.from_pretrained`` with a shared cache directory.

    :param component_cls: any class exposing ``from_pretrained``.
    :param checkpoint: checkpoint name or path forwarded as the first argument.
    :param local_cache_dir: forwarded as ``cache_dir``.
    :param local_files_only: when ``None``, it is derived from the
        ``HF_HUB_OFFLINE`` environment variable (``"1"`` means offline).
    :param kwargs: extra keyword arguments forwarded unchanged; classes whose
        name contains ``"Tokenizer"`` get ``use_fast=True`` unless the caller
        set ``use_fast`` explicitly.
    :return: whatever ``from_pretrained`` returns.
    """
    offline = local_files_only
    if offline is None:
        offline = os.environ.get("HF_HUB_OFFLINE", "0") == "1"

    looks_like_tokenizer = "Tokenizer" in getattr(component_cls, "__name__", "")
    if looks_like_tokenizer and "use_fast" not in kwargs:
        kwargs["use_fast"] = True

    return component_cls.from_pretrained(
        checkpoint,
        cache_dir=local_cache_dir,
        local_files_only=offline,
        **kwargs,
    )


# =========================
# Токенизатор батчевый
# =========================

class BatchTokenizer:
    """
    Thin wrapper around a tokenizer callable that tokenizes lists of texts.

    Two padding strategies are supported: ``"max_length"`` (every sequence is
    padded to ``max_length``; small batches may be served from a per-text LRU
    cache) and ``"dynamic"`` (each batch is padded to its longest sequence and
    the cache is bypassed).
    """

    # Keys whose tensors are always returned as ``torch.long``.
    _INTEGER_KEYS = ("input_ids", "attention_mask", "token_type_ids")

    def __init__(
        self,
        tokenizer,
        max_length: int = 512,
        cache_size: int = 10000,
        batch_size: int = 256,
        use_fast: bool = True,
        device: str = "cpu",
        padding_strategy: str = "max_length"  # "max_length" or "dynamic"
    ):
        """
        :param tokenizer: callable accepting a text or list of texts plus
            ``padding``/``truncation``/``max_length``/``return_tensors``.
        :param max_length: truncation length (and padding length for "max_length").
        :param cache_size: capacity of the per-text LRU cache.
        :param batch_size: default batch size for ``tokenize_dataset_lazy``.
        :param use_fast: stored on the instance; not read by this class.
        :param device: stored on the instance; not read by this class.
        :param padding_strategy: ``"max_length"`` or ``"dynamic"``.
        :raises ValueError: on an unknown ``padding_strategy``.
        """
        if padding_strategy not in {"max_length", "dynamic"}:
            raise ValueError("padding_strategy должен быть 'max_length' или 'dynamic'")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_fast = use_fast
        self.device = device
        self.padding_strategy = padding_strategy
        # Per-instance cache wrapped around the bound single-text tokenizer.
        self._cache = lru_cache(maxsize=cache_size)(self._tokenize_single)
        self.is_fast = getattr(tokenizer, "is_fast", False)
        if self.is_fast:
            print("✓ Используется Fast Tokenizer")

    def _tokenize_single(self, text: str) -> tuple:
        """Tokenize one text to fixed length; return ``(key, ndarray)`` pairs.

        Only fixed-length encodings are cached, otherwise cached rows could not
        be stacked into a batch.
        """
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        pairs = []
        for name, tensor in encoded.items():
            pairs.append((name, tensor.squeeze(0).cpu().numpy()))
        return tuple(pairs)

    def tokenize_batch(self, texts: List[str], use_cache: bool = True) -> Dict[str, torch.Tensor]:
        """
        Tokenize a list of texts into batched tensors.

        :param texts: texts to tokenize.
        :param use_cache: with "max_length" padding and fewer than 100 texts,
            assemble the batch from the per-text cache; ignored for "dynamic".
        :return: mapping key -> tensor with the batch on the first axis.
        """
        dynamic = self.padding_strategy == "dynamic"

        if not dynamic and use_cache and len(texts) < 100:
            rows = [dict(self._cache(text)) for text in texts]
            stacked = {}
            for name in rows[0].keys():
                target_dtype = torch.long if name in self._INTEGER_KEYS else torch.float32
                column = np.stack([row[name] for row in rows])
                stacked[name] = torch.tensor(column, dtype=target_dtype)
            return stacked

        # Whole-list tokenization: pad to the longest text for "dynamic",
        # to ``max_length`` otherwise.
        encoded = self.tokenizer(
            texts,
            padding=True if dynamic else "max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        for name in encoded:
            if name in self._INTEGER_KEYS:
                encoded[name] = encoded[name].long()
        return encoded

    def tokenize_dataset_lazy(
        self,
        texts: List[str],
        batch_size: Optional[int] = None
    ) -> Generator[Dict[str, torch.Tensor], None, None]:
        """Yield tokenized batches of ``texts`` one at a time, bypassing the cache.

        :param batch_size: batch size; falls back to ``self.batch_size`` when falsy.
        """
        step = batch_size or self.batch_size
        offset = 0
        total = len(texts)
        while offset < total:
            yield self.tokenize_batch(texts[offset:offset + step], use_cache=False)
            offset += step

    def clear_cache(self):
        """Drop every entry of the per-text cache."""
        self._cache.cache_clear()


# =========================
# Универсальный датасет
# =========================

class MultiComboDataset(Dataset):
    """
    Map-style dataset over a DataFrame that mixes text, image and audio columns.

    Every item is a dict with ``labels`` and, depending on configuration:
    ``text_tokens`` (row slice of the pre-tokenized tensors) or ``text`` (the
    joined raw string, when nothing was pre-tokenized), ``images`` and
    ``audios`` (lists of raw cell values gathered from the respective columns).
    """

    def __init__(
        self,
        df: pd.DataFrame,
        target_col: str,
        label2id: Dict[Any, int],
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer: Optional[Any] = None,         # BatchTokenizer
        text_tokenizer_fn: Optional[Callable] = None, # custom fn -> dict of tensors
        special_tokens: Optional[Dict[str, Any]] = None,
        pretokenize: bool = True,
        pretokenize_batch_size: int = 1024,
        tokenizer_returns_tensors: bool = True,
        deduplicate_texts: bool = True,
        max_cache_size=None
    ):
        """
        :param df: source frame; a copy with a reset index is kept.
        :param target_col: label column; when it is absent every label is 0.
        :param label2id: label -> class id; labels missing from it become 0.
        :param text_columns: columns joined (with the ``sep`` special token) into one text.
        :param image_columns: columns whose cells are collected into ``images``.
        :param audio_columns: columns whose cells are collected into ``audios``.
        :param text_tokenizer: object with ``tokenize_batch(texts, use_cache=...)``.
        :param text_tokenizer_fn: ``fn(text_dict, special_tokens) -> dict``; takes
            precedence over ``text_tokenizer`` for pre-tokenization.
        :param special_tokens: defaults to ``{"sep": " [SEP] "}``.
        :param pretokenize: tokenize all texts up front; switched off when the
            batch tokenizer uses dynamic padding.
        :param pretokenize_batch_size: batch size for up-front tokenization.
        :param tokenizer_returns_tensors: stored on the instance; not read here.
        :param deduplicate_texts: tokenize each distinct text only once (custom fn path).
        :param max_cache_size: accepted for interface compatibility; unused.
        """
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.target_col = target_col
        self.label2id = label2id

        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []

        # Two tokenization routes: a BatchTokenizer-like object or a custom function.
        self.batch_tokenizer = text_tokenizer
        self.text_tokenizer_fn = text_tokenizer_fn

        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors

        self._N = len(self.df)

        # Labels are materialized once as a [N] long tensor.
        if self.target_col in self.df.columns:
            y = self.df[self.target_col].map(self.label2id).fillna(0).astype(int).values
        else:
            # No label column (e.g. at inference): use zeros.
            y = np.zeros(self._N, dtype=np.int64)
        self._labels = torch.tensor(y, dtype=torch.long)

        # Image/audio cells are gathered once so __getitem__ never touches pandas.
        self._image_lists = None
        if self.image_columns:
            self._image_lists = self._collect_multi_values(self.df, self.image_columns)

        self._audio_lists = None
        if self.audio_columns:
            self._audio_lists = self._collect_multi_values(self.df, self.audio_columns)

        # Pre-tokenized bank: key -> tensor with the sample index on axis 0.
        self._tok_bank: Optional[Dict[str, torch.Tensor]] = None

        self._has_text = bool(self.text_columns)

        # Dynamic padding gives a different length per batch, so rows cannot be
        # stored in one fixed-shape bank.
        if pretokenize and self.batch_tokenizer is not None and getattr(self.batch_tokenizer, "padding_strategy", "max_length") == "dynamic":
            print("⚠ Предтокенизация отключена: выбран dynamic-паддинг для текста.")
            pretokenize = False

        if self._has_text and pretokenize:
            t0 = time.time()
            if self.batch_tokenizer is not None and self.text_tokenizer_fn is None:
                self._pretokenize_with_batch_tokenizer(pretokenize_batch_size)
            elif self.text_tokenizer_fn is not None:
                self._pretokenize_with_custom_fn(pretokenize_batch_size, deduplicate_texts=deduplicate_texts)
            # With neither tokenizer the items carry the raw joined text instead.
            t1 = time.time()
            if self._tok_bank is not None:
                shapes = {k: tuple(v.shape) for k, v in self._tok_bank.items()}
                print(f"✓ Предтокенизация завершена: {self._N} образцов за {t1 - t0:.2f}s | keys={list(self._tok_bank.keys())}, shapes={shapes}")

    def __len__(self) -> int:
        """Number of rows in the underlying frame."""
        return self._N

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Build the item dict for row ``idx`` (see the class docstring for its keys)."""
        item: Dict[str, Any] = {}
        item["labels"] = int(self._labels[idx])

        if self._tok_bank is not None:
            # Pre-tokenized: just slice the bank.
            item["text_tokens"] = {k: v[idx] for k, v in self._tok_bank.items()}
        elif self._has_text:
            # Fallback: hand out the raw string for later tokenization.
            item["text"] = self._join_text(self.df.iloc[idx])

        if self._image_lists is not None:
            item["images"] = self._image_lists[idx]
        if self._audio_lists is not None:
            item["audios"] = self._audio_lists[idx]

        return item

    # --------------------------
    # Helpers
    # --------------------------

    def clear_cache(self):
        """Drop the pre-tokenized bank and release cached CUDA memory."""
        self._tok_bank = None
        torch.cuda.empty_cache()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Report whether a pre-tokenized bank exists, its tensor shapes, and N."""
        has = self._tok_bank is not None
        sizes = {k: tuple(v.shape) for k, v in (self._tok_bank or {}).items()}
        return {"has_pretokenized": has, "shapes": sizes, "N": self._N}

    def _join_text(self, row: pd.Series) -> str:
        """Join the text columns of one row with the ``sep`` token; missing values become ""."""
        sep = self.special_tokens.get("sep", " [SEP] ")
        return sep.join([("" if pd.isna(row[c]) else str(row[c])) for c in self.text_columns])

    @staticmethod
    def _as_list(v):
        """Normalize a cell to a list: None/NaN -> [], list/tuple -> list, anything else -> [v]."""
        if v is None or (isinstance(v, float) and math.isnan(v)):
            return []
        if isinstance(v, (list, tuple)):
            return list(v)
        return [v]

    @staticmethod
    def _to_tensor(value) -> torch.Tensor:
        """Return tensors unchanged; convert anything else to a ``torch.long`` tensor.

        Non-tensor values have always ended up in a long bank, so the dtype is
        kept for compatibility.
        """
        if torch.is_tensor(value):
            return value
        return torch.as_tensor(value, dtype=torch.long)

    def _collect_multi_values(self, df: pd.DataFrame, columns: List[str]) -> List[List[Any]]:
        """For every row, concatenate the (list-normalized) cells of ``columns``, dropping None.

        Columns missing from ``df`` are ignored. Values are read column-wise to
        avoid building a Series per row.
        """
        as_list = self._as_list
        present = [df[c].tolist() for c in columns if c in df.columns]
        out: List[List[Any]] = []
        for i in range(len(df)):
            merged: List[Any] = []
            for values in present:
                merged.extend(x for x in as_list(values[i]) if x is not None)
            out.append(merged)
        return out

    # --------------------------
    # Pre-tokenization: BatchTokenizer
    # --------------------------
    def _pretokenize_with_batch_tokenizer(self, batch_size: int):
        """Tokenize all joined texts in batches and concatenate them into the bank."""
        texts = [self._join_text(self.df.iloc[i]) for i in range(self._N)]
        banks: Dict[str, List[torch.Tensor]] = {}

        for start in range(0, self._N, batch_size):
            batch = texts[start:start + batch_size]
            tok = self.batch_tokenizer.tokenize_batch(batch, use_cache=False)
            for k in tok:
                # Integer inputs as long, everything else as float32.
                if k in ("input_ids", "attention_mask", "token_type_ids"):
                    part = tok[k].long()
                else:
                    part = tok[k].to(torch.float32)
                banks.setdefault(k, []).append(part)

        self._tok_bank = {k: torch.cat(v_parts, dim=0).contiguous() for k, v_parts in banks.items()}

    # --------------------------
    # Pre-tokenization: custom fn
    # --------------------------
    def _pretokenize_with_custom_fn(self, batch_size: int, deduplicate_texts: bool = True):
        """
        Tokenize every row with ``text_tokenizer_fn`` and store results in the bank.

        The first row fixes the keys, per-sample shapes and dtypes; every later
        row must match them.

        :param batch_size: accepted for interface compatibility; the function is
            called per sample, so batching has no effect here.
        :param deduplicate_texts: call the function once per distinct text tuple
            and copy the stored row for repeats.
        :raises ValueError: if the function does not return a dict, returns
            0-d values, or a later row differs in keys or shape from the first.
        """
        if self._N == 0:
            # Shape detection needs a first row; with no rows there is nothing
            # to tokenize and the bank stays unset.
            return

        cols = self.text_columns
        # Plain Python lists so the hot loop does not go through pandas.
        col_arrays = [self.df[c].astype(str).where(~self.df[c].isna(), other="").tolist() for c in cols]

        def tokenize_row(i: int) -> Dict[str, torch.Tensor]:
            td = {c: col_arrays[j][i] for j, c in enumerate(cols)}
            raw = self.text_tokenizer_fn(td, self.special_tokens)
            if not isinstance(raw, dict):
                raise ValueError("custom text_tokenizer_fn должна возвращать dict тензоров")
            return {k: self._to_tensor(t) for k, t in raw.items()}

        first_tok = tokenize_row(0)
        if any(t.dim() == 0 for t in first_tok.values()):
            raise ValueError("text_tokenizer_fn должна возвращать тензоры с размерностью хотя бы [L]")

        bank: Dict[str, torch.Tensor] = {
            k: torch.empty((self._N, *t.shape), dtype=t.dtype) for k, t in first_tok.items()
        }

        def store(i: int, tok: Dict[str, torch.Tensor]) -> None:
            # A missing key would leave uninitialized memory in the bank, and
            # copy_ would silently broadcast a mismatched shape.
            if tok.keys() != bank.keys():
                raise ValueError(
                    f"text_tokenizer_fn вернула ключи {sorted(map(str, tok.keys()))} для образца {i}, "
                    f"ожидались {sorted(map(str, bank.keys()))}"
                )
            for k, t in tok.items():
                expected = tuple(bank[k].shape[1:])
                if tuple(t.shape) != expected:
                    raise ValueError(
                        f"text_tokenizer_fn вернула тензор формы {tuple(t.shape)} для ключа '{k}' "
                        f"(образец {i}), ожидалась {expected}"
                    )
                bank[k][i].copy_(t)

        store(0, first_tok)

        # Deduplication remembers the row where a text tuple was first stored;
        # repeats are copied from the bank instead of keeping tensors in a cache.
        first_seen: Dict[tuple, int] = {}
        if deduplicate_texts:
            first_seen[tuple(col[0] for col in col_arrays)] = 0

        for i in range(1, self._N):
            if deduplicate_texts:
                key = tuple(col[i] for col in col_arrays)
                src = first_seen.get(key)
                if src is not None:
                    for k in bank:
                        bank[k][i].copy_(bank[k][src])
                    continue
                first_seen[key] = i
            store(i, tokenize_row(i))

        self._tok_bank = {k: v.contiguous() for k, v in bank.items()}


# =========================
# Универсальный бэкенд
# =========================

class BaseBackend(nn.Module):
    """
    Common base for encoder backends.

    Subclasses provide ``collate`` and ``encode``; this class only carries the
    shared configuration attributes and small helpers around tokenizers and
    parameter freezing.
    """

    name: str = "base"
    supported: set = set()
    embed_dim: int = 0
    out_dim_per_modality: Dict[str, int] = {}
    text_tokenizer_fn: Optional[Callable] = None
    batch_tokenizer: Optional[BatchTokenizer] = None
    special_tokens: Dict[str, str] = {}
    tokenizer_returns_tensors: bool = False
    local_cache_dir: str = "./model_cache"
    text_padding_strategy: str = "max_length"  # padding strategy used for text

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn a list of dataset items into backend inputs (implemented by subclasses)."""
        raise NotImplementedError

    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        """Encode collated inputs into tensors (implemented by subclasses)."""
        raise NotImplementedError

    def freeze_all(self):
        """Disable gradients for every parameter of this module."""
        for param in self.parameters():
            param.requires_grad_(False)

    def get_out_dim(self, modality: str) -> int:
        """Return the output size registered for ``modality``, else ``embed_dim``."""
        per_modality = self.out_dim_per_modality
        if modality in per_modality:
            return per_modality[modality]
        return self.embed_dim

    def set_text_tokenizer(self, tokenizer_fn: Optional[Callable], special_tokens: Optional[Dict[str, str]] = None,
                           returns_tensors: bool = False):
        """Install a custom text tokenizer function.

        :param tokenizer_fn: the function to store (may be ``None``).
        :param special_tokens: falls back to ``{"sep": " [SEP] "}`` when falsy.
        :param returns_tensors: stored as ``tokenizer_returns_tensors``.
        """
        if not special_tokens:
            special_tokens = {"sep": " [SEP] "}
        self.text_tokenizer_fn = tokenizer_fn
        self.special_tokens = special_tokens
        self.tokenizer_returns_tensors = returns_tensors

    def set_batch_tokenizer(self, tokenizer, max_length: int = 512,
                            cache_size: int = 10000, batch_size: int = 256,
                            padding_strategy: str = "max_length"):
        """Wrap ``tokenizer`` in a ``BatchTokenizer`` and remember the padding strategy."""
        self.text_padding_strategy = padding_strategy
        wrapper = BatchTokenizer(
            tokenizer=tokenizer,
            max_length=max_length,
            cache_size=cache_size,
            batch_size=batch_size,
            use_fast=True,
            padding_strategy=padding_strategy,
        )
        self.batch_tokenizer = wrapper


class UniversalMultiBackend(BaseBackend):
    name = "universal"

    class _ParamDeviceProxy(nn.Module):
        def __init__(self, base, device: torch.device):
            super().__init__()
            self.base = base if isinstance(base, nn.Module) else None
            self._callable = base if not isinstance(base, nn.Module) else None
            self._dummy = nn.Parameter(torch.empty(0), requires_grad=False)
            with torch.no_grad():
                self._dummy.data = self._dummy.data.to(device)
            try:
                target = self.base if self.base is not None else self._callable
                if hasattr(target, "to"):
                    target.to(device)
            except Exception:
                pass

        def forward(self, *args, **kwargs):
            target = self.base if self.base is not None else self._callable
            return target(*args, **kwargs)

        def to(self, device, *args, **kwargs):
            self._dummy.data = self._dummy.data.to(device)
            try:
                target = self.base if self.base is not None else self._callable
                if hasattr(target, "to"):
                    target.to(device)
            except Exception:
                pass
            return super().to(device, *args, **kwargs)

    def _preferred_device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _wrap_if_parameterless(self, model, device: torch.device):
        try:
            it = model.parameters() if hasattr(model, "parameters") else iter(())
            next(it)
            return model
        except StopIteration:
            return self._ParamDeviceProxy(model, device)
        except Exception:
            return self._ParamDeviceProxy(model, device)

    def __init__(
        self,
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        freeze: bool = True,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        use_batch_tokenizer: bool = True,
        tokenizer_cache_size: int = 10000,
        tokenizer_batch_size: int = 256,
        local_cache_dir: str = "./model_cache",
        text_padding_strategy: str = "max_length"  # стратегия паддинга текста
    ):
        super().__init__()
        self.supported = set()
        self.out_dim_per_modality = {}
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.use_batch_tokenizer = use_batch_tokenizer
        self.tokenizer_cache_size = tokenizer_cache_size
        self.tokenizer_batch_size = tokenizer_batch_size
        self.local_cache_dir = local_cache_dir
        self.text_padding_strategy = text_padding_strategy

        self.text_model = None
        self.text_processor = None
        self.text_config = text_model_config or {}
        if text_model_config:
            self._init_text_model(text_model_config)
            self.supported.add("text")

        self.image_model = None
        self.image_processor = None
        self.image_config = image_model_config or {}
        if image_model_config:
            self._init_image_model(image_model_config)
            self.supported.add("image")

        self.audio_model = None
        self.audio_processor = None
        self.audio_config = audio_model_config or {}
        if audio_model_config:
            self._init_audio_model(audio_model_config)
            self.supported.add("audio")

        if freeze:
            self.freeze_all()

    def _ensure_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if x is None:
            return None
        if x.dim() == 1:
            return x.unsqueeze(0)
        if x.dim() > 2:
            return x.view(x.size(0), -1)
        return x

    def _normalize_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        x = self._ensure_2d(x)
        return F.normalize(x, dim=-1, eps=1e-12) if x is not None and x.numel() > 0 else x

    def _init_text_model(self, config: Dict[str, Any]):
        """Load the text encoder and its tokenizer and record the resulting settings.

        ``config['checkpoint']`` is required. ``config['model_type']``
        (default ``'auto'``) selects the loader classes: ``'clip'``, ``'clap'``
        or anything else for the generic ``AutoModel``/``AutoTokenizer`` pair.
        ``config['max_length']`` defaults to 512.

        Side effects: sets ``self.text_model`` and ``self.text_processor``,
        optionally creates the batch tokenizer, and writes ``max_length``,
        ``dim`` and ``model_type`` into ``self.text_config`` as well as the
        text entry of ``self.out_dim_per_modality``.
        """
        from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPTokenizer, ClapModel, ClapProcessor

        ckpt = config['checkpoint']
        kind = config.get('model_type', 'auto').lower()
        cache_dir = self.local_cache_dir

        print(f"Загрузка текстовой модели {ckpt}...")

        if kind == 'clip':
            self.text_model = safe_load(CLIPModel, ckpt, local_cache_dir=cache_dir)
            self.text_processor = safe_load(CLIPTokenizer, ckpt, local_cache_dir=cache_dir, use_fast=True)
            width = self.text_model.config.projection_dim
        elif kind == 'clap':
            self.text_model = safe_load(ClapModel, ckpt, local_cache_dir=cache_dir)
            processor = safe_load(ClapProcessor, ckpt, local_cache_dir=cache_dir)
            # Prefer the tokenizer bundled with the processor; load one separately otherwise.
            tokenizer = getattr(processor, 'tokenizer', None)
            if not tokenizer:
                tokenizer = safe_load(AutoTokenizer, ckpt, local_cache_dir=cache_dir, use_fast=True)
            self.text_processor = tokenizer
            width = getattr(self.text_model.config, "projection_dim", 512)
        else:
            self.text_model = safe_load(AutoModel, ckpt, local_cache_dir=cache_dir)
            self.text_processor = safe_load(AutoTokenizer, ckpt, local_cache_dir=cache_dir, use_fast=True)
            width = self.text_model.config.hidden_size

        self.text_model = self._wrap_if_parameterless(self.text_model, self._preferred_device())

        max_len = config.get('max_length', 512)
        if self.use_batch_tokenizer and self.text_processor is not None:
            self.set_batch_tokenizer(
                self.text_processor,
                max_length=max_len,
                cache_size=self.tokenizer_cache_size,
                batch_size=self.tokenizer_batch_size,
                padding_strategy=self.text_padding_strategy
            )

        self.text_config.update(max_length=max_len, dim=width, model_type=kind)
        self.out_dim_per_modality['text'] = width

    def _init_image_model(self, config: Dict[str, Any]):
        """Load the image encoder and its processor and record the resulting settings.

        ``config['checkpoint']`` is required. ``config['model_type']``
        (default ``'auto'``) chooses between the CLIP classes and the generic
        ``AutoModel``/``AutoImageProcessor`` pair. ``max_images`` defaults to 1
        and ``image_agg`` to ``'concat'``.

        The recorded output size is ``dim * max_images`` for ``'concat'``
        aggregation and ``dim`` for any other value.
        """
        from transformers import AutoModel, AutoImageProcessor, CLIPModel, CLIPImageProcessor

        ckpt = config['checkpoint']
        kind = config.get('model_type', 'auto').lower()
        cache_dir = self.local_cache_dir

        print(f"Загрузка визуальной модели {ckpt}...")

        if kind == 'clip':
            self.image_model = safe_load(CLIPModel, ckpt, local_cache_dir=cache_dir)
            self.image_processor = safe_load(CLIPImageProcessor, ckpt, local_cache_dir=cache_dir)
            width = self.image_model.config.projection_dim
        else:
            self.image_model = safe_load(AutoModel, ckpt, local_cache_dir=cache_dir)
            self.image_processor = safe_load(AutoImageProcessor, ckpt, local_cache_dir=cache_dir)
            width = self.image_model.config.hidden_size

        self.image_model = self._wrap_if_parameterless(self.image_model, self._preferred_device())

        max_images = config.get('max_images', 1)
        agg = config.get('image_agg', 'concat')
        self.image_config.update(max_images=max_images, image_agg=agg, dim=width, model_type=kind)

        if agg == 'concat':
            self.out_dim_per_modality['image'] = width * max_images
        else:
            self.out_dim_per_modality['image'] = width

    def _init_audio_model(self, config: Dict[str, Any]):
        """Load the audio encoder selected by ``config['model_type']`` and record its settings.

        Supported types:

        * ``'wav2clip'`` - imports the ``wav2clip`` package and obtains the
          model through ``get_model()`` or ``model``; no checkpoint needed,
          the embedding size is taken to be 512.
        * ``'clap'`` - ``ClapModel`` + ``ClapProcessor``; checkpoint required.
        * anything else - ``AutoModel`` + ``AutoProcessor``; checkpoint required.

        Raises:
            RuntimeError: the ``wav2clip`` package exposes neither
                ``get_model`` nor ``model``.
            ValueError: a checkpoint is required but missing.
        """
        from transformers import AutoModel, AutoProcessor, ClapModel, ClapProcessor

        kind = config.get('model_type', 'auto').lower()
        ckpt = config.get('checkpoint', None)

        print(f"Загрузка аудио модели (type={kind})...")

        if kind == 'wav2clip':
            import wav2clip as w2c
            self._w2c = w2c

            if hasattr(w2c, "get_model"):
                loaded = w2c.get_model()
            elif hasattr(w2c, "model"):
                candidate = w2c.model
                loaded = candidate() if callable(candidate) else candidate
            else:
                raise RuntimeError("wav2clip не содержит get_model()/model. Обновите пакет wav2clip.")
            self.audio_model = loaded

            # Best-effort move to the GPU; a failure leaves the model where it is.
            try:
                if isinstance(self.audio_model, torch.nn.Module) and torch.cuda.is_available():
                    self.audio_model = self.audio_model.to("cuda")
            except Exception:
                pass

            self.audio_processor = None
            width = 512
            rate = config.get('sr', 16000)

        elif kind == 'clap':
            if ckpt is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для CLAP")
            self.audio_model = safe_load(ClapModel, ckpt, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(ClapProcessor, ckpt, local_cache_dir=self.local_cache_dir)
            width = getattr(self.audio_model.config, "projection_dim", 512)
            rate = getattr(self.audio_processor, "sampling_rate", None)
            if rate is None:
                # Fall back to the feature extractor, then to 48 kHz.
                extractor = getattr(self.audio_processor, "feature_extractor", None)
                rate = getattr(extractor, "sampling_rate", 48000)

        else:
            if ckpt is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для аудио-моделей, кроме wav2clip")
            self.audio_model = safe_load(AutoModel, ckpt, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(AutoProcessor, ckpt, local_cache_dir=self.local_cache_dir)
            width = self.audio_model.config.hidden_size
            extractor = getattr(self.audio_processor, "feature_extractor", None)
            rate = getattr(extractor, "sampling_rate", 16000)

        self.audio_model = self._wrap_if_parameterless(self.audio_model, self._preferred_device())

        # An explicit 'sr' in the config always wins over the detected rate.
        final_rate = config.get('sr', rate)
        max_audios = config.get('max_audios', 1)
        agg = config.get('audio_agg', 'concat')
        self.audio_config.update(
            sr=final_rate, max_audios=max_audios, audio_agg=agg, dim=width, model_type=kind
        )

        if agg == 'concat':
            self.out_dim_per_modality['audio'] = width * max_audios
        else:
            self.out_dim_per_modality['audio'] = width

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        labels = []
        for b in batch:
            labels.append(torch.tensor(b.get("labels", 0), dtype=torch.long))
        labels = torch.stack(labels)

        backend_inputs: Dict[str, Any] = {}
        batch_size = len(batch)

        if self.text_model is not None:
            if "text_tokens" in batch[0]:
                text_inputs = {}
                for key in batch[0]["text_tokens"].keys():
                    if torch.is_tensor(batch[0]["text_tokens"][key]):
                        text_inputs[key] = torch.stack([b["text_tokens"][key] for b in batch])
                    else:
                        dtype = torch.long if key in ["input_ids", "attention_mask", "token_type_ids"] else torch.float32
                        text_inputs[key] = torch.tensor([b["text_tokens"][key] for b in batch], dtype=dtype)
                backend_inputs["text_inputs"] = text_inputs
            else:
                texts = [b.get("text", "") or " " for b in batch]
                if self.batch_tokenizer:
                    text_inputs = self.batch_tokenizer.tokenize_batch(texts, use_cache=True)
                else:
                    pad = "max_length" if getattr(self, "text_padding_strategy", "max_length") == "max_length" else True
                    text_inputs = self.text_processor(
                        texts, padding=pad, truncation=True,
                        max_length=self.text_config.get('max_length', 512),
                        return_tensors="pt"
                    )
                backend_inputs["text_inputs"] = {k: v for k, v in text_inputs.items()}

        if self.image_model is not None:
            images_lists = [b.get("images", []) for b in batch]
            flat_images, img_counts = [], []
            for lst in images_lists:
                lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
                lst = [img for img in lst if img is not None]
                img_counts.append(len(lst))
                for img in lst:
                    flat_images.append(to_pil(img))

            if len(flat_images) > 0:
                img_proc = self.image_processor(images=flat_images, return_tensors="pt")
                backend_inputs["image_inputs"] = {"pixel_values": img_proc["pixel_values"]}
            else:
                backend_inputs["image_inputs"] = {"pixel_values": None}

            backend_inputs["image_counts"] = torch.tensor(img_counts, dtype=torch.long)

        if self.audio_model is not None:
            audios_lists = [b.get("audios", []) for b in batch]
            flat_audios, aud_counts = [], []
            for lst in audios_lists:
                lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
                lst = [a for a in lst if a is not None]
                aud_counts.append(len(lst))
                for a in lst:
                    if isinstance(a, str):
                        flat_audios.append(load_audio(a, self.audio_config['sr']))
                    elif isinstance(a, np.ndarray):
                        aa = np.asarray(a, dtype=np.float32)
                        if aa.ndim > 1:
                            aa = np.squeeze(aa)
                        if aa.ndim > 1:
                            aa = aa.reshape(-1)
                        flat_audios.append(aa)
            if self.audio_config.get('model_type') == 'wav2clip':
                backend_inputs["audio_inputs"] = {"raw_audios": flat_audios}
            elif len(flat_audios) > 0:
                if self.audio_config.get('model_type') == 'clap':
                    aud_proc = self.audio_processor(
                        audios=flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_features": aud_proc["input_features"]}
                else:
                    aud_proc = self.audio_processor(
                        flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_values": aud_proc["input_values"]}
            else:
                backend_inputs["audio_inputs"] = {"input_features": None, "input_values": None, "raw_audios": []}

            backend_inputs["audio_counts"] = torch.tensor(aud_counts, dtype=torch.long)

        backend_inputs["batch_size"] = batch_size
        return {"labels": labels, "backend_inputs": backend_inputs}

    def _aggregate_embeddings(
        self,
        embs: Optional[torch.Tensor],
        counts: List[int],
        max_k: int,
        dim_hint: int,
        agg_type: str,
        batch_size: int,
        device: torch.device
    ) -> torch.Tensor:
        if embs is None or (torch.is_tensor(embs) and embs.numel() == 0):
            feat_dim = int(dim_hint) if dim_hint is not None else 0
            out_dim = feat_dim * max_k if agg_type == 'concat' else feat_dim
            return torch.zeros((batch_size, out_dim), device=device, dtype=torch.float32)

        if not torch.is_tensor(embs):
            embs = torch.as_tensor(embs, device=device, dtype=torch.float32)
        if embs.dim() == 1:
            embs = embs.unsqueeze(0)
        elif embs.dim() > 2:
            embs = embs.view(embs.size(0), -1)

        N, D = embs.size()
        out_dim = (D * max_k) if agg_type == 'concat' else D
        out = torch.zeros((batch_size, out_dim), device=device, dtype=embs.dtype)

        offset = 0
        for i, c in enumerate(counts):
            if c <= 0 or offset >= N:
                continue
            take_n = min(c, N - offset)
            sample = embs[offset:offset + take_n]
            offset += take_n

            if agg_type == 'concat':
                take = sample[:max_k]
                if take.size(0) < max_k:
                    pad = torch.zeros((max_k - take.size(0), D), device=device, dtype=embs.dtype)
                    take = torch.cat([take, pad], dim=0)
                out[i] = take.reshape(-1)
            else:
                out[i] = sample.mean(dim=0)

        return F.normalize(out, dim=-1, eps=1e-12) if out.size(1) > 0 else out

    @torch.no_grad()
    def _wav2clip_embed(self, arr: np.ndarray, device: torch.device) -> torch.Tensor:
        arr = np.asarray(arr, dtype=np.float32)
        if arr.ndim > 1:
            arr = np.squeeze(arr)
        if arr.ndim > 1:
            arr = arr.reshape(-1)
        if arr.size < 512:
            arr = np.pad(arr, (0, 512 - arr.size), mode="constant")

        try:
            emb = self._w2c.embed_audio(arr, self.audio_model)
            emb = np.asarray(emb)
        except Exception:
            x = torch.from_numpy(arr).float().unsqueeze(0).to(device)
            y = self.audio_model(x)
            if isinstance(y, (tuple, list)):
                y = y[0]
            if torch.is_tensor(y):
                if y.dim() == 2 and y.size(0) == 1:
                    y = y.squeeze(0)
                emb = y.detach().cpu().numpy()
            else:
                emb = np.asarray(y)

        if emb.ndim > 1:
            emb = emb.reshape(-1)
        return torch.as_tensor(emb, device=device, dtype=torch.float32)

    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        results: Dict[str, torch.Tensor] = {}
        batch_size = int(backend_inputs.get("batch_size", 1))

        if self.text_model is not None and "text_inputs" in backend_inputs:
            text_inputs = {k: v.to(device) for k, v in backend_inputs["text_inputs"].items()}
            if hasattr(self.text_model, "get_text_features"):
                text_z = self.text_model.get_text_features(**text_inputs)
            else:
                outputs = self.text_model(**text_inputs)
                text_z = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state.mean(dim=1)
            results["text"] = self._normalize_2d(text_z)

        if self.image_model is not None and "image_inputs" in backend_inputs:
            pi = backend_inputs["image_inputs"]["pixel_values"]
            counts = backend_inputs["image_counts"].tolist()
            total_images_needed = sum(counts)

            img_flat = None
            actual_img_dim = self.image_config.get("dim", 768)

            if pi is not None and pi.numel() > 0 and total_images_needed > 0:
                pi = pi.to(device)
                if pi.size(0) > total_images_needed:
                    pi = pi[:total_images_needed]

                if hasattr(self.image_model, "get_image_features"):
                    img_flat = self.image_model.get_image_features(pixel_values=pi)
                else:
                    outputs = self.image_model(pixel_values=pi)
                    img_flat = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state[:, 0]

                img_flat = self._normalize_2d(img_flat)
                actual_img_dim = img_flat.size(1) if img_flat is not None else actual_img_dim

            img_z = self._aggregate_embeddings(
                img_flat, counts,
                self.image_config["max_images"],
                actual_img_dim,
                self.image_config["image_agg"],
                len(counts),
                device
            )

            if actual_img_dim != self.image_config.get("dim"):
                self.image_config["dim"] = actual_img_dim
                self.out_dim_per_modality["image"] = (
                    actual_img_dim * self.image_config["max_images"]
                    if self.image_config["image_agg"] == "concat" else actual_img_dim
                )

            results["image"] = img_z

        if self.audio_model is not None and "audio_inputs" in backend_inputs:
            counts = backend_inputs["audio_counts"].tolist()
            total_audios_needed = sum(counts)

            aud_flat = None
            actual_aud_dim = self.audio_config.get("dim", 768)
            model_type = self.audio_config.get("model_type")

            if total_audios_needed > 0:
                if model_type == "clap":
                    af = backend_inputs["audio_inputs"]["input_features"]
                    if af is not None and af.numel() > 0:
                        af = af.to(device)
                        if af.size(0) > total_audios_needed:
                            af = af[:total_audios_needed]
                        with torch.cuda.amp.autocast(enabled=False):
                            aud_flat = self.audio_model.get_audio_features(input_features=af.float())
                        aud_flat = self._normalize_2d(aud_flat.float())
                        actual_aud_dim = aud_flat.size(1)

                elif model_type == "wav2clip":
                    raw_list = backend_inputs["audio_inputs"].get("raw_audios", [])
                    if len(raw_list) > total_audios_needed:
                        raw_list = raw_list[:total_audios_needed]
                    if len(raw_list) > 0:
                        embs = [self._wav2clip_embed(arr, device) for arr in raw_list]
                        aud_flat = torch.stack(embs, dim=0)
                        aud_flat = self._normalize_2d(aud_flat)
                        actual_aud_dim = aud_flat.size(1)

                else:
                    av = backend_inputs["audio_inputs"]["input_values"]
                    if av is not None and av.numel() > 0:
                        av = av.to(device)
                        if av.size(0) > total_audios_needed:
                            av = av[:total_audios_needed]
                        av = av.clamp_(-1.0, 1.0)
                        with torch.cuda.amp.autocast(enabled=False):
                            outputs = self.audio_model(input_values=av.float())
                            feats = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state.mean(dim=1)
                        aud_flat = self._normalize_2d(feats.float())
                        actual_aud_dim = aud_flat.size(1)

            aud_z = self._aggregate_embeddings(
                aud_flat, counts,
                self.audio_config["max_audios"],
                actual_aud_dim,
                self.audio_config["audio_agg"],
                len(counts),
                device
            )

            if aud_flat is not None and actual_aud_dim != self.audio_config.get("dim"):
                self.audio_config["dim"] = actual_aud_dim
                self.out_dim_per_modality["audio"] = (
                    actual_aud_dim * self.audio_config["max_audios"]
                    if self.audio_config["audio_agg"] == "concat" else actual_aud_dim
                )

            results["audio"] = aud_z

        if results:
            bs_list = [v.size(0) for v in results.values()]
            if len(set(bs_list)) != 1:
                raise RuntimeError(f"Inconsistent batch sizes across modalities: {bs_list}")

        return results


# =========================
# Классификатор
# =========================

class SingleBackboneClassifier(nn.Module):
    """MLP classification head on top of fused per-modality embeddings from one backend.

    The backend's ``encode`` supplies one embedding per modality; the selected
    modalities are fused in the fixed order image, text, audio, either by
    concatenation or by averaging, and passed through
    LayerNorm -> Linear -> GELU -> Dropout -> Linear.
    """

    def __init__(
        self,
        backend: "BaseBackend",
        modalities: List[str],
        num_labels: int,
        fusion: str = "concat",
        hidden: int = 512,
        dropout: float = 0.1
    ):
        """Create the head.

        Args:
            backend: Object providing ``get_out_dim(modality)`` and ``encode``.
            modalities: Modalities to fuse; only ``"image"``, ``"text"`` and
                ``"audio"`` are considered.
            num_labels: Number of output classes.
            fusion: ``"concat"`` or ``"mean"``.
            hidden: Width of the hidden layer.
            dropout: Dropout probability inside the head.

        Raises:
            ValueError: Unknown ``fusion``, or ``"mean"`` fusion with
                modalities of different sizes.
        """
        super().__init__()
        self.backend = backend
        self.modalities = modalities
        self.fusion = fusion
        self.num_labels = num_labels

        order = [m for m in ["image", "text", "audio"] if m in self.modalities]
        if fusion == "concat":
            in_dim = sum(self.backend.get_out_dim(m) for m in order)
        elif fusion == "mean":
            dims = [self.backend.get_out_dim(m) for m in order]
            if len(set(dims)) != 1:
                raise ValueError('Для fusion="mean" размеры модальностей должны совпадать')
            in_dim = dims[0]
        else:
            raise ValueError('fusion должен быть "concat" или "mean"')

        self.head = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_labels)
        )

    def _backbones(self) -> list:
        """Return the backend's text, image and audio models (``None`` where absent)."""
        return [
            getattr(self.backend, attr, None)
            for attr in ("text_model", "image_model", "audio_model")
        ]

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on every backbone that has trainable parameters.

        Frozen or missing backbones are skipped. ``config.use_cache`` is set
        to ``False`` where such an attribute exists. All steps are best-effort:
        a backbone that does not support checkpointing is silently left as is.
        """
        for m in self._backbones():
            if m is None:
                continue
            try:
                has_trainable = any(p.requires_grad for p in m.parameters()) if hasattr(m, "parameters") else False
            except Exception:
                has_trainable = False
            if not has_trainable:
                continue
            try:
                cfg = getattr(m, "config", None)
                if cfg is not None and hasattr(cfg, "use_cache"):
                    cfg.use_cache = False
            except Exception:
                pass
            try:
                if hasattr(m, "gradient_checkpointing_enable"):
                    try:
                        if gradient_checkpointing_kwargs is not None:
                            m.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
                        else:
                            m.gradient_checkpointing_enable()
                    except TypeError:
                        # Implementations that do not accept the kwargs argument.
                        m.gradient_checkpointing_enable()
            except Exception:
                pass

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on every backbone that supports it (best-effort)."""
        for m in self._backbones():
            if m is None:
                continue
            try:
                if hasattr(m, "gradient_checkpointing_disable"):
                    m.gradient_checkpointing_disable()
            except Exception:
                pass

    @staticmethod
    def _find_tensor_device(obj) -> Optional[torch.device]:
        """Depth-first search for the first tensor inside nested dicts/lists/tuples."""
        if isinstance(obj, torch.Tensor):
            return obj.device
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return None
        for child in children:
            found = SingleBackboneClassifier._find_tensor_device(child)
            if found is not None:
                return found
        return None

    def _infer_device_from_inputs(self, obj) -> torch.device:
        """Return the device of the first tensor found anywhere in ``obj``.

        Every value is inspected, so containers without tensors (or ``None``
        entries) that come first do not end the search. Only when no tensor
        exists at all does this fall back to CUDA if available, else CPU.
        """
        found = self._find_tensor_device(obj)
        if found is not None:
            return found
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _fuse(self, z: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Combine the per-modality embeddings in ``z`` into one tensor.

        3-D tensors are averaged over axis 1, higher-rank tensors are
        flattened to 2-D. All tensors are cut to the batch size of the first
        one before being concatenated or averaged.

        Raises:
            ValueError: None of the configured modalities is present in ``z``,
                or ``self.fusion`` holds an unsupported value.
        """
        order = [m for m in ["image", "text", "audio"] if m in self.modalities]
        feats = []
        batch_size = None
        for m in order:
            if m in z:
                t = z[m]
                if t.dim() == 3:
                    t = t.mean(dim=1)
                elif t.dim() > 3:
                    t = t.flatten(start_dim=1)
                feats.append(t)
                if batch_size is None:
                    batch_size = t.size(0)
        if not feats:
            raise ValueError("Backend не вернул эмбеддинги для выбранных модальностей")
        feats = [f[:batch_size] for f in feats]
        if self.fusion == "concat":
            return torch.cat(feats, dim=-1)
        if self.fusion == "mean":
            return torch.stack(feats, dim=0).mean(dim=0)
        raise ValueError('fusion должен быть "concat" или "mean"')

    def forward(self, backend_inputs: Dict[str, Any], labels: Optional[torch.Tensor] = None):
        """Encode, fuse and classify; returns ``SequenceClassifierOutput`` holding the logits.

        ``labels`` is accepted but not used here: no loss is computed in this method.

        Raises:
            ValueError: The backend returned no embeddings.
        """
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        if not z:
            raise ValueError("Backend не вернул эмбеддинги")
        fused = self._fuse(z)
        logits = self.head(fused)
        return SequenceClassifierOutput(logits=logits)

    @torch.no_grad()
    def get_embeddings(self, backend_inputs: Dict[str, Any], return_per_modality: bool = False):
        """Return the fused embedding, optionally together with the per-modality dict."""
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        fused = self._fuse(z)
        if return_per_modality:
            return fused, z
        return fused


# =========================
# Trainer с весами классов
# =========================

class WeightedCETrainer(Trainer):
    """``Trainer`` whose loss is cross-entropy with optional per-class weights.

    The weights are taken from ``class_weights`` when given. Otherwise, if
    both ``train_labels`` and ``num_labels`` are provided, each class seen in
    the labels gets ``n_samples / (num_labels * class_count)`` and unseen
    classes get 0. With neither, the loss is unweighted.
    """

    def __init__(self, *args, num_labels=None, train_labels=None, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        weights = None
        if class_weights is not None:
            weights = torch.as_tensor(class_weights, dtype=torch.float32)
        elif train_labels is not None and num_labels is not None:
            label_ids = np.asarray(train_labels).astype(int)
            freq = np.bincount(label_ids, minlength=num_labels)
            total = freq.sum()
            balanced = np.zeros(num_labels, dtype=np.float32)
            seen = freq > 0
            balanced[seen] = total / (num_labels * freq[seen].astype(np.float32))
            weights = torch.tensor(balanced, dtype=torch.float32)
        self.class_weights = weights

    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch: Optional[int] = None):
        """Run the model and return the (weighted) cross-entropy loss.

        ``labels`` is removed from ``inputs`` before the model is called.
        ``num_items_in_batch`` is accepted but not used.

        Raises:
            ValueError: The logits and the labels have different batch sizes.
        """
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        n_logits, n_targets = logits.size(0), targets.size(0)
        if n_logits != n_targets:
            raise ValueError(f"Batch size mismatch: logits batch={n_logits} vs labels batch={n_targets}")

        # The weights are stored on the CPU and moved next to the logits on every call.
        weight = None if self.class_weights is None else self.class_weights.to(logits.device)
        loss = F.cross_entropy(logits, targets.long(), weight=weight)
        if return_outputs:
            return loss, outputs
        return loss


# =========================
# Прогресс-логгер
# =========================

class PbarConsoleLogger(TrainerCallback):
    """
    Trainer callback that mirrors progress into a tqdm bar and prints one line per evaluation.

    The bar position follows ``state.global_step`` (never moving backwards and, when the bar has
    a total, never beyond it). The postfix shows the latest train loss, eval loss and every other
    numeric ``eval_*`` metric. Each evaluation is additionally written once per step through
    ``tqdm.write``.

    :param pbar: A tqdm-like progress bar exposing ``total``, ``n``, ``update``,
                 ``set_postfix_str`` and ``refresh``.
    """

    _EVAL_PREFIX = 'eval_'
    # Bookkeeping keys that must not be reported as quality metrics.
    _NON_METRIC_KEYS = frozenset((
        'eval_loss', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch'
    ))

    def __init__(self, pbar):
        self.pbar = pbar
        self.last_logs = {}
        self.last_train_loss = None
        self.printed_eval_steps = set()
        self.tqdm = tqdm

    def _step(self, state) -> int:
        """Return the trainer's current global step as an int (0 when unset)."""
        return int(state.global_step or 0)

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a float, or None when it is not a scalar number."""
        if isinstance(value, (str, bytes)):
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            # Lists, dicts, None, multi-element arrays: not something we can format as a number.
            return None

    def _metric_items(self, logs):
        """Yield ``(metric_name, value)`` for numeric ``eval_*`` entries that are real metrics."""
        prefix = self._EVAL_PREFIX
        for key, raw in logs.items():
            if not key.startswith(prefix) or key in self._NON_METRIC_KEYS:
                continue
            number = self._as_float(raw)
            if number is not None:
                # Strip only the leading prefix; str.replace would also eat inner "eval_" parts.
                yield key[len(prefix):], number

    def _fmt_postfix(self):
        """Build the progress-bar postfix from the most recent numeric log values."""
        parts = []
        if 'loss' in self.last_logs:
            parts.append(f"loss {self.last_logs['loss']:.4f}")
        if 'eval_loss' in self.last_logs:
            parts.append(f"val {self.last_logs['eval_loss']:.4f}")
        parts.extend(f"{name} {value:.4f}" for name, value in self._metric_items(self.last_logs))
        return " | ".join(parts)

    def _sync_bar(self, state):
        """Advance the bar to the current step; tolerate bars created without a total."""
        step = self._step(state)
        total = self.pbar.total
        target = step if total is None else min(step, total)
        if target > self.pbar.n:
            self.pbar.update(target - self.pbar.n)

    def on_step_end(self, args, state, control, **kwargs):
        self._sync_bar(state)
        if self.last_logs:
            self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        for key, raw in logs.items():
            number = self._as_float(raw)
            if number is not None:
                self.last_logs[key] = number
        train_loss = self._as_float(logs.get('loss'))
        if train_loss is not None:
            self.last_train_loss = train_loss
        self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

        if not any(key.startswith(self._EVAL_PREFIX) for key in logs):
            return

        step = self._step(state)
        if step in self.printed_eval_steps:
            # The same evaluation can be logged more than once; print it only the first time.
            return
        self.printed_eval_steps.add(step)

        train_loss_str = f"{self.last_train_loss:.10f}" if self.last_train_loss is not None else "n/a"
        val_loss = self._as_float(logs.get('eval_loss'))
        val_loss_str = f"{val_loss:.10g}" if val_loss is not None else "n/a"
        extra_parts = [f"val {name}: {value:.10f}" for name, value in self._metric_items(logs)]

        line = f"step: {step}, train loss: {train_loss_str}, val loss: {val_loss_str}"
        if extra_parts:
            line += ", " + ", ".join(extra_parts)
        self.tqdm.write(line)

    def on_train_end(self, args, state, control, **kwargs):
        self._sync_bar(state)
        self.pbar.refresh()


# =========================
# Основной пайплайн
# =========================

class SingleModelMultiComboClassification:
    """
    Универсальный пайплайн для мульти-модальной классификации (text / image / audio) поверх моделей Hugging Face.
    Поддерживает автоматический выбор бэкенда (CLIP/CLAP/Auto), батчевую токенизацию текста, кэширование,
    чанковую тренировку, раннюю остановку, взвешивание классов, предсказание и извлечение эмбеддингов.

    Основные возможности:
    - Автоматическая сборка бэкенда: CLIP для связки text+image, CLAP для text+audio, либо произвольные Auto-модели.
    - Работа с тремя модальностями: text, image, audio (любой поднабор).
    - Батчевая токенизация текста с кэшем и опциональной предварительной токенизацией датасета.
    - Чанковая тренировка очень больших датасетов без перегрузки памяти.
    - Сбалансированная (взвешенная) кросс-энтропия на основе частот классов в тренировочных данных.
    - Прогресс-бар, ранняя остановка и выбор лучшей модели по метрике.
    - Предсказания и извлечение эмбеддингов (в том числе по модальностям).

    :param modalities: Список используемых модальностей из {"text", "image", "audio"}.
    :param num_labels: Число классов в задаче классификации.
    :param target_column_name: Имя колонки с целевой меткой в DataFrame.
    :param text_columns: Список текстовых колонок (используются, если выбрана модальность "text"). Значения будут
                         конкатенированы через special_tokens["sep"] при подготовке примеров.
    :param image_columns: Список колонок с изображениями (пути к файлам, PIL.Image, np.ndarray или списки таких объектов),
                          используется, если выбрана модальность "image".
    :param audio_columns: Список колонок с аудио (пути к файлам или массивы np.ndarray; моно, float32),
                          используется, если выбрана модальность "audio". Для чтения из файлов требуется torchaudio.
    :param text_tokenizer_fn: Кастомная функция токенизации текста (если не используется встроенный BatchTokenizer).
                              Сигнатура: fn(text_dict: Dict[str, str], special_tokens: Dict[str, str]) -> Union[Dict[str, Tensor], str].
                              Если возвращает dict с ключами вроде 'input_ids', считается, что функция сразу возвращает тензоры токенов;
                              иначе строку для последующей стандартной токенизации.
    :param special_tokens: Спец. токены/разделители для подготовки текста. По умолчанию {"sep": " [SEP] "}.
    :param tokenizer_returns_tensors: Флаг, сигнализирующий, что custom text_tokenizer_fn возвращает уже тензоры
                                      (dict c 'input_ids', 'attention_mask' и т.д.). Влияет на коллатор.
    :param backend: Режим сборки бэкенда. "auto" — подобрать оптимальные модели по модальностям;
                    "clip" — CLIP для текста и изображений; "clap" — CLAP для текста и аудио; любое иное — ручные конфиги.
    :param clip_checkpoint: Чекпойнт CLIP (используется при auto/clip), по умолчанию "openai/clip-vit-base-patch32".
    :param clap_checkpoint: Чекпойнт CLAP (используется при auto/clap), по умолчанию "laion/clap-htsat-unfused".
    :param text_model_config: Конфиг текстовой модели. Минимум: {"checkpoint": "...", "model_type": "..."}.
                              Дополнительно: "max_length" и т.д. Примеры model_type: "clip", "clap", "bert", "auto".
    :param image_model_config: Конфиг визуальной модели. Минимум: {"checkpoint": "...", "model_type": "..."}.
                               Дополнительно: "max_images", "image_agg" ("concat"|"mean") и т.д. Примеры model_type: "clip", "vit", "auto".
    :param audio_model_config: Конфиг аудио-модели. Минимум: {"checkpoint": "...", "model_type": "..."} (кроме "wav2clip").
                               Дополнительно: "max_audios", "audio_agg" ("concat"|"mean"), "sr". Примеры model_type: "clap", "wav2clip", "auto".
    :param fusion: Способ слияния модальностей в классификаторе: "concat" или "mean".
                   При "mean" размеры эмбеддингов всех модальностей должны совпадать.
    :param freeze_backbone: Если True — бэкенды заморожены (тренируется только классификационная "голова").
    :param clip_max_length: Максимальная длина текста для CLIP-токенизатора (по умолчанию 77).
    :param max_images_per_sample: Сколько изображений брать на сэмпл (усреднение или конкатенация задаются в image_model_config["image_agg"]).
    :param max_audios_per_sample: Сколько аудио брать на сэмпл (аналогично image, параметр audio_model_config["audio_agg"]).
    :param use_batch_tokenizer: Использовать BatchTokenizer для текста (ускоряет токенизацию и кэширует результаты).
    :param pretokenize_data: Предварительно токенизировать текст датасета (в памяти) для ускорения обучения/инференса.
    :param pretokenize_batch_size: Батч-размер при предварительной токенизации.
    :param tokenizer_cache_size: Размер LRU-кэша в BatchTokenizer.
    :param max_pretokenize_samples: Максимум сэмплов для предварительной токенизации на чанк/датасет.
    :param local_cache_dir: Локальная директория кэша моделей/процессоров HF.
    :param text_padding: "max_length" — паддинг до фиксированной длины; "dynamic" — паддинг до максимальной длины в батче.
                         При "dynamic" предтокенизация текста автоматически отключается.

    :return: None

    :raises ValueError: Если бэкенд не поддерживает выбранные модальности (внутренняя проверка соответствия).
                        Также возможны ошибки конфигов, например fusion="mean" при несовпадающих размерах эмбеддингов.
    :raises OSError: Ошибки загрузки моделей/процессоров из Hugging Face Hub (сетевые/офлайн проблемы, отсутствующие чекпойнты).
    :raises RuntimeError: Проблемы с устройством/драйвером (CUDA/MPS), несовместимость версий зависимостей и т.п.
    """
    def __init__(
        self,
        modalities: List[str],
        num_labels: int,
        target_column_name: str,
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        backend: str = "auto",
        clip_checkpoint: str = "openai/clip-vit-base-patch32",
        clap_checkpoint: str = "laion/clap-htsat-unfused",
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        fusion: str = "concat",
        freeze_backbone: bool = True,
        clip_max_length: int = 77,
        max_images_per_sample: int = 1,
        max_audios_per_sample: int = 1,
        use_batch_tokenizer: bool = True,
        pretokenize_data: bool = True,
        pretokenize_batch_size: int = 256,
        tokenizer_cache_size: int = 10000,
        max_pretokenize_samples: int = 100000,
        local_cache_dir: str = "./model_cache",
        text_padding: str = "max_length"  # "max_length" или "dynamic"
    ):
        self.modalities = sorted(list(set(modalities)))
        self.num_labels = num_labels
        self.target_column_name = target_column_name
        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.backend_name = backend.lower()
        self.clip_checkpoint = clip_checkpoint
        self.clap_checkpoint = clap_checkpoint
        self.text_model_config = text_model_config
        self.image_model_config = image_model_config
        self.audio_model_config = audio_model_config
        self.fusion = fusion
        self.freeze_backbone = freeze_backbone
        self.clip_max_length = clip_max_length
        self.max_images_per_sample = int(max_images_per_sample)
        self.max_audios_per_sample = int(max_audios_per_sample)

        self.use_batch_tokenizer = use_batch_tokenizer
        self.pretokenize_data = pretokenize_data
        self.pretokenize_batch_size = pretokenize_batch_size
        self.tokenizer_cache_size = tokenizer_cache_size
        self.max_pretokenize_samples = max_pretokenize_samples
        self.local_cache_dir = local_cache_dir
        self.text_padding = text_padding

        self.label2id: Optional[Dict[Any, int]] = None
        self.id2label: Optional[Dict[int, str]] = None

        self.backend: Optional[BaseBackend] = None
        self.model: Optional[SingleBackboneClassifier] = None
        self.trainer: Optional[Trainer] = None
        self.compute_metrics = None
        self.progress_callback: Optional[PbarConsoleLogger] = None

        self._build_backend()

    def _build_backend(self):
        mods = set(self.modalities)
        name = self.backend_name

        if name == "auto":
            if mods == {"text", "image"}:
                self.text_model_config = self.text_model_config or {
                    'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_length': self.clip_max_length
                }
                self.image_model_config = self.image_model_config or {
                    'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'
                }
            elif mods == {"text", "audio"}:
                self.text_model_config = self.text_model_config or {
                    'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_length': 64
                }
                self.audio_model_config = self.audio_model_config or {
                    'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000
                }
            else:
                if "text" in mods and self.text_model_config is None:
                    self.text_model_config = {'checkpoint': 'bert-base-multilingual-cased', 'model_type': 'bert', 'max_length': 512}
                if "image" in mods and self.image_model_config is None:
                    self.image_model_config = {'checkpoint': 'google/vit-base-patch16-224', 'model_type': 'vit', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'}
                if "audio" in mods and self.audio_model_config is None:
                    self.audio_model_config = {'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000}

        elif name == "clip":
            self.text_model_config = self.text_model_config or {
                'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_length': self.clip_max_length
            }
            self.image_model_config = self.image_model_config or {
                'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'
            }
        elif name == "clap":
            self.text_model_config = self.text_model_config or {
                'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_length': 64
            }
            self.audio_model_config = self.audio_model_config or {
                'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000
            }
        else:
            pass

        self.backend = UniversalMultiBackend(
            text_model_config=self.text_model_config if "text" in mods else None,
            image_model_config=self.image_model_config if "image" in mods else None,
            audio_model_config=self.audio_model_config if "audio" in mods else None,
            freeze=self.freeze_backbone,
            text_tokenizer_fn=self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors,
            use_batch_tokenizer=self.use_batch_tokenizer,
            tokenizer_cache_size=self.tokenizer_cache_size,
            tokenizer_batch_size=self.pretokenize_batch_size,
            local_cache_dir=self.local_cache_dir,
            text_padding_strategy=self.text_padding
        )

        if not set(self.modalities).issubset(self.backend.supported):
            raise ValueError(f"Бэкенд {self.backend.name} не поддерживает модальности {self.modalities}")

    def _setup_metrics(self, metric_name: str):
        metric_name = metric_name.lower()
        if metric_name == "f1":
            metric = evaluate.load("f1")
            def compute(p):
                preds = p.predictions.argmax(-1)
                return metric.compute(predictions=preds, references=p.label_ids, average="weighted")
        elif metric_name == "accuracy":
            metric = evaluate.load("accuracy")
            def compute(p):
                preds = p.predictions.argmax(-1)
                return metric.compute(predictions=preds, references=p.label_ids)
        else:
            raise ValueError('metric_name должен быть "f1" или "accuracy"')
        self.compute_metrics = compute

    def _split(self, df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
        df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        n_eval = int(math.ceil(len(df) * test_size))
        return df.iloc[n_eval:].reset_index(drop=True), df.iloc[:n_eval].reset_index(drop=True)

    def _validate_data(self, df: pd.DataFrame):
        if "text" in self.modalities:
            if not self.text_columns:
                raise ValueError("Вы выбрали модальность 'text', но text_columns пустой.")
            missing = [c for c in self.text_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют текстовые колонки: {missing}")
        if "image" in self.modalities:
            if not self.image_columns:
                raise ValueError("Вы выбрали модальность 'image', но image_columns пуст.")
            missing = [c for c in self.image_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки изображений: {missing}")
        if "audio" in self.modalities:
            if not self.audio_columns:
                raise ValueError("Вы выбрали модальность 'audio', но audio_columns пуст.")
            missing = [c for c in self.audio_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки аудио: {missing}")

    def fit(
        self,
        train_data: pd.DataFrame,
        epochs: int = 3,
        test_size: float = 0.2,
        test_data: Optional[pd.DataFrame] = None,
        per_device_train_batch_size: int = 16,
        gradient_accumulation_steps: int = 1,
        learning_rate: float = 2e-4,
        metric_name: str = "f1",
        fp16: bool = True,
        logging_steps: int = 50,
        eval_steps: int = 200,
        output_dir: str = "./result",
        seed: int = 42,
        hidden: int = 512,
        dropout: float = 0.1,
        gradient_checkpointing: bool = False,
        fit_chunk_size: Optional[int] = None,
        clear_cache_every_n_chunks: int = 10
    ):
        """
        Обучает классификатор на заданном DataFrame с поддержкой валидации, ранней остановки и чанковой тренировки.
        Внутренне использует WeightedCETrainer (кросс-энтропия с весами классов по обратной частоте в train_data),
        логгер прогресса и, при необходимости, предварительную токенизацию батчей текста.

        Если test_data не передан, train_data разделяется на train/eval по test_size (стратификации нет).
        При очень больших датасетах обучение проводится чанками (fit_chunk_size), чтобы ограничить потребление памяти.

        :param train_data: Обучающий DataFrame. Должен содержать столбец target_column_name, а также столбцы по выбранным модальностям:
                           - text_columns для "text" (строки, допускаются NaN),
                           - image_columns для "image" (строки путей к файлам, PIL.Image, np.ndarray или списки этих типов),
                           - audio_columns для "audio" (пути к аудиофайлам или np.ndarray моно-сигнала; для путей требуется torchaudio).
        :param epochs: Количество эпох обучения.
        :param test_size: Доля данных на валидацию, если test_data не задан.
        :param test_data: Отдельный DataFrame для валидации. Если указан, параметр test_size игнорируется.
        :param per_device_train_batch_size: Батч-размер на устройство для обучения.
        :param gradient_accumulation_steps: Шаги аккумулирования градиента (эффективный батч = batch_size * steps).
        :param learning_rate: Начальная скорость обучения (оптимизатор и шедулер создаются внутри Trainer).
        :param metric_name: Метрика ранней остановки/выбора лучшей модели: "f1" (weighted) или "accuracy".
        :param fp16: Использовать ли полуточность (только при наличии CUDA).
        :param logging_steps: Периодичность логирования в шагах.
        :param eval_steps: Периодичность валидации/сохранения модели в шагах.
        :param output_dir: Директория для чекпойнтов/логов.
        :param seed: Начальное зерно для воспроизводимости.
        :param hidden: Размер скрытого слоя классификационной головы.
        :param dropout: Дропаут в классификационной голове.
        :param gradient_checkpointing: Включить gradient checkpointing в бэкендах (если они обучаемые).
        :param fit_chunk_size: Размер чанка для поэтапной тренировки (None — весь датасет за эпоху без разбиения).
        :param clear_cache_every_n_chunks: Каждые N чанков очищать кэш токенизации (для экономии памяти).

        :return: self (для чейнинга вызовов).

        :raises ValueError:
            - Отсутствуют обязательные колонки по модальностям в train_data (внутренняя проверка _validate_data).
            - Невозможно собрать классификатор с fusion="mean" при разных размерах эмбеддингов модальностей.
        :raises RuntimeError: Ошибки, возникающие внутри transformers.Trainer (например, рассогласование батчей/логитов),
                              проблемы с устройством (CUDA OOM, MPS), ошибки чтения аудио/изображений.
        :raises OSError: Ошибки чтения исходных файлов данных (изображения/аудио) или кэша моделей.
        """
        self._validate_data(train_data)
        set_seed(seed)

        classes = sorted(train_data[self.target_column_name].unique().tolist())
        if self.num_labels != len(classes):
            print(f"Warning: num_labels={self.num_labels} != len(classes)={len(classes)}")
        self.label2id = {c: i for i, c in enumerate(classes)}
        self.id2label = {i: str(c) for c, i in self.label2id.items()}

        if test_data is None:
            df_train, df_eval = self._split(train_data, test_size=test_size, seed=seed)
        else:
            df_train, df_eval = train_data, test_data

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

        ds_eval = MultiComboDataset(
            df=df_eval,
            target_col=self.target_column_name,
            label2id=self.label2id,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
            text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=(self.pretokenize_data and has_bt and len(df_eval) < 50000 and self.text_padding != "dynamic"),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_eval), self.max_pretokenize_samples),
            tokenizer_returns_tensors=self.tokenizer_returns_tensors
        )

        y_train_all = np.array([self.label2id[y] for y in df_train[self.target_column_name].tolist()], dtype=int)
        counts = np.bincount(y_train_all, minlength=self.num_labels)
        n_all = counts.sum()
        class_weights = np.zeros(self.num_labels, dtype=np.float32)
        nonzero = counts > 0
        class_weights[nonzero] = n_all / (self.num_labels * counts[nonzero].astype(np.float32))

        self.model = SingleBackboneClassifier(
            backend=self.backend,
            modalities=self.modalities,
            num_labels=self.num_labels,
            fusion=self.fusion,
            hidden=hidden,
            dropout=dropout
        )
        self._setup_metrics(metric_name)

        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=max(4, per_device_train_batch_size // 2),
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=max(1, gradient_accumulation_steps * 2),
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.01,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_strategy="steps",
            save_steps=eval_steps,
            load_best_model_at_end=True,
            metric_for_best_model=f"eval_{metric_name}",
            save_total_limit=1,
            logging_strategy="steps",
            logging_steps=logging_steps,
            report_to="none",
            fp16=fp16 and torch.cuda.is_available(),
            dataloader_num_workers=min(4, os.cpu_count() or 4),
            seed=seed,
            remove_unused_columns=False,
            gradient_checkpointing=gradient_checkpointing,
            dataloader_pin_memory=True,
            ddp_find_unused_parameters=False,
            disable_tqdm=True
        )

        def data_collator(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            return self.backend.collate(batch_list)

        def steps_for_size(sz: int, bsz: int, accum: int) -> int:
            return max(0, math.ceil(math.ceil(sz / max(1, bsz)) / max(1, accum)))

        def chunk_slices(index_array: np.ndarray, chunk_size: int):
            for i in range(0, len(index_array), chunk_size):
                yield index_array[i:i + chunk_size]

        n_train = len(df_train)
        rng = np.random.default_rng(seed)
        train_idx = np.arange(n_train)

        chunk_size = fit_chunk_size if (fit_chunk_size and fit_chunk_size > 0) else len(train_idx)

        total_steps = 0
        for _ in range(epochs):
            rng.shuffle(train_idx)
            for slc in chunk_slices(train_idx, chunk_size):
                total_steps += steps_for_size(len(slc), per_device_train_batch_size, gradient_accumulation_steps)

        dummy_idx = np.arange(min(len(df_train), 1))
        ds_train_init = (
            MultiComboDataset(
                df=df_train.iloc[dummy_idx],
                target_col=self.target_column_name,
                label2id=self.label2id,
                text_columns=self.text_columns,
                image_columns=self.image_columns,
                audio_columns=self.audio_columns,
                text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
                text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
                special_tokens=self.special_tokens,
                pretokenize=False,
                tokenizer_returns_tensors=self.tokenizer_returns_tensors
            ) if len(dummy_idx) > 0 else ds_eval
        )

        self.trainer = WeightedCETrainer(
            model=self.model,
            args=args,
            train_dataset=ds_train_init,
            eval_dataset=ds_eval,
            data_collator=data_collator,
            compute_metrics=self.compute_metrics,
            num_labels=self.num_labels,
            train_labels=y_train_all,
            class_weights=class_weights
        )
        self.trainer.remove_callback(PrinterCallback)

        if total_steps > 0:
            self.trainer.create_optimizer_and_scheduler(num_training_steps=total_steps)

        pbar = tqdm(total=total_steps, desc="Training Progress", unit="step")
        cb = PbarConsoleLogger(pbar)
        self.trainer.add_callback(cb)
        self.progress_callback = cb

        steps_done = 0
        chunk_counter = 0

        for ep in range(epochs):
            rng = np.random.default_rng(seed + ep)
            shuffled = np.arange(n_train)
            rng.shuffle(shuffled)

            for slc in chunk_slices(shuffled, chunk_size):
                chunk_df = df_train.iloc[slc]
                should_pretokenize = (
                    self.pretokenize_data and has_bt and self.text_padding != "dynamic"
                    and len(slc) < self.max_pretokenize_samples and len(slc) > 100
                )

                ds_chunk = MultiComboDataset(
                    df=chunk_df,
                    target_col=self.target_column_name,
                    label2id=self.label2id,
                    text_columns=self.text_columns,
                    image_columns=self.image_columns,
                    audio_columns=self.audio_columns,
                    text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
                    text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
                    special_tokens=self.special_tokens,
                    pretokenize=should_pretokenize,
                    pretokenize_batch_size=self.pretokenize_batch_size,
                    max_cache_size=min(len(slc), self.max_pretokenize_samples),
                    tokenizer_returns_tensors=self.tokenizer_returns_tensors
                )

                self.trainer.train_dataset = ds_chunk

                chunk_steps = steps_for_size(len(slc), per_device_train_batch_size, gradient_accumulation_steps)
                if chunk_steps == 0:
                    del ds_chunk, chunk_df
                    continue

                self.trainer.args.max_steps = steps_done + chunk_steps
                self.trainer.train()
                steps_done += chunk_steps

                chunk_counter += 1
                if chunk_counter % clear_cache_every_n_chunks == 0:
                    if hasattr(ds_chunk, 'clear_cache'):
                        ds_chunk.clear_cache()
                        print(f"✓ Очищен кэш токенизации после {chunk_counter} чанков")

                del ds_chunk, chunk_df
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        pbar.close()

        if getattr(self.backend, "batch_tokenizer", None):
            self.backend.batch_tokenizer.clear_cache()

        return self

    def predict(
        self,
        df: pd.DataFrame,
        return_label_str: bool = False,
        return_proba: bool = False,
        batch_size: Optional[int] = None
    ) -> np.ndarray:
        """
        Выполняет инференс на новом DataFrame и возвращает предсказания.
        Если в df отсутствует столбец target_column_name, он будет добавлен фиктивными значениями.

        :param df: DataFrame с теми же колонками по выбранным модальностям, что и при обучении.
                   - text_columns: строки,
                   - image_columns: пути к изображениям, PIL.Image, np.ndarray или списки таких элементов,
                   - audio_columns: пути к аудиофайлам или np.ndarray (моно, float32).
        :param return_label_str: Если True — вернуть массив строковых меток (id2label), иначе — индексы классов.
        :param return_proba: Если True — вернуть распределения вероятностей (softmax) формы [N, num_labels].
                             При включении этого флага игнорируется return_label_str.
        :param batch_size: Переопределяет per_device_eval_batch_size на время инференса (опционально).

        :return:
            - Если return_proba=True: np.ndarray формы [N, num_labels] — вероятности классов.
            - Иначе:
                - Если return_label_str=True: np.ndarray формы [N] со строковыми метками,
                - Иначе: np.ndarray формы [N] с индексами предсказанных классов.

        :raises RuntimeError: Если модель не обучена (trainer отсутствует).
        :raises ValueError: Ошибки приведения данных (например, несоответствие ожидаемым колонкам/типам),
                            внутренние ошибки коллатора/бэкенда (рассогласование размеров батчей).
        :raises OSError: Ошибки чтения исходных файлов (изображения/аудио).
        """
        if self.trainer is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        df_c = df.copy()
        if self.target_column_name not in df_c.columns:
            df_c[self.target_column_name] = list(self.label2id.keys())[0]

        print(f"Preparing dataset for prediction ({len(df_c)} samples)...")

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

        ds = MultiComboDataset(
            df=df_c,
            target_col=self.target_column_name,
            label2id=self.label2id,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
            text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=(self.pretokenize_data and has_bt and len(df_c) < 10000 and self.text_padding != "dynamic"),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_c), 10000),
            tokenizer_returns_tensors=self.tokenizer_returns_tensors
        )

        if batch_size:
            original_bs = self.trainer.args.per_device_eval_batch_size
            self.trainer.args.per_device_eval_batch_size = batch_size

        effective_batch_size = batch_size or self.trainer.args.per_device_eval_batch_size
        num_batches = (len(df_c) + effective_batch_size - 1) // effective_batch_size

        print(f"Running predictions (batch_size={effective_batch_size}, num_batches={num_batches})...")

        original_disable_tqdm = self.trainer.args.disable_tqdm
        self.trainer.args.disable_tqdm = False

        preds = self.trainer.predict(test_dataset=ds)

        self.trainer.args.disable_tqdm = original_disable_tqdm

        if batch_size:
            self.trainer.args.per_device_eval_batch_size = original_bs

        if hasattr(ds, 'clear_cache'):
            ds.clear_cache()

        if return_proba:
            logits = preds.predictions
            exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
            probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
            return probabilities

        y_pred = np.argmax(preds.predictions, axis=-1)
        if return_label_str:
            return np.array([self.id2label[int(i)] for i in y_pred])
        return y_pred

    def get_embeddings(
        self,
        df: pd.DataFrame,
        batch_size: int = 32,
        return_per_modality: bool = False
    ):
        """
        Extract embeddings for every row of ``df``.

        The data is wrapped in a ``MultiComboDataset``, batched with the backend's
        ``collate`` and passed through ``self.model.get_embeddings`` with gradient
        tracking disabled. Per-batch results are moved to CPU and stacked with
        ``np.vstack``.

        If ``df`` lacks the ``target_column_name`` column, a dummy one is added
        (filled with the first key of ``label2id``) so the dataset can be built.
        The caller's DataFrame is not modified; a copy is used.

        :param df: DataFrame with the modality columns configured on this pipeline
                   (same layout as for predict()).
        :param batch_size: Number of rows per forward pass.
        :param return_per_modality: If True, additionally return the per-modality embeddings.

        :return:
            - If return_per_modality=False: np.ndarray with the fused embeddings,
              one row per sample.
            - If return_per_modality=True: tuple ``(fused, per_mod)`` where ``per_mod``
              maps every name in ``self.modalities`` to the stacked embeddings of that
              modality.

        :raises RuntimeError: If ``self.trainer`` or ``self.model`` is None.
        :raises ValueError: From ``np.vstack`` when there is nothing to stack (an empty
                            ``df``, or a requested modality the model returned no
                            embeddings for).
        Errors raised while the dataset or backend reads the inputs propagate unchanged.
        """
        if self.trainer is None or self.model is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        # Use the device the trained model lives on; a model without parameters
        # falls back to CUDA when available, otherwise CPU.
        try:
            target_device = next(self.trainer.model.parameters()).device
        except StopIteration:
            target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model.to(target_device).eval()

        frame = df.copy()
        if self.target_column_name not in frame.columns:
            # Dummy target so the dataset constructor has a label column to read.
            frame[self.target_column_name] = list(self.label2id.keys())[0]

        print(f"Preparing dataset for embeddings extraction ({len(frame)} samples)...")

        use_batch_tok = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))
        if use_batch_tok:
            tokenizer, tokenizer_fn = self.backend.batch_tokenizer, None
        else:
            tokenizer, tokenizer_fn = None, self.text_tokenizer_fn

        dataset = MultiComboDataset(
            df=frame,
            target_col=self.target_column_name,
            label2id=self.label2id,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=tokenizer,
            text_tokenizer_fn=tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=False,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors
        )

        def to_device(value, device: torch.device):
            # Recursively move tensors nested in dicts / lists / tuples; leave the rest as is.
            if torch.is_tensor(value):
                return value.to(device)
            if isinstance(value, dict):
                return {key: to_device(item, device) for key, item in value.items()}
            if isinstance(value, list):
                return [to_device(item, device) for item in value]
            if isinstance(value, tuple):
                return type(value)([to_device(item, device) for item in value])
            return value

        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.backend.collate,
        )

        fused_chunks = []
        modality_chunks = {name: [] for name in self.modalities} if return_per_modality else None

        num_batches = (len(frame) + batch_size - 1) // batch_size  # ceil division

        print(f"Extracting embeddings (batch_size={batch_size}, num_batches={num_batches})...")

        with torch.no_grad():
            progress = tqdm(loader, total=num_batches, desc="Extracting embeddings", unit="batch", leave=True)
            for batch in progress:
                inputs = to_device(batch["backend_inputs"], target_device)
                fused, per = self.model.get_embeddings(backend_inputs=inputs, return_per_modality=True)
                fused_chunks.append(fused.cpu().numpy())
                if not return_per_modality:
                    continue
                for name, chunks in modality_chunks.items():
                    if name in per:
                        chunks.append(per[name].cpu().numpy())

        print("Concatenating embeddings...")
        fused_arr = np.vstack(fused_chunks)

        if not return_per_modality:
            print(f"✓ Embeddings shape: {fused_arr.shape}")
            return fused_arr

        per_mod = {name: np.vstack(chunks) for name, chunks in modality_chunks.items()}
        print(f"✓ Fused embeddings shape: {fused_arr.shape}")
        for name, arr in per_mod.items():
            print(f"✓ {name.capitalize()} embeddings shape: {arr.shape}")

        return fused_arr, per_mod

Функции для создания фиктивных данных.

In [None]:
import numpy as np

def make_rand_image(h=512, w=512):
    """Return a random RGB image as an ``(h, w, 3)`` uint8 array drawn via ``np.random.rand``."""
    noise = np.random.rand(h, w, 3)
    scaled = noise * 255
    return scaled.astype("uint8")

def make_sine_audio(sr=48000, seconds=1.0, freq=440.0):
    """Return a float32 sine wave of amplitude 0.1: ``int(sr * seconds)`` samples at ``freq`` Hz."""
    n_samples = int(sr * seconds)
    timeline = np.linspace(0, seconds, n_samples, endpoint=False)
    wave = 0.1 * np.sin(2 * np.pi * freq * timeline)
    return wave.astype(np.float32)

Пример использования 1.

In [None]:
import pandas as pd
import torch

# Data: string labels, so predict(return_label_str=True) returns readable names.
df_clip_clap = pd.DataFrame([
    {"text": "A man riding a bike", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 440.0), "label": "sports"},
    {"text": "A cat lying on sofa", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 330.0), "label": "lifestyle"},
    {"text": "Stock market is volatile", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 550.0), "label": "business"},
    {"text": "Runner on the track", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 660.0), "label": "sports"},
    {"text": "New cafe opens downtown", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 220.0), "label": "lifestyle"},
    {"text": "Company reports revenue", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 770.0), "label": "business"},
])

pipe1 = SingleModelMultiComboClassification(
    modalities=["text", "image", "audio"],
    num_labels=3,                          # == number of unique labels
    target_column_name="label",
    text_columns=["text"],
    image_columns=["image"],
    audio_columns=["audio"],
    # Explicit backend configs
    text_model_config={
        "checkpoint": "openai/clip-vit-base-patch32",
        "model_type": "clip",
        "max_length": 77
    },
    image_model_config={
        "checkpoint": "openai/clip-vit-base-patch32",
        "model_type": "clip",
        "max_images": 1,
        "image_agg": "mean"
    },
    audio_model_config={
        "checkpoint": "laion/clap-htsat-unfused",
        "model_type": "clap",
        "sr": 48000,
        "max_audios": 1,
        "audio_agg": "mean"
    },
    fusion="mean",                         # every modality yields 512-d embeddings → mean
    # NOTE(review): the original comment said "linear probing", but False presumably
    # leaves the backbone trainable — set True for linear probing. Confirm the intent.
    freeze_backbone=False,
    use_batch_tokenizer=True,              # fast tokenizer
    pretokenize_data=True,                 # tokenize ahead of training
    pretokenize_batch_size=128,
    tokenizer_cache_size=5000,
    max_pretokenize_samples=100000,
    local_cache_dir="./model_cache"
)

pipe1.fit(
    train_data=df_clip_clap,
    epochs=2,
    test_size=0.33,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    metric_name="f1",
    fp16=torch.cuda.is_available(),  # fp16 mixed precision needs CUDA; stay in fp32 on CPU
    logging_steps=1,
    eval_steps=2,
    output_dir="./mc_max_clip_clap",
    seed=42,
    hidden=512,
    dropout=0.2,
    gradient_checkpointing=True,
    fit_chunk_size=2,          # chunks of 2 samples
    clear_cache_every_n_chunks=5,
)

# Predictions — class probabilities
probas1 = pipe1.predict(df_clip_clap.iloc[:3], return_proba=True)
print("Probas shape:", probas1.shape)

# Predictions — string labels
labels1 = pipe1.predict(df_clip_clap.iloc[:3], return_label_str=True)
print("Labels:", labels1)

# Embeddings (fused + per modality)
fused1, per1 = pipe1.get_embeddings(df_clip_clap.iloc[:3], batch_size=2, return_per_modality=True)
print("Fused shape:", fused1.shape)
for m, arr in per1.items():
    print(f"{m} emb shape:", arr.shape)

Пример использования 2.

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

def build_binary_multimodal_df(n_per_class: int = 50, sr: int = 16000) -> pd.DataFrame:
    """
    Build a shuffled synthetic two-class ("pets" / "news") multimodal DataFrame.

    Each row has a random Russian sentence ("text"), a random image from
    ``make_rand_image()`` ("image"), a one-second sine wave from
    ``make_sine_audio()`` ("audio") and the class name ("label"). The two classes
    draw their tone frequency from disjoint lists.

    :param n_per_class: Number of rows generated for each class.
    :param sr: Sample rate passed to ``make_sine_audio``.
    :return: DataFrame with ``2 * n_per_class`` rows in shuffled order.
    """
    # Imported here because this cell does not import `random` itself; relying on an
    # earlier notebook cell having done so would raise NameError when run standalone.
    import random

    # Vocabulary for the "pets" sentences
    pet_animals = ["кошка", "собака", "щенок", "кот", "котёнок", "пёс", "питомец", "котик", "пёсик", "котяра"]
    pet_actions = ["сидит", "лежит", "играет", "смотрит", "прячется", "спит", "тянется", "мурлычет", "исследует", "охотится"]
    pet_places = ["на подоконнике", "на диване", "на ковре", "на кухне", "в коробке", "на кресле", "у окна", "в саду", "на полу", "на стуле"]

    # Vocabulary for the "news" sentences
    news_subjects = ["власти города", "жители района", "аналитики", "эксперты", "журналисты", "компания", "департамент", "учёные", "инженеры", "ведомство"]
    news_verbs = ["обсудили", "обновили", "сообщили", "рассказали", "анонсировали", "заявили", "подтвердили", "планируют", "запустили", "увеличили"]
    news_topics = ["экономику", "политику", "транспорт", "погоду", "технологии", "культуру", "спорт", "здравоохранение", "образование", "экологию"]

    rows = []

    # PETS class
    pet_freqs = [260.0, 280.0, 300.0, 320.0, 340.0, 360.0]
    for _ in range(n_per_class):
        text = f"{random.choice(pet_animals).capitalize()} {random.choice(pet_actions)} {random.choice(pet_places)}"
        img = make_rand_image()
        aud = make_sine_audio(sr, 1.0, random.choice(pet_freqs))
        rows.append({"text": text, "image": img, "audio": aud, "label": "pets"})

    # NEWS class
    news_freqs = [560.0, 580.0, 600.0, 620.0, 640.0, 660.0]
    for _ in range(n_per_class):
        text = f"{random.choice(news_subjects).capitalize()} {random.choice(news_verbs)} новости про {random.choice(news_topics)}"
        img = make_rand_image()
        aud = make_sine_audio(sr, 1.0, random.choice(news_freqs))
        rows.append({"text": text, "image": img, "audio": aud, "label": "news"})

    random.shuffle(rows)
    return pd.DataFrame(rows)

import torch  # used below to enable fp16 only when CUDA is present

train_data = build_binary_multimodal_df(n_per_class=12)
# Stratified split so both classes appear in train and eval.
df_train, df_eval = train_test_split(
    train_data, test_size=0.3, random_state=42, shuffle=True,
    stratify=train_data['label'])

pipe3 = SingleModelMultiComboClassification(
    modalities=["text", "image", "audio"],
    num_labels=2,
    target_column_name="label",
    text_columns=["text"],
    image_columns=["image"],
    audio_columns=["audio"],
    text_model_config={
        "checkpoint": "DeepPavlov/rubert-base-cased",
        "model_type": "bert",
        "max_length": 256
    },
    image_model_config={
        "checkpoint": "google/vit-base-patch16-224",
        "model_type": "vit",
        "max_images": 1,
        "image_agg": "mean"
    },
    audio_model_config={
        "checkpoint": "facebook/wav2vec2-base-960h",
        "model_type": "auto",   # AutoModel + AutoProcessor
        "sr": 16000,
        "max_audios": 1,
        "audio_agg": "mean"
    },
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=True,
    pretokenize_data=True,
    pretokenize_batch_size=128,
    tokenizer_cache_size=5000,
    local_cache_dir="./model_cache"
)

pipe3.fit(
    train_data=df_train,
    test_data=df_eval,
    epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    metric_name="f1",
    fp16=torch.cuda.is_available(),  # fp16 mixed precision needs CUDA; stay in fp32 on CPU
    logging_steps=1,
    eval_steps=1,
    output_dir="./mc_max_rubert_vit_w2v2",
    seed=2025,
    hidden=512,
    dropout=0.2,
    gradient_checkpointing=True,
    fit_chunk_size=8,
    clear_cache_every_n_chunks=3
)

print("Pred (labels):", pipe3.predict(df_eval, return_label_str=True))
fused3, per3 = pipe3.get_embeddings(df_eval, batch_size=2, return_per_modality=True)
print("Fused:", fused3.shape, "| text:", per3["text"].shape, "| image:", per3["image"].shape, "| audio:", per3["audio"].shape)

Пример использования 3.

In [None]:
import pandas as pd

# Tiny text-only dataset: three rows, two distinct labels.
df_min = pd.DataFrame({
    "text": ["Привет, как дела?", "Сегодня отличная погода", "До встречи!"],
    "label": ["greet", "weather", "greet"],
})

pipe_min = SingleModelMultiComboClassification(
    modalities=["text"],
    target_column_name="label",
    text_columns=["text"],
    num_labels=2,  # exactly the number of unique labels in df_min
    # Only the text model needs to be configured for a text-only pipeline.
    text_model_config={
        "checkpoint": "DeepPavlov/rubert-base-cased",
        "model_type": "bert",
        "max_length": 128,
    },
    fusion="concat",  # irrelevant with a single modality
    freeze_backbone=True,
)

# Minimal fit: all defaults (no early stopping, no chunking).
pipe_min.fit(train_data=df_min)

print("Pred (ids):", pipe_min.predict(df_min))
print("Pred (labels):", pipe_min.predict(df_min, return_label_str=True))
emb_min = pipe_min.get_embeddings(df_min)
print("Embeddings shape:", emb_min.shape)

Пример использования 4.

In [None]:
import os

import pandas as pd
from huggingface_hub import login

# Never commit access tokens to source. The token is read from the HF_TOKEN
# environment variable; when it is unset, login() asks for it interactively.
# NOTE(review): a literal token used to be hard-coded here — if it was ever a real
# one, revoke it in the Hugging Face account settings.
login(token=os.environ.get("HF_TOKEN"))

df_min = pd.DataFrame([
    {"text": "Привет, как дела?", "label": "greet"},
    {"text": "Сегодня отличная погода", "label": "weather"},
    {"text": "До встречи!", "label": "greet"},
])

pipe_min = SingleModelMultiComboClassification(
    modalities=["text"],
    num_labels=2,
    target_column_name="label",
    text_columns=["text"],
    text_model_config={
        "checkpoint": "google/embeddinggemma-300m",
        "max_length": 2048
    },
    text_padding='max_length',
    freeze_backbone=False
)

pipe_min.fit(train_data=df_min)

print("Pred (ids):", pipe_min.predict(df_min))
print("Pred (labels):", pipe_min.predict(df_min, return_label_str=True))
emb_min = pipe_min.get_embeddings(df_min)
print("Embeddings shape:", emb_min.shape)

# Дообучение регрессора, который работает с данными разных модальностей.

Реализация регрессора.

In [None]:
# !pip install --upgrade --no-cache-dir \
#   --extra-index-url https://download.pytorch.org/whl/cu124 \
#   pillow==11.1.0 \
#   numpy==1.26.4 \
#   pandas==2.2.3 \
#   tqdm==4.67.1 \
#   transformers==4.51.3 \
#   evaluate==0.4.5 \
#   wav2clip==0.1.0 \
#   torch==2.6.0+cu124 \
#   torchaudio==2.6.0+cu124
!pip install evaluate wav2clip

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import math
import random
import gc
from functools import lru_cache
from typing import Any, Callable, Dict, Generator, List, Optional, Union

from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from transformers import TrainingArguments, Trainer
from transformers.trainer_callback import TrainerCallback, PrinterCallback, EarlyStoppingCallback
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate


# =========================
# Утилиты
# =========================

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, then all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def to_pil(x: Union[str, np.ndarray, 'Image.Image']) -> 'Image.Image':
    """
    Convert ``x`` to an RGB ``PIL.Image``.

    :param x: An existing PIL image, a filesystem path (``str`` or ``os.PathLike``),
              or a NumPy array accepted by ``Image.fromarray``.
    :return: A PIL image converted to mode "RGB".
    :raises ValueError: If ``x`` is none of the supported types.
    """
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    if isinstance(x, (str, os.PathLike)):
        # Image.open is lazy and keeps the file open; convert() reads the pixel data,
        # and the context manager then closes the handle instead of leaking it.
        with Image.open(x) as img:
            return img.convert("RGB")
    if isinstance(x, np.ndarray):
        return Image.fromarray(x).convert("RGB")
    raise ValueError("Ожидается путь/np.ndarray/PIL.Image")


def load_audio(path: str, target_sr: int) -> np.ndarray:
    """
    Load an audio file with torchaudio as a mono float32 waveform at ``target_sr``.

    :param path: Path handed to ``torchaudio.load``.
    :param target_sr: Desired sample rate; the signal is resampled when the file's rate differs.
    :return: NumPy float32 array of the waveform with dimension 0 squeezed out.
    :raises RuntimeError: If torchaudio cannot be imported.
    """
    try:
        import torchaudio
    except Exception as e:
        raise RuntimeError("Требуется torchaudio: pip install torchaudio") from e

    waveform, sr = torchaudio.load(path)

    # More than one row along dim 0 (channels, per torchaudio.load) → average down to one.
    has_multiple_channels = waveform.size(0) > 1
    if has_multiple_channels:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)

    mono = waveform.squeeze(0)
    return mono.numpy().astype(np.float32)


def safe_load(component_cls, checkpoint: str, local_cache_dir: str = "./model_cache",
              local_files_only: Optional[bool] = None, **kwargs):
    """
    Call ``component_cls.from_pretrained`` with a local cache dir and offline awareness.

    :param component_cls: Any class exposing ``from_pretrained``.
    :param checkpoint: Checkpoint name or path forwarded to ``from_pretrained``.
    :param local_cache_dir: Passed as ``cache_dir``.
    :param local_files_only: If None, derived from the ``HF_HUB_OFFLINE`` environment
        variable. The values "1", "ON", "YES" and "TRUE" (case-insensitive) count as
        offline, matching how huggingface_hub parses its boolean variables.
    :param kwargs: Extra keyword arguments forwarded unchanged. For classes whose name
        contains "Tokenizer", ``use_fast=True`` is added unless the caller set it.
    :return: Whatever ``component_cls.from_pretrained`` returns.
    """
    if local_files_only is None:
        offline_flag = os.environ.get("HF_HUB_OFFLINE", "0").upper()
        local_files_only = offline_flag in {"1", "ON", "YES", "TRUE"}
    if "Tokenizer" in getattr(component_cls, "__name__", ""):
        kwargs.setdefault("use_fast", True)
    return component_cls.from_pretrained(
        checkpoint, cache_dir=local_cache_dir, local_files_only=local_files_only, **kwargs
    )


# =========================
# Токенизатор батчевый (с поддержкой стратегии паддинга)
# =========================

class BatchTokenizer:
    """
    Thin batching/caching wrapper around a callable tokenizer.

    Two padding strategies are supported:
      - "max_length": every text is padded to ``max_length``; small batches are served
        from a per-text LRU cache and stacked.
      - "dynamic": the whole list is tokenized at once and padded to the longest
        item (``padding=True``); the per-text cache is bypassed because cached rows
        would have different lengths.
    """

    # Keys cast to torch.long; every other key becomes float32 on the cached path.
    _INT_KEYS = ("input_ids", "attention_mask", "token_type_ids")
    # Batches of this size or larger skip the per-text cache and are tokenized in one call.
    _CACHE_MAX_BATCH = 100

    def __init__(
        self,
        tokenizer,
        max_length: int = 512,
        cache_size: int = 10000,
        batch_size: int = 256,
        use_fast: bool = True,
        device: str = "cpu",
        padding_strategy: str = "max_length"  # "max_length" or "dynamic"
    ):
        """
        :param tokenizer: Callable accepting a string or list of strings plus the
            ``padding`` / ``truncation`` / ``max_length`` / ``return_tensors`` keywords.
        :param max_length: Truncation (and fixed padding) length.
        :param cache_size: ``maxsize`` of the per-text LRU cache.
        :param batch_size: Default chunk size for ``tokenize_dataset_lazy``.
        :param use_fast: Stored on the instance; not otherwise used by this class.
        :param device: Stored on the instance; not otherwise used by this class.
        :param padding_strategy: "max_length" or "dynamic".
        :raises ValueError: If ``padding_strategy`` is not one of the two supported values.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_fast = use_fast
        self.device = device
        self.padding_strategy = padding_strategy
        if self.padding_strategy not in {"max_length", "dynamic"}:
            raise ValueError("padding_strategy должен быть 'max_length' или 'dynamic'")
        # Per-instance LRU cache over the single-text tokenization.
        self._cache = lru_cache(maxsize=cache_size)(self._tokenize_single)
        self.is_fast = hasattr(tokenizer, "is_fast") and tokenizer.is_fast
        if self.is_fast:
            print("✓ Используется Fast Tokenizer")

    def _tokenize_single(self, text: str) -> tuple:
        """Tokenize one text with fixed padding; return ``(key, ndarray)`` pairs for caching."""
        # Only fixed-length results are cached — otherwise rows could not be stacked.
        result = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return tuple((k, v.squeeze(0).cpu().numpy()) for k, v in result.items())

    def _tokenize_list(self, texts: List[str], padding) -> Dict[str, torch.Tensor]:
        """Tokenize ``texts`` in a single tokenizer call and cast the index tensors to long."""
        result = self.tokenizer(
            texts,
            padding=padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        for key in result:
            if key in self._INT_KEYS:
                result[key] = result[key].long()
        return result

    def tokenize_batch(self, texts: List[str], use_cache: bool = True) -> Dict[str, torch.Tensor]:
        """
        Tokenize a list of texts into batched tensors.

        :param texts: Texts to tokenize.
        :param use_cache: With fixed padding, serve non-empty batches of fewer than
            100 texts from the per-text cache. Ignored for dynamic padding.
        :return: Mapping from tokenizer output key to a batched tensor. On the cached
            path this is a plain dict; otherwise it is the tokenizer's own result object.
        """
        # Dynamic padding: tokenize the whole list at once (padding=True → longest).
        if self.padding_strategy == "dynamic":
            return self._tokenize_list(texts, padding=True)

        # Fixed padding: small batches go through the cache. An empty batch has no
        # first row to take the keys from, so it is delegated to the tokenizer
        # instead of raising IndexError here.
        if use_cache and 0 < len(texts) < self._CACHE_MAX_BATCH:
            rows = [dict(self._cache(text)) for text in texts]
            batch_dict = {}
            for key in rows[0]:
                dtype = torch.long if key in self._INT_KEYS else torch.float32
                batch_dict[key] = torch.tensor(np.stack([row[key] for row in rows]), dtype=dtype)
            return batch_dict

        return self._tokenize_list(texts, padding="max_length")

    def tokenize_dataset_lazy(
        self,
        texts: List[str],
        batch_size: Optional[int] = None
    ) -> Generator[Dict[str, torch.Tensor], None, None]:
        """Yield tokenized chunks of ``texts`` (``batch_size`` each), bypassing the cache."""
        batch_size = batch_size or self.batch_size
        for start in range(0, len(texts), batch_size):
            yield self.tokenize_batch(texts[start:start + batch_size], use_cache=False)

    def clear_cache(self):
        """Drop every entry of the per-text LRU cache."""
        self._cache.cache_clear()


# =========================
# Универсальный датасет (ускорённый)
# =========================

class MultiComboDataset(Dataset):
    """
    Map-style dataset over a pandas DataFrame with text / image / audio columns and
    a numeric target column.

    Each item is a dict with:
      - "labels": float32 tensor row taken from the prepared label matrix;
      - "text_tokens" (row slices of the pre-tokenized tensors) or "text" (the raw
        joined string), when text columns are configured;
      - "images" / "audios": flat lists gathered from the configured columns.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        target_col: str,
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer: Optional["BatchTokenizer"] = None,
        text_tokenizer_fn: Optional[Callable] = None,
        text_tokenizer_fn_batched: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, Any]] = None,
        pretokenize: bool = False,
        pretokenize_batch_size: int = 256,
        max_cache_size: int = 100000,
        tokenizer_returns_tensors: bool = False,
        cache_dir: Optional[str] = None,
        deduplicate_texts: bool = True
    ):
        """
        Build the dataset and, when requested and possible, pre-tokenize the text.

        :param df: Source data; its index is reset (a new frame is stored).
        :param target_col: Name of the target column (may be absent → zero labels).
        :param text_columns: Columns joined into one text with the "sep" special token.
        :param image_columns: Columns whose values are gathered into per-row image lists.
        :param audio_columns: Columns whose values are gathered into per-row audio lists.
        :param text_tokenizer: Object with ``tokenize_batch(texts, use_cache=...)``
            (a BatchTokenizer).
        :param text_tokenizer_fn: Custom per-row function ``fn(text_dict, special_tokens)``
            returning a dict of tensors.
        :param text_tokenizer_fn_batched: Custom batched function
            ``fn(list_of_text_dicts, special_tokens)`` returning a dict of [B, L] tensors.
        :param special_tokens: Defaults to ``{"sep": " [SEP] "}``.
        :param pretokenize: Tokenize all texts up front into tensor banks.
        :param pretokenize_batch_size: Batch size used during pre-tokenization.
        :param max_cache_size: Accepted for compatibility; not used by this class.
        :param tokenizer_returns_tensors: Stored on the instance; not otherwise used here.
        :param cache_dir: Accepted for compatibility; not used by this class.
        :param deduplicate_texts: In the single-function path, tokenize each distinct
            text only once.

        Pre-tokenization is skipped (with a printed warning) when the batch tokenizer
        uses dynamic padding, or when no tokenizer of any kind was supplied; items
        then carry the raw "text" string instead.
        """
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.target_col = target_col

        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []

        self.batch_tokenizer = text_tokenizer
        self.text_tokenizer_fn = text_tokenizer_fn
        self.text_tokenizer_fn_batched = text_tokenizer_fn_batched

        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors

        self._N = len(self.df)

        # Labels as a float32 tensor [N, K] (K == 1 for scalar targets).
        self._labels = self._prepare_labels(self.df, self.target_col)

        # Per-row image/audio lists collected once, so __getitem__ avoids pandas.
        self._image_lists = self._collect_multi_values(self.df, self.image_columns) if self.image_columns else None
        self._audio_lists = self._collect_multi_values(self.df, self.audio_columns) if self.audio_columns else None

        # Banks of pre-tokenized text: dict(key -> tensor whose first dimension is N).
        self._tok_bank: Optional[Dict[str, torch.Tensor]] = None
        self._has_text = bool(self.text_columns)

        # Dynamic padding is incompatible with pre-tokenization (row shapes differ).
        if pretokenize and self.batch_tokenizer is not None and getattr(self.batch_tokenizer, "padding_strategy", "max_length") == "dynamic":
            print("⚠ Предтокенизация отключена: выбран dynamic-паддинг для текста.")
            pretokenize = False

        if self._has_text and pretokenize:
            has_custom_fn = self.text_tokenizer_fn is not None or self.text_tokenizer_fn_batched is not None
            if self.batch_tokenizer is not None and not has_custom_fn:
                self._pretokenize_with_batch_tokenizer(pretokenize_batch_size)
            elif has_custom_fn:
                self._pretokenize_with_custom_fn(pretokenize_batch_size, deduplicate_texts=deduplicate_texts)
            else:
                # No tokenizer at all: previously this crashed calling None. Fall back
                # to raw strings, exactly as when pretokenize=False.
                print("⚠ Предтокенизация пропущена: не задан токенизатор текста.")

    # --------------------------
    # Labels
    # --------------------------
    def _prepare_labels(self, df: pd.DataFrame, target_col: str) -> torch.Tensor:
        """
        Convert the target column into a float32 tensor of shape [N, K].

        Sequence-valued targets (list / tuple / ndarray) become rows of their own
        length; scalars become one-element rows; values that cannot be converted to
        float become 0.0. Shorter rows are right-padded with zeros up to the longest
        row. A missing column yields zeros of shape [N, 1].
        """
        if target_col not in df.columns:
            return torch.zeros((len(df), 1), dtype=torch.float32)
        labels_list = []
        # Read the column once instead of materialising a row Series per sample.
        for v in df[target_col].tolist():
            if isinstance(v, (list, tuple, np.ndarray)):
                arr = np.asarray(v, dtype=np.float32)
            else:
                try:
                    arr = np.asarray([float(v)], dtype=np.float32)
                except (TypeError, ValueError):
                    # Best effort: non-numeric targets (e.g. dummy labels at inference) → 0.0.
                    arr = np.asarray([0.0], dtype=np.float32)
            labels_list.append(arr)
        # Align K to the longest row.
        K = max(a.shape[0] for a in labels_list) if labels_list else 1
        out = np.zeros((len(labels_list), K), dtype=np.float32)
        for i, a in enumerate(labels_list):
            out[i, :a.shape[0]] = a
        return torch.tensor(out, dtype=torch.float32)

    # --------------------------
    # Text / multimedia helpers
    # --------------------------
    def _join_text(self, row: pd.Series) -> str:
        """Join the row's text columns with the "sep" special token; nulls become ""."""
        sep = self.special_tokens.get("sep", " [SEP] ")
        return sep.join([str(row[c]) if pd.notnull(row[c]) else "" for c in self.text_columns])

    @staticmethod
    def _as_list(v):
        """Wrap a value into a list: None/NaN → [], list/tuple → list copy, anything else → [v]."""
        if v is None or (isinstance(v, float) and np.isnan(v)):
            return []
        if isinstance(v, (list, tuple)):
            return list(v)
        return [v]

    def _collect_multi_values(self, df: pd.DataFrame, columns: List[str]) -> List[List[Any]]:
        """For each row, flatten the values of ``columns`` into one list, dropping None items."""
        out = []
        as_list = self._as_list
        for _, row in df.iterrows():
            lst: List[Any] = []
            for c in columns:
                if c in row:
                    lst.extend([x for x in as_list(row[c]) if x is not None])
            out.append(lst)
        return out

    def _rows_as_text_dicts(self, start: int, end: int) -> List[Dict[str, str]]:
        """Return rows [start, end) as ``{text_column: str}`` dicts; null cells become ""."""
        cols = list(self.text_columns)
        batch = []
        for i in range(start, end):
            row = self.df.iloc[i]
            batch.append({c: ("" if pd.isna(row[c]) else str(row[c])) for c in cols})
        return batch

    @staticmethod
    def _as_tensor_dict(tok) -> Dict[str, torch.Tensor]:
        """Copy ``tok`` into a new dict, converting non-tensor values with ``torch.tensor``."""
        return {k: (t if torch.is_tensor(t) else torch.tensor(t)) for k, t in tok.items()}

    # --------------------------
    # Pre-tokenization: BatchTokenizer
    # --------------------------
    def _pretokenize_with_batch_tokenizer(self, batch_size: int):
        """Tokenize all joined texts in batches and concatenate them into ``_tok_bank``."""
        print("Предтокенизация с BatchTokenizer...")
        texts = [self._join_text(self.df.iloc[i]) for i in range(self._N)]
        banks: Dict[str, List[torch.Tensor]] = {}

        for start in range(0, self._N, batch_size):
            batch = texts[start:start + batch_size]
            tok = self.batch_tokenizer.tokenize_batch(batch, use_cache=False)
            # Index-like keys → long, everything else → float32.
            for k in tok:
                if k in ("input_ids", "attention_mask", "token_type_ids"):
                    tok[k] = tok[k].long()
                else:
                    tok[k] = tok[k].to(torch.float32)
            for k, v in tok.items():
                banks.setdefault(k, []).append(v)

        self._tok_bank = {k: torch.cat(v_parts, dim=0).contiguous() for k, v_parts in banks.items()}
        shapes = {k: tuple(v.shape) for k, v in self._tok_bank.items()}
        print(f"✓ Предтокенизация завершена: {self._N} образцов | keys={list(self._tok_bank.keys())}, shapes={shapes}")

    # --------------------------
    # Pre-tokenization: custom functions (batched / single + dedup)
    # --------------------------
    def _pretokenize_with_custom_fn(self, batch_size: int, deduplicate_texts: bool = True):
        """
        Fill ``_tok_bank`` using the custom tokenizer functions.

        The batched function is preferred when present; otherwise the single-row
        function is applied row by row, optionally reusing results for identical texts.
        Does nothing when there are no text columns or no rows.

        :raises ValueError: If the first tokenizer call does not return a dict.
        """
        # Nothing to tokenize; the code below reads the first row to infer shapes.
        if not self._has_text or self._N == 0:
            return

        cols = list(self.text_columns)

        # Batched function available — use it.
        if self.text_tokenizer_fn_batched is not None:
            print("Предтокенизация кастомной batched-функцией...")
            # The first portion determines the bank shapes and dtypes.
            first_end = min(self._N, max(8, batch_size))
            first_tok = self.text_tokenizer_fn_batched(self._rows_as_text_dicts(0, first_end), self.special_tokens)
            if not isinstance(first_tok, dict):
                raise ValueError("text_tokenizer_fn_batched должна возвращать dict тензоров [B, L]")
            bank: Dict[str, torch.Tensor] = {}
            for k, t in self._as_tensor_dict(first_tok).items():
                bank[k] = torch.empty((self._N, t.size(1)), dtype=t.dtype)
                bank[k][:first_end] = t[:first_end]

            # Remaining rows.
            for start in range(first_end, self._N, batch_size):
                end = min(self._N, start + batch_size)
                tok = self.text_tokenizer_fn_batched(self._rows_as_text_dicts(start, end), self.special_tokens)
                for k, t in self._as_tensor_dict(tok).items():
                    bank[k][start:end] = t

            self._tok_bank = {k: v.contiguous() for k, v in bank.items()}
            shapes = {k: tuple(v.shape) for k, v in self._tok_bank.items()}
            print(f"✓ Предтокенизация кастомной batched-функцией завершена: shapes={shapes}")
            return

        # Otherwise: single-row function + deduplication.
        print("Предтокенизация кастомной single-функцией...")
        # Column values as plain Python lists (no pandas in the hot loop).
        col_arrays = [self.df[c].astype(str).where(~self.df[c].isna(), other="").tolist() for c in cols]

        # The first row determines the bank shapes and dtypes.
        first_td = {c: col_arrays[j][0] for j, c in enumerate(cols)}
        first_tok = self.text_tokenizer_fn(first_td, self.special_tokens)
        if not isinstance(first_tok, dict):
            raise ValueError("custom text_tokenizer_fn должна возвращать dict тензоров")
        first_tok = self._as_tensor_dict(first_tok)

        bank = {k: torch.empty((self._N, *t.shape), dtype=t.dtype) for k, t in first_tok.items()}
        for k, t in first_tok.items():
            bank[k][0].copy_(t)

        # Maps the tuple of column texts to its tokenization (clones, so later
        # mutation of a returned tensor cannot corrupt the cache).
        cache: Dict[tuple, Dict[str, torch.Tensor]] = {}
        if deduplicate_texts:
            cache[tuple(first_td.get(c, "") for c in cols)] = {k: v.clone() for k, v in first_tok.items()}

        for i in range(1, self._N):
            td = {c: col_arrays[j][i] for j, c in enumerate(cols)}
            tok = None
            key = None
            if deduplicate_texts:
                key = tuple(td.get(c, "") for c in cols)
                tok = cache.get(key)
            if tok is None:
                tok = self._as_tensor_dict(self.text_tokenizer_fn(td, self.special_tokens))
                if deduplicate_texts:
                    cache[key] = {k: v.clone() for k, v in tok.items()}
            for k, t in tok.items():
                bank[k][i].copy_(t)

        self._tok_bank = {k: v.contiguous() for k, v in bank.items()}
        shapes = {k: tuple(v.shape) for k, v in self._tok_bank.items()}
        print(f"✓ Предтокенизация кастомной single-функцией завершена: shapes={shapes}")

    # --------------------------
    # Dataset interface
    # --------------------------
    def __len__(self) -> int:
        """Number of rows."""
        return self._N

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Assemble the item dict for row ``idx`` (see the class docstring for the keys)."""
        item: Dict[str, Any] = {}

        # Labels: one row of the [N, K] label tensor.
        item["labels"] = self._labels[idx]

        # Text: a slice of the bank, or the raw joined string when nothing was pre-tokenized.
        if self.text_columns:
            if self._tok_bank is not None:
                item["text_tokens"] = {k: v[idx] for k, v in self._tok_bank.items()}
            else:
                item["text"] = self._join_text(self.df.iloc[idx])

        # Images / audio: the lists collected at construction time.
        if self._image_lists is not None:
            item["images"] = self._image_lists[idx]
        if self._audio_lists is not None:
            item["audios"] = self._audio_lists[idx]

        return item

    # --------------------------
    # Service
    # --------------------------
    def get_cache_stats(self) -> Dict[str, Any]:
        """Report whether a token bank exists, the bank tensor shapes and the row count."""
        has = self._tok_bank is not None
        sizes = {k: tuple(v.shape) for k, v in (self._tok_bank or {}).items()}
        return {"has_pretokenized": has, "shapes": sizes, "N": self._N}

    def clear_cache(self):
        """Drop the token bank and ask Python / CUDA to release the freed memory."""
        self._tok_bank = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# =========================
# Универсальный бэкенд
# =========================

class BaseBackend(nn.Module):
    """Base class for modality backends.

    Subclasses are expected to implement :meth:`collate` and :meth:`encode`;
    this class only supplies shared defaults, parameter freezing, output-size
    lookup and tokenizer configuration.

    Note: ``supported``, ``out_dim_per_modality`` and ``special_tokens`` are
    mutable class-level defaults and are therefore shared by every instance
    that does not assign its own value.
    """

    name: str = "base"
    supported: set = set()
    embed_dim: int = 0
    out_dim_per_modality: Dict[str, int] = {}
    text_tokenizer_fn: Optional[Callable] = None
    batch_tokenizer: Optional[BatchTokenizer] = None
    special_tokens: Dict[str, str] = {}
    tokenizer_returns_tensors: bool = False
    local_cache_dir: str = "./model_cache"
    text_padding_strategy: str = "max_length"  # text padding strategy

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn a list of dataset items into a batch; must be overridden."""
        raise NotImplementedError

    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        """Compute per-modality embeddings for a collated batch; must be overridden."""
        raise NotImplementedError

    def freeze_all(self):
        """Turn off gradient tracking for every parameter of this module."""
        for param in self.parameters():
            param.requires_grad = False

    def get_out_dim(self, modality: str) -> int:
        """Return the registered output size for ``modality``, falling back to
        ``embed_dim`` when the modality has no entry."""
        dims = self.out_dim_per_modality
        if modality in dims:
            return dims[modality]
        return self.embed_dim

    def set_text_tokenizer(self, tokenizer_fn: Optional[Callable], special_tokens: Optional[Dict[str, str]] = None,
                           returns_tensors: bool = False):
        """Store a custom text tokenizer function and its settings.

        A missing or empty ``special_tokens`` is replaced by the default
        ``{"sep": " [SEP] "}``.
        """
        if not special_tokens:
            special_tokens = {"sep": " [SEP] "}
        self.text_tokenizer_fn = tokenizer_fn
        self.special_tokens = special_tokens
        self.tokenizer_returns_tensors = returns_tensors

    def set_batch_tokenizer(self, tokenizer, max_length: int = 512,
                            cache_size: int = 10000, batch_size: int = 256,
                            padding_strategy: str = "max_length"):
        """Remember the padding strategy and build a ``BatchTokenizer`` around
        ``tokenizer`` with the given limits (``use_fast`` is always ``True``)."""
        self.text_padding_strategy = padding_strategy
        settings = dict(
            tokenizer=tokenizer,
            max_length=max_length,
            cache_size=cache_size,
            batch_size=batch_size,
            use_fast=True,
            padding_strategy=padding_strategy,
        )
        self.batch_tokenizer = BatchTokenizer(**settings)


class UniversalMultiBackend(BaseBackend):
    name = "universal"

    class _ParamDeviceProxy(nn.Module):
        def __init__(self, base, device: torch.device):
            super().__init__()
            self.base = base if isinstance(base, nn.Module) else None
            self._callable = base if not isinstance(base, nn.Module) else None
            self._dummy = nn.Parameter(torch.empty(0), requires_grad=False)
            with torch.no_grad():
                self._dummy.data = self._dummy.data.to(device)
            try:
                target = self.base if self.base is not None else self._callable
                if hasattr(target, "to"):
                    target.to(device)
            except Exception:
                pass

        def forward(self, *args, **kwargs):
            target = self.base if self.base is not None else self._callable
            return target(*args, **kwargs)

        def to(self, device, *args, **kwargs):
            self._dummy.data = self._dummy.data.to(device)
            try:
                target = self.base if self.base is not None else self._callable
                if hasattr(target, "to"):
                    target.to(device)
            except Exception:
                pass
            return super().to(device, *args, **kwargs)

    def _model_device(self, model, default: torch.device) -> torch.device:
        try:
            return next(model.parameters()).device
        except StopIteration:
            pass
        try:
            buf = next(model.buffers())
            return buf.device
        except StopIteration:
            pass
        return default

    def _preferred_device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _wrap_if_parameterless(self, model, device: torch.device):
        try:
            it = model.parameters() if hasattr(model, "parameters") else iter(())
            next(it)
            return model
        except StopIteration:
            return self._ParamDeviceProxy(model, device)
        except Exception:
            return self._ParamDeviceProxy(model, device)

    def __init__(
        self,
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        freeze: bool = True,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        use_batch_tokenizer: bool = True,
        tokenizer_cache_size: int = 10000,
        tokenizer_batch_size: int = 256,
        local_cache_dir: str = "./model_cache",
        text_padding_strategy: str = "max_length"
    ):
        super().__init__()
        self.supported = set()
        self.out_dim_per_modality = {}
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.use_batch_tokenizer = use_batch_tokenizer
        self.tokenizer_cache_size = tokenizer_cache_size
        self.tokenizer_batch_size = tokenizer_batch_size
        self.local_cache_dir = local_cache_dir
        self.text_padding_strategy = text_padding_strategy

        self.text_model = None
        self.text_processor = None
        self.text_config = text_model_config or {}
        if text_model_config:
            self._init_text_model(text_model_config)
            self.supported.add("text")

        self.image_model = None
        self.image_processor = None
        self.image_config = image_model_config or {}
        if image_model_config:
            self._init_image_model(image_model_config)
            self.supported.add("image")

        self.audio_model = None
        self.audio_processor = None
        self.audio_config = audio_model_config or {}
        if audio_model_config:
            self._init_audio_model(audio_model_config)
            self.supported.add("audio")

        if freeze:
            self.freeze_all()

    def _ensure_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if x is None:
            return None
        if x.dim() == 1:
            return x.unsqueeze(0)
        if x.dim() > 2:
            return x.view(x.size(0), -1)
        return x

    def _normalize_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        x = self._ensure_2d(x)
        return F.normalize(x, dim=-1, eps=1e-12) if x is not None and x.numel() > 0 else x

    def _init_text_model(self, config: Dict[str, Any]):
        """Load the text encoder and its tokenizer described by ``config``.

        ``config['checkpoint']`` is required. ``config['model_type']``
        (default ``'auto'``) selects CLIP, CLAP or the generic ``AutoModel``
        path; ``config['max_length']`` defaults to 512. The embedding size is
        recorded in ``text_config['dim']`` and ``out_dim_per_modality['text']``.
        """
        from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPTokenizer, ClapModel, ClapProcessor

        checkpoint = config['checkpoint']
        model_type = config.get('model_type', 'auto').lower()

        print(f"Загрузка текстовой модели {checkpoint}...")

        def load(cls, **extra):
            # Every loader call shares the checkpoint and the local cache directory.
            return safe_load(cls, checkpoint, local_cache_dir=self.local_cache_dir, **extra)

        if model_type == 'clip':
            self.text_model = load(CLIPModel)
            self.text_processor = load(CLIPTokenizer, use_fast=True)
            dim = self.text_model.config.projection_dim
        elif model_type == 'clap':
            self.text_model = load(ClapModel)
            clap_processor = load(ClapProcessor)
            # Prefer the tokenizer bundled with the processor, if it has one.
            self.text_processor = getattr(clap_processor, 'tokenizer', None) or load(AutoTokenizer, use_fast=True)
            dim = getattr(self.text_model.config, "projection_dim", 512)
        else:
            self.text_model = load(AutoModel)
            self.text_processor = load(AutoTokenizer, use_fast=True)
            dim = self.text_model.config.hidden_size

        self.text_model = self._wrap_if_parameterless(self.text_model, self._preferred_device())

        max_length = config.get('max_length', 512)
        if self.use_batch_tokenizer and self.text_processor is not None:
            self.set_batch_tokenizer(
                self.text_processor,
                max_length=max_length,
                cache_size=self.tokenizer_cache_size,
                batch_size=self.tokenizer_batch_size,
                padding_strategy=self.text_padding_strategy
            )

        self.text_config.update(max_length=max_length, dim=dim, model_type=model_type)
        self.out_dim_per_modality['text'] = dim

    def _init_image_model(self, config: Dict[str, Any]):
        """Load the image encoder and its processor described by ``config``.

        ``config['checkpoint']`` is required. ``config['model_type']``
        (default ``'auto'``) selects CLIP or the generic ``AutoModel`` path.
        ``max_images`` (default 1) and ``image_agg`` (default ``'concat'``)
        determine the registered output size: ``dim * max_images`` for
        ``'concat'``, otherwise ``dim``.
        """
        from transformers import AutoModel, AutoImageProcessor, CLIPModel, CLIPImageProcessor

        checkpoint = config['checkpoint']
        model_type = config.get('model_type', 'auto').lower()

        print(f"Загрузка визуальной модели {checkpoint}...")

        def load(cls):
            return safe_load(cls, checkpoint, local_cache_dir=self.local_cache_dir)

        if model_type == 'clip':
            self.image_model = load(CLIPModel)
            self.image_processor = load(CLIPImageProcessor)
            dim = self.image_model.config.projection_dim
        else:
            self.image_model = load(AutoModel)
            self.image_processor = load(AutoImageProcessor)
            dim = self.image_model.config.hidden_size

        self.image_model = self._wrap_if_parameterless(self.image_model, self._preferred_device())

        max_images = config.get('max_images', 1)
        image_agg = config.get('image_agg', 'concat')
        self.image_config.update(
            max_images=max_images,
            image_agg=image_agg,
            dim=dim,
            model_type=model_type,
        )

        self.out_dim_per_modality['image'] = dim * max_images if image_agg == 'concat' else dim

    def _init_audio_model(self, config: Dict[str, Any]):
        """Load the audio encoder described by ``config``.

        ``config['model_type']`` (default ``'auto'``) selects one of:
          * ``'wav2clip'``: model from the ``wav2clip`` package, no processor,
            embedding size fixed at 512, default sampling rate 16000;
          * ``'clap'``: ``ClapModel`` + ``ClapProcessor`` (checkpoint required),
            sampling rate taken from the processor, defaulting to 48000;
          * anything else: ``AutoModel`` + ``AutoProcessor`` (checkpoint
            required), sampling rate from the feature extractor, default 16000.
        An explicit ``config['sr']`` always overrides the detected rate.

        Raises:
            ValueError: checkpoint missing for a non-wav2clip model.
            RuntimeError: the wav2clip package exposes neither ``get_model``
                nor ``model``.
        """
        from transformers import AutoModel, AutoProcessor, ClapModel, ClapProcessor

        model_type = config.get('model_type', 'auto').lower()
        checkpoint = config.get('checkpoint', None)

        print(f"Загрузка аудио модели (type={model_type})...")

        def extractor_rate(processor, fallback):
            # Sampling rate of the processor's feature extractor, if it has one.
            extractor = getattr(processor, "feature_extractor", None)
            return getattr(extractor, "sampling_rate", fallback) if extractor is not None else fallback

        if model_type == 'wav2clip':
            import wav2clip as w2c
            self._w2c = w2c

            if hasattr(w2c, "get_model"):
                w2c_model = w2c.get_model()
            elif hasattr(w2c, "model"):
                factory = w2c.model
                w2c_model = factory() if callable(factory) else factory
            else:
                raise RuntimeError("wav2clip не содержит get_model()/model. Обновите пакет wav2clip.")

            self.audio_model = w2c_model

            # Best-effort move to CUDA; failures are deliberately ignored.
            try:
                if isinstance(self.audio_model, torch.nn.Module) and torch.cuda.is_available():
                    self.audio_model = self.audio_model.to("cuda")
            except Exception:
                pass

            self.audio_processor = None
            dim = 512
            sr = config.get('sr', 16000)

        elif model_type == 'clap':
            if checkpoint is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для CLAP")
            self.audio_model = safe_load(ClapModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(ClapProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = getattr(self.audio_model.config, "projection_dim", 512)
            sr = getattr(self.audio_processor, "sampling_rate", None)
            if sr is None:
                sr = extractor_rate(self.audio_processor, 48000)

        else:
            if checkpoint is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для аудио-моделей, кроме wav2clip")
            self.audio_model = safe_load(AutoModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(AutoProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = self.audio_model.config.hidden_size
            sr = extractor_rate(self.audio_processor, 16000)

        self.audio_model = self._wrap_if_parameterless(self.audio_model, self._preferred_device())

        self.audio_config.update(
            sr=config.get('sr', sr),
            max_audios=config.get('max_audios', 1),
            audio_agg=config.get('audio_agg', 'concat'),
            dim=dim,
            model_type=model_type,
        )

        if self.audio_config['audio_agg'] == 'concat':
            self.out_dim_per_modality['audio'] = dim * self.audio_config['max_audios']
        else:
            self.out_dim_per_modality['audio'] = dim

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Assemble dataset items into ``{"labels", "backend_inputs"}``.

        ``labels`` is a float32 tensor stacked from each item's ``"labels"``
        (0.0 when missing). ``backend_inputs`` holds, for every loaded
        modality, the processed inputs plus a per-sample item count
        (``image_counts`` / ``audio_counts``), and always ``batch_size``.

        Audio items may be file paths (``str``), ``np.ndarray`` or
        ``torch.Tensor`` waveforms. ``None`` and unsupported item types are
        skipped and are NOT counted, so that ``audio_counts`` always sums to
        the number of waveforms actually passed on. (Previously such items
        were counted but dropped, which shifted later samples' embeddings
        onto the wrong rows during aggregation.)
        """
        # Labels (regression: vector or scalar)
        labels = torch.stack(
            [torch.as_tensor(b.get("labels", 0.0), dtype=torch.float32) for b in batch]
        )

        backend_inputs: Dict[str, Any] = {}
        batch_size = len(batch)

        # Text
        if self.text_model is not None:
            if "text_tokens" in batch[0]:
                # Items were tokenized ahead of time: stack them per key.
                text_inputs = {}
                for key in batch[0]["text_tokens"].keys():
                    if torch.is_tensor(batch[0]["text_tokens"][key]):
                        text_inputs[key] = torch.stack([b["text_tokens"][key] for b in batch])
                    else:
                        dtype = torch.long if key in ["input_ids", "attention_mask", "token_type_ids"] else torch.float32
                        text_inputs[key] = torch.tensor([b["text_tokens"][key] for b in batch], dtype=dtype)
                backend_inputs["text_inputs"] = text_inputs
            else:
                # Raw strings: tokenize the whole batch at once; empty text becomes " ".
                texts = [b.get("text", "") or " " for b in batch]
                if self.batch_tokenizer:
                    text_inputs = self.batch_tokenizer.tokenize_batch(texts, use_cache=True)
                else:
                    pad = "max_length" if getattr(self, "text_padding_strategy", "max_length") == "max_length" else True
                    text_inputs = self.text_processor(
                        texts, padding=pad, truncation=True,
                        max_length=self.text_config.get('max_length', 512),
                        return_tensors="pt"
                    )
                backend_inputs["text_inputs"] = {k: v for k, v in text_inputs.items()}

        # Images
        if self.image_model is not None:
            images_lists = [b.get("images", []) for b in batch]
            flat_images, img_counts = [], []
            for lst in images_lists:
                lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
                lst = [img for img in lst if img is not None]
                img_counts.append(len(lst))
                for img in lst:
                    flat_images.append(to_pil(img))

            if len(flat_images) > 0:
                img_proc = self.image_processor(images=flat_images, return_tensors="pt")
                backend_inputs["image_inputs"] = {"pixel_values": img_proc["pixel_values"]}
            else:
                backend_inputs["image_inputs"] = {"pixel_values": None}

            backend_inputs["image_counts"] = torch.tensor(img_counts, dtype=torch.long)

        # Audio
        if self.audio_model is not None:
            audios_lists = [b.get("audios", []) for b in batch]
            flat_audios, aud_counts = [], []
            for lst in audios_lists:
                lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
                kept = 0
                for a in lst:
                    if a is None:
                        continue
                    if isinstance(a, str):
                        wave = load_audio(a, self.audio_config['sr'])
                    elif isinstance(a, (np.ndarray, torch.Tensor)):
                        if torch.is_tensor(a):
                            a = a.detach().to(device="cpu", dtype=torch.float32).numpy()
                        wave = np.asarray(a, dtype=np.float32)
                        # Collapse to a flat mono-style vector.
                        if wave.ndim > 1:
                            wave = np.squeeze(wave)
                        if wave.ndim > 1:
                            wave = wave.reshape(-1)
                    else:
                        # Unsupported item type: skipped and not counted.
                        continue
                    flat_audios.append(wave)
                    kept += 1
                # Count only what was kept so counts and embeddings stay aligned.
                aud_counts.append(kept)

            if self.audio_config.get('model_type') == 'wav2clip':
                backend_inputs["audio_inputs"] = {"raw_audios": flat_audios}
            elif len(flat_audios) > 0:
                if self.audio_config.get('model_type') == 'clap':
                    aud_proc = self.audio_processor(
                        audios=flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_features": aud_proc["input_features"]}
                else:
                    aud_proc = self.audio_processor(
                        flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_values": aud_proc["input_values"]}
            else:
                backend_inputs["audio_inputs"] = {"input_features": None, "input_values": None, "raw_audios": []}

            backend_inputs["audio_counts"] = torch.tensor(aud_counts, dtype=torch.long)

        backend_inputs["batch_size"] = batch_size
        return {"labels": labels, "backend_inputs": backend_inputs}

    def _aggregate_embeddings(
        self,
        embs: Optional[torch.Tensor],
        counts: List[int],
        max_k: int,
        dim_hint: int,
        agg_type: str,
        batch_size: int,
        device: torch.device
    ) -> torch.Tensor:
        if embs is None or (torch.is_tensor(embs) and embs.numel() == 0):
            feat_dim = int(dim_hint) if dim_hint is not None else 0
            out_dim = feat_dim * max_k if agg_type == 'concat' else feat_dim
            return torch.zeros((batch_size, out_dim), device=device, dtype=torch.float32)

        if not torch.is_tensor(embs):
            embs = torch.as_tensor(embs, device=device, dtype=torch.float32)
        if embs.dim() == 1:
            embs = embs.unsqueeze(0)
        elif embs.dim() > 2:
            embs = embs.view(embs.size(0), -1)

        N, D = embs.size()
        out_dim = (D * max_k) if agg_type == 'concat' else D
        out = torch.zeros((batch_size, out_dim), device=device, dtype=embs.dtype)

        offset = 0
        for i, c in enumerate(counts):
            if c <= 0 or offset >= N:
                continue
            take_n = min(c, N - offset)
            sample = embs[offset:offset + take_n]
            offset += take_n

            if agg_type == 'concat':
                take = sample[:max_k]
                if take.size(0) < max_k:
                    pad = torch.zeros((max_k - take.size(0), D), device=device, dtype=embs.dtype)
                    take = torch.cat([take, pad], dim=0)
                out[i] = take.reshape(-1)
            else:
                out[i] = sample.mean(dim=0)

        return F.normalize(out, dim=-1, eps=1e-12) if out.size(1) > 0 else out

    @torch.no_grad()
    def _wav2clip_embed(self, arr: np.ndarray, device: torch.device) -> torch.Tensor:
        arr = np.asarray(arr, dtype=np.float32)
        if arr.ndim > 1:
            arr = np.squeeze(arr)
        if arr.ndim > 1:
            arr = arr.reshape(-1)
        if arr.size < 512:
            arr = np.pad(arr, (0, 512 - arr.size), mode="constant")

        try:
            emb = self._w2c.embed_audio(arr, self.audio_model)
            emb = np.asarray(emb)
        except Exception:
            model_dev = self._model_device(self.audio_model, default=device)
            x = torch.from_numpy(arr).float().unsqueeze(0).to(model_dev)
            y = self.audio_model(x)
            if isinstance(y, (tuple, list)):
                y = y[0]
            if torch.is_tensor(y):
                if y.dim() == 2 and y.size(0) == 1:
                    y = y.squeeze(0)
                emb = y.detach().cpu().numpy()
            else:
                emb = np.asarray(y)

        if emb.ndim > 1:
            emb = emb.reshape(-1)
        return torch.as_tensor(emb, device=device, dtype=torch.float32)

    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        """Compute one embedding matrix per available modality.

        Args:
            backend_inputs: the ``backend_inputs`` dict produced by ``collate``.
            device: device the inputs are moved to before each model call.

        Returns:
            Dict with any of the keys ``"text"``, ``"image"``, ``"audio"``;
            each value has one row per sample.

        Raises:
            RuntimeError: the modalities disagree on the number of rows.

        Side effect: when a model returns embeddings whose width differs from
        the configured ``dim``, the modality config and
        ``out_dim_per_modality`` are updated to the observed width.

        Changes from the previous revision: autocast is disabled through
        ``torch.autocast`` instead of the deprecated ``torch.cuda.amp.autocast``;
        audio ``input_values`` are clamped out of place, so the caller's batch
        tensor is no longer modified; an unused ``batch_size`` local was removed.
        """
        results: Dict[str, torch.Tensor] = {}

        # Text
        if self.text_model is not None and "text_inputs" in backend_inputs:
            text_inputs = {k: v.to(device) for k, v in backend_inputs["text_inputs"].items()}
            if hasattr(self.text_model, "get_text_features"):
                text_z = self.text_model.get_text_features(**text_inputs)
            else:
                outputs = self.text_model(**text_inputs)
                # NOTE(review): the mean fallback averages every position and
                # does not apply attention_mask - confirm this is intended
                # when inputs are padded.
                text_z = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state.mean(dim=1)
            results["text"] = self._normalize_2d(text_z)

        # Images
        if self.image_model is not None and "image_inputs" in backend_inputs:
            pi = backend_inputs["image_inputs"]["pixel_values"]
            counts = backend_inputs["image_counts"].tolist()
            total_images_needed = sum(counts)

            img_flat = None
            actual_img_dim = self.image_config.get("dim", 768)

            if pi is not None and pi.numel() > 0 and total_images_needed > 0:
                pi = pi.to(device)
                # Drop any surplus rows beyond what the counts account for.
                if pi.size(0) > total_images_needed:
                    pi = pi[:total_images_needed]

                if hasattr(self.image_model, "get_image_features"):
                    img_flat = self.image_model.get_image_features(pixel_values=pi)
                else:
                    outputs = self.image_model(pixel_values=pi)
                    img_flat = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state[:, 0]

                img_flat = self._normalize_2d(img_flat)
                actual_img_dim = img_flat.size(1) if img_flat is not None else actual_img_dim

            img_z = self._aggregate_embeddings(
                img_flat, counts,
                self.image_config["max_images"],
                actual_img_dim,
                self.image_config["image_agg"],
                len(counts),
                device
            )

            # Keep the registered sizes in sync with what the model really returned.
            if actual_img_dim != self.image_config.get("dim"):
                self.image_config["dim"] = actual_img_dim
                self.out_dim_per_modality["image"] = (
                    actual_img_dim * self.image_config["max_images"]
                    if self.image_config["image_agg"] == "concat" else actual_img_dim
                )

            results["image"] = img_z

        # Audio
        if self.audio_model is not None and "audio_inputs" in backend_inputs:
            counts = backend_inputs["audio_counts"].tolist()
            total_audios_needed = sum(counts)

            aud_flat = None
            actual_aud_dim = self.audio_config.get("dim", 768)
            model_type = self.audio_config.get("model_type")

            if total_audios_needed > 0:
                if model_type == "clap":
                    af = backend_inputs["audio_inputs"]["input_features"]
                    if af is not None and af.numel() > 0:
                        af = af.to(device)
                        if af.size(0) > total_audios_needed:
                            af = af[:total_audios_needed]
                        # Run the audio tower in full precision (CUDA autocast off).
                        with torch.autocast(device_type="cuda", enabled=False):
                            aud_flat = self.audio_model.get_audio_features(input_features=af.float())
                        aud_flat = self._normalize_2d(aud_flat.float())
                        actual_aud_dim = aud_flat.size(1)

                elif model_type == "wav2clip":
                    raw_list = backend_inputs["audio_inputs"].get("raw_audios", [])
                    if len(raw_list) > total_audios_needed:
                        raw_list = raw_list[:total_audios_needed]
                    if len(raw_list) > 0:
                        embs = [self._wav2clip_embed(arr, device) for arr in raw_list]
                        aud_flat = torch.stack(embs, dim=0)
                        aud_flat = self._normalize_2d(aud_flat)
                        actual_aud_dim = aud_flat.size(1)

                else:
                    av = backend_inputs["audio_inputs"]["input_values"]
                    if av is not None and av.numel() > 0:
                        av = av.to(device)
                        if av.size(0) > total_audios_needed:
                            av = av[:total_audios_needed]
                        # Out-of-place clamp: `.to(device)` may hand back the
                        # caller's own tensor, which must not be modified.
                        av = av.clamp(-1.0, 1.0)
                        with torch.autocast(device_type="cuda", enabled=False):
                            outputs = self.audio_model(input_values=av.float())
                            feats = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state.mean(dim=1)
                        aud_flat = self._normalize_2d(feats.float())
                        actual_aud_dim = aud_flat.size(1)

            aud_z = self._aggregate_embeddings(
                aud_flat, counts,
                self.audio_config["max_audios"],
                actual_aud_dim,
                self.audio_config["audio_agg"],
                len(counts),
                device
            )

            if aud_flat is not None and actual_aud_dim != self.audio_config.get("dim"):
                self.audio_config["dim"] = actual_aud_dim
                self.out_dim_per_modality["audio"] = (
                    actual_aud_dim * self.audio_config["max_audios"]
                    if self.audio_config["audio_agg"] == "concat" else actual_aud_dim
                )

            results["audio"] = aud_z

        # Strict check that all modalities agree on the batch size
        if results:
            bs_list = [v.size(0) for v in results.values()]
            if len(set(bs_list)) != 1:
                raise RuntimeError(f"Inconsistent batch sizes across modalities: {bs_list}")

        return results


# =========================
# Классификатор (голова регрессии)
# =========================

class SingleBackboneClassifier(nn.Module):
    """Fuse per-modality embeddings from a backend and feed them to an MLP head.

    Modalities are always fused in the fixed order image, text, audio
    (restricted to those listed in ``modalities``). ``fusion="concat"``
    concatenates them along the feature axis, ``fusion="mean"`` averages them
    and therefore requires equal sizes. The head is
    LayerNorm -> Linear -> GELU -> Dropout -> Linear(num_labels).
    """

    def __init__(
        self,
        backend: "BaseBackend",
        modalities: List[str],
        num_labels: int,
        fusion: str = "concat",
        hidden: int = 512,
        dropout: float = 0.1
    ):
        """Build the head sized from the backend's per-modality output sizes.

        Raises:
            ValueError: ``fusion`` is neither ``"concat"`` nor ``"mean"``, or
                ``"mean"`` is requested for modalities of different sizes.
        """
        super().__init__()
        self.backend = backend
        self.modalities = modalities
        self.fusion = fusion
        self.num_labels = num_labels

        order = self._ordered_modalities()
        if fusion == "concat":
            in_dim = sum(self.backend.get_out_dim(m) for m in order)
        elif fusion == "mean":
            dims = [self.backend.get_out_dim(m) for m in order]
            if len(set(dims)) != 1:
                raise ValueError('Для fusion="mean" размеры модальностей должны совпадать')
            in_dim = dims[0]
        else:
            raise ValueError('fusion должен быть "concat" или "mean"')

        self.head = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_labels)
        )

    def _ordered_modalities(self) -> List[str]:
        """Return the configured modalities in the fixed fusion order."""
        return [m for m in ["image", "text", "audio"] if m in self.modalities]

    def _backbone_models(self) -> List[Any]:
        """Return the backend's text, image and audio models that are set."""
        candidates = (getattr(self.backend, attr, None)
                      for attr in ("text_model", "image_model", "audio_model"))
        return [m for m in candidates if m is not None]

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on every backbone that has trainable
        parameters and supports it.

        ``config.use_cache`` is switched off first when present. All steps are
        best-effort: failures on one backbone are ignored.
        """
        for m in self._backbone_models():
            try:
                has_trainable = any(p.requires_grad for p in m.parameters()) if hasattr(m, "parameters") else False
            except Exception:
                has_trainable = False
            if not has_trainable:
                # Nothing to checkpoint for a frozen backbone.
                continue
            try:
                cfg = getattr(m, "config", None)
                if cfg is not None and hasattr(cfg, "use_cache"):
                    cfg.use_cache = False
            except Exception:
                pass
            try:
                if hasattr(m, "gradient_checkpointing_enable"):
                    try:
                        if gradient_checkpointing_kwargs is not None:
                            m.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
                        else:
                            m.gradient_checkpointing_enable()
                    except TypeError:
                        # The model's method does not accept the kwargs argument.
                        m.gradient_checkpointing_enable()
            except Exception:
                pass

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on every backbone that supports it
        (best-effort; failures are ignored)."""
        for m in self._backbone_models():
            try:
                if hasattr(m, "gradient_checkpointing_disable"):
                    m.gradient_checkpointing_disable()
            except Exception:
                pass

    @staticmethod
    def _find_tensor_device(obj) -> Optional[torch.device]:
        """Return the device of the first tensor found in ``obj`` (a tensor or
        an arbitrarily nested dict), or ``None`` when there is no tensor."""
        if isinstance(obj, torch.Tensor):
            return obj.device
        if isinstance(obj, dict):
            for value in obj.values():
                found = SingleBackboneClassifier._find_tensor_device(value)
                if found is not None:
                    return found
        return None

    def _infer_device_from_inputs(self, obj) -> torch.device:
        """Return the device of the first tensor inside ``obj``.

        Non-tensor entries (``None``, ints, lists, ...) are skipped instead of
        ending the search; previously the first such entry made the method
        return the default device even when real tensors followed. The default
        (CUDA if available, else CPU) is used only when no tensor exists.
        """
        found = self._find_tensor_device(obj)
        if found is not None:
            return found
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _fuse(self, z: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Combine the per-modality embeddings in ``z`` into one matrix.

        3-D inputs are averaged over axis 1, higher-rank inputs are flattened
        per sample, and every modality is cut to the row count of the first
        one before concatenating or averaging.
        """
        feats = []
        batch_size = None
        for m in self._ordered_modalities():
            if m in z:
                t = z[m]
                if t.dim() == 3:
                    t = t.mean(dim=1)
                elif t.dim() > 3:
                    t = t.view(t.size(0), -1)
                feats.append(t)
                if batch_size is None:
                    batch_size = t.size(0)
        if batch_size is not None:
            feats = [f[:batch_size] for f in feats]
        if self.fusion == "concat":
            return torch.cat(feats, dim=-1)
        elif self.fusion == "mean":
            return torch.stack(feats, dim=0).mean(dim=0)

    def forward(self, backend_inputs: Dict[str, Any], labels: Optional[torch.Tensor] = None):
        """Encode, fuse and project a batch; returns ``SequenceClassifierOutput``
        carrying only ``logits``.

        ``labels`` is accepted but not used here; no loss is computed.

        Raises:
            ValueError: the backend produced no embeddings.
        """
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        if not z:
            raise ValueError("Backend не вернул эмбеддинги")
        fused = self._fuse(z)
        logits = self.head(fused)
        return SequenceClassifierOutput(logits=logits)

    @torch.no_grad()
    def get_embeddings(self, backend_inputs: Dict[str, Any], return_per_modality: bool = False):
        """Return the fused embeddings without gradients; with
        ``return_per_modality=True`` return ``(fused, per_modality_dict)``."""
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        fused = self._fuse(z)
        if return_per_modality:
            return fused, z
        return fused


# =========================
# Trainer для регрессии (MSE)
# =========================

class MSETrainer(Trainer):
    """Trainer whose loss is the mean squared error between logits and labels."""

    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch: Optional[int] = None):
        """Compute the MSE loss for one batch.

        ``"labels"`` is popped from ``inputs`` and cast to float32; the rest
        of ``inputs`` is passed to the model as keyword arguments. Logits with
        a trailing dimension of size 1 are squeezed, and the labels are viewed
        with the shape of the predictions before the loss is taken.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        """
        targets = inputs.pop("labels").to(torch.float32)
        outputs = model(**inputs)
        preds = outputs.logits
        if preds.dim() != 1 and preds.size(-1) == 1:
            preds = preds.squeeze(-1)
        loss = F.mse_loss(preds, targets.view_as(preds))
        if return_outputs:
            return loss, outputs
        return loss


# =========================
# Прогресс-логгер
# =========================

class PbarConsoleLogger(TrainerCallback):
    """Trainer callback that mirrors progress and metrics onto a tqdm bar.

    The bar position follows ``state.global_step``, the postfix shows the
    latest numeric log values, and one summary line is written per evaluation
    step via ``tqdm.write``.
    """

    # Keys excluded from the list of reported evaluation metrics.
    _NON_METRIC_KEYS = frozenset({
        'eval_loss', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch'
    })

    def __init__(self, pbar):
        """Store the progress bar and reset the logging state.

        Args:
            pbar: Progress bar object exposing ``n``, ``total``, ``update``,
                ``refresh`` and ``set_postfix_str``.
        """
        self.pbar = pbar
        self.last_logs = {}
        self.last_train_loss = None
        self.printed_eval_steps = set()
        self.tqdm = tqdm

    def _step(self, state) -> int:
        """Return ``state.global_step`` as an int, treating a falsy value as 0."""
        return int(state.global_step or 0)

    def _fmt_postfix(self):
        """Build the postfix string from the most recent numeric log values."""
        parts = []
        if 'loss' in self.last_logs:
            parts.append(f"loss {self.last_logs['loss']:.4f}")
        if 'eval_loss' in self.last_logs:
            parts.append(f"val {self.last_logs['eval_loss']:.4f}")
        for k, v in self.last_logs.items():
            if k.startswith('eval_') and k not in self._NON_METRIC_KEYS:
                parts.append(f"{k.replace('eval_', '')} {v:.4f}")
        return " | ".join(parts)

    def _advance_to_step(self, state):
        """Move the bar forward to the current global step; never backwards."""
        n = self._step(state)
        # A bar created without a total has total=None; do not clamp in that case.
        if self.pbar.total is not None:
            n = min(n, self.pbar.total)
        if n > self.pbar.n:
            self.pbar.update(n - self.pbar.n)

    def on_step_end(self, args, state, control, **kwargs):
        """Advance the bar after each optimisation step and refresh the postfix."""
        self._advance_to_step(state)
        if self.last_logs:
            self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record numeric log values and print one summary line per eval step."""
        if not logs:
            return
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.last_logs[k] = float(v)
        if 'loss' in logs and isinstance(logs['loss'], (int, float)):
            self.last_train_loss = float(logs['loss'])

        self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

        if not any(k.startswith('eval_') for k in logs):
            return

        step = self._step(state)
        if step in self.printed_eval_steps:
            # Evaluation results for this step were already printed.
            return
        self.printed_eval_steps.add(step)

        train_loss_str = f"{self.last_train_loss:.10f}" if self.last_train_loss is not None else "n/a"
        val_loss = logs.get('eval_loss', None)
        val_loss_str = f"{float(val_loss):.10g}" if isinstance(val_loss, (int, float)) else "n/a"

        extra_parts = []
        for k, v in logs.items():
            if not k.startswith('eval_') or k in self._NON_METRIC_KEYS:
                continue
            try:
                value = float(v)
            except (TypeError, ValueError):
                # Non-scalar metric values cannot be formatted; skip them
                # rather than aborting training from inside a logging hook.
                continue
            extra_parts.append(f"val {k.replace('eval_', '')}: {value:.10f}")

        line = f"step: {step}, train loss: {train_loss_str}, val loss: {val_loss_str}"
        if extra_parts:
            line += ", " + ", ".join(extra_parts)
        self.tqdm.write(line)

    def on_train_end(self, args, state, control, **kwargs):
        """Bring the bar up to the final global step when training finishes."""
        self._advance_to_step(state)
        self.pbar.refresh()


# =========================
# Основной пайплайн (регрессия с несколькими таргет-колонками)
# =========================

class SingleModelMultiComboRegression:
    """
    Высокоуровневый пайплайн мультимодальной регрессии (text / image / audio) поверх моделей Hugging Face
    и wav2clip. Поддерживает автоматическую сборку бэкенда под набор модальностей, батчевую токенизацию
    (включая dynamic padding), предварительную токенизацию датасета (отключается при dynamic), чанковую тренировку,
    раннюю остановку, извлечение эмбеддингов.

    Теперь поддерживаются несколько столбцов-таргетов. Вы передаёте список target_column_names,
    а число выходов (num_labels) определяется автоматически как len(target_column_names).
    """
    def __init__(
        self,
        modalities: List[str],
        target_column_names: List[str],
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        backend: str = "auto",
        clip_checkpoint: str = "openai/clip-vit-base-patch32",
        clap_checkpoint: str = "laion/clap-htsat-unfused",
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        fusion: str = "concat",
        freeze_backbone: bool = True,
        clip_max_length: int = 77,
        max_images_per_sample: int = 1,
        max_audios_per_sample: int = 1,
        use_batch_tokenizer: bool = True,
        pretokenize_data: bool = True,
        pretokenize_batch_size: int = 256,
        tokenizer_cache_size: int = 10000,
        max_pretokenize_samples: int = 100000,
        local_cache_dir: str = "./model_cache",
        text_padding: str = "max_length"
    ):
        self.modalities = sorted(list(set(modalities)))
        self.target_column_names = list(target_column_names)
        self.num_labels = len(self.target_column_names)

        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.backend_name = backend.lower()
        self.clip_checkpoint = clip_checkpoint
        self.clap_checkpoint = clap_checkpoint
        self.text_model_config = text_model_config
        self.image_model_config = image_model_config
        self.audio_model_config = audio_model_config
        self.fusion = fusion
        self.freeze_backbone = freeze_backbone
        self.clip_max_length = clip_max_length
        self.max_images_per_sample = int(max_images_per_sample)
        self.max_audios_per_sample = int(max_audios_per_sample)

        self.use_batch_tokenizer = use_batch_tokenizer
        self.pretokenize_data = pretokenize_data
        self.pretokenize_batch_size = pretokenize_batch_size
        self.tokenizer_cache_size = tokenizer_cache_size
        self.max_pretokenize_samples = max_pretokenize_samples
        self.local_cache_dir = local_cache_dir
        self.text_padding = text_padding

        self._target_vec_col = "__target_vector__"

        self.backend: Optional[BaseBackend] = None
        self.model: Optional[SingleBackboneClassifier] = None
        self.trainer: Optional[Trainer] = None
        self.compute_metrics = None
        self.progress_callback: Optional[PbarConsoleLogger] = None

        self._build_backend()

    def _build_backend(self):
        """Fill in default model configs and create the ``UniversalMultiBackend``.

        Defaults by ``self.backend_name``:
          * ``"clip"``, or ``"auto"`` with exactly {text, image}: CLIP configs
            for text and image replace any falsy config.
          * ``"clap"``, or ``"auto"`` with exactly {text, audio}: CLAP configs
            for text and audio replace any falsy config.
          * ``"auto"`` otherwise: BERT / ViT / CLAP defaults are used only for
            requested modalities whose config is ``None``.
          * any other name: configs are left as given.

        Raises:
            ValueError: If the backend does not support every requested modality.
        """
        requested = set(self.modalities)
        kind = self.backend_name

        def clip_text_cfg():
            return {
                'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_length': self.clip_max_length
            }

        def clip_image_cfg():
            return {
                'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'
            }

        def clap_text_cfg():
            return {
                'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_length': 64
            }

        def clap_audio_cfg():
            return {
                'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000
            }

        is_auto = kind == "auto"
        wants_clip_pair = kind == "clip" or (is_auto and requested == {"text", "image"})
        wants_clap_pair = kind == "clap" or (is_auto and requested == {"text", "audio"})

        if wants_clip_pair:
            self.text_model_config = self.text_model_config or clip_text_cfg()
            self.image_model_config = self.image_model_config or clip_image_cfg()
        elif wants_clap_pair:
            self.text_model_config = self.text_model_config or clap_text_cfg()
            self.audio_model_config = self.audio_model_config or clap_audio_cfg()
        elif is_auto:
            # Independent per-modality defaults; only a missing (None) config is replaced.
            if "text" in requested and self.text_model_config is None:
                self.text_model_config = {'checkpoint': 'bert-base-multilingual-cased', 'model_type': 'bert', 'max_length': 512}
            if "image" in requested and self.image_model_config is None:
                self.image_model_config = {'checkpoint': 'google/vit-base-patch16-224', 'model_type': 'vit', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'}
            if "audio" in requested and self.audio_model_config is None:
                self.audio_model_config = clap_audio_cfg()

        # Configs of modalities that were not requested are not handed to the backend.
        text_cfg = self.text_model_config if "text" in requested else None
        image_cfg = self.image_model_config if "image" in requested else None
        audio_cfg = self.audio_model_config if "audio" in requested else None

        self.backend = UniversalMultiBackend(
            text_model_config=text_cfg,
            image_model_config=image_cfg,
            audio_model_config=audio_cfg,
            freeze=self.freeze_backbone,
            text_tokenizer_fn=self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors,
            use_batch_tokenizer=self.use_batch_tokenizer,
            tokenizer_cache_size=self.tokenizer_cache_size,
            tokenizer_batch_size=self.pretokenize_batch_size,
            local_cache_dir=self.local_cache_dir,
            text_padding_strategy=self.text_padding
        )

        if set(self.modalities).issubset(self.backend.supported):
            return
        raise ValueError(f"Бэкенд {self.backend.name} не поддерживает модальности {self.modalities}")

    def _setup_metrics(self, metric_name: str):
        name = metric_name.lower()
        if name not in ("rmse", "mae", "r2"):
            raise ValueError('metric_name для регрессии должен быть "rmse", "mae" или "r2"')

        def compute(p):
            preds = p.predictions
            y = p.label_ids
            preds = preds.squeeze(-1) if preds.ndim == 2 and preds.shape[-1] == 1 else preds
            y = y.squeeze(-1) if y.ndim == 2 and y.shape[-1] == 1 else y
            axis = 0 if preds.ndim == 2 else None
            if name == "rmse":
                err = preds - y
                mse = np.mean(err**2, axis=axis)
                rmse = np.sqrt(mse)
                return {"rmse": float(np.mean(rmse))}
            elif name == "mae":
                mae = np.mean(np.abs(preds - y), axis=axis)
                return {"mae": float(np.mean(mae))}
            else:
                y_mean = np.mean(y, axis=axis, keepdims=True) if preds.ndim == 2 else np.mean(y)
                ss_res = np.sum((y - preds) ** 2, axis=axis)
                ss_tot = np.sum((y - y_mean) ** 2, axis=axis)
                r2 = 1.0 - (ss_res / (ss_tot + 1e-12))
                return {"r2": float(np.mean(r2))}

        self.compute_metrics = compute

    def _split(self, df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
        df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        n_eval = int(math.ceil(len(df) * test_size))
        return df.iloc[n_eval:].reset_index(drop=True), df.iloc[:n_eval].reset_index(drop=True)

    def _validate_data_modalities(self, df: pd.DataFrame):
        if "text" in self.modalities:
            if not self.text_columns:
                raise ValueError("Вы выбрали модальность 'text', но text_columns пустой.")
            missing = [c for c in self.text_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют текстовые колонки: {missing}")
        if "image" in self.modalities:
            if not self.image_columns:
                raise ValueError("Вы выбрали модальность 'image', но image_columns пуст.")
            missing = [c for c in self.image_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки изображений: {missing}")
        if "audio" in self.modalities:
            if not self.audio_columns:
                raise ValueError("Вы выбрали модальность 'audio', но audio_columns пуст.")
            missing = [c for c in self.audio_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки аудио: {missing}")

    def _validate_targets_present(self, df: pd.DataFrame):
        missing = [c for c in self.target_column_names if c not in df.columns]
        if missing:
            raise ValueError(f"В DataFrame отсутствуют целевые колонки: {missing}")

    def _attach_target_vector(self, df: pd.DataFrame, fill_zeros: bool = False) -> pd.DataFrame:
        df_c = df.copy()
        K = self.num_labels
        if fill_zeros:
            df_c[self._target_vec_col] = [np.zeros(K, dtype=np.float32) for _ in range(len(df_c))]
        else:
            def _row_to_vec(row):
                vals = [row[c] for c in self.target_column_names]
                return np.asarray(vals, dtype=np.float32)
            df_c[self._target_vec_col] = df_c.apply(_row_to_vec, axis=1)
        return df_c

    def fit(
        self,
        train_data: pd.DataFrame,
        epochs: int = 3,
        test_size: float = 0.2,
        test_data: Optional[pd.DataFrame] = None,
        per_device_train_batch_size: int = 16,
        gradient_accumulation_steps: int = 1,
        learning_rate: float = 2e-4,
        metric_name: str = "rmse",
        fp16: bool = True,
        logging_steps: int = 50,
        eval_steps: int = 200,
        output_dir: str = "./result_reg",
        seed: int = 42,
        hidden: int = 256,
        dropout: float = 0.1,
        gradient_checkpointing: bool = False,
        fit_chunk_size: Optional[int] = None,
        clear_cache_every_n_chunks: int = 10
    ):
        """Train the model with ``MSETrainer``, optionally chunk by chunk.

        Args:
            train_data: Training frame; must contain the modality and target
                columns.
            epochs: Number of passes over the training rows.
            test_size: Fraction held out for evaluation when ``test_data`` is
                not given.
            test_data: Explicit evaluation frame; disables the internal split.
            per_device_train_batch_size, gradient_accumulation_steps,
            learning_rate, logging_steps, eval_steps, output_dir, seed:
                Passed to ``TrainingArguments``.
            metric_name: ``"rmse"``, ``"mae"`` or ``"r2"``; also selects the
                best checkpoint (lower is better for rmse/mae, higher for r2).
            fp16: Enabled only when CUDA is available.
            hidden, dropout: Passed to ``SingleBackboneClassifier``.
            gradient_checkpointing: Enables checkpointing on the model and in
                the training arguments.
            fit_chunk_size: Rows per training chunk; ``None`` or a
                non-positive value trains on all rows as one chunk.
            clear_cache_every_n_chunks: Clear the chunk dataset cache after
                every n-th chunk; values ``<= 0`` disable this.

        Returns:
            ``self``.

        Raises:
            ValueError: For missing columns or an unsupported ``metric_name``.
        """
        self._validate_data_modalities(train_data)
        self._validate_targets_present(train_data)
        if test_data is not None:
            self._validate_data_modalities(test_data)
            self._validate_targets_present(test_data)
        # Fail fast on an unsupported metric before any dataset or model is built.
        self._setup_metrics(metric_name)
        metric_key = metric_name.lower()
        set_seed(seed)

        if test_data is not None:
            df_train, df_eval = train_data, test_data
        else:
            df_train, df_eval = self._split(train_data, test_size, seed)

        df_train_ext = self._attach_target_vector(df_train, fill_zeros=False)
        df_eval_ext = self._attach_target_vector(df_eval, fill_zeros=False)

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))
        # Pre-tokenisation is skipped for dynamic padding.
        can_pretokenize = bool(self.pretokenize_data and has_bt and self.text_padding != "dynamic")

        def build_dataset(frame: pd.DataFrame, **extra):
            # Shared MultiComboDataset arguments; ``extra`` carries the per-call options.
            return MultiComboDataset(
                df=frame,
                target_col=self._target_vec_col,
                text_columns=self.text_columns,
                image_columns=self.image_columns,
                audio_columns=self.audio_columns,
                text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
                text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
                special_tokens=self.special_tokens,
                tokenizer_returns_tensors=self.tokenizer_returns_tensors,
                **extra
            )

        ds_eval = build_dataset(
            df_eval_ext,
            pretokenize=(can_pretokenize and len(df_eval_ext) < 50000),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_eval_ext), self.max_pretokenize_samples),
        )

        self.model = SingleBackboneClassifier(
            backend=self.backend,
            modalities=self.modalities,
            num_labels=self.num_labels,
            fusion=self.fusion,
            hidden=hidden,
            dropout=dropout
        )
        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=max(4, per_device_train_batch_size // 2),
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=max(1, gradient_accumulation_steps * 2),
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.0,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_strategy="steps",
            save_steps=eval_steps,
            load_best_model_at_end=True,
            # compute_metrics reports the lower-cased metric name.
            metric_for_best_model=f"eval_{metric_key}",
            # Without this, any metric not ending in "loss" is treated as
            # "greater is better", which would keep the worst RMSE/MAE checkpoint.
            greater_is_better=(metric_key == "r2"),
            save_total_limit=1,
            logging_strategy="steps",
            logging_steps=logging_steps,
            report_to="none",
            fp16=fp16 and torch.cuda.is_available(),
            dataloader_num_workers=min(4, os.cpu_count() or 4),
            seed=seed,
            remove_unused_columns=False,
            gradient_checkpointing=gradient_checkpointing,
            dataloader_pin_memory=True,
            ddp_find_unused_parameters=False,
            disable_tqdm=True
        )

        def regression_collator(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            out = self.backend.collate(batch_list)
            if "labels" in out:
                out["labels"] = out["labels"].to(torch.float32)
            return out

        def steps_for_size(sz: int, bsz: int, accum: int) -> int:
            # Optimiser steps needed for ``sz`` samples.
            return max(0, math.ceil(math.ceil(sz / max(1, bsz)) / max(1, accum)))

        def chunk_slices(index_array: np.ndarray, chunk_size: int):
            for i in range(0, len(index_array), chunk_size):
                yield index_array[i:i + chunk_size]

        n_train = len(df_train_ext)
        # max(1, ...) keeps range() valid when the training frame is empty.
        chunk_size = max(1, fit_chunk_size if (fit_chunk_size and fit_chunk_size > 0) else n_train)

        # Chunk lengths do not depend on the shuffle, so one epoch's step count
        # can be computed once and multiplied.
        steps_per_epoch = sum(
            steps_for_size(len(slc), per_device_train_batch_size, gradient_accumulation_steps)
            for slc in chunk_slices(np.arange(n_train), chunk_size)
        )
        total_steps = epochs * steps_per_epoch

        # The Trainer needs some train dataset at construction time; a
        # one-row placeholder is used and replaced before each train() call.
        if n_train > 0:
            ds_train_init = build_dataset(df_train_ext.iloc[np.arange(1)], pretokenize=False)
        else:
            ds_train_init = ds_eval

        self.trainer = MSETrainer(
            model=self.model,
            args=args,
            train_dataset=ds_train_init,
            eval_dataset=ds_eval,
            data_collator=regression_collator,
            compute_metrics=self.compute_metrics
        )
        self.trainer.remove_callback(PrinterCallback)

        if total_steps > 0:
            # Created up front so the schedule spans all chunks and epochs.
            self.trainer.create_optimizer_and_scheduler(num_training_steps=total_steps)

        pbar = tqdm(total=total_steps, desc="Training Progress", unit="step")
        cb = PbarConsoleLogger(pbar)
        self.trainer.add_callback(cb)
        self.progress_callback = cb

        steps_done = 0
        chunk_counter = 0

        try:
            for ep in range(epochs):
                rng = np.random.default_rng(seed + ep)
                shuffled = np.arange(n_train)
                rng.shuffle(shuffled)
                for slc in chunk_slices(shuffled, chunk_size):
                    chunk_steps = steps_for_size(len(slc), per_device_train_batch_size, gradient_accumulation_steps)
                    if chunk_steps == 0:
                        continue

                    ds_chunk = build_dataset(
                        df_train_ext.iloc[slc],
                        pretokenize=(can_pretokenize
                                     and len(slc) < self.max_pretokenize_samples and len(slc) > 100),
                        pretokenize_batch_size=self.pretokenize_batch_size,
                        max_cache_size=min(len(slc), self.max_pretokenize_samples),
                    )
                    self.trainer.train_dataset = ds_chunk

                    # NOTE(review): max_steps is cumulative across train() calls. If
                    # Trainer.train() restarts global_step at 0 on every call, later
                    # chunks train for more steps than intended - confirm against
                    # the installed transformers version.
                    self.trainer.args.max_steps = steps_done + chunk_steps
                    self.trainer.train()
                    steps_done += chunk_steps

                    chunk_counter += 1
                    # The > 0 guard avoids a modulo-by-zero.
                    if clear_cache_every_n_chunks > 0 and chunk_counter % clear_cache_every_n_chunks == 0:
                        if hasattr(ds_chunk, 'clear_cache'):
                            ds_chunk.clear_cache()
                            print(f"✓ Очищен кэш токенизации после {chunk_counter} чанков")

                    del ds_chunk
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
        finally:
            # Release the progress bar even if training fails.
            pbar.close()

        if getattr(self.backend, "batch_tokenizer", None):
            self.backend.batch_tokenizer.clear_cache()

        return self

    def predict(
        self,
        df: pd.DataFrame,
        batch_size: Optional[int] = None
    ) -> np.ndarray:
        if self.trainer is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        df_c = self._attach_target_vector(df, fill_zeros=True)

        print(f"Preparing dataset for prediction ({len(df_c)} samples)...")

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

        ds = MultiComboDataset(
            df=df_c,
            target_col=self._target_vec_col,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
            text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=(self.pretokenize_data and has_bt and len(df_c) < 10000 and self.text_padding != "dynamic"),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_c), 10000),
            tokenizer_returns_tensors=self.tokenizer_returns_tensors
        )

        if batch_size:
            original_bs = self.trainer.args.per_device_eval_batch_size
            self.trainer.args.per_device_eval_batch_size = batch_size

        effective_batch_size = batch_size or self.trainer.args.per_device_eval_batch_size
        num_batches = (len(df_c) + effective_batch_size - 1) // effective_batch_size

        print(f"Running predictions (batch_size={effective_batch_size}, num_batches={num_batches})...")

        original_disable_tqdm = self.trainer.args.disable_tqdm
        self.trainer.args.disable_tqdm = False

        preds = self.trainer.predict(test_dataset=ds)

        self.trainer.args.disable_tqdm = original_disable_tqdm

        if batch_size:
            self.trainer.args.per_device_eval_batch_size = original_bs

        if hasattr(ds, 'clear_cache'):
            ds.clear_cache()

        y = preds.predictions
        y = y.squeeze(-1) if y.ndim == 2 and y.shape[-1] == 1 else y
        return y

    def get_embeddings(self, df: pd.DataFrame, batch_size: int = 32, return_per_modality: bool = False):
        if self.trainer is None or self.model is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        try:
            device = next(self.trainer.model.parameters()).device
        except StopIteration:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model.to(device).eval()

        df_c = self._attach_target_vector(df, fill_zeros=True)

        print(f"Preparing dataset for embeddings extraction ({len(df_c)} samples)...")

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

        ds = MultiComboDataset(
            df=df_c,
            target_col=self._target_vec_col,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
            text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=False,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors
        )

        def collate(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            return self.backend.collate(batch_list)

        def move_to_device(obj, device: torch.device):
            if torch.is_tensor(obj):
                return obj.to(device)
            if isinstance(obj, dict):
                return {k: move_to_device(v, device) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                t = [move_to_device(v, device) for v in obj]
                return type(obj)(t) if not isinstance(obj, list) else t
            return obj

        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)
        fused_list = []
        per_mod_lists = {m: [] for m in self.modalities} if return_per_modality else None

        num_batches = (len(df_c) + batch_size - 1) // batch_size

        print(f"Extracting embeddings (batch_size={batch_size}, num_batches={num_batches})...")

        with torch.no_grad():
            for batch in tqdm(loader, total=num_batches, desc="Extracting embeddings", unit="batch", leave=True):
                bi = move_to_device(batch["backend_inputs"], device)
                fused, per = self.model.get_embeddings(backend_inputs=bi, return_per_modality=True)
                fused_list.append(fused.cpu().numpy())
                if return_per_modality:
                    for m in per_mod_lists.keys():
                        if m in per:
                            per_mod_lists[m].append(per[m].cpu().numpy())

        print("Concatenating embeddings...")
        fused_arr = np.vstack(fused_list)

        if not return_per_modality:
            print(f"✓ Embeddings shape: {fused_arr.shape}")
            return fused_arr

        per_mod = {m: np.vstack(chunks) for m, chunks in per_mod_lists.items()}
        print(f"✓ Fused embeddings shape: {fused_arr.shape}")
        for m, arr in per_mod.items():
            print(f"✓ {m.capitalize()} embeddings shape: {arr.shape}")

        return fused_arr, per_mod

Пример использования 1.

In [None]:
import numpy as np
import pandas as pd

# Data: short titles, a 48000-sample sine wave per row and a random target
n = 8
sine_wave = np.sin(np.linspace(0, 2*np.pi, 48000)).astype(np.float32) * 0.1
df = pd.DataFrame({
    "title": [f"some short text {i}" for i in range(n)],
    "audio": [sine_wave.copy() for _ in range(n)],
    "target": np.random.randn(n).astype(np.float32)
})

# Per-modality model configs (text=CLIP, audio=CLAP)
text_cfg = {
    "checkpoint": "openai/clip-vit-base-patch32",
    "model_type": "clip",
    "max_length": 77
}
audio_cfg = {
    "checkpoint": "laion/clap-htsat-unfused",
    "model_type": "clap",
    "sr": 48000,
    "max_audios": 2,
    "audio_agg": "concat"
}

# Pipeline initialisation
pipe = SingleModelMultiComboRegression(
    modalities=["text","audio"],
    target_column_names=["target"],
    text_columns=["title"],
    audio_columns=["audio"],
    text_model_config=text_cfg,
    audio_model_config=audio_cfg,
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=True,
    pretokenize_data=True,
    pretokenize_batch_size=64,
    tokenizer_cache_size=10000,
    max_pretokenize_samples=50000,
    local_cache_dir="./model_cache"
)

# Training
fit_options = dict(
    epochs=2,
    test_size=0.25,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    metric_name="rmse",
    fp16=True,                       # effective only when CUDA is available
    logging_steps=10,
    eval_steps=10,
    output_dir="./out_clip_clap",
    seed=42,
    hidden=512,
    dropout=0.1,
    gradient_checkpointing=True,     # enable gradient checkpointing where supported
    fit_chunk_size=None,
    clear_cache_every_n_chunks=2
)
pipe.fit(train_data=df, **fit_options)

# Predictions
y_pred = pipe.predict(df, batch_size=4)
print("Pred shape:", y_pred.shape)

# Embeddings
emb_fused, emb_per = pipe.get_embeddings(df, batch_size=4, return_per_modality=True)
print("Fused:", emb_fused.shape, "| text:", emb_per["text"].shape, "| audio:", emb_per["audio"].shape)

Пример использования 2.

In [None]:
import numpy as np
import pandas as pd

# Data: text, a random 224x224 RGB image, random 16000-sample audio and a random target
n = 9
df = pd.DataFrame({
    "text": [f"example text #{i}" for i in range(n)],
    "image": [(np.random.rand(224,224,3) * 255).astype(np.uint8) for _ in range(n)],
    "audio": [(np.random.randn(16000).astype(np.float32)*0.05) for _ in range(n)],
    "target": np.random.randn(n).astype(np.float32)
})

# One model config per modality
text_cfg = {
    "checkpoint": "distilbert-base-uncased",
    "model_type": "auto",
    "max_length": 128
}
image_cfg = {
    "checkpoint": "google/vit-base-patch16-224",
    "model_type": "vit",
    "max_images": 2,
    "image_agg": "mean"
}
audio_cfg = {
    "checkpoint": "facebook/wav2vec2-base-960h",
    "model_type": "auto",
    "sr": 16000,
    "max_audios": 2,
    "audio_agg": "mean"
}

pipe = SingleModelMultiComboRegression(
    modalities=["text","image","audio"],
    target_column_names=["target"],
    text_columns=["text"],
    image_columns=["image"],
    audio_columns=["audio"],
    text_model_config=text_cfg,
    image_model_config=image_cfg,
    audio_model_config=audio_cfg,
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=True,
    pretokenize_data=True,
    pretokenize_batch_size=32,
    tokenizer_cache_size=20000,
    max_pretokenize_samples=100000,
    local_cache_dir="./model_cache"
)

fit_options = dict(
    epochs=2,
    test_size=0.3,
    per_device_train_batch_size=3,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    metric_name="r2",
    fp16=True,
    logging_steps=5,
    eval_steps=5,
    output_dir="./out_auto_triplet",
    seed=123,
    hidden=384,
    dropout=0.1,
    gradient_checkpointing=True,
    fit_chunk_size=6,                # demonstrates chunked training
)
pipe.fit(train_data=df, **fit_options)

y = pipe.predict(df, batch_size=3)
print("Pred shape:", y.shape)

fused, per_mod = pipe.get_embeddings(df, batch_size=3, return_per_modality=True)
print("Fused:", fused.shape, "| text:", per_mod["text"].shape, "| image:", per_mod["image"].shape, "| audio:", per_mod["audio"].shape)

Пример использования 3.

In [None]:
import numpy as np
import pandas as pd

# Usage example: audio-only regression with the "wav2clip" audio model type.
# SingleModelMultiComboRegression is defined elsewhere in this file.

# Data: n identical one-period sine waveforms (16000 float32 samples each) and a random target.
n = 5
df = pd.DataFrame({
    "audio": [(np.sin(np.linspace(0, 6.28, 16000)).astype(np.float32)*0.05) for _ in range(n)],
    "target": np.random.randn(n)
})

pipe = SingleModelMultiComboRegression(
    modalities=["audio"],
    target_column_names=["target"],
    audio_columns=["audio"],
    audio_model_config={"model_type": "wav2clip", "sr": 16000, "max_audios": 1, "audio_agg": "mean"},
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=False,      # text is not used
    pretokenize_data=False
)

# One epoch with evaluation and logging on every step.
pipe.fit(df, epochs=1, test_size=0.33, per_device_train_batch_size=2, eval_steps=1, logging_steps=1, output_dir="./out_w2c")

pred = pipe.predict(df)
print("Pred shape:", pred.shape)

emb = pipe.get_embeddings(df)
print("Embeddings:", emb.shape)    # expected [N, 512]

Пример использования 4.

In [None]:
import numpy as np
import pandas as pd
from huggingface_hub import login
# Usage example: text-only regression with two targets on a Hugging Face Hub checkpoint.
# NOTE(review): the token is hard-coded in the notebook; presumably it should come from an
# environment variable or an interactive login instead of being committed — confirm.
login('hf_флжптфджуртн ижщрнтур пфмудгьтпруцдждсп йжт жтжщ йр45цт зз н нй2хщй тэщш рэйффыуео зрьт тжвэвцу5фвое')  # your HF token

# Data: n short strings and two random regression targets.
n = 10
df = pd.DataFrame({
    "text": [f"example text #{i}" for i in range(n)],
    "target_1": np.random.randn(n),
    "target_2": np.random.randn(n)
})

# The backbone is NOT frozen here (freeze_backbone=False); texts are pre-tokenized.
pipe = SingleModelMultiComboRegression(
    modalities=["text"],
    target_column_names=["target_1", "target_2"],
    text_columns=["text"],
    text_model_config={
        "checkpoint": "google/embeddinggemma-300m",
        "max_length": 2048
    },
    text_padding="dynamic",
    freeze_backbone=False,
    pretokenize_data=True
)

# One epoch with evaluation and logging on every step.
pipe.fit(df, epochs=1, test_size=0.33, per_device_train_batch_size=2, eval_steps=1, logging_steps=1, output_dir="./test123")

pred = pipe.predict(df)
print("Pred shape:", pred.shape)

emb = pipe.get_embeddings(df)
# NOTE(review): "[N, 512]" looks copied from the previous example; the embedding width
# depends on the checkpoint used above — confirm the real value.
print("Embeddings:", emb.shape)    # expected [N, 512]

# Дообучение классификатора с RuCLIP, который работает ещё и со звуком.

Реализация классификатора.

In [None]:
!pip install --upgrade --no-cache-dir \
  --extra-index-url https://download.pytorch.org/whl/cu124 \
  torch==2.6.0+cu124 torchaudio==2.6.0+cu124 \

!pip install --extra-index-url https://download.pytorch.org/whl/cu124 \
  torch==2.6.0+cu124 torchaudio==2.6.0+cu124
!pip install "open-clip-torch==2.26.1" "transformers==4.51.3" "evaluate==0.4.5"
!pip install "ruclip @ git+https://github.com/ai-forever/ru-clip.git@main#egg=ruclip"

import os, math, random, gc, time
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

import evaluate
from transformers import TrainingArguments, Trainer
from transformers.trainer_callback import TrainerCallback, PrinterCallback, EarlyStoppingCallback
from transformers.modeling_outputs import SequenceClassifierOutput

# =========================
# Утилиты
# =========================

def set_seed(seed: int = 42):
    """Seed every random number generator used in this file with ``seed``.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU
    generator and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def to_pil(x: Union[str, np.ndarray, 'Image.Image']) -> 'Image.Image':
    """Convert a supported image representation to an RGB ``PIL.Image``.

    Args:
        x: an already opened ``PIL.Image``, a path to an image file, or a
            NumPy array accepted by ``Image.fromarray``.

    Returns:
        A new image in ``"RGB"`` mode.

    Raises:
        ValueError: if ``x`` is none of the supported types.
    """
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    if isinstance(x, str):
        # Image.open is lazy and keeps the file open; convert() decodes the data into a
        # new image, so the handle can be released right away instead of waiting for GC.
        with Image.open(x) as opened:
            return opened.convert("RGB")
    if isinstance(x, np.ndarray):
        return Image.fromarray(x).convert("RGB")
    raise ValueError("Ожидается путь/np.ndarray/PIL.Image")

def load_audio(path: str, target_sr: int) -> np.ndarray:
    """Load an audio file as a mono float32 waveform at ``target_sr``.

    Args:
        path: file path handed to ``torchaudio.load``.
        target_sr: sample rate of the returned waveform; the signal is
            resampled when the file's own rate differs.

    Returns:
        The waveform as a float32 NumPy array with the channel axis squeezed out.

    Raises:
        RuntimeError: if ``torchaudio`` cannot be imported.
    """
    try:
        import torchaudio
    except Exception as e:
        raise RuntimeError("Требуется torchaudio: pip install torchaudio") from e

    signal, native_sr = torchaudio.load(path)

    # Down-mix multi-channel audio by averaging over the first (channel) axis.
    has_several_channels = signal.size(0) > 1
    if has_several_channels:
        signal = signal.mean(dim=0, keepdim=True)

    if native_sr != target_sr:
        signal = torchaudio.functional.resample(signal, orig_freq=native_sr, new_freq=target_sr)

    mono = signal.squeeze(0)
    return mono.numpy().astype(np.float32)


# =========================
# RuCLIP токенизатор (батч + кэш)
# =========================

class RuCLIPBatchTokenizer:
    """Batch wrapper around a CLIP-style tokenizer with a per-text LRU cache.

    The wrapped ``tokenizer`` is called as ``tokenizer(list_of_str, context_length=...)``
    and must return an integer tensor with one row per text.

    Args:
        tokenizer: the callable described above.
        context_length: value forwarded to the tokenizer as ``context_length``.
        cache_size: maximum number of distinct texts kept in the LRU cache.
    """

    def __init__(self, tokenizer, context_length: int = 77, cache_size: int = 20000):
        self.tokenizer = tokenizer
        self.context_length = context_length
        # Per-instance cache keyed by the text itself.
        self._cache = lru_cache(maxsize=cache_size)(self._tok_one)

    def _tok_one(self, text: str) -> np.ndarray:
        """Tokenize one text and return its ids as an int64 NumPy row."""
        ids = self.tokenizer([text], context_length=self.context_length)
        return ids.squeeze(0).cpu().numpy().astype(np.int64)

    def tokenize_batch(self, texts: List[str]) -> torch.Tensor:
        """Tokenize ``texts`` into a ``torch.long`` tensor with one row per text.

        Batches shorter than 100 texts go through the per-text cache; larger
        ones are sent to the tokenizer in a single call (and are not cached).
        An empty batch yields an empty ``[0, context_length]`` tensor instead
        of failing inside ``np.stack``.
        """
        if len(texts) == 0:
            return torch.zeros((0, self.context_length), dtype=torch.long)
        if len(texts) < 100:
            rows = [self._cache(t) for t in texts]
            # np.stack copies, so the cached arrays are never exposed for mutation.
            return torch.from_numpy(np.stack(rows, axis=0)).long()
        return self.tokenizer(texts, context_length=self.context_length).long()

    def clear_cache(self):
        """Drop every cached tokenization."""
        self._cache.cache_clear()


# =========================
# Датасет
# =========================

class MultiComboDataset(Dataset):
    """Map-style dataset yielding raw per-sample items for the RuCLIP collate function.

    Each item is a dict with ``"labels"`` (int) and, depending on configuration,
    ``"text_tokens"`` (pre-tokenized row) or ``"text"`` (joined string),
    ``"images"`` (list of raw image entries) and ``"audios"`` (list of raw audio entries).

    Args:
        df: source frame; its index is reset.
        target_col: label column; if it is absent all labels are 0.
        label2id: mapping from raw label value to class id.
        text_columns: columns joined with ``" [SEP] "`` into one text.
        image_columns: columns whose values (single entry or list/tuple) are gathered per row.
        audio_columns: same as ``image_columns`` but for audio entries.
        text_batch_tokenizer: object with ``tokenize_batch(list_of_str) -> LongTensor``.
        pretokenize_text: tokenize every text once in the constructor.
        pretokenize_batch_size: number of texts per ``tokenize_batch`` call.
        ruclip_context_len: stored as ``context_len``; also the width of the empty token bank.
        audio_sr: stored as ``audio_sr``.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        target_col: str,
        label2id: Dict[Any, int],
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_batch_tokenizer: Optional["RuCLIPBatchTokenizer"] = None,
        pretokenize_text: bool = True,
        pretokenize_batch_size: int = 2048,
        ruclip_context_len: int = 77,
        audio_sr: int = 48000
    ):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.target_col = target_col
        self.label2id = label2id
        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []
        self.text_tok = text_batch_tokenizer
        self.context_len = ruclip_context_len
        self.audio_sr = audio_sr

        self._N = len(self.df)
        if self.target_col in self.df.columns:
            # NOTE(review): labels missing from label2id silently become class 0 (fillna(0)).
            y = self.df[self.target_col].map(self.label2id).fillna(0).astype(int).values
        else:
            y = np.zeros(self._N, dtype=np.int64)
        self._labels = torch.tensor(y, dtype=torch.long)

        self._image_lists = self._collect_multi(self.df, self.image_columns) if self.image_columns else None
        self._audio_lists = self._collect_multi(self.df, self.audio_columns) if self.audio_columns else None

        self._text_bank: Optional[torch.Tensor] = None
        if self.text_columns and pretokenize_text and self.text_tok is not None:
            texts = [self._join_text(row) for _, row in self.df.iterrows()]
            chunks = [
                self.text_tok.tokenize_batch(texts[start:start + pretokenize_batch_size])
                for start in range(0, len(texts), pretokenize_batch_size)
            ]
            if chunks:
                self._text_bank = torch.cat(chunks, dim=0).contiguous()
            else:
                # Empty frame: torch.cat([]) would raise, so keep a well-formed empty bank.
                self._text_bank = torch.zeros((0, self.context_len), dtype=torch.long)
            print(f"✓ Предтокенизация RuCLIP: shape={tuple(self._text_bank.shape)}")

    def __len__(self): return self._N

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the raw item for row ``idx`` (see the class docstring for keys)."""
        item: Dict[str, Any] = {"labels": int(self._labels[idx])}
        if self._text_bank is not None:
            item["text_tokens"] = self._text_bank[idx]
        elif self.text_columns:
            # No token bank: hand over the joined string, to be tokenized later.
            item["text"] = self._join_text(self.df.iloc[idx])
        if self._image_lists is not None:
            item["images"] = self._image_lists[idx]
        if self._audio_lists is not None:
            item["audios"] = self._audio_lists[idx]
        return item

    @staticmethod
    def _as_list(v):
        """Normalize a cell value to a list: None/NaN -> [], list/tuple -> list, else [v]."""
        if v is None or (isinstance(v, float) and np.isnan(v)): return []
        if isinstance(v, (list, tuple)): return list(v)
        return [v]

    def _collect_multi(self, df: pd.DataFrame, cols: List[str]) -> List[List[Any]]:
        """For every row, gather the non-None entries of all ``cols`` into one flat list."""
        out = []
        for _, row in df.iterrows():
            lst = []
            for c in cols:
                if c in row: lst.extend([x for x in self._as_list(row[c]) if x is not None])
            out.append(lst)
        return out

    def _join_text(self, row: pd.Series) -> str:
        """Join the row's text columns with " [SEP] "; missing/NaN values become ""."""
        parts = []
        for c in self.text_columns:
            v = row.get(c, "")
            if pd.isna(v): v = ""
            parts.append(str(v))
        return " [SEP] ".join(parts)

    def clear_cache(self):
        """Drop the pre-tokenized text bank and ask PyTorch to release cached CUDA memory."""
        self._text_bank = None
        torch.cuda.empty_cache()


# =========================
# Бэкенд: RuCLIP (+ опционально аудио через CLAP)
# =========================

class RuCLIPBackend(nn.Module):
    """Embedding backend: a CLIP model for text and images plus optional CLAP audio.

    ``collate`` turns a list of dataset items into tensors, ``encode`` turns those
    tensors into one L2-normalized embedding per modality and per sample.

    Args:
        ruclip_model_name: open_clip architecture name (also used to pick the tokenizer).
        ruclip_pretrained: weights id; ids that look like ai-forever RuCLIP repos are
            loaded through the ``ruclip`` package, everything else through open_clip.
        ruclip_context_len: ``context_length`` passed to the text tokenizer.
        max_images / max_audios: number of items kept per sample for ``"concat"`` aggregation.
        image_agg / audio_agg: ``"concat"`` or ``"mean"``.
        audio_cfg: ``None`` or ``{'type': 'clap', 'checkpoint': ..., 'sr': ...}``.
        freeze: set ``requires_grad=False`` on the backbones; ``encode`` then also runs
            without gradient tracking. With ``freeze=False`` gradients flow through
            ``encode`` whenever the caller has gradients enabled.
        device: target device; defaults to CUDA when available, else CPU.

    Raises:
        RuntimeError: the ``ruclip`` package is needed but cannot be imported.
        ValueError: ``audio_cfg['type']`` is not ``'clap'``.
    """
    name = "ruclip"

    def __init__(
        self,
        ruclip_model_name: str = "ViT-B-32",
        ruclip_pretrained: Optional[str] = "hf-hub:ai-forever/ru-clip-vit-base-patch32-224",
        ruclip_context_len: int = 77,
        max_images: int = 1,
        image_agg: str = "concat",          # concat|mean
        max_audios: int = 1,
        audio_agg: str = "concat",          # concat|mean
        audio_cfg: Optional[Dict[str, Any]] = None,   # None or {'type':'clap','checkpoint':..., 'sr':48000}
        freeze: bool = True,
        device: Optional[torch.device] = None
    ):
        super().__init__()
        import open_clip

        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        # Remembered so that encode() can decide whether to track gradients.
        self._frozen = bool(freeze)

        # Detect a "native" ai-forever RuCLIP checkpoint (loaded through the ruclip package).
        use_ruclip_pkg = False
        ruclip_repo = (ruclip_pretrained or "")  # may be None
        if ruclip_repo.startswith("hf-hub:"):
            repopath = ruclip_repo.replace("hf-hub:", "")
        else:
            repopath = ruclip_repo

        base_name = repopath.split("/")[-1] if repopath else ""
        # Marker: repo ai-forever/ru-clip-... or a name starting with ru-clip-/ruclip-.
        if ("ai-forever/ru-clip-" in repopath) or base_name.startswith("ru-clip-") or base_name.startswith("ruclip-"):
            use_ruclip_pkg = True

        print(f"Загрузка RuCLIP: {ruclip_model_name} / {ruclip_pretrained}")

        if use_ruclip_pkg:
            try:
                import ruclip
            except Exception as e:
                raise RuntimeError(
                    "Для загрузки RuCLIP из HF нужен пакет 'ruclip'. Установите: pip install ruclip==0.0.2"
                ) from e

            # ruclip.load is given a name like "ruclip-vit-base-patch32-224";
            # an incoming "ru-clip-..." is rewritten to "ruclip-...".
            ruclip_id = base_name
            if ruclip_id.startswith("ru-clip-"):
                ruclip_id = ruclip_id.replace("ru-clip-", "ruclip-", 1)

            clip_model, processor = ruclip.load(ruclip_id, device=str(self.device))
            self.ruclip_model = clip_model.eval()

            # NOTE(review): this assumes the ruclip processor exposes `preprocess` (or is itself
            # callable on one image) and a `tokenizer` callable as
            # tokenizer(list_of_str, context_length=...) — verify against the ruclip package.
            self.ruclip_preprocess = getattr(processor, "preprocess", processor)

            # Tokenizer: taken from the processor when present, otherwise from open_clip.
            proc_tokenizer = getattr(processor, "tokenizer", None)
            self.ruclip_tokenizer = proc_tokenizer or open_clip.get_tokenizer(ruclip_model_name)

        else:
            # Fallback: open_clip (openai/laion2b/... tags or pretrained=None).
            try:
                # create_model_and_transforms returns (model, preprocess_train, preprocess_val);
                # feature extraction needs the deterministic validation transform, exactly as
                # the two fallbacks below already use.
                self.ruclip_model, _, self.ruclip_preprocess = open_clip.create_model_and_transforms(
                    ruclip_model_name, pretrained=ruclip_pretrained
                )
            except Exception:
                # Retry through the HF path with the pretrained_hf flag.
                if isinstance(ruclip_pretrained, str):
                    hf_path = ruclip_pretrained.replace("hf-hub:", "")
                else:
                    hf_path = ruclip_pretrained
                try:
                    self.ruclip_model, _, self.ruclip_preprocess = open_clip.create_model_and_transforms(
                        ruclip_model_name, pretrained=hf_path, pretrained_hf=True
                    )
                except Exception:
                    # Deepest fallback: build the model alone and take the default transforms.
                    self.ruclip_model = open_clip.create_model(ruclip_model_name, pretrained=hf_path)
                    try:
                        _, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
                            ruclip_model_name, pretrained=None
                        )
                        self.ruclip_preprocess = preprocess_val
                    except Exception:
                        # Minimal fallback built from torchvision.transforms.
                        from torchvision import transforms
                        self.ruclip_preprocess = transforms.Compose([
                            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
                            transforms.CenterCrop(224),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                                 std=(0.26862954, 0.26130258, 0.27577711)),
                        ])
            self.ruclip_tokenizer = open_clip.get_tokenizer(ruclip_model_name)

        self.context_len = ruclip_context_len
        self.text_tok = RuCLIPBatchTokenizer(self.ruclip_tokenizer, context_length=self.context_len)

        self.max_images = int(max_images); self.image_agg = image_agg
        self.max_audios = int(max_audios); self.audio_agg = audio_agg

        # Audio: CLAP only.
        self.audio_model = None
        self.clap_processor = None
        self.audio_sr = 48000
        self.audio_enabled = False

        if audio_cfg is not None:
            at = str(audio_cfg.get("type", "")).lower()
            if at != "clap":
                raise ValueError("audio_cfg['type'] должен быть 'clap' или None.")
            from transformers import ClapModel, ClapProcessor
            ckpt = audio_cfg.get("checkpoint", "laion/clap-htsat-unfused")
            print(f"Загрузка CLAP: {ckpt}")
            self.audio_model = ClapModel.from_pretrained(ckpt)
            self.clap_processor = ClapProcessor.from_pretrained(ckpt)
            # Sampling rate: processor attribute, else its feature extractor, else 48000;
            # an explicit audio_cfg['sr'] overrides all of them.
            sr = getattr(self.clap_processor, "sampling_rate", None)
            if sr is None:
                fe = getattr(self.clap_processor, "feature_extractor", None)
                sr = getattr(fe, "sampling_rate", 48000) if fe is not None else 48000
            self.audio_sr = int(audio_cfg.get("sr", sr))
            self.audio_enabled = True

        self.ruclip_model.to(self.device).eval()
        if isinstance(self.audio_model, nn.Module):
            self.audio_model.to(self.device).eval()

        if freeze:
            for p in self.ruclip_model.parameters(): p.requires_grad = False
            if isinstance(self.audio_model, nn.Module):
                for p in self.audio_model.parameters(): p.requires_grad = False

        # Embedding widths per modality ("concat" multiplies by the per-sample item limit).
        self.ruclip_dim = self._infer_ruclip_dim()
        self.out_dim_per_modality = {
            "text": self.ruclip_dim,
            "image": self.ruclip_dim if self.image_agg == "mean" else self.ruclip_dim * self.max_images
        }
        if self.audio_enabled:
            ad = getattr(getattr(self.audio_model, "config", None), "projection_dim", 512)
            self.audio_dim = int(ad)
            self.out_dim_per_modality["audio"] = self.audio_dim if self.audio_agg == "mean" else self.audio_dim * self.max_audios
        else:
            self.audio_dim = 0

    def _infer_ruclip_dim(self) -> int:
        """Return the width of the embeddings produced by ``encode_text``.

        Reads the OUTPUT side of ``text_projection``: CLIP implementations apply a
        Parameter projection as ``x @ text_projection`` (so the result has
        ``shape[-1]`` features) and an ``nn.Linear`` stores its weight as
        ``[out_features, in_features]``. Without a usable projection a dummy text
        is encoded and measured.
        """
        proj = getattr(self.ruclip_model, "text_projection", None)
        if proj is not None:
            if hasattr(proj, 'shape'):   # nn.Parameter / tensor of shape [width, embed_dim]
                return int(proj.shape[-1])
            if hasattr(proj, 'weight'):  # nn.Linear
                return int(proj.weight.shape[0])
        # Fallback: run one dummy text through the model.
        with torch.no_grad():
            ids = self.ruclip_tokenizer(["test"], context_length=self.context_len).to(self.device)
            z = F.normalize(self.ruclip_model.encode_text(ids), dim=-1)
        return int(z.shape[-1])

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate dataset items into ``{"labels", "backend_inputs"}``.

        ``backend_inputs`` carries token ids, a flat stack of preprocessed images with
        per-sample counts and, when audio is enabled, flat CLAP input features with
        per-sample counts. Counts always describe exactly the rows present in the flat
        tensors, so ``encode`` can split them back per sample.

        Raises:
            ValueError: the first item has neither ``"text_tokens"`` nor ``"text"``.
        """
        labels = torch.tensor([b.get("labels", 0) for b in batch], dtype=torch.long)

        # Text
        if "text_tokens" in batch[0]:
            text_ids = torch.stack([b["text_tokens"] for b in batch], dim=0)
        elif "text" in batch[0]:
            texts = [b.get("text", "") or " " for b in batch]
            text_ids = self.text_tok.tokenize_batch(texts)
        else:
            raise ValueError("Ожидается модальность 'text' для RuCLIP")

        # Images
        images_lists = [b.get("images", []) for b in batch]
        flat_images, img_counts = [], []
        for lst in images_lists:
            lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
            lst = [img for img in lst if img is not None]
            img_counts.append(len(lst))
            for img in lst:
                flat_images.append(to_pil(img))
        if len(flat_images) > 0:
            px = torch.stack([self.ruclip_preprocess(img) for img in flat_images], dim=0)
        else:
            px = torch.empty(0)

        # Audio (optional) — CLAP
        aud_counts = None
        audio_pack: Dict[str, Any] = {}
        if self.audio_enabled:
            aud_lists = [b.get("audios", []) for b in batch]
            flat_audios, aud_counts = [], []
            for lst in aud_lists:
                lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
                kept = 0
                for a in lst:
                    if a is None:
                        continue
                    if isinstance(a, str):
                        flat_audios.append(load_audio(a, self.audio_sr))
                    elif isinstance(a, np.ndarray):
                        aa = np.asarray(a, dtype=np.float32)
                        if aa.ndim > 1: aa = np.squeeze(aa)
                        if aa.ndim > 1: aa = aa.reshape(-1)
                        flat_audios.append(aa)
                    else:
                        # Unsupported entries are skipped AND not counted; counting them
                        # would shift every later sample's embeddings in _aggregate.
                        continue
                    kept += 1
                aud_counts.append(kept)
            if len(flat_audios) > 0:
                proc = self.clap_processor(audios=flat_audios, sampling_rate=self.audio_sr, padding=True, return_tensors="pt")
                audio_pack["features"] = proc["input_features"]
            else:
                audio_pack["features"] = torch.empty(0)

        return {
            "labels": labels,
            "backend_inputs": {
                "text_ids": text_ids,
                "pixel_values": px,
                "image_counts": torch.tensor(img_counts, dtype=torch.long),
                "audio": audio_pack if self.audio_enabled else None,
                "audio_counts": torch.tensor(aud_counts, dtype=torch.long) if aud_counts is not None else None,
                "batch_size": len(batch),
            }
        }

    def _aggregate(self, embs: Optional[torch.Tensor], counts: List[int], max_k: int, agg: str, dim_hint: int) -> torch.Tensor:
        """Turn a flat ``[N, D]`` embedding stack into one row per sample.

        ``counts[i]`` rows are consumed for sample ``i`` in order. ``"concat"`` keeps the
        first ``max_k`` rows and zero-pads up to ``max_k``; any other ``agg`` averages all
        of the sample's rows. Samples without rows stay zero. The result is L2-normalized.
        With no embeddings at all a float32 zero matrix sized from ``dim_hint`` is returned.
        Gradient tracking follows the caller's context.
        """
        device = self.device
        bs = len(counts)
        if embs is None or (torch.is_tensor(embs) and embs.numel() == 0):
            out_dim = dim_hint * max_k if agg == "concat" else dim_hint
            return torch.zeros((bs, out_dim), device=device, dtype=torch.float32)
        if embs.dim() == 1: embs = embs.unsqueeze(0)
        if embs.dim() > 2: embs = embs.view(embs.size(0), -1)
        N, D = embs.size()
        out_dim = D * max_k if agg == "concat" else D
        out = torch.zeros((bs, out_dim), device=device, dtype=embs.dtype)
        off = 0
        for i, c in enumerate(counts):
            if c <= 0 or off >= N: continue
            take_n = min(c, N - off)
            sample = embs[off:off+take_n]; off += take_n
            if agg == "concat":
                take = sample[:max_k]
                if take.size(0) < max_k:
                    pad = torch.zeros((max_k - take.size(0), D), device=device, dtype=embs.dtype)
                    take = torch.cat([take, pad], dim=0)
                out[i] = take.reshape(-1)
            else:
                out[i] = sample.mean(dim=0)
        return F.normalize(out, dim=-1, eps=1e-12) if out.size(1) > 0 else out

    def encode(self, backend_inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Encode collated inputs into ``{"text", "image"[, "audio"]}`` embeddings.

        Gradients are tracked only when the backend was built with ``freeze=False`` and
        the caller has gradients enabled; a frozen backend always runs without them.
        """
        track_grad = torch.is_grad_enabled() and not getattr(self, "_frozen", True)
        with torch.set_grad_enabled(track_grad):
            # Text
            text_ids = backend_inputs["text_ids"].to(self.device)
            zt = F.normalize(self.ruclip_model.encode_text(text_ids), dim=-1)

            # Images
            px = backend_inputs["pixel_values"]
            img_counts = backend_inputs["image_counts"].tolist()
            total_imgs = sum(img_counts)
            zi_flat = None
            if isinstance(px, torch.Tensor) and px.numel() > 0 and total_imgs > 0:
                px = px.to(self.device)
                if px.size(0) > total_imgs:
                    px = px[:total_imgs]
                zi_flat = F.normalize(self.ruclip_model.encode_image(px), dim=-1)
            zi = self._aggregate(zi_flat, img_counts, self.max_images, self.image_agg, self.ruclip_dim)

            out = {"text": zt, "image": zi}

            # Audio (optional) — CLAP
            if self.audio_enabled:
                ac = backend_inputs["audio_counts"].tolist() if backend_inputs["audio_counts"] is not None else [0]*zt.size(0)
                total_a = sum(ac)
                za_flat = None
                audio_pack = backend_inputs["audio"] or {}
                feats = audio_pack.get("features", None)
                if feats is not None and isinstance(feats, torch.Tensor) and feats.numel() > 0 and total_a > 0:
                    feats = feats.to(self.device)
                    if feats.size(0) > total_a:
                        feats = feats[:total_a]
                    z = self.audio_model.get_audio_features(input_features=feats.float())
                    za_flat = F.normalize(z.float(), dim=-1)
                    # Keep the stored width in sync with what the model actually returned.
                    self.audio_dim = int(za_flat.size(1))
                za = self._aggregate(za_flat, ac, self.max_audios, self.audio_agg, self.audio_dim)
                out["audio"] = za

            return out

    def get_out_dim(self, modality: str) -> int:
        """Return the embedding width for ``modality`` (0 when it is not provided)."""
        return self.out_dim_per_modality.get(modality, 0)

    def get_text_tokenizer(self) -> "RuCLIPBatchTokenizer":
        """Return the batch tokenizer used by ``collate``."""
        return self.text_tok

# =========================
# Классификатор
# =========================

class SingleBackboneClassifier(nn.Module):
    """Classification head on top of the fused per-modality embeddings of a backend.

    Modalities are always fused in the fixed order image, text, audio, restricted to
    those listed in ``modalities``. ``fusion="concat"`` concatenates them,
    ``fusion="mean"`` averages them (all widths must then be equal).

    Raises:
        ValueError: unknown ``fusion`` value, or unequal widths with ``fusion="mean"``.
    """

    def __init__(self, backend: RuCLIPBackend, modalities: List[str], num_labels: int,
                 fusion: str = "concat", hidden: int = 512, dropout: float = 0.1):
        super().__init__()
        self.backend = backend
        self.modalities = modalities
        self.fusion = fusion
        self.num_labels = num_labels

        if fusion not in ("concat", "mean"):
            raise ValueError('fusion должен быть "concat" или "mean"')

        widths = [
            self.backend.get_out_dim(modality)
            for modality in ("image", "text", "audio")
            if modality in self.modalities
        ]
        if fusion == "mean":
            if len(set(widths)) != 1:
                raise ValueError('Для fusion="mean" размеры модальностей должны совпадать')
            in_dim = widths[0]
        else:
            in_dim = sum(widths)

        # LayerNorm -> Linear -> GELU -> Dropout -> Linear, created in this order.
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_labels),
        ]
        self.head = nn.Sequential(*layers)

    def _fuse(self, z: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Fuse the embeddings in ``z`` into one tensor with one row per sample."""
        pieces = []
        for modality in ("image", "text", "audio"):
            if modality not in self.modalities or modality not in z:
                continue
            emb = z[modality]
            # Anything with more than two dims is flattened to [batch, features].
            pieces.append(emb.view(emb.size(0), -1) if emb.dim() > 2 else emb)
        if self.fusion == "concat":
            return torch.cat(pieces, dim=-1)
        return torch.stack(pieces, dim=0).mean(dim=0)

    def forward(self, backend_inputs: Dict[str, Any], labels: Optional[torch.Tensor] = None):
        """Return logits for the batch; ``labels`` is accepted but not used here."""
        embeddings = self.backend.encode(backend_inputs)
        logits = self.head(self._fuse(embeddings))
        return SequenceClassifierOutput(logits=logits)

    @torch.no_grad()
    def get_embeddings(self, backend_inputs: Dict[str, Any], return_per_modality: bool = False):
        """Return the fused embedding, optionally together with the per-modality dict."""
        per_modality = self.backend.encode(backend_inputs)
        fused = self._fuse(per_modality)
        return (fused, per_modality) if return_per_modality else fused


# =========================
# Trainer с весами классов и прогресс
# =========================

class WeightedCETrainer(Trainer):
    """``Trainer`` whose loss is cross-entropy with optional per-class weights.

    Weights come from ``class_weights`` when given; otherwise, when both
    ``train_labels`` and ``num_labels`` are given, they are derived from label
    frequencies as ``n / (num_labels * count)`` (0 for classes never seen).
    With neither, the loss is unweighted.
    """

    def __init__(self, *args, num_labels=None, train_labels=None, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        weights = None
        if class_weights is not None:
            weights = torch.as_tensor(class_weights, dtype=torch.float32)
        elif not (train_labels is None or num_labels is None):
            label_ids = np.asarray(train_labels).astype(int)
            per_class = np.bincount(label_ids, minlength=num_labels)
            total = per_class.sum()
            balanced = np.zeros(num_labels, dtype=np.float32)
            seen = per_class > 0
            balanced[seen] = total / (num_labels * per_class[seen].astype(np.float32))
            weights = torch.tensor(balanced, dtype=torch.float32)
        self.class_weights = weights

    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch: Optional[int] = None):
        """Compute weighted cross-entropy; ``"labels"`` is popped from ``inputs`` before the forward pass."""
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        if self.class_weights is None:
            weight = None
        else:
            weight = self.class_weights.to(logits.device)
        loss = F.cross_entropy(logits, targets.long(), weight=weight)
        if return_outputs:
            return loss, outputs
        return loss


class PbarConsoleLogger(TrainerCallback):
    """Trainer callback that mirrors training progress onto a tqdm bar.

    The bar position follows ``state.global_step``, its postfix shows the latest
    numeric log values, and one line per evaluation step is written with the
    evaluation metrics.
    """

    def __init__(self, pbar: tqdm):
        self.pbar = pbar
        self.last_logs = {}
        self.printed_eval_steps = set()

    def _step(self, state) -> int:
        """Return ``state.global_step`` as an int (0 when missing or falsy)."""
        return int(getattr(state, "global_step", 0) or 0)

    def _fmt_postfix(self):
        """Build the postfix: train loss, validation loss, then the other eval metrics."""
        logs = self.last_logs
        skipped = ('eval_loss','eval_runtime','eval_samples_per_second','eval_steps_per_second','epoch')
        parts = []
        if 'loss' in logs:
            parts.append(f"loss {logs['loss']:.4f}")
        if 'eval_loss' in logs:
            parts.append(f"val {logs['eval_loss']:.4f}")
        parts.extend(
            f"{key.replace('eval_','')} {value:.4f}"
            for key, value in logs.items()
            if key.startswith('eval_') and key not in skipped
        )
        return " | ".join(parts)

    def on_train_begin(self, args, state, control, **kwargs):
        """Reset the bar to ``state.max_steps`` when that is known."""
        max_steps = int(getattr(state, "max_steps", 0) or 0)
        if max_steps > 0:
            self.pbar.reset(total=max_steps)
        self.pbar.refresh()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Remember numeric log values, sync the bar and print each eval step once."""
        if not logs:
            return
        self.last_logs.update(
            {key: float(value) for key, value in logs.items() if isinstance(value, (int, float))}
        )
        # Keep the bar in sync with the current step.
        self.on_step_end(args, state, control)
        # Print one validation line per step.
        if not any(key.startswith('eval_') for key in logs.keys()):
            return
        step = self._step(state)
        if step in self.printed_eval_steps:
            return
        self.printed_eval_steps.add(step)
        excl = {'eval_loss','eval_runtime','eval_samples_per_second','eval_steps_per_second','epoch'}
        extra = []
        for key, value in logs.items():
            if key.startswith('eval_') and key not in excl:
                extra.append(f"{key.replace('eval_','')}: {float(value):.6f}")
        from tqdm.auto import tqdm as _tqdm
        _tqdm.write(f"step {step} | " + ", ".join(extra))

    def on_step_end(self, args, state, control, **kwargs):
        """Advance the bar to the current global step (capped at its total) and refresh it."""
        current = self._step(state)
        target = min(current, self.pbar.total) if self.pbar.total else current
        if target > self.pbar.n:
            self.pbar.update(target - self.pbar.n)
        if self.last_logs:
            self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

    def on_train_end(self, args, state, control, **kwargs):
        """Catch the bar up to the final step and always close it."""
        try:
            final_step = self._step(state)
            if self.pbar.total and final_step > self.pbar.n:
                self.pbar.update(final_step - self.pbar.n)
        finally:
            self.pbar.close()


# =========================
# Пайплайн: RuCLIP классификация
# =========================

class RuCLIPMultiModalClassification:
    """
    Multimodal classification pipeline built around a single RuCLIP backend.

    Supported modality sets:
      - ['text', 'image']          — text and images go through RuCLIP;
      - ['text', 'image', 'audio'] — audio is additionally encoded with CLAP.
    """

    def __init__(
        self,
        modalities: List[str],                    # ['text','image'] or ['text','image','audio']
        num_labels: int,
        target_column_name: str,
        text_columns: List[str],
        image_columns: List[str],
        audio_columns: Optional[List[str]] = None,
        ruclip_model_name: str = "ViT-B-32",
        ruclip_pretrained: Optional[str] = "hf-hub:ai-forever/ru-clip-vit-base-patch32-224",
        ruclip_context_len: int = 77,
        max_images_per_sample: int = 1,
        image_agg: str = "concat",
        audio_cfg: Optional[Dict[str, Any]] = None,    # {'type':'clap','checkpoint':..., 'sr':48000} or None
        max_audios_per_sample: int = 1,
        audio_agg: str = "concat",
        freeze_backbone: bool = True
    ):
        """Store the data schema and build the shared ``RuCLIPBackend``.

        Args:
            modalities: modality names; duplicates are dropped and the list is
                sorted.  Must contain 'text' and 'image', optionally 'audio'.
            num_labels: number of output classes of the classification head.
            target_column_name: DataFrame column holding the label.
            text_columns / image_columns / audio_columns: DataFrame columns
                read for each modality (``audio_columns`` defaults to ``[]``).
            ruclip_model_name, ruclip_pretrained, ruclip_context_len,
            max_images_per_sample, image_agg, audio_cfg,
            max_audios_per_sample, audio_agg, freeze_backbone: forwarded
                unchanged to ``RuCLIPBackend``.
        """
        self.modalities = sorted(set(modalities))
        assert set(self.modalities).issuperset({"text","image"}) and set(self.modalities).issubset({"text","image","audio"}), \
            "Поддерживаются только ['text','image'] или ['text','image','audio']"

        self.num_labels = num_labels
        self.target_column_name = target_column_name
        self.text_columns = text_columns
        self.image_columns = image_columns
        self.audio_columns = audio_columns or []

        self.backend = RuCLIPBackend(
            ruclip_model_name=ruclip_model_name,
            ruclip_pretrained=ruclip_pretrained,
            ruclip_context_len=ruclip_context_len,
            max_images=max_images_per_sample,
            image_agg=image_agg,
            max_audios=max_audios_per_sample,
            audio_agg=audio_agg,
            audio_cfg=audio_cfg,  # CLAP only
            freeze=freeze_backbone
        )

        self.model: Optional[SingleBackboneClassifier] = None
        self.trainer: Optional[Trainer] = None
        self.compute_metrics = None

    def _validate_data(self, df: pd.DataFrame):
        """Raise ``ValueError`` if any configured modality column is missing in ``df``.

        Audio columns are checked only when 'audio' is among the modalities,
        in which case ``audio_columns`` must also be non-empty.
        """
        for c in self.text_columns:
            if c not in df.columns:
                raise ValueError(f"Нет текстовой колонки '{c}' в DataFrame")
        for c in self.image_columns:
            if c not in df.columns:
                raise ValueError(f"Нет колонки изображений '{c}' в DataFrame")
        if "audio" in self.modalities:
            if not self.audio_columns:
                raise ValueError("Вы выбрали 'audio', но audio_columns пуст")
            for c in self.audio_columns:
                if c not in df.columns:
                    raise ValueError(f"Нет аудио колонки '{c}' в DataFrame")

    def _split(self, df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
        """Shuffle ``df`` and split it into ``(train, eval)`` frames.

        The first ``ceil(len(df) * test_size)`` shuffled rows form the eval
        part; both parts get a fresh 0..n-1 index.  The split is not stratified.
        """
        shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        n_eval = int(math.ceil(len(shuffled) * test_size))
        train_part = shuffled.iloc[n_eval:].reset_index(drop=True)
        eval_part = shuffled.iloc[:n_eval].reset_index(drop=True)
        return train_part, eval_part

    def _setup_metrics(self, metric_name: str):
        """Build ``self.compute_metrics`` for "f1" (weighted) or "accuracy".

        The name is matched case-insensitively; any other value raises
        ``ValueError``.
        """
        metric_name = metric_name.lower()
        if metric_name == "f1":
            metric = evaluate.load("f1")
            def compute(p):
                preds = p.predictions.argmax(-1)
                return metric.compute(predictions=preds, references=p.label_ids, average="weighted")
        elif metric_name == "accuracy":
            metric = evaluate.load("accuracy")
            def compute(p):
                preds = p.predictions.argmax(-1)
                return metric.compute(predictions=preds, references=p.label_ids)
        else:
            raise ValueError('metric_name должен быть "f1" или "accuracy"')
        self.compute_metrics = compute

    def _make_dataset(self, df: pd.DataFrame, pretokenize_text: bool) -> "MultiComboDataset":
        """Wrap ``df`` into a ``MultiComboDataset`` using the current schema and backend settings.

        Audio columns are passed only when 'audio' is an active modality.
        Requires ``self.label2id`` (set by ``fit``).
        """
        return MultiComboDataset(
            df=df,
            target_col=self.target_column_name,
            label2id=self.label2id,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=(self.audio_columns if "audio" in self.modalities else None),
            text_batch_tokenizer=self.backend.get_text_tokenizer(),
            pretokenize_text=pretokenize_text,
            ruclip_context_len=self.backend.context_len,
            audio_sr=self.backend.audio_sr
        )

    def _compute_class_weights(self, labels: np.ndarray) -> np.ndarray:
        """Return balanced per-class weights ``n / (num_labels * count)`` as float32.

        Classes that do not occur in ``labels`` get weight 0.

        Raises:
            ValueError: if ``labels`` contains an id >= ``num_labels``.
        """
        counts = np.bincount(labels, minlength=self.num_labels)
        if counts.size > self.num_labels:
            raise ValueError(
                f"Found label id {counts.size - 1}, but num_labels={self.num_labels}"
            )
        weights = np.zeros(self.num_labels, dtype=np.float32)
        present = counts > 0
        weights[present] = counts.sum() / (self.num_labels * counts[present].astype(np.float32))
        return weights

    def fit(
        self,
        train_data: pd.DataFrame,
        epochs: int = 3,
        test_size: float = 0.2,
        test_data: Optional[pd.DataFrame] = None,
        per_device_train_batch_size: int = 16,
        gradient_accumulation_steps: int = 1,
        learning_rate: float = 2e-4,
        metric_name: str = "f1",
        fp16: bool = True,
        logging_steps: int = 50,
        eval_steps: int = 200,
        output_dir: str = "./result",
        seed: int = 42,
        hidden: int = 512,
        dropout: float = 0.1,
        fusion: str = "concat",
        early_stopping_patience: Optional[int] = 3,
        early_stopping_threshold: float = 0.0
    ):
        """Train the classification head (and backbone, unless frozen) and return ``self``.

        Args:
            train_data: frame with the target and all modality columns.
            epochs: number of training epochs.
            test_size: eval fraction used when ``test_data`` is not given.
            test_data: optional explicit eval frame; disables the internal split.
            per_device_train_batch_size, gradient_accumulation_steps,
            learning_rate, logging_steps, output_dir, seed: forwarded to
                ``TrainingArguments``.
            metric_name: "f1" or "accuracy" (case-insensitive); also used as
                the model-selection metric.
            fp16: enable fp16; applied only when CUDA is available.
            eval_steps: evaluation and checkpoint period, in steps.
            hidden, dropout, fusion: forwarded to ``SingleBackboneClassifier``.
            early_stopping_patience: a positive value enables
                ``EarlyStoppingCallback``; ``None`` or 0 disables it.
            early_stopping_threshold: threshold for the early-stopping callback.

        Raises:
            ValueError: on missing columns, an unknown metric name, or when the
                data contains more classes than ``num_labels``.
        """
        self._validate_data(train_data)
        set_seed(seed)
        # Normalise once: _setup_metrics lowercases the name, and the Trainer
        # must be given the same key for `metric_for_best_model`.
        metric_name = metric_name.lower()

        classes = sorted(train_data[self.target_column_name].unique().tolist())
        if len(classes) > self.num_labels:
            # The head would have fewer outputs than there are label ids.
            raise ValueError(
                f"num_labels={self.num_labels} is smaller than the number of classes "
                f"found in the data ({len(classes)})"
            )
        self.label2id = {c: i for i, c in enumerate(classes)}
        self.id2label = {i: str(c) for c, i in self.label2id.items()}
        if self.num_labels != len(classes):
            print(f"Warning: num_labels={self.num_labels} != len(classes)={len(classes)}")

        if test_data is None:
            df_train, df_eval = self._split(train_data, test_size=test_size, seed=seed)
        else:
            df_train, df_eval = train_data, test_data

        y_train_all = np.array([self.label2id[y] for y in df_train[self.target_column_name].tolist()], dtype=int)
        class_weights = self._compute_class_weights(y_train_all)

        ds_eval = self._make_dataset(df_eval, pretokenize_text=True)
        ds_train = self._make_dataset(df_train, pretokenize_text=True)

        self.model = SingleBackboneClassifier(
            backend=self.backend,
            modalities=self.modalities,
            num_labels=self.num_labels,
            fusion=fusion,
            hidden=hidden,
            dropout=dropout
        )
        self._setup_metrics(metric_name)

        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=max(4, per_device_train_batch_size // 2),
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=max(1, gradient_accumulation_steps * 2),
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.01,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_strategy="steps",
            save_steps=eval_steps,
            load_best_model_at_end=True,
            metric_for_best_model=f"eval_{metric_name}",
            save_total_limit=1,
            logging_strategy="steps",
            logging_steps=logging_steps,
            report_to="none",
            fp16=fp16 and torch.cuda.is_available(),
            dataloader_num_workers=min(4, os.cpu_count() or 4),
            seed=seed,
            remove_unused_columns=False,
            dataloader_pin_memory=True,
            ddp_find_unused_parameters=False,
            disable_tqdm=True
        )

        def data_collator(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            return self.backend.collate(batch_list)

        self.trainer = WeightedCETrainer(
            model=self.model,
            args=args,
            train_dataset=ds_train,
            eval_dataset=ds_eval,
            data_collator=data_collator,
            compute_metrics=self.compute_metrics,
            num_labels=self.num_labels,
            train_labels=y_train_all,
            class_weights=class_weights
        )
        self.trainer.remove_callback(PrinterCallback)

        if early_stopping_patience is not None and early_stopping_patience > 0:
            esc = EarlyStoppingCallback(
                early_stopping_patience=int(early_stopping_patience),
                early_stopping_threshold=float(early_stopping_threshold)
            )
            self.trainer.add_callback(esc)

        # Custom progress bar; the real total is filled in when training starts.
        pbar = tqdm(total=0, desc="Training Progress", unit="step", leave=False, dynamic_ncols=True)
        self.trainer.add_callback(PbarConsoleLogger(pbar))
        try:
            self.trainer.train()
        finally:
            # Best-effort cleanup: the callback normally closes the bar already.
            try:
                pbar.close()
            except Exception:
                pass

        self.backend.get_text_tokenizer().clear_cache()
        return self

    def predict(self, df: pd.DataFrame, return_label_str: bool = False, return_proba: bool = False, batch_size: Optional[int] = None) -> np.ndarray:
        """Predict classes for ``df``.

        Args:
            df: frame with the modality columns; the target column is optional
                (a placeholder label is inserted so the dataset can be built).
            return_label_str: return label strings instead of class ids.
            return_proba: return softmax probabilities; takes precedence over
                ``return_label_str``.
            batch_size: temporary eval batch size for this call only.

        Raises:
            RuntimeError: if the model has not been fitted.
        """
        if self.trainer is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")
        df_c = df.copy()
        if self.target_column_name not in df_c.columns:
            df_c[self.target_column_name] = next(iter(self.label2id))
        ds = self._make_dataset(df_c, pretokenize_text=len(df_c) < 10000)

        trainer_args = self.trainer.args
        original_bs = trainer_args.per_device_eval_batch_size
        original_disable_tqdm = trainer_args.disable_tqdm
        if batch_size:
            trainer_args.per_device_eval_batch_size = batch_size
        trainer_args.disable_tqdm = False
        try:
            preds = self.trainer.predict(test_dataset=ds)
        finally:
            # Restore the trainer settings even when prediction fails.
            trainer_args.disable_tqdm = original_disable_tqdm
            trainer_args.per_device_eval_batch_size = original_bs

        if return_proba:
            logits = preds.predictions
            # Subtract the row maximum for a numerically stable softmax.
            exp = np.exp(logits - np.max(logits, axis=1, keepdims=True))
            return exp / np.sum(exp, axis=1, keepdims=True)

        y_pred = np.argmax(preds.predictions, axis=-1)
        if return_label_str:
            return np.array([self.id2label[int(i)] for i in y_pred])
        return y_pred

    def get_embeddings(self, df: pd.DataFrame, batch_size: int = 32, return_per_modality: bool = False):
        """Extract fused (and optionally per-modality) embeddings for ``df``.

        Returns:
            ``fused`` array, or ``(fused, per_modality_dict)`` when
            ``return_per_modality`` is True.  Modalities for which the model
            returned nothing are left out of the dict.

        Raises:
            RuntimeError: if the model has not been fitted.
            ValueError: if ``df`` yields no batches.
        """
        if self.trainer is None or self.model is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")
        df_c = df.copy()
        if self.target_column_name not in df_c.columns:
            df_c[self.target_column_name] = next(iter(self.label2id))
        ds = self._make_dataset(df_c, pretokenize_text=False)

        def collate(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            return self.backend.collate(batch_list)

        def move_to_device(obj):
            # Recursively move tensors (also inside dicts); leave other objects as is.
            if torch.is_tensor(obj):
                return obj.to(device)
            if isinstance(obj, dict):
                return {k: move_to_device(v) for k, v in obj.items()}
            return obj

        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)
        fused_list = []
        per_mod_lists = {m: [] for m in self.modalities} if return_per_modality else None

        print(f"Extracting embeddings (batch_size={batch_size})...")
        device = next(self.model.parameters()).device
        self.model.eval()
        with torch.no_grad():
            for batch in tqdm(loader, unit="batch"):
                bi = move_to_device(batch["backend_inputs"])
                fused, per = self.model.get_embeddings(backend_inputs=bi, return_per_modality=True)
                fused_list.append(fused.cpu().numpy())
                if return_per_modality:
                    for m, chunks in per_mod_lists.items():
                        if m in per:
                            chunks.append(per[m].cpu().numpy())

        if not fused_list:
            raise ValueError("No embeddings were produced: the input DataFrame is empty")
        fused_arr = np.vstack(fused_list)
        if not return_per_modality:
            print(f"✓ Embeddings shape: {fused_arr.shape}")
            return fused_arr
        # Skip modalities without chunks: np.vstack fails on an empty list.
        per_mod = {m: np.vstack(chunks) for m, chunks in per_mod_lists.items() if chunks}
        print(f"✓ Fused embeddings shape: {fused_arr.shape}")
        for m, arr in per_mod.items():
            print(f"✓ {m.capitalize()} embeddings shape: {arr.shape}")
        return fused_arr, per_mod

Пример использования 1.

In [None]:
import numpy as np
import pandas as pd

set_seed(123)

# 1) Synthetic data: 2 images and 2 audio clips per sample
N = 12
df = pd.DataFrame({
    "label": ["спорт", "еда", "техника"] * (N // 3) + (["спорт"] * (N % 3)),
    "title": [f"заголовок {i}" for i in range(N)],
    "desc":  [f"описание {i}" for i in range(N)],
})

# Random 224x224 RGB uint8 image (values drawn from [0, 255)).
def mk_img():
    return np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)

# Float32 440 Hz sine wave with amplitude 0.1, `dur_s` seconds long at sample rate `sr`.
def mk_wav(sr=48000, dur_s=0.2):
    t = np.linspace(0, dur_s, int(sr*dur_s), endpoint=False, dtype=np.float32)
    return (0.1*np.sin(2*np.pi*440*t)).astype(np.float32)

df["image_path"] = [[mk_img(), mk_img()] for _ in range(N)]     # 2 images per sample
df["audio_path"] = [[mk_wav(), mk_wav()] for _ in range(N)]     # 2 audio clips per sample

# 2) Pipeline setup: RuCLIP from the HF Hub, audio through CLAP
pipeline = RuCLIPMultiModalClassification(
    modalities=["text","image","audio"],
    num_labels=3,
    target_column_name="label",
    text_columns=["title","desc"],
    image_columns=["image_path"],
    audio_columns=["audio_path"],

    ruclip_model_name="ViT-B-32",
    ruclip_pretrained="hf-hub:ai-forever/ru-clip-vit-base-patch32-224",  # downloads the RuCLIP weights
    ruclip_context_len=77,

    max_images_per_sample=2,            # 2 images -> image_agg is applied to 2 feature vectors
    image_agg="concat",                 # concat or mean

    audio_cfg={
        "type": "clap",
        "checkpoint": "laion/clap-htsat-unfused",
        "sr":48000
    },  # downloads CLAP (~600MB)
    max_audios_per_sample=2,
    audio_agg="mean",                   # average the 2 audio feature vectors

    freeze_backbone=True                # freeze RuCLIP/CLAP, train only the head
)

# 3) Training
pipeline.fit(
    train_data=df,
    epochs=2,
    test_size=0.25,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=3e-4,
    metric_name="f1",
    fp16=True,                          # faster when CUDA is available
    logging_steps=5,
    eval_steps=10,
    output_dir="./result_ex1",
    seed=123,
    hidden=768,                         # hidden layer size of the head
    dropout=0.2,
    fusion="concat",                    # the "safe" mode when modality dimensions differ
    early_stopping_patience=2,
    early_stopping_threshold=0.0
)

# 4) Predictions (return class probabilities)
proba = pipeline.predict(df.iloc[:5], return_proba=True)
print("proba shape:", proba.shape)      # (5, 3)

# 5) Embeddings: fused and per modality
fused, per = pipeline.get_embeddings(df.iloc[:5], batch_size=2, return_per_modality=True)
print("fused:", fused.shape)
for k,v in per.items():
    print(f"{k}: {v.shape}")

Пример использования 2.

In [None]:
import numpy as np
import pandas as pd

set_seed(7)

# 1) Toy data: 2 images per sample, no audio
N = 10
df = pd.DataFrame({
    "label": ["A","B"] * (N//2) + (["A"] if N%2 else []),
    "title": [f"title {i}" for i in range(N)],
    "desc":  [f"desc {i}" for i in range(N)],
})
df["image_path"] = [[np.zeros((224,224,3), dtype=np.uint8), np.ones((224,224,3), dtype=np.uint8)*255] for _ in range(N)]

# 2) Setup: ruclip_pretrained=None (no weights are downloaded), up to 2 images per sample
pipeline = RuCLIPMultiModalClassification(
    modalities=["text","image"],
    num_labels=2,
    target_column_name="label",
    text_columns=["title","desc"],
    image_columns=["image_path"],

    ruclip_model_name="ViT-B-32",
    ruclip_pretrained=None,             # offline mode (randomly initialised RuCLIP weights)
    ruclip_context_len=77,

    max_images_per_sample=2,
    image_agg="mean",                   # average the 2 images
    freeze_backbone=True
)

# 3) Training with an extended set of parameters
pipeline.fit(
    train_data=df,
    epochs=3,
    test_size=0.3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    metric_name="accuracy",
    fp16=False,                         # with random weights running on CPU is fine
    logging_steps=2,
    eval_steps=4,
    output_dir="./result_ex2",
    seed=7,
    hidden=256,
    dropout=0.1,
    fusion="concat",
    early_stopping_patience=1
)

# 4) Predictions (string labels), overriding batch_size on the fly
y_str = pipeline.predict(df.iloc[:6], return_label_str=True, batch_size=3)
print("pred labels:", y_str.tolist())

# 5) Fused embeddings only (no per-modality breakdown)
fused = pipeline.get_embeddings(df.iloc[:6], batch_size=3, return_per_modality=False)
print("fused emb:", fused.shape)

Пример использования 3.

In [None]:
import numpy as np
import pandas as pd

set_seed(1)

# 1) Minimal data: one image per sample, no audio
N = 6
df = pd.DataFrame({
    "label": ["кошки","собаки"] * (N//2) + (["кошки"] if N%2 else []),
    "title": [f"пример {i}" for i in range(N)],
})
# `desc` is optional — only 'title' is used here
df["image_path"] = [np.random.randint(0,255,(224,224,3),dtype=np.uint8) for _ in range(N)]

# 2) Setup with minimal configuration
pipeline = RuCLIPMultiModalClassification(
    modalities=["text","image"],
    num_labels=2,
    target_column_name="label",
    text_columns=["title"],
    image_columns=["image_path"],
    # Set to None if you do not want to download RuCLIP
    ruclip_model_name="ViT-B-32",
    ruclip_pretrained=None,
    freeze_backbone=True
)

# 3) Short training run and prediction
pipeline.fit(df, epochs=1, test_size=0.33, per_device_train_batch_size=4, metric_name="accuracy", fp16=False)
pred = pipeline.predict(df.iloc[:3], return_label_str=True)
print("pred:", pred.tolist())

# Дообучение классификатора картинок с библиотекой timm для SOTA результатов.

Реализация классификатора.

In [None]:
!pip install --upgrade --no-cache-dir \
  --extra-index-url https://download.pytorch.org/whl/cu124 \
  torch==2.6.0+cu124 \
  torchvision==0.21.0+cu124 \
  torchaudio==2.6.0+cu124 \
  timm==0.9.16 \
  albumentations==1.4.14 \
  opencv-python-headless==4.10.0.84 \
  accelerate==0.33.0 \
  scikit-learn==1.5.2 \
  iterative-stratification==0.1.7 \
  pillow==11.1.0 \
  numpy==1.26.4 \
  pandas==2.2.3 \
  tqdm==4.67.1

# ==========================
# УСТАНОВКА И ИМПОРТЫ
# ==========================
# В ноутбуке вы уже ставите зависимости через pip.
# Здесь — полный код пайплайна.

import os, math, random, warnings
from typing import List, Optional, Dict, Any, Tuple, Union, Callable

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import timm
from timm.data import resolve_data_config
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy, LabelSmoothingCrossEntropy
from timm.utils import ModelEmaV2

from accelerate import Accelerator
from accelerate.state import PartialState

import cv2

from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from sklearn.metrics import accuracy_score, f1_score, average_precision_score, roc_auc_score

try:
    from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
    HAS_MLSTRAT = True
except Exception:
    HAS_MLSTRAT = False


# ==========================
# Утилиты
# ==========================

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def _is_path_like(x) -> bool:
    return isinstance(x, str)


def _to_chw32(img: np.ndarray) -> np.ndarray:
    # img: HWC RGB float32
    return img.transpose(2, 0, 1).astype("float32")


def _normalize(img: np.ndarray, mean, std) -> np.ndarray:
    img = img.astype("float32")
    if img.max() > 1.5:
        img = img / 255.0
    img = (img - mean) / std
    return img


def _global_pool_and_flatten(feat: torch.Tensor) -> torch.Tensor:
    # feat: [B, C] или [B, C, H, W] -> [B, C]
    if feat.ndim == 4:
        feat = torch.mean(feat, dim=[2, 3])
    return feat


# ==========================
# Numpy-аугментации
# ==========================

def np_resize(img: np.ndarray, out_size: int) -> np.ndarray:
    """Resize ``img`` to a square ``out_size`` x ``out_size`` with bilinear interpolation."""
    target_shape = (out_size, out_size)
    return cv2.resize(img, target_shape, interpolation=cv2.INTER_LINEAR)

def np_random_resized_crop(img: np.ndarray, out_size: int,
                           scale: Tuple[float, float] = (0.7, 1.0),
                           ratio: Tuple[float, float] = (0.8, 1.25)) -> np.ndarray:
    """Crop a random region of ``img`` and resize it to ``out_size`` x ``out_size``.

    Up to 10 attempts are made to sample a crop whose area is a ``scale``
    fraction of the image and whose width/height ratio lies in ``ratio``.
    The first crop that fits inside the image is used.  If none fits, the
    whole image is resized instead.
    """
    height, width = img.shape[:2]
    full_area = height * width
    attempts_left = 10
    while attempts_left > 0:
        attempts_left -= 1
        crop_area = random.uniform(*scale) * full_area
        crop_ratio = random.uniform(*ratio)
        crop_w = int(round(math.sqrt(crop_area * crop_ratio)))
        crop_h = int(round(math.sqrt(crop_area / crop_ratio)))
        fits = 0 < crop_w <= width and 0 < crop_h <= height
        if not fits:
            continue
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        patch = img[top:top + crop_h, left:left + crop_w, :]
        return cv2.resize(patch, (out_size, out_size), interpolation=cv2.INTER_LINEAR)
    # Fallback: no valid crop was sampled, so resize the full image (no crop is taken).
    return np_resize(img, out_size)

def np_hflip(img: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Flip ``img`` horizontally with probability ``p``; otherwise return it unchanged."""
    should_flip = random.random() < p
    return cv2.flip(img, 1) if should_flip else img

def np_shift_scale_rotate(img: np.ndarray,
                          shift_limit: float = 0.05,
                          scale_limit: float = 0.1,
                          rotate_limit: int = 15,
                          p: float = 0.7) -> np.ndarray:
    """Apply a random rotation, scaling and shift about the image centre with probability ``p``.

    The angle is drawn from ``[-rotate_limit, rotate_limit]`` degrees, the
    scale from ``1 ± scale_limit`` and each shift from ``± shift_limit`` of
    the corresponding image side.  Borders are filled by reflection and the
    output keeps the input size.
    """
    if random.random() >= p:
        return img
    rows, cols = img.shape[:2]
    # Keep the sampling order: angle, scale, x-shift, y-shift.
    angle = random.uniform(-rotate_limit, rotate_limit)
    zoom = 1.0 + random.uniform(-scale_limit, scale_limit)
    dx = random.uniform(-shift_limit, shift_limit) * cols
    dy = random.uniform(-shift_limit, shift_limit) * rows
    affine = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, zoom)
    affine[0, 2] = affine[0, 2] + dx
    affine[1, 2] = affine[1, 2] + dy
    return cv2.warpAffine(
        img,
        affine,
        (cols, rows),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT_101,
    )

def _adjust_brightness(img: np.ndarray, factor: float) -> np.ndarray:
    out = img.astype(np.float32) * factor
    return np.clip(out, 0, 255).astype(np.uint8)

def _adjust_contrast(img: np.ndarray, factor: float) -> np.ndarray:
    mean = np.mean(img, axis=(0,1), keepdims=True)
    out = (img.astype(np.float32) - mean) * factor + mean
    return np.clip(out, 0, 255).astype(np.uint8)

def _adjust_saturation(img: np.ndarray, factor: float) -> np.ndarray:
    """Multiply the HSV saturation channel of an RGB image by ``factor``.

    The saturation is clipped to [0, 255] before converting back to RGB.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
    saturation = hsv[..., 1] * factor
    hsv[..., 1] = np.clip(saturation, 0, 255)
    as_bytes = hsv.astype(np.uint8)
    return cv2.cvtColor(as_bytes, cv2.COLOR_HSV2RGB)

def _adjust_hue(img: np.ndarray, delta: float) -> np.ndarray:
    """Shift the HSV hue channel of an RGB image by ``delta`` (fraction of 180 hue units).

    The shift is truncated to an integer and wrapped modulo 180.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.int32)
    offset = int(delta * 180.0)  # hue is wrapped modulo 180 below
    hsv[..., 0] = np.mod(hsv[..., 0] + offset, 180)
    as_bytes = np.clip(hsv, 0, 255).astype(np.uint8)
    return cv2.cvtColor(as_bytes, cv2.COLOR_HSV2RGB)

def np_color_jitter(img: np.ndarray,
                    brightness: float = 0.2,
                    contrast: float = 0.2,
                    saturation: float = 0.2,
                    hue: float = 0.1,
                    p: float = 0.5) -> np.ndarray:
    """Randomly jitter brightness, contrast, saturation and hue with probability ``p``.

    Each adjustment whose strength is positive is applied once, in a random
    order.  Brightness/contrast/saturation factors are drawn from
    ``[1 - s, 1 + s]`` and the hue shift from ``[-hue, hue]``, each at the
    moment the adjustment is applied.
    """
    if random.random() >= p:
        return img

    strengths = (
        ("brightness", brightness),
        ("contrast", contrast),
        ("saturation", saturation),
        ("hue", hue),
    )
    order = [kind for kind, strength in strengths if strength > 0]
    random.shuffle(order)

    result = img
    for kind in order:
        if kind == "brightness":
            result = _adjust_brightness(result, random.uniform(1 - brightness, 1 + brightness))
        elif kind == "contrast":
            result = _adjust_contrast(result, random.uniform(1 - contrast, 1 + contrast))
        elif kind == "saturation":
            result = _adjust_saturation(result, random.uniform(1 - saturation, 1 + saturation))
        else:
            result = _adjust_hue(result, random.uniform(-hue, hue))
    return result

def np_coarse_dropout(img: np.ndarray,
                      max_holes: int = 1,
                      max_h_frac: float = 0.3,
                      max_w_frac: float = 0.3,
                      p: float = 0.5,
                      fill_value: Optional[Tuple[int,int,int]] = None) -> np.ndarray:
    """Paint up to ``max_holes`` random rectangles over a copy of ``img`` with probability ``p``.

    Each rectangle is at most ``max_h_frac`` x ``max_w_frac`` of the image
    size (and at least 1 pixel per side).  When ``fill_value`` is None the
    per-channel mean colour of the image is used.  The input array is never
    modified; it is returned as is when the augmentation is skipped.
    """
    if random.random() >= p:
        return img
    height, width = img.shape[:2]
    result = img.copy()
    n_holes = random.randint(1, max_holes)
    if fill_value is None:
        # Mean colour of the image, truncated to integers.
        channel_means = result.reshape(-1, 3).mean(axis=0)
        fill_value = tuple(int(value) for value in channel_means)
    max_hole_h = max(1, int(max_h_frac * height))
    max_hole_w = max(1, int(max_w_frac * width))
    for _ in range(n_holes):
        # Sampling order matters for reproducibility: height, width, top, left.
        hole_h = random.randint(1, max_hole_h)
        hole_w = random.randint(1, max_hole_w)
        top = random.randint(0, max(0, height - hole_h))
        left = random.randint(0, max(0, width - hole_w))
        result[top:top + hole_h, left:left + hole_w, :] = fill_value
    return result

class NpTransformPipeline:
    """Apply a list of numpy augmentations one after another."""

    def __init__(self, ops: List[Callable[[np.ndarray], np.ndarray]]):
        # Callables are applied in list order.
        self.ops = ops

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Feed ``img`` through every op and return the final result."""
        current = img
        remaining = iter(self.ops)
        for transform in remaining:
            current = transform(current)
        return current


# ==========================
# Датасет
# ==========================

class PandasImageDataset(Dataset):
    """
    Generic image dataset backed by a pandas DataFrame.

    image_column: holds either a file path (str) or a ready numpy array (HWC RGB).
    target_column: class index / class name (multiclass) or a string / collection
        of labels (multilabel).
    transforms: callable(img: HWC uint8 RGB) -> HWC uint8 RGB, applied before
        normalisation.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        image_column: str,
        target_column: Optional[str],
        transforms: Optional[Callable[[np.ndarray], np.ndarray]],
        classes: Optional[List[str]] = None,
        multilabel: bool = False,
        label_sep: str = " ",
        mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
        std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
        return_targets: bool = True,
    ):
        # Data source and column layout.
        self.df = df.reset_index(drop=True)
        self.image_column = image_column
        self.target_column = target_column

        # Image processing.
        self.transforms = transforms
        self.mean = np.array(mean, dtype="float32")
        self.std = np.array(std, dtype="float32")

        # Target encoding.
        self.multilabel = multilabel
        self.label_sep = label_sep
        self.return_targets = return_targets
        self.classes = classes
        if self.classes:
            self.class_to_idx = {name: pos for pos, name in enumerate(self.classes)}
        else:
            self.class_to_idx = None

    def __len__(self):
        return len(self.df)

    def _read_image(self, src):
        """Load ``src`` (path or array) as an RGB array.

        Paths are read with OpenCV and converted BGR -> RGB.  Arrays are
        expanded from grayscale to 3 channels, stripped of an alpha channel,
        and non-uint8 arrays with values up to 1.5 are rescaled to uint8.
        """
        if _is_path_like(src):
            bgr = cv2.imread(src)
            if bgr is None:
                raise FileNotFoundError(f"Cannot read image at path: {src}")
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if not isinstance(src, np.ndarray):
            raise ValueError("image_column должен содержать путь (str) или numpy.ndarray")

        arr = src
        if arr.ndim == 2:
            arr = np.stack((arr, arr, arr), axis=-1)
        if arr.ndim == 3 and arr.shape[2] == 4:
            arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
        if arr.dtype != np.uint8 and arr.max() <= 1.5:
            arr = (arr * 255.0).astype(np.uint8)
        return arr

    def _encode_target(self, row) -> Union[int, np.ndarray]:
        """Encode the target of ``row``.

        Returns None without a target column, an int class index for
        multiclass, or a float32 multi-hot vector for multilabel (labels not
        present in ``classes`` are ignored).
        """
        if self.target_column is None:
            return None
        raw = row[self.target_column]

        if not self.multilabel:
            mapping = self.class_to_idx
            return int(mapping[raw]) if mapping else int(raw)

        # Multilabel: bring the raw value to a list of label names first.
        if isinstance(raw, str):
            names = [piece for piece in raw.split(self.label_sep) if piece]
        elif isinstance(raw, (list, tuple, set)):
            names = list(raw)
        else:
            names = [raw]

        multi_hot = np.zeros(len(self.classes), dtype="float32")
        for name in names:
            if name in self.class_to_idx:
                multi_hot[self.class_to_idx[name]] = 1.0
        return multi_hot

    def __getitem__(self, idx):
        """Return the normalised CHW image tensor, plus the target when requested."""
        row = self.df.iloc[idx]
        image = self._read_image(row[self.image_column])
        if self.transforms is not None:
            image = self.transforms(image)

        normalised = _normalize(image, self.mean, self.std)
        image_tensor = torch.from_numpy(_to_chw32(normalised))

        wants_target = self.return_targets and self.target_column is not None
        if not wants_target:
            return image_tensor

        encoded = self._encode_target(row)
        if self.multilabel:
            target_tensor = torch.from_numpy(encoded)
        else:
            target_tensor = torch.tensor(encoded, dtype=torch.long)
        return image_tensor, target_tensor


# ==========================
# Лоссы
# ==========================

class FocalLossMultiLabel(nn.Module):
    """Focal loss on top of per-element binary cross-entropy with logits.

    Each BCE term is weighted by ``(1 - p_t) ** gamma``, where ``p_t`` is the
    predicted probability of the true state of the label.  ``reduction`` is
    "mean", "sum", or anything else to get the unreduced tensor.
    """

    def __init__(self, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        prob = torch.sigmoid(logits)
        bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        # Probability assigned to the correct state of every label.
        prob_true = prob * targets + (1 - prob) * (1 - targets)
        focal = bce * ((1 - prob_true) ** self.gamma)
        if self.reduction == "mean":
            return focal.mean()
        elif self.reduction == "sum":
            return focal.sum()
        return focal

class AsymmetricLossMultiLabel(nn.Module):
    """Asymmetric multilabel loss with separate focusing for positives and negatives.

    Negative probabilities are shifted up by ``clip`` (capped at 1), the log
    terms are clamped at ``eps`` and weighted by ``(1 - p_t) ** gamma`` with
    ``gamma_pos`` for positive targets and ``gamma_neg`` for negative ones.
    ``reduction`` is "mean", "sum", or anything else for the unreduced tensor.
    """

    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8, reduction="mean"):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits, targets):
        prob_pos = torch.sigmoid(logits)
        prob_neg = 1 - prob_pos
        if self.clip and self.clip > 0:
            # Probability shifting for negatives.
            prob_neg = (prob_neg + self.clip).clamp(max=1)

        anti_targets = 1 - targets
        log_lik = targets * torch.log(prob_pos.clamp(min=self.eps)) \
            + anti_targets * torch.log(prob_neg.clamp(min=self.eps))

        if self.gamma_pos > 0 or self.gamma_neg > 0:
            prob_true = prob_pos * targets + prob_neg * anti_targets
            exponent = self.gamma_pos * targets + self.gamma_neg * anti_targets
            log_lik = log_lik * ((1 - prob_true) ** exponent)

        loss = -log_lik
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss


# ==========================
# SAM (опционально)
# ==========================

class SAM(torch.optim.Optimizer):
    """Two-step Sharpness-Aware Minimization wrapper around a base optimizer.

    Usage per batch: backward pass, ``first_step`` (perturb the weights along
    the gradient), second backward pass, ``second_step`` (undo the
    perturbation and let the base optimizer update the weights).

    Args:
        params: parameters or parameter groups to optimise.
        base_optimizer: optimizer class (not an instance), e.g. ``torch.optim.SGD``.
        rho: perturbation radius, must be > 0.
        adaptive: scale the perturbation element-wise by ``p ** 2``.
        **kwargs: forwarded to ``base_optimizer`` (lr, weight_decay, ...).
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        if rho <= 0:
            raise ValueError("rho should be > 0")
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)
        # Both optimizers share the same param_groups objects.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move every parameter with a gradient to ``w + e(w)`` and remember ``e(w)``."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale
                p.add_(e_w)
                self.state[p]['e_w'] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation from ``first_step`` and run the base optimizer step.

        Restoration is driven by the recorded perturbation rather than by the
        presence of a gradient: a parameter perturbed in ``first_step`` is
        always restored, and a perturbation is never applied twice.
        """
        for group in self.param_groups:
            for p in group['params']:
                e_w = self.state[p].pop('e_w', None)
                if e_w is not None:
                    p.sub_(e_w)
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        raise NotImplementedError("SAM doesn't use step(). Use first_step() and second_step().")

    def zero_grad(self, set_to_none: bool = True):
        """Clear gradients via the base optimizer.

        Accepts ``set_to_none`` like ``torch.optim.Optimizer.zero_grad`` so
        callers that pass this keyword keep working.
        """
        self.base_optimizer.zero_grad(set_to_none=set_to_none)

    def _grad_norm(self):
        """Return the global L2 norm of all gradients (0 when there are none)."""
        # All norms are moved to the device of the first parameter before stacking.
        shared_device = self.param_groups[0]['params'][0].device
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group['params']
            if p.grad is not None
        ]
        if not norms:
            return torch.tensor(0.0, device=shared_device)
        return torch.norm(torch.stack(norms), p=2)


# ==========================
# Пайплайн
# ==========================

class ImageClassificationPipeline:
    """
    Универсальный пайплайн для задач классификации изображений (multiclass/multilabel)
    на PyTorch + timm с аугментациями на numpy/OpenCV, Accelerate, EMA, SAM, LLRD и TTA.

    Параметры инициализации касаются только данных/модели/инференса.
    Параметры обучения и валидации задаются в методе fit(...).

    :param target_column_name: имя колонки с таргетом в DataFrame
    :param image_column_name: имя колонки с изображением (путь к файлу или numpy HWC RGB)
    :param model_name: имя модели для timm.create_model (например, "resnet18", "vit_tiny_patch16_224")
    :param pretrained: использовать предобученные веса timm
    :param drop_path_rate: коэффициент stochastic depth (DropPath)
    :param multilabel: True — задача мультилейбл (multi-hot), False — мультикласс
    :param label_sep: разделитель меток в строковых таргетах для мультилейбл
    :param model_img_size: фиксированный размер входа для некоторых моделей (например, ViT 224)
    :param tta_hflip: использовать горизонтальный флип в TTA
    :param tta_scales: список размеров для multi-scale TTA; если None — без масштабов
    :param tta_crop_size: размер кропа для 5-crop TTA; если None — без мультикропов
    :param seed: базовое случайное зерно
    :param amp: включить mixed precision ("fp16") для Accelerate; False — "no"
    """
    def __init__(
        self,
        target_column_name: str,
        image_column_name: str,
        model_name: str = "convnextv2_base",
        pretrained: bool = True,
        drop_path_rate: float = 0.2,
        multilabel: bool = False,
        label_sep: str = " ",
        model_img_size: Optional[int] = None,
        tta_hflip: bool = True,
        tta_scales: Optional[List[int]] = None,
        tta_crop_size: Optional[int] = None,
        seed: int = 42,
        amp: bool = True,
    ):
        """
        Store data/model/inference settings, reset all fit-time state to
        neutral defaults and create the Accelerate ``Accelerator``.
        """
        # --- data & model ---
        self.target_col = target_column_name
        self.image_col = image_column_name
        self.model_name = model_name
        self.pretrained = pretrained
        self.drop_path_rate = drop_path_rate
        self.multilabel = multilabel
        self.label_sep = label_sep
        self.model_img_size = model_img_size

        # --- test-time augmentation (inference only) ---
        self.tta_hflip = tta_hflip
        self.tta_scales = tta_scales
        self.tta_crop_size = tta_crop_size

        # --- misc ---
        self.seed = seed
        self.amp = amp

        # --- results, populated by fit() ---
        self.classes_: Optional[List[str]] = None
        self.models_: List[Any] = []
        self.oof_pred_proba_: Optional[np.ndarray] = None
        self.oof_targets_: Optional[np.ndarray] = None
        self.oof_fold_: Optional[np.ndarray] = None
        self.thresholds_: Optional[np.ndarray] = None
        self._data_config: Optional[Dict[str, Any]] = None

        # --- training hyper-parameters; fit() overwrites every one of them ---
        # image sizes / schedule
        self.img_sizes: List[int] = [224]
        self.stage_epochs: Optional[List[int]] = None
        self.aug_strength: str = "medium"
        self.epochs: int = 0
        self.batch_size: int = 0
        self.num_workers: int = 0
        # optimisation
        self.lr: float = 0.0
        self.weight_decay: float = 0.0
        self.warmup_epochs: float = 0.0
        self.grad_clip: float = 0.0
        self.ema_decay: Optional[float] = None
        self.use_sam: bool = False
        self.sam_rho: float = 0.0
        self.layer_decay: Optional[float] = None
        self.grad_accum_steps: int = 1
        # regularisation / mixing
        self.mixup_alpha: float = 0.0
        self.cutmix_alpha: float = 0.0
        self.label_smoothing: float = 0.0
        self.disable_mix_last_n_epochs: int = 0
        # class balancing
        self.class_weights_in_loss: bool = False
        self.use_weighted_sampler: bool = False
        self.class_weights_: Optional[List[float]] = None
        # validation / splitting
        self.val_metric: str = "f1_macro"
        self.optimize_thresholds: bool = True
        self.n_folds: int = 0
        self.fold_column: Optional[str] = None
        self.group_column: Optional[str] = None
        self.stratify: bool = True
        # internal: the SAM + accumulation warning is emitted only once
        self._warned_sam_accum: bool = False

        # --- Accelerate ---
        # If the shared Accelerate state already exists, its mixed-precision
        # mode cannot be changed, so only warn when it differs from the request.
        # NOTE(review): constructing PartialState() may itself initialise the
        # shared state; if so, the fresh-state branch below is never taken and
        # `amp` is effectively ignored — confirm against the installed
        # accelerate version.
        partial_state = PartialState()
        if not getattr(partial_state, "initialized", False):
            self.accelerator = Accelerator(mixed_precision=("fp16" if self.amp else "no"))
        else:
            self.accelerator = Accelerator()
            try:
                cur_mp = self.accelerator.state.mixed_precision
                want_mp = "fp16" if self.amp else "no"
                if cur_mp != want_mp:
                    warnings.warn(
                        f"Accelerate is already initialized with mixed_precision='{cur_mp}'. "
                        f"Requested amp={'True' if self.amp else 'False'} (='{want_mp}') will be ignored for this instance."
                    )
            except Exception:
                # Best effort only: the warning must never break construction.
                pass

    # ---------------------- Трансформы ----------------------

    def _build_train_tfms(self, img_size: int) -> Callable[[np.ndarray], np.ndarray]:
        """
        Build the training augmentation pipeline for ``self.aug_strength``.

        "light" applies a random resized crop and a horizontal flip only;
        "heavy" and the default ("medium", also used for any other value)
        additionally apply shift/scale/rotate, colour jitter and coarse
        dropout with strength-specific arguments.

        :param img_size: output side length passed to the random resized crop
        :return: an ``NpTransformPipeline`` wrapping the list of operations
        """
        strength = self.aug_strength
        if strength == "light":
            crop_scale = (0.85, 1.0)
            ssr_args = jitter_args = None
            ssr_p = jitter_p = None
        elif strength == "heavy":
            crop_scale = (0.6, 1.0)
            ssr_args, ssr_p = (0.05, 0.2, 20), 0.7
            jitter_args, jitter_p = (0.3, 0.3, 0.3, 0.15), 0.6
        else:
            crop_scale = (0.7, 1.0)
            ssr_args, ssr_p = (0.05, 0.1, 15), 0.6
            jitter_args, jitter_p = (0.2, 0.2, 0.2, 0.1), 0.5

        # Geometry shared by every strength level.
        steps = [
            lambda im: np_random_resized_crop(im, img_size, scale=crop_scale),
            lambda im: np_hflip(im, p=0.5),
        ]
        if ssr_args is not None:
            steps.extend([
                lambda im: np_shift_scale_rotate(im, *ssr_args, p=ssr_p),
                lambda im: np_color_jitter(im, *jitter_args, p=jitter_p),
                lambda im: np_coarse_dropout(im, max_holes=1, max_h_frac=0.3, max_w_frac=0.3, p=0.5),
            ])
        return NpTransformPipeline(steps)

    def _build_valid_tfms(self, img_size: int) -> Callable[[np.ndarray], np.ndarray]:
        """
        Создаёт преобразование валидации/инференса (только Resize).

        :param img_size: итоговый размер стороны изображения (квадрат)
        :return: callable(img: np.ndarray HWC uint8 RGB) -> np.ndarray HWC uint8 RGB
        """
        return lambda im: np_resize(im, img_size)

    # ---------------------- Сплиты ----------------------

    def _make_folds(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Создаёт колонку fold c номером фолда (Stratified/Group/Multilabel).

        :param df: исходный DataFrame
        :return: копия df с колонкой "fold"
        """
        if self.fold_column and self.fold_column in df.columns:
            return df.copy()

        df = df.copy()
        df["fold"] = -1

        if self.group_column:
            gkf = GroupKFold(n_splits=self.n_folds)
            for k, (_, val_idx) in enumerate(gkf.split(df, groups=df[self.group_column])):
                df.loc[df.index[val_idx], "fold"] = k
            return df

        if self.multilabel and HAS_MLSTRAT:
            Y = df[self.target_col]
            if isinstance(Y.iloc[0], str):
                Y_bin = Y.str.get_dummies(sep=self.label_sep)
            elif isinstance(Y.iloc[0], (list, tuple, set)):
                uniq = sorted({lab for labs in Y for lab in (labs if isinstance(labs, (list, tuple, set)) else [labs])})
                Y_bin = pd.DataFrame(
                    [[1 if u in (y if isinstance(y, (list, tuple, set)) else [y]) else 0 for u in uniq] for y in Y],
                    columns=uniq
                )
            else:
                raise ValueError("Для multilabel ожидается строка меток или список/множество в target_column.")
            mskf = MultilabelStratifiedKFold(n_splits=self.n_folds, shuffle=True, random_state=self.seed)
            for k, (_, val_idx) in enumerate(mskf.split(df, Y_bin)):
                df.loc[df.index[val_idx], "fold"] = k
            return df

        if self.stratify and not self.multilabel:
            skf = StratifiedKFold(n_splits=self.n_folds, shuffle=True, random_state=self.seed)
            for k, (_, val_idx) in enumerate(skf.split(df, y=df[self.target_col])):
                df.loc[df.index[val_idx], "fold"] = k
        else:
            kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=self.seed)
            for k, (_, val_idx) in enumerate(kf.split(df)):
                df.loc[df.index[val_idx], "fold"] = k
        return df

    # ---------------------- LLRD ----------------------

    def _param_groups_llrd(self, model, base_lr: float, weight_decay: float):
        """
        Формирует группы параметров с layer-wise lr decay для ViT/Swin.

        :param model: модель timm
        :param base_lr: базовое значение learning rate
        :param weight_decay: коэфф. L2-регуляризации для AdamW
        :return: список групп параметров (dict)
        """
        layer_decay = self.layer_decay
        if not layer_decay or layer_decay >= 1.0:
            return [{"params": [p for p in model.parameters() if p.requires_grad], "lr": base_lr, "weight_decay": weight_decay}]

        layers = []
        if hasattr(model, "blocks"):  # ViT/DeiT
            layers = list(model.blocks)
        elif hasattr(model, "layers"):  # Swin
            for l in model.layers:
                if hasattr(l, "blocks"):
                    layers += list(l.blocks)
                else:
                    layers.append(l)

        if not layers:
            warnings.warn("LLRD: не найден blocks/layers — один LR для всех.")
            return [{"params": [p for p in model.parameters() if p.requires_grad], "lr": base_lr, "weight_decay": weight_decay}]

        n = len(layers) + 1
        layer_map = {}
        for i, layer in enumerate(layers):
            for name, p in layer.named_parameters(recurse=True):
                if p.requires_grad:
                    layer_map[id(p)] = i

        head_params = []
        body_params = [[] for _ in range(n)]
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if id(p) in layer_map:
                body_params[layer_map[id(p)]].append(p)
            else:
                head_params.append(p)

        groups = []
        for i, params in enumerate(body_params):
            if not params:
                continue
            lr_i = base_lr * (layer_decay ** (n - i - 1))
            groups.append({"params": params, "lr": lr_i, "weight_decay": weight_decay})
        if head_params:
            groups.append({"params": head_params, "lr": base_lr, "weight_decay": weight_decay})
        return groups

    # ---------------------- Метрики и пороги ----------------------

    def _compute_metric(self, y_true, y_proba) -> float:
        """
        Compute the validation/OOF score selected by ``self.val_metric``.

        Multiclass: predictions are ``argmax`` over axis 1; "accuracy" gives
        accuracy, any other value gives macro F1.
        Multilabel: "map_macro" gives macro average precision; "auc_ovr" gives
        macro ROC AUC and falls back to macro average precision if that raises;
        any other value gives macro F1 after thresholding with
        ``self.thresholds_`` (0.5 when no thresholds are set).

        :param y_true: ground-truth labels
        :param y_proba: predicted probabilities, indexed as (sample, class)
        :return: metric value
        """
        metric = self.val_metric

        if not self.multilabel:
            hard_labels = y_proba.argmax(1)
            if metric == "accuracy":
                return accuracy_score(y_true, hard_labels)
            return f1_score(y_true, hard_labels, average="macro")

        if metric == "map_macro":
            return average_precision_score(y_true, y_proba, average="macro")

        if metric == "auc_ovr":
            try:
                return roc_auc_score(y_true, y_proba, average="macro")
            except Exception:
                # ROC AUC could not be computed: use mAP instead.
                return average_precision_score(y_true, y_proba, average="macro")

        cutoffs = getattr(self, "thresholds_", None)
        if cutoffs is None:
            cutoffs = 0.5
        elif not np.isscalar(cutoffs):
            # Per-class thresholds broadcast across the sample axis.
            cutoffs = np.array(cutoffs)[None, :]
        binarized = (y_proba >= cutoffs).astype(int)
        return f1_score(y_true, binarized, average="macro", zero_division=0)

    def _tune_thresholds(self, y_true: np.ndarray, y_proba: np.ndarray):
        """
        Подбирает per-class пороги для мультилейбл по OOF (грид [0.05..0.95]).

        :param y_true: истинные метки мультилейбл (N, C)
        :param y_proba: вероятности модели (N, C)
        """
        if not self.multilabel:
            self.thresholds_ = None
            return
        grid = np.linspace(0.05, 0.95, 19)
        best_per_class = []
        for c in range(y_proba.shape[1]):
            best_c, sc = 0.5, -1
            y_t = y_true[:, c]
            y_p = y_proba[:, c]
            for t in grid:
                s = f1_score(y_t, (y_p >= t).astype(int), average="binary", zero_division=0)
                if s > sc:
                    sc, best_c = s, t
            best_per_class.append(best_c)
        self.thresholds_ = np.array(best_per_class, dtype="float32")

    # ---------------------- LR schedule ----------------------

    def _build_scheduler(self, optimizer, steps_per_epoch: int, epochs: int):
        """
        Строит LambdaLR: линейный warmup -> косинусное убывание.

        :param optimizer: оптимизатор, к которому привязан шедулер
        :param steps_per_epoch: число шагов на эпоху
        :param epochs: число эпох для данной стадии
        :return: torch.optim.lr_scheduler.LambdaLR
        """
        warmup_steps = int(max(1, self.warmup_epochs * steps_per_epoch))
        total_steps = int(max(1, epochs * steps_per_epoch))

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))
            t = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return 0.5 * (1.0 + math.cos(math.pi * t))

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # ---------------------- Обучение / Валидация ----------------------

    def _create_model_safe(self, pretrained: bool, img_size: Optional[int], num_classes: int):
        """
        Create the timm model, retrying without ``img_size`` when rejected.

        ``img_size`` is passed to ``timm.create_model`` only when it is not
        None. If the call raises a ``TypeError`` whose message mentions
        "img_size", the model is created again without that argument; any
        other ``TypeError`` propagates.

        :param pretrained: load pretrained weights
        :param img_size: fixed input size, or None to omit the argument
        :param num_classes: number of classifier outputs
        :return: the created model
        """
        common = {
            "pretrained": pretrained,
            "num_classes": num_classes,
            "drop_path_rate": self.drop_path_rate,
        }
        first_try = dict(common)
        if img_size is not None:
            first_try["img_size"] = img_size

        try:
            return timm.create_model(self.model_name, **first_try)
        except TypeError as err:
            if "img_size" not in str(err):
                raise
            return timm.create_model(self.model_name, **common)

    def _train_one_epoch(self, model, train_loader, loss_fn, optimizer, scheduler, ema=None, mixup_fn=None):
        """
        Run one training epoch (AMP, SAM, EMA, mixup/cutmix, clipping, accumulation).

        With a regular optimizer the whole micro-step (forward, backward and
        the update) runs inside ``accelerator.accumulate``; gradients are
        clipped and the optimizer/scheduler are stepped only on iterations
        where Accelerate reports ``sync_gradients``. With ``SAM`` every batch
        performs the two-pass update and accumulation is not used. EMA is
        updated only after iterations that actually changed the weights.

        When ``mixup_fn`` is None and ``loss_fn`` is a soft-target loss, a plain
        ``nn.CrossEntropyLoss`` is used instead.

        :param model: model being trained
        :param train_loader: iterable of ``(images, targets)`` batches
        :param loss_fn: training loss
        :param optimizer: ``SAM`` instance or a regular optimizer
        :param scheduler: LR scheduler or None
        :param ema: EMA object exposing ``update(model)``, or None
        :param mixup_fn: callable ``(x, y) -> (x, y)`` applied per batch, or None
        :return: mean of the per-batch losses over ``len(train_loader)``
        """
        model.train()
        total_loss = 0.0

        # Soft-target losses need soft labels, which only mixup produces.
        is_soft_target = isinstance(loss_fn, SoftTargetCrossEntropy) or 'SoftTarget' in str(loss_fn.__class__)
        if mixup_fn is None and is_soft_target:
            current_loss_fn = nn.CrossEntropyLoss().to(self.accelerator.device)
        else:
            current_loss_fn = loss_fn

        pbar = tqdm(train_loader, desc="Training", leave=False, disable=not self.accelerator.is_main_process)

        is_sam = isinstance(optimizer, SAM)
        use_accum = max(1, getattr(self, "grad_accum_steps", 1))
        if is_sam and use_accum > 1 and not self._warned_sam_accum:
            warnings.warn("SAM с grad_accum_steps>1 не поддерживается корректно; аккумуляция будет проигнорирована (используется 1).")
            self._warned_sam_accum = True

        for imgs, targets in pbar:
            x, y = imgs, targets
            if mixup_fn is not None:
                x, y = mixup_fn(x, y)

            if is_sam:
                # First pass: gradient at w, then move to w + e(w).
                with self.accelerator.autocast():
                    preds = model(x)
                    loss = current_loss_fn(preds, y)
                self.accelerator.backward(loss)
                if self.grad_clip:
                    self.accelerator.clip_grad_norm_(self.accelerator.unwrap_model(model).parameters(), self.grad_clip)
                optimizer.first_step(zero_grad=True)

                # Second pass: gradient at the perturbed point drives the update.
                with self.accelerator.autocast():
                    preds2 = model(x)
                    loss2 = current_loss_fn(preds2, y)
                self.accelerator.backward(loss2)
                if self.grad_clip:
                    self.accelerator.clip_grad_norm_(self.accelerator.unwrap_model(model).parameters(), self.grad_clip)
                optimizer.second_step(zero_grad=True)
                weights_updated = True
            else:
                with self.accelerator.accumulate(model):
                    with self.accelerator.autocast():
                        preds = model(x)
                        loss = current_loss_fn(preds, y)
                    self.accelerator.backward(loss)
                    weights_updated = bool(self.accelerator.sync_gradients)
                    if weights_updated:
                        # Clip the fully accumulated gradient once per update;
                        # clipping partial sums on every micro-batch distorts it
                        # and unscales fp16 gradients repeatedly.
                        if self.grad_clip:
                            self.accelerator.clip_grad_norm_(model.parameters(), self.grad_clip)
                        optimizer.step()
                        optimizer.zero_grad()
                        if scheduler is not None:
                            scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value
            if ema is not None and weights_updated:
                ema.update(self.accelerator.unwrap_model(model))
            if is_sam and scheduler is not None:
                scheduler.step()
            pbar.set_postfix({'loss': f'{loss_value:.4f}'})

        pbar.close()
        return total_loss / max(1, len(train_loader))

    @torch.no_grad()
    def _validate(self, model, valid_loader, loss_fn):
        """
        Evaluate the model: mean loss plus gathered probabilities and targets.

        Probabilities are ``sigmoid`` (multilabel) or ``softmax`` over dim 1
        (multiclass) of the logits, computed in float32 so that half-precision
        autocast output neither loses precision nor fails on conversion to
        numpy. Predictions and targets are collected with
        ``gather_for_metrics`` so that samples duplicated to pad the last
        distributed batch are dropped.

        :param model: model used for inference
        :param valid_loader: iterable of ``(images, targets)`` batches
        :param loss_fn: validation loss
        :return: ``(val_loss, probs, targets)`` — float, np.ndarray, np.ndarray
        """
        model.eval()
        probs, tgts, losses = [], [], []
        pbar = tqdm(valid_loader, desc="Validating", leave=False, disable=not self.accelerator.is_main_process)

        for imgs, targets in pbar:
            with self.accelerator.autocast():
                logits = model(imgs)
                loss = loss_fn(logits, targets)

            loss_gathered = self.accelerator.gather(loss.detach())
            if loss_gathered.ndim == 0:
                loss_gathered = loss_gathered.unsqueeze(0)
            losses.append(loss_gathered.float().cpu())

            logits32 = logits.detach().float()
            p = torch.sigmoid(logits32) if self.multilabel else torch.softmax(logits32, dim=1)
            # One call keeps predictions and targets truncated consistently.
            p_all, t_all = self.accelerator.gather_for_metrics((p, targets.detach()))
            probs.append(p_all.cpu())
            tgts.append(t_all.cpu())
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        pbar.close()
        probs = torch.cat(probs).numpy()
        tgts = torch.cat(tgts).numpy()
        vloss = torch.cat(losses).mean().item() if losses else 0.0
        return vloss, probs, tgts

    # ---------------------- Fit (обучение) ----------------------

    def fit(
        self,
        df: pd.DataFrame,
        test_size: Optional[float] = None,
        n_folds: int = 5,
        fold_column: Optional[str] = None,
        group_column: Optional[str] = None,
        stratify: bool = True,
        img_sizes: Union[List[int], Tuple[int, ...]] = (224, 384),
        stage_epochs: Optional[List[int]] = None,
        aug_strength: str = "medium",
        epochs: int = 20,
        batch_size: int = 32,
        num_workers: int = 4,
        lr: float = 2e-3,
        weight_decay: float = 1e-4,
        warmup_epochs: float = 3.0,
        grad_clip: float = 1.0,
        ema_decay: Optional[float] = 0.9999,
        use_sam: bool = False,
        sam_rho: float = 0.05,
        mixup_alpha: float = 0.8,
        cutmix_alpha: float = 1.0,
        label_smoothing: float = 0.1,
        disable_mix_last_n_epochs: int = 3,
        class_weights_in_loss: bool = False,
        use_weighted_sampler: bool = False,
        layer_decay: Optional[float] = None,
        val_metric: str = "f1_macro",
        optimize_thresholds: bool = True,
        seed: Optional[int] = None,
        holdout_seed: Optional[int] = None,
        grad_accum_steps: int = 1,
        bce_pos_weight: Optional[Union[float, List[float], np.ndarray]] = None,
        grad_checkpointing: bool = False,
    ):
        """
        Обучает модели и сохраняет OOF-предсказания (holdout или K-fold).

        :param df: DataFrame с колонками изображения и таргета (+ опц. group_column)
        :param test_size: доля данных на валидацию в holdout-режиме (0 < test_size < 1); если None — K-fold
        :param n_folds: число фолдов в K-fold режиме
        :param fold_column: имя готовой колонки с номером фолда (если уже есть)
        :param group_column: имя колонки групп для группового сплита (holdout/K-fold)
        :param stratify: использовать стратификацию в мультиклассе
        :param img_sizes: список размеров для прогрессивного ресайза
        :param stage_epochs: список эпох на стадию; если None — распределяются равномерно
        :param aug_strength: сила аугментаций: "light" | "medium" | "heavy"
        :param epochs: общее число эпох (сумма по стадиям)
        :param batch_size: размер батча
        :param num_workers: число воркеров DataLoader
        :param lr: learning rate для AdamW
        :param weight_decay: weight decay (L2) для AdamW
        :param warmup_epochs: длительность warmup (в “эпохах”) для линейного разогрева LR
        :param grad_clip: максимум L2-нормы градиента (0/None — выключить)
        :param ema_decay: коэффициент EMA (None/0 — выключить EMA)
        :param use_sam: включить SAM (Sharpness-Aware Minimization)
        :param sam_rho: радиус окна для SAM
        :param mixup_alpha: параметр mixup (0 — выключить)
        :param cutmix_alpha: параметр cutmix (0 — выключить)
        :param label_smoothing: сглаживание меток для CrossEntropy
        :param disable_mix_last_n_epochs: число финальных эпох, где mixup/cutmix отключён
        :param class_weights_in_loss: использовать веса классов в CrossEntropy (только мультикласс)
        :param use_weighted_sampler: балансировка батчей через WeightedRandomSampler (только мультикласс)
        :param layer_decay: LLRD коэффициент (для ViT/Swin), например 0.75–0.8; None — выключено
        :param val_metric: метрика валидации: "accuracy" | "f1_macro" | "auc_ovr" | "map_macro"
        :param optimize_thresholds: подбирать per-class пороги для мультилейбл по OOF
        :param seed: случайное зерно обучения; если None — используется значение из __init__
        :param holdout_seed: зерно для holdout-сплита; если None — используется seed
        :param grad_accum_steps: число микробатчей для аккумуляции градиентов (для обычного оптимизатора)
        :param bce_pos_weight: pos_weight для BCEWithLogitsLoss (float | список | np.array) для мультилейбл;
                               если None — считается автоматически по train split; игнорируется для focal/asl
        :param grad_checkpointing: включить gradient checkpointing, если модель timm поддерживает set_grad_checkpointing
        :return: self
        """
        if seed is not None:
            self.seed = seed
        set_seed(self.seed)

        # Синхронизация параметров обучения c self.*
        self.n_folds = n_folds
        self.fold_column = fold_column
        self.group_column = group_column
        self.stratify = stratify

        self.img_sizes = list(img_sizes) if isinstance(img_sizes, (list, tuple)) else [img_sizes]
        self.stage_epochs = stage_epochs
        self.aug_strength = aug_strength
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.weight_decay = weight_decay
        self.warmup_epochs = warmup_epochs
        self.grad_clip = grad_clip
        self.ema_decay = ema_decay
        self.use_sam = use_sam
        self.sam_rho = sam_rho
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.label_smoothing = label_smoothing
        self.disable_mix_last_n_epochs = disable_mix_last_n_epochs
        self.class_weights_in_loss = class_weights_in_loss
        self.use_weighted_sampler = use_weighted_sampler
        self.layer_decay = layer_decay
        self.val_metric = val_metric
        self.optimize_thresholds = optimize_thresholds
        self.grad_accum_steps = max(1, int(grad_accum_steps))

        # Классы
        if self.multilabel:
            uniq = set()
            for y in df[self.target_col]:
                if isinstance(y, str):
                    uniq.update([lab for lab in y.split(self.label_sep) if lab])
                elif isinstance(y, (list, tuple, set)):
                    uniq.update(list(y))
                else:
                    raise ValueError("Для multilabel ожидается строка меток или список/множество.")
            self.classes_ = sorted(list(uniq))
        else:
            self.classes_ = sorted(list(pd.unique(df[self.target_col])))
        num_classes = len(self.classes_)
        if num_classes <= 0:
            raise ValueError("Не удалось определить список классов.")

        # Режим: holdout или k-fold
        use_holdout = test_size is not None
        if use_holdout:
            if not (0.0 < test_size < 1.0):
                raise ValueError("test_size должен быть в диапазоне (0, 1).")
            ho_seed = self.seed if holdout_seed is None else holdout_seed

            df_use = df.copy().reset_index(drop=True)
            df_use["fold"] = -1

            n = len(df_use)
            val_size = test_size

            if self.group_column:
                from sklearn.model_selection import GroupShuffleSplit
                gss = GroupShuffleSplit(n_splits=1, test_size=val_size, random_state=ho_seed)
                train_idx, val_idx = next(gss.split(df_use, groups=df_use[self.group_column]))
            elif self.multilabel:
                try:
                    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit as MLSSS  # type: ignore
                    Y = df_use[self.target_col]
                    if isinstance(Y.iloc[0], str):
                        Y_bin = Y.str.get_dummies(sep=self.label_sep).values
                    elif isinstance(Y.iloc[0], (list, tuple, set)):
                        uniq = self.classes_
                        Y_bin = np.stack([
                            np.array([1.0 if u in (y if isinstance(y, (list, tuple, set)) else [y]) else 0.0 for u in uniq], dtype=np.float32)
                            for y in Y
                        ], axis=0)
                    else:
                        y0 = Y.iloc[0]
                        if isinstance(y0, (np.ndarray, list, tuple)):
                            Y_bin = np.stack(Y.values, axis=0).astype(np.float32)
                        else:
                            raise ValueError("Невозможно стратифицировать multilabel: непонятный формат target_column.")
                    msss = MLSSS(n_splits=1, test_size=val_size, random_state=ho_seed)
                    train_idx, val_idx = next(msss.split(np.zeros((len(df_use), 1)), Y_bin))
                except Exception:
                    from sklearn.model_selection import ShuffleSplit
                    ss = ShuffleSplit(n_splits=1, test_size=val_size, random_state=ho_seed)
                    train_idx, val_idx = next(ss.split(df_use))
            else:
                from sklearn.model_selection import train_test_split
                y = df_use[self.target_col].values
                if self.stratify:
                    train_idx, val_idx = train_test_split(
                        np.arange(n), test_size=val_size, random_state=ho_seed, stratify=y
                    )
                else:
                    train_idx, val_idx = train_test_split(
                        np.arange(n), test_size=val_size, random_state=ho_seed, stratify=None
                    )
            df_use.loc[val_idx, "fold"] = 0
            folds_iter = [0]
            n_folds_eff = 1
        else:
            df_use = self._make_folds(df)
            folds_iter = list(range(self.n_folds))
            n_folds_eff = self.n_folds

        # OOF
        self.oof_pred_proba_ = np.zeros((len(df_use), num_classes), dtype="float32")
        self.oof_targets_ = np.zeros((len(df_use), num_classes if self.multilabel else 1), dtype="float32")
        self.oof_fold_ = np.full(len(df_use), -1, dtype="int32")

        # Эпохи по стадиям
        if self.stage_epochs is None:
            base = self.epochs // len(self.img_sizes)
            rem = self.epochs - base * len(self.img_sizes)
            self.stage_epochs = [base + (1 if i < rem else 0) for i in range(len(self.img_sizes))]

        self.models_.clear()

        # Цикл по фолдам
        fold_pbar = tqdm(folds_iter, desc="Folds", disable=not self.accelerator.is_main_process)
        for fold in fold_pbar:
            fold_pbar.set_description(f"Fold {fold+1}/{n_folds_eff}")

            if "fold" not in df_use.columns:
                raise RuntimeError("Нет fold-колонки. Проверьте _make_folds()/holdout-сплит.")

            if use_holdout:
                df_tr = df_use[df_use["fold"] != 0].reset_index(drop=True)
                df_va = df_use[df_use["fold"] == 0].reset_index()
            else:
                df_tr = df_use[df_use["fold"] != fold].reset_index(drop=True)
                df_va = df_use[df_use["fold"] == fold].reset_index()

            # timm default_cfg
            ghost = self._create_model_safe(self.pretrained, self.model_img_size or self.img_sizes[-1], num_classes)
            self._data_config = resolve_data_config({}, model=ghost)
            mean = tuple(self._data_config.get("mean", (0.485, 0.456, 0.406)))
            std = tuple(self._data_config.get("std", (0.229, 0.224, 0.225)))

            # Модель
            model = self._create_model_safe(self.pretrained, self.model_img_size or self.img_sizes[-1], num_classes)

            # Включаем gradient checkpointing, если запрошено и модель поддерживает
            if grad_checkpointing and hasattr(model, "set_grad_checkpointing"):
                model.set_grad_checkpointing(True)
                if self.accelerator.is_main_process:
                    print("[GC] Gradient checkpointing enabled.")

            ema = None

            # Балансировка (мультикласс): class weights
            self.class_weights_ = None
            if (not self.multilabel) and self.class_weights_in_loss:
                counts = df_tr[self.target_col].value_counts().reindex(self.classes_, fill_value=0).values.astype(np.float32)
                cw = (counts.sum() / (counts + 1e-6))
                self.class_weights_ = (cw / cw.mean()).tolist()

            # Sampler (мультикласс): WeightedRandomSampler
            sampler = None
            if self.use_weighted_sampler and not self.multilabel:
                counts_map = df_tr[self.target_col].value_counts().reindex(self.classes_, fill_value=0).to_dict()
                sample_weights = df_tr[self.target_col].map(lambda x: 1.0 / max(1, counts_map.get(x, 1))).values
                sampler = WeightedRandomSampler(
                    weights=torch.as_tensor(sample_weights, dtype=torch.double),
                    num_samples=len(sample_weights),
                    replacement=True
                )

            # pos_weight для мультилейбл BCE
            pos_weight_tensor = None
            if self.multilabel:
                Y = df_tr[self.target_col]
                if isinstance(Y.iloc[0], str):
                    Y_bin = Y.str.get_dummies(sep=self.label_sep).reindex(columns=self.classes_, fill_value=0).values.astype(np.float32)
                elif isinstance(Y.iloc[0], (list, tuple, set)):
                    Y_bin = np.zeros((len(Y), num_classes), dtype=np.float32)
                    for i, y in enumerate(Y):
                        labs = list(y) if isinstance(y, (list, tuple, set)) else [y]
                        for lab in labs:
                            if lab in self.classes_:
                                Y_bin[i, self.classes_.index(lab)] = 1.0
                else:
                    y0 = Y.iloc[0]
                    if isinstance(y0, (np.ndarray, list, tuple)) and len(y0) == num_classes:
                        Y_bin = np.stack(Y.values, axis=0).astype(np.float32)
                    else:
                        Y_bin = np.zeros((len(Y), num_classes), dtype=np.float32)

                if bce_pos_weight is not None:
                    if isinstance(bce_pos_weight, (list, tuple, np.ndarray)):
                        pw = np.asarray(bce_pos_weight, dtype=np.float32)
                        if pw.shape[0] != num_classes:
                            raise ValueError("Длина bce_pos_weight должна совпадать с числом классов.")
                        pos_weight_tensor = torch.as_tensor(pw, dtype=torch.float32, device=self.accelerator.device)
                    else:
                        pos_weight_tensor = torch.full((num_classes,), float(bce_pos_weight), dtype=torch.float32, device=self.accelerator.device)
                else:
                    pc = Y_bin.sum(axis=0) + 1e-6
                    Ntr = float(Y_bin.shape[0])
                    pw = (Ntr - pc) / pc
                    pw = np.clip(pw, 1.0, 100.0).astype(np.float32)
                    pos_weight_tensor = torch.as_tensor(pw, dtype=torch.float32, device=self.accelerator.device)

            # Лоссы
            mixup_active = (self.mixup_alpha > 0 or self.cutmix_alpha > 0) and not self.multilabel
            if self.multilabel:
                loss_name = getattr(self, "loss_name", "bce")
                if loss_name == "focal":
                    train_loss_fn = FocalLossMultiLabel(gamma=2.0)
                    val_loss_fn = nn.BCEWithLogitsLoss()
                elif loss_name == "asl":
                    train_loss_fn = AsymmetricLossMultiLabel(gamma_pos=0, gamma_neg=4, clip=0.05)
                    val_loss_fn = nn.BCEWithLogitsLoss()
                else:
                    if pos_weight_tensor is not None:
                        train_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
                        val_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
                    else:
                        train_loss_fn = nn.BCEWithLogitsLoss()
                        val_loss_fn = nn.BCEWithLogitsLoss()
            else:
                if mixup_active:
                    train_loss_fn = SoftTargetCrossEntropy()
                elif self.label_smoothing > 0:
                    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=self.label_smoothing)
                elif self.class_weights_in_loss and self.class_weights_:
                    train_loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(self.class_weights_, dtype=torch.float32))
                else:
                    train_loss_fn = nn.CrossEntropyLoss()
                val_loss_fn = nn.CrossEntropyLoss()

            # Оптимизатор / SAM (с LLRD при необходимости)
            if self.layer_decay:
                param_groups = self._param_groups_llrd(model, self.lr, self.weight_decay)
                base_optimizer = torch.optim.AdamW(param_groups, lr=self.lr, weight_decay=self.weight_decay)
                optimizer = SAM(param_groups, torch.optim.AdamW, rho=self.sam_rho, lr=self.lr, weight_decay=self.weight_decay) if self.use_sam else base_optimizer
            else:
                base_optimizer = torch.optim.AdamW(model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
                optimizer = SAM(model.parameters(), torch.optim.AdamW, rho=self.sam_rho, lr=self.lr, weight_decay=self.weight_decay) if self.use_sam else base_optimizer

            best_score = -1.0
            best_state = None

            # Прогрессивный ресайз
            for stage, (img_size, epc) in enumerate(zip(self.img_sizes, self.stage_epochs)):
                train_tfms = self._build_train_tfms(img_size)
                valid_tfms = self._build_valid_tfms(img_size)

                train_ds = PandasImageDataset(
                    df_tr, self.image_col, self.target_col, train_tfms,
                    classes=self.classes_, multilabel=self.multilabel, label_sep=self.label_sep,
                    mean=mean, std=std, return_targets=True
                )
                valid_ds = PandasImageDataset(
                    df_va, self.image_col, self.target_col, valid_tfms,
                    classes=self.classes_, multilabel=self.multilabel, label_sep=self.label_sep,
                    mean=mean, std=std, return_targets=True
                )

                pin_mem = (self.accelerator.device.type == 'cuda')
                train_loader = DataLoader(
                    train_ds, batch_size=self.batch_size, shuffle=(sampler is None), sampler=sampler,
                    num_workers=self.num_workers, pin_memory=pin_mem, drop_last=True, persistent_workers=False
                )
                valid_loader = DataLoader(
                    valid_ds, batch_size=max(1, self.batch_size * 2), shuffle=False,
                    num_workers=self.num_workers, pin_memory=pin_mem, drop_last=False, persistent_workers=False
                )

                mixup_fn = None
                if mixup_active and (not self.multilabel):
                    mixup_fn = Mixup(
                        mixup_alpha=self.mixup_alpha, cutmix_alpha=self.cutmix_alpha,
                        label_smoothing=self.label_smoothing, num_classes=num_classes
                    )

                # Scheduler
                sched_target_opt = optimizer.base_optimizer if isinstance(optimizer, SAM) else optimizer
                scheduler = self._build_scheduler(sched_target_opt, max(1, len(train_loader)), epc)

                # Accelerate.prepare: если SAM, готовим только base_optimizer
                if isinstance(optimizer, SAM):
                    opt_to_prepare = optimizer.base_optimizer
                else:
                    opt_to_prepare = optimizer

                model, opt_prepared, train_loader, valid_loader, scheduler = self.accelerator.prepare(
                    model, opt_to_prepare, train_loader, valid_loader, scheduler
                )

                if isinstance(optimizer, SAM):
                    optimizer.base_optimizer = opt_prepared
                else:
                    optimizer = opt_prepared

                # EMA после prepare на правильном девайсе
                if self.ema_decay:
                    if ema is None:
                        ema = ModelEmaV2(
                            self.accelerator.unwrap_model(model),
                            decay=self.ema_decay,
                            device=self.accelerator.device
                        )
                    else:
                        ema.set(self.accelerator.unwrap_model(model))

                for epoch in range(epc):
                    global_epoch_idx = sum(self.stage_epochs[:stage]) + epoch
                    use_mix_now = mixup_fn is not None and global_epoch_idx < (self.epochs - self.disable_mix_last_n_epochs)

                    train_loss = self._train_one_epoch(
                        model, train_loader, train_loss_fn, optimizer, scheduler, ema,
                        mixup_fn if use_mix_now else None
                    )

                    eval_model = ema.module if (ema is not None) else model
                    vloss, vproba, vtgts = self._validate(eval_model, valid_loader, val_loss_fn)
                    vt = vtgts if self.multilabel else vtgts.reshape(-1)
                    score = self._compute_metric(vt, vproba)

                    if self.accelerator.is_main_process:
                        tqdm.write(
                            f"fold {fold+1}/{n_folds_eff} | stage {stage+1}/{len(self.img_sizes)} | "
                            f"epoch {epoch+1}/{epc} | train_loss {train_loss:.4f} | val_loss {vloss:.4f} | score {score:.5f}"
                        )

                    if score > best_score:
                        best_score = score
                        state = {k: v.cpu() for k, v in self.accelerator.unwrap_model(eval_model).state_dict().items()}
                        best_state = state

                # unwrap после стадии
                model = self.accelerator.unwrap_model(model)

            # Восстановить лучшую модель фолда
            if best_state is None:
                raise RuntimeError("best_state не был сохранён — проверьте цикл обучения.")
            best_model = self._create_model_safe(False, self.model_img_size or self.img_sizes[-1], num_classes)
            best_model.load_state_dict(best_state, strict=True)
            best_model = self.accelerator.prepare(best_model)
            best_model.eval()

            # OOF для вал. части
            va_tfms = self._build_valid_tfms(self.img_sizes[-1])
            va_ds = PandasImageDataset(
                df_va, self.image_col, self.target_col, va_tfms,
                classes=self.classes_, multilabel=self.multilabel, label_sep=self.label_sep,
                mean=mean, std=std, return_targets=True
            )
            va_loader = DataLoader(
                va_ds, batch_size=max(1, self.batch_size * 2), shuffle=False,
                num_workers=self.num_workers, pin_memory=(self.accelerator.device.type == 'cuda'),
                persistent_workers=False
            )
            va_loader = self.accelerator.prepare(va_loader)

            all_probs, all_tgts = [], []
            with torch.no_grad():
                for imgs, tgts in va_loader:
                    logits = best_model(imgs)
                    probs = torch.sigmoid(logits) if self.multilabel else torch.softmax(logits, dim=1)
                    all_probs.append(self.accelerator.gather(probs).cpu())
                    all_tgts.append(self.accelerator.gather(tgts).cpu())
            all_probs = torch.cat(all_probs).numpy()
            all_tgts = torch.cat(all_tgts).numpy()

            self.oof_pred_proba_[df_va.index.values] = all_probs
            if self.multilabel:
                self.oof_targets_[df_va.index.values] = all_tgts
            else:
                self.oof_targets_[df_va.index.values, 0] = all_tgts
            self.oof_fold_[df_va.index.values] = fold if not use_holdout else 0

            self.models_.append(self.accelerator.unwrap_model(best_model))
            self.accelerator.wait_for_everyone()

        # Пороги (multilabel)
        if self.optimize_thresholds and self.multilabel:
            mask = (self.oof_fold_ >= 0)
            if mask.any():
                self._tune_thresholds(self.oof_targets_[mask], self.oof_pred_proba_[mask])

        # Итоговый CV по доступным OOF
        mask = (self.oof_fold_ >= 0)
        if mask.any():
            if self.multilabel:
                y_true, y_proba = self.oof_targets_[mask], self.oof_pred_proba_[mask]
            else:
                y_true, y_proba = self.oof_targets_[mask].reshape(-1), self.oof_pred_proba_[mask]
            cv_score = self._compute_metric(y_true, y_proba)
        else:
            cv_score = float("nan")

        if self.accelerator.is_main_process:
            print(f"\nCV {self.val_metric}: {cv_score:.5f}" if not np.isnan(cv_score) else "\nCV: no validation predictions (mask empty)")

        return self

    # ---------------------- TTA ----------------------

    def _apply_tta(self, imgs: torch.Tensor, scales: List[int], crop_size: Optional[int]) -> List[torch.Tensor]:
        """
        Генерирует TTA-вариации батча: multi-scale, 5-crop, hflip.

        :param imgs: входной батч тензоров [B, C, H, W]
        :param scales: список размеров для ресайза (квадрат)
        :param crop_size: размер 5-crop (None — без multi-crop)
        :return: список TTA-версий батча (тензоры)
        """
        outs = []

        def resize(imgs_, new_size):
            return torch.nn.functional.interpolate(imgs_, size=(new_size, new_size), mode="bilinear", align_corners=False)

        def five_crops(imgs_, out_size):
            H, W = imgs_.shape[-2:]
            if out_size is None or out_size >= H or out_size >= W:
                return [imgs_]
            coords = [
                (0, 0), (0, W - out_size),
                (H - out_size, 0), (H - out_size, W - out_size),
                ((H - out_size) // 2, (W - out_size) // 2)
            ]
            return [imgs_[..., y:y + out_size, x:x + out_size] for y, x in coords]

        for s in scales:
            r = resize(imgs, s)
            crops = five_crops(r, crop_size) if crop_size is not None else [r]
            for ci in crops:
                outs.append(ci)
                if self.tta_hflip:
                    outs.append(torch.flip(ci, dims=[-1]))
        return outs

    # ---------------------- Predict ----------------------

    @torch.no_grad()
    def predict(self, df: pd.DataFrame, return_proba: bool = True, batch_size: Optional[int] = None,
                tta: bool = True) -> np.ndarray:
        """
        Predict for ``df``, averaging probabilities over the fold models (and over TTA views when enabled).

        :param df: DataFrame holding the image column (file path or numpy HWC RGB array)
        :param return_proba: True — return probabilities (N×C); False — class indices (multiclass only)
        :param batch_size: inference batch size; None means 2× the training batch size
        :param tta: enable Test-Time Augmentation
        :return: np.ndarray of probabilities (multilabel/multiclass) or of class indices
                 (multiclass with return_proba=False); multilabel always returns probabilities
        :raises RuntimeError: if ``self.models_`` is empty (fit() has not been run)
        """
        if not self.models_:
            raise RuntimeError("Сначала вызовите fit().")
        batch_size = batch_size or max(1, self.batch_size * 2)
        n_rows = len(df)

        mean = tuple(self._data_config.get("mean", (0.485, 0.456, 0.406))) if self._data_config else (0.485, 0.456, 0.406)
        std = tuple(self._data_config.get("std", (0.229, 0.224, 0.225))) if self._data_config else (0.229, 0.224, 0.225)
        base_size = self.img_sizes[-1]
        valid_tfms = self._build_valid_tfms(base_size)

        ds = PandasImageDataset(
            df, self.image_col, None, valid_tfms,
            classes=self.classes_, multilabel=self.multilabel, label_sep=self.label_sep,
            mean=mean, std=std, return_targets=False
        )
        pin_mem = (self.accelerator.device.type == 'cuda')
        dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=pin_mem, persistent_workers=False)
        dl = self.accelerator.prepare(dl)

        if self.model_img_size is not None:
            # A fixed model input size allows exactly one scale and no multi-crop.
            scales = [self.model_img_size]
            crop_size = None
        else:
            scales = self.tta_scales if (tta and self.tta_scales) else [base_size]
            crop_size = (self.tta_crop_size if (tta and self.tta_crop_size and self.tta_crop_size < max(scales)) else None)

        all_fold_probs = []
        for i, model in enumerate(self.models_):
            model.to(self.accelerator.device)
            model.eval()
            fold_probs = []
            pbar_desc = f"Predicting fold {i+1}/{len(self.models_)}"
            for imgs in tqdm(dl, desc=pbar_desc, leave=False, disable=not self.accelerator.is_main_process):
                aug_imgs = self._apply_tta(imgs, scales, crop_size) if tta else [imgs]
                tta_preds = []
                for batch_aug in aug_imgs:
                    logits = model(batch_aug)
                    p = torch.sigmoid(logits) if self.multilabel else torch.softmax(logits, dim=1)
                    tta_preds.append(p)
                # Average the probabilities of all TTA views of this batch.
                tta_mean = torch.stack(tta_preds, dim=0).mean(dim=0)
                fold_probs.append(self.accelerator.gather(tta_mean).cpu())
            # In multi-process runs the prepared DataLoader may repeat samples so that every
            # process receives full batches, and gather() then yields more rows than len(df).
            # With Accelerate's default batch sharding the extras sit at the tail, so trimming
            # restores one row per input; on a single process this is a no-op.
            all_fold_probs.append(torch.cat(fold_probs, dim=0)[:n_rows].numpy())

        probs = np.mean(all_fold_probs, axis=0)

        if self.multilabel or return_proba:
            return probs
        return probs.argmax(1)

    # ---------------------- Embeddings ----------------------

    @torch.no_grad()
    def get_embeddings(self, df: pd.DataFrame, batch_size: Optional[int] = None) -> np.ndarray:
        """
        Extract embeddings (features before the classifier head), averaged over the fold models.

        :param df: DataFrame holding the image column (file path or numpy HWC RGB array)
        :param batch_size: inference batch size; None means 2× the training batch size
        :return: embedding matrix of shape [N, D]
        :raises RuntimeError: if ``self.models_`` is empty (fit() has not been run)
        """
        if not self.models_:
            raise RuntimeError("Сначала вызовите fit().")
        batch_size = batch_size or max(1, self.batch_size * 2)
        n_rows = len(df)

        mean = tuple(self._data_config.get("mean", (0.485, 0.456, 0.406))) if self._data_config else (0.485, 0.456, 0.406)
        std = tuple(self._data_config.get("std", (0.229, 0.224, 0.225))) if self._data_config else (0.229, 0.224, 0.225)
        img_size = self.img_sizes[-1]
        tfms = self._build_valid_tfms(img_size)

        ds = PandasImageDataset(
            df, self.image_col, None, tfms,
            classes=self.classes_, multilabel=self.multilabel, label_sep=self.label_sep,
            mean=mean, std=std, return_targets=False
        )
        pin_mem = (self.accelerator.device.type == 'cuda')
        dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=pin_mem, persistent_workers=False)
        dl = self.accelerator.prepare(dl)

        def extract_features(model, x):
            # Prefer the backbone features; fall back to the plain forward pass.
            if hasattr(model, "forward_features"):
                feat = model.forward_features(x)
            else:
                feat = model(x)
            if isinstance(feat, (list, tuple)):
                feat = feat[-1]
            if feat.ndim == 3:   # token sequence [B, N, D] (ViT/DeiT-style): take token 0 (CLS)
                feat = feat[:, 0, :]
            if feat.ndim == 4:   # feature map [B, C, H, W] (CNN-style): global average pooling
                feat = torch.mean(feat, dim=[2, 3])
            return feat

        all_fold_embs = []
        for i, model in enumerate(self.models_):
            model.to(self.accelerator.device)
            model.eval()
            fold_embs = []
            pbar_desc = f"Embeddings fold {i+1}/{len(self.models_)}"
            for imgs in tqdm(dl, desc=pbar_desc, leave=False, disable=not self.accelerator.is_main_process):
                feat = extract_features(model, imgs)
                feat = feat.flatten(1)
                fold_embs.append(self.accelerator.gather(feat).cpu())
            # In multi-process runs the prepared DataLoader may repeat samples so that every
            # process receives full batches, and gather() then yields more rows than len(df).
            # With Accelerate's default batch sharding the extras sit at the tail, so trimming
            # restores one row per input; on a single process this is a no-op.
            all_fold_embs.append(torch.cat(fold_embs, dim=0)[:n_rows].numpy())

        return np.mean(all_fold_embs, axis=0)

    def save(self, path: str):
        """
        Save the fold ensemble and the pipeline metadata into a single .pt file.

        Only the main Accelerate process writes; every other process returns immediately.

        :param path: destination file path handed to ``torch.save``
        :raises RuntimeError: if called on the main process while ``self.models_`` is empty
        """
        if not self.accelerator.is_main_process:
            return
        if not self.models_:
            raise RuntimeError("Нет обученных моделей: сначала вызовите fit().")

        num_classes = len(self.classes_ or [])
        mean = tuple(self._data_config.get("mean", (0.485, 0.456, 0.406))) if self._data_config else (0.485, 0.456, 0.406)
        std = tuple(self._data_config.get("std", (0.229, 0.224, 0.225))) if self._data_config else (0.229, 0.224, 0.225)

        # Weights are copied to CPU so the checkpoint does not depend on the training device.
        # No unwrapping happens here: the models are expected to be stored already unwrapped.
        state_dicts = [
            {k: v.cpu() for k, v in m.state_dict().items()}
            for m in self.models_
        ]

        ckpt = {
            "format": "ICPv1",
            "saved_at": time.strftime("%Y-%m-%d %H:%M:%S"),
            "lib_versions": {
                # torch.__version__ is a TorchVersion (a str subclass); storing a plain str keeps
                # that class out of the pickle, so restricted (weights_only) loading does not
                # need it on its allowlist.
                "torch": str(torch.__version__),
                "timm": getattr(timm, "__version__", "unknown"),
                "accelerate": getattr(__import__("accelerate"), "__version__", "unknown"),
            },
            # constructor parameters
            "target_column_name": self.target_col,
            "image_column_name": self.image_col,
            "model_name": self.model_name,
            "pretrained": self.pretrained,
            "drop_path_rate": self.drop_path_rate,
            "multilabel": self.multilabel,
            "label_sep": self.label_sep,
            "model_img_size": self.model_img_size,
            "tta_hflip": self.tta_hflip,
            "tta_scales": self.tta_scales,
            "tta_crop_size": self.tta_crop_size,
            # training / inference metadata
            "classes": self.classes_,
            "num_classes": num_classes,
            "img_sizes": self.img_sizes,
            "val_metric": self.val_metric,
            "thresholds": (self.thresholds_.tolist() if self.thresholds_ is not None else None),
            "mean": mean,
            "std": std,
            # ensemble weights (one state dict per fold)
            "state_dicts": state_dicts,
        }
        torch.save(ckpt, path)
        print(f"[save] checkpoint saved to: {path}")

    @classmethod
    def load(cls, path: str, map_location: Optional[str] = "cpu"):
        """
        Rebuild a pipeline from a checkpoint file written by ``save()``.

        The pipeline is constructed with ``pretrained=False``, ``amp=True`` and ``seed=42``;
        every other constructor argument comes from the checkpoint (with defaults for
        missing optional keys). One model is created per stored state dict, loaded with
        ``strict=True`` and switched to eval mode.

        :param path: checkpoint file path
        :param map_location: forwarded to ``torch.load``
        :return: pipeline instance with ``models_`` populated
        """
        import warnings

        # Checkpoints may contain a pickled TorchVersion; allowlist it where the API exists.
        try:
            from torch.torch_version import TorchVersion
            from torch.serialization import add_safe_globals
            add_safe_globals([TorchVersion])
        except ImportError:
            # This torch build has no safe-globals registry; loading proceeds without it.
            pass

        try:
            ckpt = torch.load(path, map_location=map_location)
        except OSError:
            # Missing or unreadable file: retrying with full unpickling cannot help.
            raise
        except Exception as exc:
            # SECURITY: weights_only=False runs the full pickle machinery, which can execute
            # arbitrary code embedded in the file. Only load checkpoints from trusted sources.
            warnings.warn(
                f"[load] default torch.load failed ({type(exc).__name__}: {exc}); "
                "retrying with weights_only=False, which is unsafe for untrusted files.",
                RuntimeWarning,
                stacklevel=2,
            )
            ckpt = torch.load(path, map_location=map_location, weights_only=False)

        pipe = cls(
            target_column_name=ckpt.get("target_column_name", "label"),
            image_column_name=ckpt.get("image_column_name", "image"),
            model_name=ckpt["model_name"],
            pretrained=False,
            multilabel=ckpt["multilabel"],
            drop_path_rate=ckpt.get("drop_path_rate", 0.0),
            label_sep=ckpt.get("label_sep", " "),
            model_img_size=ckpt.get("model_img_size", None),
            tta_hflip=ckpt.get("tta_hflip", True),
            tta_scales=ckpt.get("tta_scales", None),
            tta_crop_size=ckpt.get("tta_crop_size", None),
            amp=True,
            seed=42,
        )

        # Restore the fitted metadata that the constructor does not take.
        pipe.classes_ = ckpt["classes"]
        pipe.img_sizes = ckpt.get("img_sizes", [224])
        pipe.val_metric = ckpt.get("val_metric", "f1_macro")
        thr = ckpt.get("thresholds", None)
        pipe.thresholds_ = (torch.tensor(thr).numpy() if thr is not None else None)
        pipe._data_config = {
            "mean": tuple(ckpt.get("mean", (0.485, 0.456, 0.406))),
            "std":  tuple(ckpt.get("std",  (0.229, 0.224, 0.225))),
        }

        num_classes = ckpt["num_classes"]
        pipe.models_.clear()
        for sd in ckpt["state_dicts"]:
            m = pipe._create_model_safe(
                pretrained=False,
                img_size=(pipe.model_img_size or pipe.img_sizes[-1]),
                num_classes=num_classes
            )
            m.load_state_dict(sd, strict=True)
            m.eval()
            pipe.models_.append(m)

        print(f"[load] loaded {len(pipe.models_)} model(s) from {path} ({ckpt['model_name']}); classes={num_classes}")
        return pipe

Краш-тесты.

In [None]:
import warnings
warnings.filterwarnings("ignore", message="`torch.cuda.amp.GradScaler", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*does not have many workers.*")

import numpy as np
import pandas as pd
import cv2
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
import tempfile
import os

# ==================== ТЕСТ 1: Базовый multiclass с малым датасетом ====================
print("TEST 1: Tiny dataset (edge case)")
def test_1_tiny_dataset():
    """Crash test: 2-fold multiclass training on 10 random 32x32 images with batch size 3."""
    gen = np.random.default_rng(1)
    n_samples = 10
    pictures = []
    for _ in range(n_samples):
        pictures.append(gen.integers(0, 255, (32, 32, 3), dtype=np.uint8))
    labels = gen.integers(0, 2, n_samples)
    frame = pd.DataFrame({"img": pictures, "label": labels})

    pipeline_args = dict(
        target_column_name="label",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
        amp=False,
    )
    clf = ImageClassificationPipeline(**pipeline_args)

    # EMA and mixup/cutmix are switched off explicitly for this run.
    fit_args = dict(
        img_sizes=[32],
        epochs=2,
        batch_size=3,
        n_folds=2,
        num_workers=0,
        ema_decay=None,
        mixup_alpha=0.0,
        cutmix_alpha=0.0,
    )
    clf.fit(frame, **fit_args)

    preds = clf.predict(frame, return_proba=False, tta=False)
    print(f"✓ Test 1 passed. Predictions shape: {preds.shape}, Accuracy: {accuracy_score(labels, preds):.2f}")

test_1_tiny_dataset()

# ==================== ТЕСТ 2: Multilabel с разными форматами меток ====================
print("\nTEST 2: Multilabel with various label formats")
def test_2_multilabel_formats():
    """Crash test: multilabel training with string-encoded and list-encoded label sets."""
    gen = np.random.default_rng(2)
    n_samples = 30

    def random_images():
        # One 48x48 RGB uint8 array per sample.
        return [gen.integers(0, 255, (48, 48, 3), dtype=np.uint8) for _ in range(n_samples)]

    def random_subset(vocabulary):
        # Draw the subset size first, then the labels (without repetition).
        how_many = gen.integers(1, 4)
        return gen.choice(vocabulary, how_many, replace=False)

    # Format 1: labels joined into one separator-delimited string
    string_images = random_images()
    string_tags = [
        " ".join(sorted(random_subset(['cat', 'dog', 'bird', 'fish', 'mouse'])))
        for _ in range(n_samples)
    ]
    df1 = pd.DataFrame({"img": string_images, "tags": string_tags})

    # Format 2: labels as Python lists
    list_images = random_images()
    list_tags = [random_subset(['A', 'B', 'C', 'D']).tolist() for _ in range(n_samples)]
    df2 = pd.DataFrame({"img": list_images, "tags": list_tags})

    # Run with string labels (focal loss, threshold tuning enabled)
    clf1 = ImageClassificationPipeline(
        target_column_name="tags", image_column_name="img",
        model_name="resnet18", pretrained=False,
        multilabel=True, label_sep=" ",
    )
    clf1.loss_name = "focal"
    clf1.fit(
        df1, img_sizes=[48], epochs=2, batch_size=8, n_folds=2, num_workers=0,
        val_metric="map_macro", optimize_thresholds=True,
    )
    proba1 = clf1.predict(df1.iloc[:5], tta=False)

    # Run with list labels (asymmetric loss)
    clf2 = ImageClassificationPipeline(
        target_column_name="tags", image_column_name="img",
        model_name="resnet18", pretrained=False,
        multilabel=True,
    )
    clf2.loss_name = "asl"
    clf2.fit(
        df2, img_sizes=[48], epochs=2, batch_size=8, n_folds=2, num_workers=0,
        val_metric="map_macro",
    )
    proba2 = clf2.predict(df2.iloc[:5], tta=False)

    print(f"✓ Test 2 passed. String labels shape: {proba1.shape}, List labels shape: {proba2.shape}")
    print(f"  Thresholds found: {clf1.thresholds_}")

test_2_multilabel_formats()

# ==================== ТЕСТ 3: Работа с файлами изображений ====================
print("\nTEST 3: Image file paths instead of arrays")
def test_3_file_paths():
    """Crash test: the image column holds JPEG file paths instead of in-memory arrays."""
    gen = np.random.default_rng(3)
    n_samples = 20

    # Training and prediction stay inside the context so the files still exist when read.
    with tempfile.TemporaryDirectory() as tmpdir:
        file_names, categories = [], []
        for i in range(n_samples):
            pixels = gen.integers(0, 255, (64, 64, 3), dtype=np.uint8)
            target = os.path.join(tmpdir, f"img_{i}.jpg")
            # The array is treated as RGB and converted to BGR before cv2 writes it.
            cv2.imwrite(target, cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR))
            file_names.append(target)
            categories.append(gen.choice(['class_a', 'class_b', 'class_c']))

        frame = pd.DataFrame({"image_path": file_names, "category": categories})

        clf = ImageClassificationPipeline(
            target_column_name="category", image_column_name="image_path",
            model_name="resnet18", pretrained=False,
        )
        fit_args = dict(
            img_sizes=[64], epochs=2, batch_size=5, n_folds=2,
            num_workers=0, stratify=True,
        )
        clf.fit(frame, **fit_args)

        preds = clf.predict(frame.iloc[:5], return_proba=True, tta=False)
        print(f"✓ Test 3 passed. File paths work. Predictions shape: {preds.shape}")

test_3_file_paths()

# ==================== ТЕСТ 4: Progressive resize с heavy augmentations ====================
print("\nTEST 4: Progressive resize + heavy augmentations")
def test_4_progressive_resize():
    """Crash test: three-stage progressive resize (64 -> 96 -> 128) with heavy augmentations."""
    gen = np.random.default_rng(4)
    n_samples = 40

    def random_square_image():
        # Side length is drawn first, then the pixel values.
        side = gen.choice([64, 96, 128])
        return gen.integers(0, 255, (side, side, 3), dtype=np.uint8)

    pictures = [random_square_image() for _ in range(n_samples)]
    targets = gen.integers(0, 4, n_samples)
    frame = pd.DataFrame({"img": pictures, "label": targets})

    clf = ImageClassificationPipeline(
        target_column_name="label", image_column_name="img",
        model_name="resnet18", pretrained=False,
    )

    fit_args = dict(
        img_sizes=[64, 96, 128],
        stage_epochs=[2, 2, 2],
        aug_strength="heavy",
        epochs=6,
        batch_size=8,
        n_folds=2,
        num_workers=0,
    )
    clf.fit(frame, **fit_args)

    preds = clf.predict(frame.iloc[:5], return_proba=False, tta=False)
    print(f"✓ Test 4 passed. Progressive resize works. Final predictions: {preds}")

test_4_progressive_resize()

# ==================== ТЕСТ 5: ViT с LLRD и SAM ====================
print("\nTEST 5: Vision Transformer with LLRD and SAM")
def test_5_vit_llrd_sam():
    """Crash test: ViT-tiny trained with layer-wise LR decay and the SAM optimizer, then embeddings."""
    gen = np.random.default_rng(5)
    n_samples = 30
    pictures = []
    for _ in range(n_samples):
        pictures.append(gen.integers(0, 255, (224, 224, 3), dtype=np.uint8))
    targets = gen.integers(0, 3, n_samples)
    frame = pd.DataFrame({"img": pictures, "label": targets})

    clf = ImageClassificationPipeline(
        target_column_name="label", image_column_name="img",
        model_name="vit_tiny_patch16_224", pretrained=False,
        model_img_size=224,
    )

    optimisation = dict(
        layer_decay=0.75,   # LLRD
        use_sam=True,       # SAM optimizer
        sam_rho=0.05,
        lr=1e-3,
    )
    clf.fit(
        frame, img_sizes=[224], epochs=3, batch_size=8, n_folds=2, num_workers=0,
        **optimisation,
    )

    emb = clf.get_embeddings(frame.iloc[:10])
    print(f"✓ Test 5 passed. ViT+LLRD+SAM works. Embeddings shape: {emb.shape}")

test_5_vit_llrd_sam()

# ==================== ТЕСТ 6: Mixup/Cutmix с отключением на последних эпохах ====================
print("\nTEST 6: Mixup/Cutmix with disable on last epochs")
def test_6_mixup_cutmix():
    """Smoke-test Mixup/CutMix with mixing disabled for the last epochs.

    Fits a non-pretrained ``resnet18`` on 50 random 64x64 images (5 label
    values) with mixup, cutmix and label smoothing turned on and
    ``disable_mix_last_n_epochs=2``, then prints the shape of the
    ``predict(..., return_proba=True)`` output for the first five rows.
    """
    rng = np.random.default_rng(6)
    n_samples = 50

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (64, 64, 3), dtype=np.uint8))
    labels = rng.integers(0, 5, n_samples)
    df = pd.DataFrame({"img": imgs, "label": labels})

    clf = ImageClassificationPipeline(
        **dict(
            target_column_name="label",
            image_column_name="img",
            model_name="resnet18",
            pretrained=False,
        )
    )

    mixing = dict(
        mixup_alpha=0.8,
        cutmix_alpha=1.0,
        label_smoothing=0.1,
        disable_mix_last_n_epochs=2,
    )
    clf.fit(
        df,
        img_sizes=[64],
        epochs=5,
        batch_size=10,
        n_folds=2,
        num_workers=0,
        **mixing,
    )

    preds = clf.predict(df[:5], return_proba=True, tta=False)
    print(f"✓ Test 6 passed. Mixup/Cutmix with disable works. Predictions shape: {preds.shape}")

test_6_mixup_cutmix()

# ==================== ТЕСТ 7: GroupKFold для предотвращения утечек ====================
print("\nTEST 7: GroupKFold to prevent leakage")
def test_7_group_kfold():
    """Smoke-test group-aware cross-validation via ``group_column``.

    Creates 60 random 48x48 images with binary labels, assigned to 15 groups of
    four consecutive rows each (column ``patient_id``), fits with 3 folds and
    prints the shape of ``clf.oof_pred_proba_``.
    """
    rng = np.random.default_rng(7)
    n_groups, per_group = 15, 4
    n_samples = 60

    # 0,0,0,0,1,1,1,1,... – every group id repeated four times.
    groups = np.repeat(np.arange(n_groups), per_group)
    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (48, 48, 3), dtype=np.uint8))
    labels = rng.integers(0, 2, n_samples)
    df = pd.DataFrame({"img": imgs, "label": labels, "patient_id": groups})

    pipeline_kwargs = dict(
        target_column_name="label",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)

    fit_kwargs = dict(
        img_sizes=[48],
        epochs=2,
        batch_size=8,
        n_folds=3,
        num_workers=0,
        group_column="patient_id",
    )
    clf.fit(df, **fit_kwargs)

    print(f"✓ Test 7 passed. GroupKFold works. OOF shape: {clf.oof_pred_proba_.shape}")

test_7_group_kfold()

# ==================== ТЕСТ 8: Imbalanced dataset с WeightedSampler и class weights ====================
print("\nTEST 8: Imbalanced dataset handling (multiclass)")
def test_8_imbalanced():
    """Smoke-test imbalanced multiclass training (weighted sampler + class weights).

    The dataset has 100 / 20 / 10 random 32x32 images for classes 0 / 1 / 2.
    After fitting with ``val_metric="f1_macro"``, ``predict`` is called on the
    first ten rows belonging to the two rare classes and the result is printed.
    """
    rng = np.random.default_rng(8)

    class_counts = (100, 20, 10)  # samples for classes 0, 1, 2
    imgs, labels = [], []
    # Classes are generated one after another so the seeded stream is consumed
    # in a fixed order.
    for class_id, count in enumerate(class_counts):
        for _ in range(count):
            imgs.append(rng.integers(0, 255, (32, 32, 3), dtype=np.uint8))
            labels.append(class_id)
    df = pd.DataFrame({"img": imgs, "label": labels})

    pipeline_kwargs = dict(
        target_column_name="label",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)

    fit_kwargs = dict(
        img_sizes=[32],
        epochs=3,
        batch_size=16,
        n_folds=3,
        num_workers=0,
        use_weighted_sampler=True,
        class_weights_in_loss=True,
        val_metric="f1_macro",
    )
    clf.fit(df, **fit_kwargs)

    rare_rows = df["label"].isin([1, 2])
    test_df = df[rare_rows].head(10)
    if not len(test_df) > 0:
        print(f"✓ Test 8 passed. Imbalanced handling works.")
        return
    preds = clf.predict(test_df, return_proba=False, tta=False)
    print(f"✓ Test 8 passed. Imbalanced handling works. Predictions for rare classes: {preds}")

test_8_imbalanced()

# ==================== ТЕСТ 9: EfficientNet с разными TTA стратегиями ====================
print("\nTEST 9: EfficientNet with various TTA strategies")
def test_9_efficientnet_tta():
    """Smoke-test EfficientNet-B0 inference with and without TTA.

    The pipeline is configured with horizontal-flip TTA, three TTA scales and a
    112 crop; it is fitted on 25 random 128x128 images with EMA enabled, and
    ``predict`` is then run twice on the first five rows (``tta=False`` and
    ``tta=True``). Both output shapes are printed.
    """
    rng = np.random.default_rng(9)
    n_samples = 25

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (128, 128, 3), dtype=np.uint8))
    labels = rng.integers(0, 3, n_samples)
    df = pd.DataFrame({"img": imgs, "label": labels})

    tta_settings = dict(
        tta_hflip=True,
        tta_scales=[112, 128, 144],
        tta_crop_size=112,
    )
    clf = ImageClassificationPipeline(
        target_column_name="label",
        image_column_name="img",
        model_name="efficientnet_b0",
        pretrained=False,
        **tta_settings,
    )

    fit_kwargs = dict(
        img_sizes=[128],
        epochs=2,
        batch_size=5,
        n_folds=2,
        num_workers=0,
        ema_decay=0.999,
    )
    clf.fit(df, **fit_kwargs)

    head = df[:5]
    preds_no_tta = clf.predict(head, return_proba=True, tta=False)
    preds_with_tta = clf.predict(df[:5], return_proba=True, tta=True)

    print(f"✓ Test 9 passed. EfficientNet+TTA works.")
    print(f"  No TTA shape: {preds_no_tta.shape}, With TTA shape: {preds_with_tta.shape}")

test_9_efficientnet_tta()

# ==================== ТЕСТ 10: ConvNeXt с валидацией на отдельном датасете ====================
print("\nTEST 10: ConvNeXt with train/val split")
def test_10_convnext_split():
    """Smoke-test ConvNeXt-tiny with an external stratified train/test split.

    80 random 96x96 images get string labels (alpha/beta/gamma/delta); 25% is
    held out with ``train_test_split``. After a 3-fold fit on the training part,
    ``predict`` is run on the held-out part both with TTA (probabilities) and
    without (``return_proba=False``). The latter output is mapped through
    ``clf.classes_`` and scored with ``accuracy_score``.
    """
    rng = np.random.default_rng(10)
    n_samples = 80
    class_names = ['alpha', 'beta', 'gamma', 'delta']

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (96, 96, 3), dtype=np.uint8))
    labels = rng.choice(class_names, n_samples)
    df = pd.DataFrame({"img": imgs, "label": labels})
    df_train, df_test = train_test_split(
        df, test_size=0.25, stratify=df['label'], random_state=10
    )

    pipeline_kwargs = dict(
        target_column_name="label",
        image_column_name="img",
        model_name="convnext_tiny",
        pretrained=False,
        drop_path_rate=0.2,
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)

    fit_kwargs = dict(
        img_sizes=[96],
        epochs=3,
        batch_size=8,
        n_folds=3,
        num_workers=0,
        warmup_epochs=1.0,
        grad_clip=1.0,
        val_metric="accuracy",
    )
    clf.fit(df_train, **fit_kwargs)

    # The probability call is kept only to exercise the TTA path; its result is unused.
    test_proba = clf.predict(df_test, return_proba=True, tta=True)
    test_preds = clf.predict(df_test, return_proba=False, tta=False)
    decoded = [clf.classes_[i] for i in test_preds]
    acc = accuracy_score(df_test['label'].values, decoded)
    print(f"✓ Test 10 passed. ConvNeXt works. Test accuracy: {acc:.2f}")

test_10_convnext_split()

# ==================== ТЕСТ 11: Граничный случай - 1 пример на класс ====================
print("\nTEST 11: Edge case - 1 sample per class")
def test_11_one_sample_per_class():
    """Edge-case test on a six-image dataset (two samples for each of 3 classes).

    Fits with 2 stratified folds, batch size 2 and mixing disabled. Any
    exception raised by ``fit`` or ``predict`` is caught and reported with a
    "failed" message instead of propagating, so this test never aborts the cell.
    """
    rng = np.random.default_rng(11)
    labels = [0, 0, 1, 1, 2, 2]
    imgs = []
    for _ in range(6):
        imgs.append(rng.integers(0, 255, (32, 32, 3), dtype=np.uint8))
    df = pd.DataFrame({"img": imgs, "label": labels})

    pipeline_kwargs = dict(
        target_column_name="label",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)

    fit_kwargs = dict(
        img_sizes=[32],
        epochs=1,
        batch_size=2,
        n_folds=2,
        num_workers=0,
        stratify=True,
        mixup_alpha=0.0,
        cutmix_alpha=0.0,
    )
    # Deliberately broad: this is a crash test that reports instead of raising.
    try:
        clf.fit(df, **fit_kwargs)
        preds = clf.predict(df, return_proba=False, tta=False)
        print(f"✓ Test 11 passed. Ultra-small dataset works. Predictions: {preds}")
    except Exception as e:
        print(f"✗ Test 11 failed with error: {e}")

test_11_one_sample_per_class()

# ==================== ТЕСТ 12: Swin Transformer с разными входными размерами ====================
print("\nTEST 12: Swin Transformer with different input sizes")
def test_12_swin():
    """Smoke-test a Swin-tiny backbone with layer decay.

    Fits ``swin_tiny_patch4_window7_224`` (not pretrained, drop-path 0.1) on 30
    random 224x224 images with 4 label values, then prints the shape of the
    embeddings returned for the first ten rows.
    """
    rng = np.random.default_rng(12)
    n_samples = 30

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (224, 224, 3), dtype=np.uint8))
    labels = rng.integers(0, 4, n_samples)
    df = pd.DataFrame({"img": imgs, "label": labels})

    pipeline_kwargs = {
        "target_column_name": "label",
        "image_column_name": "img",
        "model_name": "swin_tiny_patch4_window7_224",
        "pretrained": False,
        "model_img_size": 224,
        "drop_path_rate": 0.1,
    }
    clf = ImageClassificationPipeline(**pipeline_kwargs)

    fit_kwargs = {
        "img_sizes": [224],
        "epochs": 2,
        "batch_size": 4,
        "n_folds": 2,
        "num_workers": 0,
        "layer_decay": 0.8,
    }
    clf.fit(df, **fit_kwargs)

    emb = clf.get_embeddings(df[:10])
    print(f"✓ Test 12 passed. Swin Transformer works. Embeddings shape: {emb.shape}")

test_12_swin()

# ==================== ТЕСТ 13: Multilabel с одним активным классом (граничный случай) ====================
print("\nTEST 13: Multilabel edge case - single active class")
def test_13_multilabel_edge():
    """Multilabel edge case: most rows carry a single tag.

    Of 40 random 64x64 images, the first 30 get one tag out of A..E and the
    remaining 10 get two or three distinct tags joined by a space. The pipeline
    is fitted with ``loss_name="bce"`` and threshold optimisation; the shape of
    the ``predict`` output and ``clf.thresholds_`` are printed.
    """
    rng = np.random.default_rng(13)
    n_samples = 40
    n_single = 30
    tag_pool = ['A', 'B', 'C', 'D', 'E']

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (64, 64, 3), dtype=np.uint8))

    # Single-tag rows first, multi-tag rows after – the same draw order as a
    # single loop branching on the row index.
    labels = [rng.choice(tag_pool) for _ in range(min(n_single, n_samples))]
    for _ in range(max(n_samples - n_single, 0)):
        n = rng.integers(2, 4)  # 2 or 3 tags (upper bound exclusive)
        labs = rng.choice(tag_pool, n, replace=False)
        labels.append(" ".join(sorted(labs)))
    df = pd.DataFrame({"img": imgs, "tags": labels})

    pipeline_kwargs = dict(
        target_column_name="tags",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
        multilabel=True,
        label_sep=" ",
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)
    clf.loss_name = "bce"

    fit_kwargs = dict(
        img_sizes=[64],
        epochs=3,
        batch_size=8,
        n_folds=3,
        num_workers=0,
        optimize_thresholds=True,
    )
    clf.fit(df, **fit_kwargs)

    proba = clf.predict(df[:5], tta=False)
    print(f"✓ Test 13 passed. Multilabel with mostly single labels works.")
    print(f"  Probabilities shape: {proba.shape}, Thresholds: {clf.thresholds_}")

test_13_multilabel_edge()

# ==================== ТЕСТ 14: Holdout (test_size) мультикласс ====================
print("\nTEST 14: Holdout (test_size) multiclass split")
def test_14_holdout_multiclass():
    """Smoke-test holdout training (``test_size``) for multiclass data.

    Fits on 120 random 48x48 images with ``test_size=0.3`` and
    ``stratify=True`` (no ``n_folds``), reports how many entries of
    ``clf.oof_fold_`` are non-negative, and prints ``predict`` output for the
    first eight rows.
    """
    rng = np.random.default_rng(14)
    n_samples = 120

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (48, 48, 3), dtype=np.uint8))
    labels = rng.integers(0, 3, n_samples)
    df = pd.DataFrame({"img": imgs, "label": labels})

    pipeline_kwargs = dict(
        target_column_name="label",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)

    fit_kwargs = dict(
        test_size=0.3,
        stratify=True,
        img_sizes=[48],
        epochs=2,
        batch_size=16,
        num_workers=0,
    )
    clf.fit(df, **fit_kwargs)

    # presumably rows outside the holdout carry a negative fold id – verify in fit()
    mask = (clf.oof_fold_ >= 0)
    print(f"Holdout val size: {mask.sum()} of {len(mask)}")
    preds = clf.predict(df[:8], return_proba=False, tta=False)
    print(f"✓ Test 14 passed. Holdout multiclass works. Predictions: {preds}")

test_14_holdout_multiclass()

# ==================== ТЕСТ 15: Holdout (test_size) мультилейбл ====================
print("\nTEST 15: Holdout (test_size) multilabel split")
def test_15_holdout_multilabel():
    """Smoke-test holdout training (``test_size``) for multilabel data.

    Each of 60 random 64x64 images receives 1-3 distinct tags ``t0``..``t4``
    joined by spaces. The pipeline uses ``loss_name="asl"``, ``test_size=0.25``,
    ``val_metric="map_macro"`` and threshold optimisation.
    """
    rng = np.random.default_rng(15)
    n_samples, n_classes = 60, 5

    # All images are drawn before any tags, keeping the seeded stream order fixed.
    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (64, 64, 3), dtype=np.uint8))

    tags = []
    for _ in range(n_samples):
        k = rng.integers(1, 4)  # 1..3 tags (upper bound exclusive)
        chosen = rng.choice(n_classes, size=k, replace=False).tolist()
        chosen.sort()
        tags.append(" ".join(f"t{i}" for i in chosen))
    df = pd.DataFrame({"img": imgs, "tags": tags})

    pipeline_kwargs = dict(
        target_column_name="tags",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
        multilabel=True,
        label_sep=" ",
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)
    clf.loss_name = "asl"

    fit_kwargs = dict(
        test_size=0.25,
        img_sizes=[64],
        epochs=2,
        batch_size=8,
        num_workers=0,
        val_metric="map_macro",
        optimize_thresholds=True,
    )
    clf.fit(df, **fit_kwargs)

    mask = (clf.oof_fold_ >= 0)
    print(f"Holdout val size (multilabel): {mask.sum()} of {len(mask)}")
    proba = clf.predict(df[:6], tta=False)
    print(f"✓ Test 15 passed. Holdout multilabel works. Proba shape: {proba.shape}")

test_15_holdout_multilabel()

# ==================== ТЕСТ 16: Аккумуляция градиентов (grad_accum_steps) ====================
print("\nTEST 16: Gradient Accumulation")
def test_16_grad_accum():
    """Smoke-test gradient accumulation (``grad_accum_steps``).

    Fits on 64 random 64x64 images with a micro-batch of 4 and 4 accumulation
    steps (mixing disabled), then prints ``predict`` output for the first
    eight rows.
    """
    rng = np.random.default_rng(16)
    n_samples = 64

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (64, 64, 3), dtype=np.uint8))
    labels = rng.integers(0, 4, n_samples)
    df = pd.DataFrame({"img": imgs, "label": labels})

    pipeline_kwargs = dict(
        target_column_name="label",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)

    # Small batch + accumulation = larger effective batch.
    fit_kwargs = dict(
        img_sizes=[64],
        epochs=2,
        batch_size=4,        # micro-batch
        grad_accum_steps=4,  # 4 steps => equivalent batch of 16
        n_folds=2,
        num_workers=0,
        mixup_alpha=0.0,
        cutmix_alpha=0.0,
    )
    clf.fit(df, **fit_kwargs)

    preds = clf.predict(df[:8], return_proba=False, tta=False)
    print(f"✓ Test 16 passed. Grad accumulation works. Predictions: {preds}")

test_16_grad_accum()

# ==================== ТЕСТ 17: Multilabel BCE pos_weight (явно) ====================
print("\nTEST 17: Multilabel BCE with explicit pos_weight")
def test_17_bce_posweight():
    """Smoke-test multilabel BCE with an explicit ``bce_pos_weight`` vector.

    Tags ``t0``..``t3`` are sampled independently per image with probabilities
    0.8 / 0.5 / 0.3 / 0.05 (``t3`` is the rare class); a row that ends up with
    no tag falls back to ``t0``. The fit receives ``bce_pos_weight`` =
    ``[1.0, 1.5, 2.0, 10.0]``.
    """
    rng = np.random.default_rng(17)
    n_samples = 80

    imgs = []
    for _ in range(n_samples):
        imgs.append(rng.integers(0, 255, (72, 72, 3), dtype=np.uint8))

    # (tag, inclusion probability); t3 is the rare class.
    tag_probs = (("t0", 0.8), ("t1", 0.5), ("t2", 0.3), ("t3", 0.05))
    tags = []
    for _ in range(n_samples):
        # Exactly one uniform draw per tag, in tag order, for every row.
        labs = [tag for tag, prob in tag_probs if rng.random() < prob]
        if not labs:
            labs = ["t0"]
        tags.append(" ".join(sorted(set(labs))))
    df = pd.DataFrame({"img": imgs, "tags": tags})

    pipeline_kwargs = dict(
        target_column_name="tags",
        image_column_name="img",
        model_name="resnet18",
        pretrained=False,
        multilabel=True,
        label_sep=" ",
    )
    clf = ImageClassificationPipeline(**pipeline_kwargs)
    clf.loss_name = "bce"

    # Explicit pos_weight, boosting the rare last class.
    bce_pw = [1.0, 1.5, 2.0, 10.0]

    fit_kwargs = dict(
        img_sizes=[72],
        epochs=2,
        batch_size=8,
        n_folds=2,
        num_workers=0,
        val_metric="map_macro",
        optimize_thresholds=True,
        bce_pos_weight=bce_pw,
    )
    clf.fit(df, **fit_kwargs)

    proba = clf.predict(df[:6], tta=False)
    print(f"✓ Test 17 passed. Explicit BCE pos_weight works. Proba shape: {proba.shape}")

test_17_bce_posweight()

# ==================== ФИНАЛЬНАЯ СТАТИСТИКА ====================
# Final banner: a blank line, a 60-character rule, then one summary line per
# capability exercised by the crash tests above.
print("\n" + "="*60)
print(
    "ALL CRASH TESTS (WITH HOLDOUT, ACCUM, POS_WEIGHT) COMPLETED SUCCESSFULLY!",
    "="*60,
    "The pipeline is robust and handles:",
    "✓ Tiny datasets (N=6-10)",
    "✓ Various image formats (numpy arrays, file paths)",
    "✓ Progressive resize and strong augmentations",
    "✓ Multiple architectures (ResNet, EfficientNet, ConvNeXt, ViT, Swin)",
    "✓ Multiclass and multilabel classification",
    "✓ Imbalanced datasets (sampler + class weights; pos_weight for multilabel)",
    "✓ GroupKFold for preventing data leakage",
    "✓ Advanced optimizations (SAM, LLRD, EMA)",
    "✓ Mixup/Cutmix with smart disabling",
    "✓ Multiple TTA strategies",
    "✓ Edge cases (1-2 samples per class)",
    "✓ Different loss functions (CE, BCE, Focal, ASL)",
    "✓ Holdout mode via test_size for fast iteration",
    "✓ Gradient accumulation (grad_accum_steps)",
    "✓ Explicit pos_weight for multilabel BCE",
    sep="\n",
)

Пример использования 1.

In [None]:
import warnings
warnings.filterwarnings("ignore", message="`torch.cuda.amp.GradScaler", category=FutureWarning)

import numpy as np, pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

rng = np.random.default_rng(42)

def make_blob(h, w, color):
    """Return an ``h`` x ``w`` uint8 RGB image with one filled circle plus noise.

    The circle centre is drawn from the central half of the image and the radius
    from ``[min(h, w)//8, min(h, w)//4)`` using the module-level ``rng``; uniform
    noise in ``[0, 10)`` is then added to every channel.

    Args:
        h: image height in pixels.
        w: image width in pixels.
        color: three channel values for the circle.

    Returns:
        ``np.ndarray`` of shape ``(h, w, 3)`` and dtype ``uint8``.
    """
    import cv2  # imported lazily, as in the rest of this cell

    img = np.zeros((h, w, 3), dtype=np.uint8)
    cy, cx = rng.integers(h//4, 3*h//4), rng.integers(w//4, 3*w//4)
    r = rng.integers(min(h,w)//8, min(h,w)//4)
    # Plain ints for every OpenCV argument (the generator returns NumPy scalars).
    cv2.circle(img, (int(cx), int(cy)), int(r), tuple(map(int, color)), -1)
    noise = rng.integers(0, 10, size=img.shape, dtype=np.uint8)
    # uint8 addition wraps around (e.g. 250 + 9 -> 3); widen first and saturate
    # at 255 so bright colours stay bright. Identical to the in-place add
    # whenever no channel exceeds 246.
    return np.minimum(img.astype(np.uint16) + noise, 255).astype(np.uint8)

# Example 1: multiclass classification on synthetic "coloured blob" images.
# 90 images of 96x96; the label (0..2) selects the blob colour from `palette`.
N, H, W = 90, 96, 96
labels = rng.integers(0, 3, size=N)
palette = [(200, 50, 50), (50, 200, 50), (50, 50, 200)]
imgs = [make_blob(H, W, palette[int(y)]) for y in labels]
df = pd.DataFrame({"image": imgs, "label": labels})
# Stratified 75/25 split; df_va is used only for the final score below.
df_tr, df_va = train_test_split(df, test_size=0.25, stratify=df["label"], random_state=42)

clf = ImageClassificationPipeline(
    target_column_name="label",
    image_column_name="image",
    model_name="resnet18",
    pretrained=True,
    drop_path_rate=0.1,
    multilabel=False,
    amp=True,
)

# Two-stage progressive resize (96 then 128, three epochs each) with 3 folds.
clf.fit(
    df_tr,
    img_sizes=[96, 128],
    stage_epochs=[3, 3],
    aug_strength="medium",
    epochs=6,
    batch_size=16,
    num_workers=0,            # safe for notebooks/Colab
    ema_decay=0.9999,
    lr=2e-3,
    weight_decay=1e-4,
    warmup_epochs=0.5,
    grad_clip=1.0,
    mixup_alpha=0.5,
    cutmix_alpha=0.8,
    label_smoothing=0.05,
    disable_mix_last_n_epochs=1,
    n_folds=3,
    val_metric="f1_macro",
)

# `proba` (TTA on) is computed for demonstration; the score uses the non-TTA `preds`.
proba = clf.predict(df_va, return_proba=True, tta=True)
preds = clf.predict(df_va, return_proba=False, tta=False)
print("VAL F1-macro:", f1_score(df_va["label"].values, preds, average="macro"))

Пример использования 2.

In [None]:
import warnings
warnings.filterwarnings("ignore", message="`torch.cuda.amp.GradScaler", category=FutureWarning)

import numpy as np, pandas as pd

rng = np.random.default_rng(13)
N, H, W, C = 80, 96, 96, 5

def make_multi(h, w, active_ids):
    """Return a noisy ``h`` x ``w`` uint8 image with one line per active tag id.

    The background is uniform noise in ``[0, 50)``. Even ids produce a
    horizontal line at a random row, odd ids a vertical line at a random
    column; each id has its own colour and the line thickness is 3. Random
    values come from the module-level ``rng``.
    """
    import cv2

    img = rng.integers(0, 50, size=(h, w, 3), dtype=np.uint8)
    palette = ((200, 0, 0), (0, 200, 0), (0, 0, 200), (200, 200, 0), (200, 0, 200))
    margin = 10  # lines and their random offsets stay this far from the border
    for tag_id in active_ids:
        if tag_id % 2 == 0:
            row = rng.integers(margin, h - margin)
            start, end = (margin, row), (w - margin, row)
        else:
            col = rng.integers(margin, w - margin)
            start, end = (col, margin), (col, h - margin)
        cv2.line(img, start, end, palette[tag_id], 3)
    return img

# Example 2: multilabel classification on synthetic "coloured lines" images.
# `tags_all` lists every possible tag name; it is not referenced again in this cell.
tags_all = [f"t{i}" for i in range(C)]
imgs, tag_strings = [], []
for _ in range(N):
    # 1..3 distinct tag ids per image (upper bound of `integers` is exclusive).
    k = rng.integers(1, 4)
    labs = sorted(rng.choice(C, size=k, replace=False).tolist())
    imgs.append(make_multi(H, W, labs))
    # Target format: tag names joined by a space, matching label_sep=" " below.
    tag_strings.append(" ".join([f"t{i}" for i in labs]))

df = pd.DataFrame({"image": imgs, "tags": tag_strings})

clf_ml = ImageClassificationPipeline(
    target_column_name="tags",
    image_column_name="image",
    model_name="resnet18",
    pretrained=True,
    multilabel=True,
    label_sep=" ",
    amp=True,
)
# The loss is selected by assigning the attribute after construction.
clf_ml.loss_name = "asl"  # "bce" | "asl" | "focal"

clf_ml.fit(
    df,
    img_sizes=[96, 128],
    stage_epochs=[3, 3],
    aug_strength="medium",
    epochs=6,
    batch_size=16,
    num_workers=0,
    ema_decay=0.9999,
    lr=1e-3,
    weight_decay=1e-4,
    val_metric="map_macro",
    optimize_thresholds=True,
    n_folds=3,
)

# Predictions are made on the training frame itself (demonstration only).
proba = clf_ml.predict(df, tta=False)
print("Tuned thresholds:", clf_ml.thresholds_)

Пример использования 3.

In [None]:
import warnings
warnings.filterwarnings("ignore", message="`torch.cuda.amp.GradScaler", category=FutureWarning)

import numpy as np, pandas as pd
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(7)
N, H, W, KCLS = 90, 224, 224, 4

def make_shape(h, w, cls):
    """Return an ``h`` x ``w`` uint8 image showing the shape for class ``cls``.

    Class 0 is a filled circle, 1 a filled rectangle, 2 a filled triangle and
    any other value a cross made of two 8-pixel-thick lines. Each class has its
    own colour. Uniform noise in ``[0, 15)`` from the module-level ``rng`` is
    added in place at the end.
    """
    import cv2

    img = np.zeros((h, w, 3), dtype=np.uint8)
    c = [(220,70,70),(70,220,70),(70,70,220),(220,220,70)][cls]

    # Shared anchor coordinates: centre and the quarter / three-quarter lines.
    mid_x, mid_y = w // 2, h // 2
    left, right = w // 4, 3 * w // 4
    top, bottom = h // 4, 3 * h // 4

    if cls == 0:
        cv2.circle(img, (mid_x, mid_y), min(h, w) // 4, c, -1)
    elif cls == 1:
        cv2.rectangle(img, (left, top), (right, bottom), c, -1)
    elif cls == 2:
        triangle = np.array([[mid_x, top], [left, bottom], [right, bottom]], np.int32)
        cv2.fillPoly(img, [triangle], c)
    else:
        cv2.line(img, (left, mid_y), (right, mid_y), c, 8)
        cv2.line(img, (mid_x, top), (mid_x, bottom), c, 8)

    noise = rng.integers(0, 15, size=img.shape, dtype=np.uint8)
    img += noise
    return img

# Example 3: ViT with group-aware folds, then a logistic regression on embeddings.
# Labels and group ids are drawn independently; images are rendered from the labels.
labels = rng.integers(0, KCLS, size=N)
groups = rng.integers(0, 15, size=N)
imgs = [make_shape(H, W, int(y)) for y in labels]
df = pd.DataFrame({"image": imgs, "label": labels, "group_id": groups})

clf_vit = ImageClassificationPipeline(
    target_column_name="label",
    image_column_name="image",
    model_name="vit_tiny_patch16_224",
    pretrained=True,
    multilabel=False,
    model_img_size=224,       # fixed input size for ViT
    tta_hflip=True,           # TTA parameters are part of inference
    tta_scales=[224],
    tta_crop_size=None,
    amp=True,
    drop_path_rate=0.2,
)

clf_vit.fit(
    df,
    img_sizes=[224],
    stage_epochs=[6],
    aug_strength="medium",
    epochs=6,
    batch_size=16,
    num_workers=0,
    ema_decay=0.9999,
    lr=1e-3,
    weight_decay=0.02,
    layer_decay=0.8,        # LLRD
    n_folds=3,
    group_column="group_id",
    val_metric="f1_macro",
)

proba = clf_vit.predict(df, return_proba=True, tta=True)

# Embeddings + logistic regression
emb = clf_vit.get_embeddings(df)
lr = LogisticRegression(max_iter=300, n_jobs=-1)
# NOTE: the classifier is fitted and scored on the same rows, so the printed
# F1 is an in-sample figure, not a validation score.
lr.fit(emb, labels)
pred_lr = lr.predict(emb)
print("F1-macro (логрег на эмбеддингах):", f1_score(labels, pred_lr, average="macro"))

Пример использования 4.

In [None]:
import warnings
warnings.filterwarnings("ignore", message="`torch.cuda.amp.GradScaler", category=FutureWarning)

import numpy as np, pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score

# 1) Синтетика (4 класса, картинки 224x224)
rng = np.random.default_rng(21)
N, H, W, KCLS = 300, 224, 224, 4

def make_shape(h, w, cls):
    """Return an ``h`` x ``w`` uint8 image showing the shape for class ``cls``.

    Class 0 is a filled circle, 1 a filled rectangle, 2 a filled triangle and
    any other value a cross made of two 8-pixel-thick lines. Each class has its
    own colour. Uniform noise in ``[0, 10)`` from the module-level ``rng`` is
    added in place at the end.
    """
    import cv2

    img = np.zeros((h, w, 3), dtype=np.uint8)
    c = [(220,70,70),(70,220,70),(70,70,220),(220,220,70)][cls]

    # Shared anchor coordinates: centre and the quarter / three-quarter lines.
    mid_x, mid_y = w // 2, h // 2
    left, right = w // 4, 3 * w // 4
    top, bottom = h // 4, 3 * h // 4

    if cls == 0:
        cv2.circle(img, (mid_x, mid_y), min(h, w) // 4, c, -1)
    elif cls == 1:
        cv2.rectangle(img, (left, top), (right, bottom), c, -1)
    elif cls == 2:
        triangle = np.array([[mid_x, top], [left, bottom], [right, bottom]], np.int32)
        cv2.fillPoly(img, [triangle], c)
    else:
        cv2.line(img, (left, mid_y), (right, mid_y), c, 8)
        cv2.line(img, (mid_x, top), (mid_x, bottom), c, 8)

    # a little noise
    noise = rng.integers(0, 10, size=img.shape, dtype=np.uint8)
    img += noise
    return img

# Example 4: holdout training (no K-fold) with gradient accumulation, scored on
# an external test split.
labels = rng.integers(0, KCLS, size=N)
imgs = [make_shape(H, W, int(y)) for y in labels]
df = pd.DataFrame({"image": imgs, "label": labels})

# Separate test split (final evaluation on data never passed to fit)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df["label"], random_state=2025)

# 2) Pipeline (constructor takes only data / model / inference parameters)
clf = ImageClassificationPipeline(
    target_column_name="label",
    image_column_name="image",
    model_name="vit_tiny_patch16_224",
    pretrained=True,
    multilabel=False,
    model_img_size=224,      # fixed input size for ViT
    tta_hflip=True,
    tta_scales=[224],
    tta_crop_size=None,
    amp=True,
    drop_path_rate=0.2,
)

# 3) Train without K-fold: holdout via test_size (here 10% of df_train is used
#    for validation inside fit), plus grad_accum_steps (gradient accumulation)
#    to save VRAM.
clf.fit(
    df_train,
    test_size=0.1,            # holdout inside the training set
    img_sizes=[224],          # single size (no progressive resize)
    epochs=5,
    batch_size=8,             # micro-batch
    grad_accum_steps=2,       # => effective batch ~ 16
    num_workers=0,
    ema_decay=0.9995,
    lr=8e-4,
    weight_decay=0.02,
    warmup_epochs=0.5,
    grad_clip=1.0,
    layer_decay=0.8,          # LLRD for ViT (optional)
    val_metric="f1_macro",
)

# 4) Evaluation on the external test split (df_test)
proba_test = clf.predict(df_test, return_proba=True, tta=True)
# argmax over axis 1 yields a column index; it is compared with the integer
# labels directly, which works here because the labels are 0..KCLS-1.
preds_test = proba_test.argmax(1)
y_true = df_test["label"].values
print("Test Accuracy:", accuracy_score(y_true, preds_test))
print("Test F1-macro:", f1_score(y_true, preds_test, average="macro"))

# Дообучение YOLO для детекции или подсчёта объектов на картинке.

In [None]:
# yolo_detection_pipeline.py
# Обновлённый пайплайн:
# - авто-выбор устройства (GPU '0' / мульти-GPU '0,1,...' / CPU),
# - тюнинг подсчёта: plain (conf/iou/max_det) ИЛИ area-gated (cs/cb/area_thr + iou/max_det),
# - Ridge-калибровка по резидуалу (K-fold CV, стандартизация),
# - устойчивый инференс при мульти-GPU (тюнинг/инференс на первой карте),
# - финальная валидация RMSE/MAE для задачи подсчёта.
#
# Требования:
# pip install -U ultralytics opencv-python pandas numpy scikit-learn tqdm

!pip install ultralytics

import os
import cv2
import json
import shutil
import tempfile
import warnings
import random
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from ultralytics import YOLO
from ultralytics.utils.downloads import attempt_download_asset

from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import KFold


class YOLODetectionPipeline:
    """
    YOLO-пайплайн из pandas DataFrame с опциональным подбором гиперпараметров подсчёта
    (plain или area-gated) и калибровкой линейной моделью (Ridge по резидуалу).

    Входные данные (DataFrame):
      - image_path (str): путь к изображению (jpg/png)
      - boxes_col: GT-боксы в YOLO-нормировке [0..1] (списки/массивы/строка)

    Основной сценарий:
      - fit(): обучает YOLO; (опц.) тюнинг порогов (plain/area-gated + iou/max_det);
               (опц.) калибрует Ridge; (опц.) валидирует RMSE/MAE.
      - predict(): возвращает детекции (boxes_json + count). Калибровка НЕ применяется.
      - predict_counts(): возвращает числовой подсчёт; применяет лучшие пороги/калибровку, если включены.

    Управление временем тюнинга:
      - tune_val_subsample: подвыборка валидации (int — кол-во, float — доля 0..1).
      - tune_max_combinations: ограничивает число проверяемых комбо (случайно из сетки).
    """

    def __init__(self,
                 model_ckpt: str = "yolov8n.pt",
                 data_root: str | None = None,
                 image_col: str = "image_path",
                 boxes_col: str = "boxes",
                 class_names: list[str] | None = None,
                 use_symlinks: bool = True,
                 verbose: bool = True,
                 # переключатели
                 enable_tuning: bool = True,
                 enable_ridge: bool = True,
                 validate_count: bool = True,
                 # режим тюнинга plain vs area-gated
                 enable_area_gate: bool = True,        # True → тюним (conf_small, conf_big, area_thr) + (iou, max_det)
                 enable_tta_flip: bool = False,        # True → TTA flip при plain-подсчёте
                 # управление временем тюнинга
                 tune_val_subsample: int | float | None = None,  # int=кол-во; float=доля [0..1]
                 tune_max_combinations: int | None = 100,
                 random_state: int = 42,
                 # сетки для plain-тюнинга
                 tune_conf_grid = (0.20, 0.25, 0.30, 0.35),
                 tune_iou_grid  = (0.55, 0.60),
                 tune_max_det_grid = (300, 600),
                 # сетки для area-gated (под мелкие объекты из ваших стат)
                 tune_conf_small_grid = (0.10, 0.12, 0.14, 0.18),
                 tune_conf_big_grid   = (0.30, 0.40, 0.50),
                 tune_area_thr_grid   = (0.0008, 0.0010, 0.0012, 0.0015),
                 gate_conf_base: float = 0.07,  # базовый conf для извлечения кандидатов при area-gated
                 # сетка для Ridge
                 ridge_alpha_grid = (0.3, 1.0, 3.0),
                 # пороги нормированной площади для фич Ridge/plain
                 small_thr: float = 0.0010,
                 big_thr: float   = 0.003):
        self.model_ckpt = model_ckpt
        self.image_col = image_col
        self.boxes_col = boxes_col
        self.class_names = class_names or ["obj"]
        self.use_symlinks = use_symlinks
        self.verbose = verbose

        self.enable_tuning = enable_tuning
        self.enable_ridge = enable_ridge
        self.validate_count = validate_count
        self.enable_area_gate = enable_area_gate
        self.enable_tta_flip = enable_tta_flip

        self.tune_val_subsample = tune_val_subsample
        self.tune_max_combinations = tune_max_combinations
        self.random_state = random_state
        random.seed(random_state)
        np.random.seed(random_state)

        # рабочая папка
        self._tmpdir_owned = False
        if data_root is None:
            self.data_root = tempfile.mkdtemp(prefix="yolo_ds_")
            self._tmpdir_owned = True
        else:
            self.data_root = os.path.abspath(data_root)
            os.makedirs(self.data_root, exist_ok=True)

        self.dataset_yaml = os.path.join(self.data_root, "dataset.yaml")
        self.model_path = None
        self._model = None
        self._device = None  # строка устройства, использованная при fit()

        # сетки и пороги
        self.tune_conf_grid = tuple(float(x) for x in tune_conf_grid)
        self.tune_iou_grid  = tuple(float(x) for x in tune_iou_grid)
        self.tune_max_det_grid = tuple(int(x) for x in tune_max_det_grid)

        self.tune_conf_small_grid = tuple(float(x) for x in tune_conf_small_grid)
        self.tune_conf_big_grid   = tuple(float(x) for x in tune_conf_big_grid)
        self.tune_area_thr_grid   = tuple(float(x) for x in tune_area_thr_grid)
        self.gate_conf_base = float(gate_conf_base)

        self.ridge_alpha_grid = tuple(float(x) for x in ridge_alpha_grid)
        self.small_thr = float(small_thr)
        self.big_thr   = float(big_thr)

        # сохранённые результаты тюнинга/калибровки
        self.calib_ = dict(
            # plain режим:
            best_conf=None, best_iou=None, best_max_det=None,
            # area-gated:
            gate_conf_small=None, gate_conf_big=None, gate_area_thr=None, gate_conf_base=None,
            # Ridge:
            ridge_alpha=None, ridge_model=None, ridge_mu=None, ridge_sd=None,
            # общий:
            imgsz=None
        )

    # -------------------- device helpers --------------------
    @staticmethod
    def _resolve_device(device: str | int | None) -> str:
        """
        Выбор устройства для обучения:
          - None / "auto": "0,1,...,N-1" при наличии CUDA, иначе "cpu"
          - иначе вернуть строку как есть (например, "0" или "cpu")
        """
        if device is None or str(device).lower() == "auto":
            if torch.cuda.is_available() and torch.cuda.device_count() > 0:
                n = torch.cuda.device_count()
                return ",".join(str(i) for i in range(n))
            return "cpu"
        return str(device)

    def _infer_device(self) -> str:
        """
        Устройство для инференса/тюнинга:
          - если тренировались на '0,1,...' → берём первую карту '0'
          - если тренировались на 'k' → её же
          - иначе авто: '0' при наличии CUDA, 'cpu' без GPU
        """
        if getattr(self, "_device", None):
            if isinstance(self._device, str) and "," in self._device:
                return self._device.split(",")[0]
            return self._device
        return "0" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu"

    # -------------------- helpers: разметка → YOLO-txt --------------------
    @staticmethod
    def _is_nan_like(x):
        if x is None: return True
        if isinstance(x, float) and np.isnan(x): return True
        if isinstance(x, str) and x.strip()=="": return True
        return False

    def _parse_boxes(self, row):
        boxes_raw = row[self.boxes_col] if self.boxes_col in row else None
        if self._is_nan_like(boxes_raw): return []
        out = []
        if isinstance(boxes_raw, (list, tuple, np.ndarray)):
            for it in boxes_raw:
                vals = list(map(float, it))
                if len(vals) >= 4:
                    x,y,w,h = vals[:4]
                    if 0 <= x <= 1 and 0 <= y <= 1 and 0 < w <= 1 and 0 < h <= 1:
                        out.append((0, x,y,w,h))
        elif isinstance(boxes_raw, str):
            lines = [ln.strip() for ln in boxes_raw.strip().splitlines() if ln.strip()]
            for ln in lines:
                parts = ln.split()
                vals = list(map(float, parts))
                if len(vals) == 4:
                    x,y,w,h = vals
                    out.append((0, x,y,w,h))
                elif len(vals) >= 5:
                    cls,x,y,w,h = int(vals[0]), *vals[1:5]
                    out.append((cls, float(x),float(y),float(w),float(h)))
        return out

    def _link_or_copy(self, src, dst):
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        if self.use_symlinks:
            try:
                if os.path.lexists(dst): os.remove(dst)
                os.symlink(os.path.abspath(src), dst)
                return
            except Exception:
                pass
        shutil.copy2(src, dst)

    def _write_label_file(self, label_path, boxes):
        os.makedirs(os.path.dirname(label_path), exist_ok=True)
        with open(label_path, "w", encoding="utf-8") as f:
            for cls, x,y,w,h in boxes:
                f.write(f"{int(cls)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")

    def _materialize(self, train_df, val_df, train_split="train", val_split="val"):
        for split_name, df in [(train_split, train_df), (val_split, val_df)]:
            img_dir = os.path.join(self.data_root, "images", split_name)
            lbl_dir = os.path.join(self.data_root, "labels", split_name)
            os.makedirs(img_dir, exist_ok=True); os.makedirs(lbl_dir, exist_ok=True)
            it = df.iterrows()
            if self.verbose: it = tqdm(it, total=len(df), desc=f"[build] {split_name}")
            for _, row in it:
                src = row[self.image_col]
                if not os.path.exists(src):
                    raise FileNotFoundError(f"Image not found: {src}")
                fname = os.path.basename(src)
                stem, _ = os.path.splitext(fname)
                self._link_or_copy(src, os.path.join(img_dir, fname))
                self._write_label_file(os.path.join(lbl_dir, stem + ".txt"), self._parse_boxes(row))

        with open(self.dataset_yaml, "w", encoding="utf-8") as f:
            f.write(f"path: {self.data_root}\ntrain: images/{train_split}\nval: images/{val_split}\nnames:\n")
            for i, name in enumerate(self.class_names):
                f.write(f"  {i}: {name}\n")

    # -------------------- fit: train + (tune/ridge/validate) --------------------
    def fit(self,
            train_df: pd.DataFrame,
            val_df: pd.DataFrame | None = None,
            test_size: float = 0.2,
            epochs: int = 50,
            imgsz: int = 640,
            batch: int = 16,
            device: str | int | None = "auto",
            workers: int = 4,
            patience: int = 50,
            optimizer: str = "auto",
            augment: bool = True,
            seed: int = 42,
            close_mosaic: int | None = 10,
            cos_lr: bool = True,
            rect: bool = False,
            iou: float = 0.7,
            **extra_train_kwargs):
        """
        Train the detector, then tune / calibrate / validate object counting.

        Steps: resolve and remember the training device, split off a validation
        set when none is given, materialise the YOLO dataset on disk, train
        ``YOLO(self.model_ckpt)``, remember the best checkpoint and finally run
        the post-fit tuning, Ridge calibration and count validation.

        Args:
            train_df: rows holding ``self.image_col`` and ``self.boxes_col``.
            val_df: validation rows; when None a split is sampled from
                ``train_df``, stratified by bins of the per-image box count.
            test_size: fraction sampled from every count bin for that split.
            epochs, imgsz, batch, workers, patience, optimizer, augment,
            close_mosaic, cos_lr, rect, iou: forwarded to ``YOLO.train`` under
                the same names.
            device: passed through ``self._resolve_device``; the result is kept
                in ``self._device`` and used for training.
            seed: seeds ``numpy.random`` and ``random`` and is forwarded to
                training.
            **extra_train_kwargs: merged over the arguments above before
                calling ``YOLO.train`` (they win on conflicts).

        Returns:
            The checkpoint path stored in ``self.model_path``.

        Raises:
            FileNotFoundError: when a missing ``*.pt`` checkpoint cannot be
                downloaded.

        Notes:
            A failure inside the post-fit stage does not abort ``fit``; it is
            printed when ``self.verbose`` is set and emitted as a
            ``RuntimeWarning`` otherwise, so it is never lost silently.
        """
        import warnings  # local import: only needed for the post-fit report

        # choose the training device and remember it
        self._device = self._resolve_device(device)
        if self.verbose:
            print(f"[device] training device='{self._device}'")

        np.random.seed(seed); random.seed(seed)

        # no val_df given: simple stratification by bins of the box count
        if val_df is None:
            # unique positional index: the label-based selection below would
            # over-select / over-drop rows on duplicated index labels
            train_df = train_df.reset_index(drop=True)
            tmp = train_df.copy()
            counts = [len(self._parse_boxes(r)) for _, r in tmp.iterrows()]
            tmp["_bins"] = np.clip((np.array(counts)//5).astype(int), 0, 50)
            val_mask = tmp.groupby("_bins", group_keys=False).apply(
                lambda g: g.sample(frac=test_size, random_state=seed)).index
            val_df = train_df.loc[val_mask]
            train_df = train_df.drop(index=val_mask)
            train_df = train_df.reset_index(drop=True); val_df = val_df.reset_index(drop=True)

        self._materialize(train_df, val_df)

        # load / download the starting checkpoint
        if not os.path.exists(self.model_ckpt) and self.model_ckpt.endswith(".pt"):
            try:
                if self.verbose: print(f"Checkpoint '{self.model_ckpt}' not found. Attempting to download...")
                attempt_download_asset(self.model_ckpt)
            except Exception as e:
                # keep the original cause attached for debugging
                raise FileNotFoundError(f"Failed to download '{self.model_ckpt}'. Error: {e}") from e

        # train args
        train_args = {
            'data': self.dataset_yaml, 'epochs': epochs, 'imgsz': imgsz, 'batch': batch,
            'device': self._device, 'workers': workers, 'patience': patience, 'optimizer': optimizer,
            'augment': augment, 'seed': seed, 'close_mosaic': close_mosaic, 'cos_lr': cos_lr,
            'rect': rect, 'iou': iou, 'verbose': self.verbose
        }
        train_args.update(extra_train_kwargs)

        # training
        model = YOLO(self.model_ckpt)
        model.train(**train_args)

        # pick the best checkpoint: trainer.best, else ckpt_path, else the start checkpoint
        best_path = None
        if hasattr(model, "trainer") and getattr(model.trainer, "best", None):
            best_path = str(model.trainer.best)     # .../runs/detect/exp/weights/best.pt
        elif getattr(model, "ckpt_path", None):
            best_path = str(model.ckpt_path)
        else:
            best_path = self.model_ckpt
        self.model_path = best_path
        # drop any predictor cached by an earlier predict()/fit(); otherwise the
        # inference helpers would keep using the old weights
        self._model = None
        if self.verbose:
            print(f"[fit] best model: {self.model_path}")

        # tuning / calibration / validation (best effort)
        try:
            self._tune_and_or_calibrate(val_df, imgsz=imgsz)
            if self.validate_count:
                self._validate_counting(val_df)
        except Exception as e:
            message = f"[post-fit] skipped tuning/calibration/validation due to: {e}"
            if self.verbose:
                print(message)
            else:
                warnings.warn(message, RuntimeWarning, stacklevel=2)

        return self.model_path

    # -------------------- инференс-хелперы --------------------
    def _ensure_model(self):
        if self._model is None:
            path = self.model_path or self.model_ckpt
            self._model = YOLO(path)

    @torch.no_grad()
    def _raw_counts(self, paths, imgsz, conf, iou, max_det):
        """
        Count detections per image as a float array.

        Images are sent to the model in chunks of 64. With
        ``self.enable_tta_flip`` every chunk is predicted twice, once plainly
        and once with ``augment=True``, and the two counts are averaged.
        """
        self._ensure_model()
        counts = []
        device = self._infer_device()
        use_tta = bool(self.enable_tta_flip)
        shared = dict(imgsz=imgsz, conf=conf, iou=iou, max_det=max_det,
                      device=device, verbose=False)

        def _n_boxes(result):
            return int(len(result.boxes)) if result.boxes is not None else 0

        label = "[counts-tta]" if use_tta else "[counts]"
        for start in tqdm(range(0, len(paths), 64), disable=not self.verbose, desc=label):
            chunk = paths[start:start + 64]
            plain = self._model.predict(chunk, **shared)
            if not use_tta:
                counts.extend(_n_boxes(r) for r in plain)
                continue
            augmented = self._model.predict(chunk, augment=True, **shared)
            counts.extend(0.5 * (_n_boxes(a) + _n_boxes(b)) for a, b in zip(plain, augmented))
        return np.array(counts, dtype=float)

    @torch.no_grad()
    def _yolo_feats(self, paths, imgsz, conf, iou, max_det):
        """
        Per-image features for the Ridge calibrator.

        Columns, in order: ``n, conf_sum, conf_mean, conf_max, area_mean,
        frac_small, frac_mid, frac_big``. Areas are ``w * h`` of the normalised
        boxes clipped to [0, 1]; the three fractions split the boxes by
        ``self.small_thr`` and ``self.big_thr``. An image without detections
        gives an all-zero row.

        Returns:
            float ndarray of shape ``(len(paths), 8)``. The shape and dtype are
            fixed even for an empty ``paths`` or when no image has detections,
            so the stored standardisation vectors always broadcast against it.
        """
        n_features = 8
        self._ensure_model()
        rows = []
        dev = self._infer_device()
        for i in tqdm(range(0, len(paths), 64), disable=not self.verbose, desc="[feats]"):
            batch = paths[i:i+64]
            res = self._model.predict(batch, imgsz=imgsz, conf=conf, iou=iou,
                                      max_det=max_det, device=dev, verbose=False)
            for r in res:
                if r.boxes is None or len(r.boxes) == 0:
                    rows.append([0.0] * n_features)
                    continue
                confs = r.boxes.conf.cpu().numpy()
                xywhn = r.boxes.xywhn.cpu().numpy()
                areas = (xywhn[:, 2] * xywhn[:, 3]).clip(0, 1)
                rows.append([
                    float(len(confs)),                                                      # n
                    float(confs.sum()),                                                     # conf_sum
                    float(confs.mean()),                                                    # conf_mean
                    float(confs.max()),                                                     # conf_max
                    float(areas.mean()),                                                    # area_mean
                    float((areas < self.small_thr).mean()),                                 # frac_small
                    float(((areas >= self.small_thr) & (areas <= self.big_thr)).mean()),    # frac_mid
                    float((areas > self.big_thr).mean()),                                   # frac_big
                ])
        # reshape keeps the column count when there are no rows at all
        return np.asarray(rows, dtype=float).reshape(-1, n_features)

    @torch.no_grad()
    def _detect_conf_area(self, paths, imgsz, conf_base, iou, max_det):
        """
        Run detection at the base threshold ``conf_base`` and return, per image,
        an ``N x 2`` array of ``[conf, area]`` rows (input of the area gate).

        ``area`` is ``w * h`` of the normalised box clipped to [0, 1]. Images
        without detections give an empty ``(0, 2)`` float32 array.
        """
        self._ensure_model()
        device = self._infer_device()
        per_image = []
        for start in tqdm(range(0, len(paths), 64), disable=not self.verbose, desc="[boxes]"):
            chunk = paths[start:start + 64]
            results = self._model.predict(chunk, imgsz=imgsz, conf=conf_base, iou=iou,
                                          max_det=max_det, device=device, verbose=False)
            for result in results:
                if result.boxes is not None and len(result.boxes) != 0:
                    scores = result.boxes.conf.cpu().numpy()
                    boxes_n = result.boxes.xywhn.cpu().numpy()
                    box_area = (boxes_n[:, 2] * boxes_n[:, 3]).clip(0, 1)
                    per_image.append(np.stack([scores, box_area], axis=1))
                else:
                    per_image.append(np.empty((0, 2), dtype=np.float32))
        return per_image

    @staticmethod
    def _count_with_area_gate(conf_area_list, conf_small, conf_big, area_thr):
        """Подсчёт с двупороговой фильтрацией по площади."""
        counts = []
        for ca in conf_area_list:
            if ca.size == 0:
                counts.append(0); continue
            conf = ca[:, 0]; area = ca[:, 1]
            small_mask = (area < area_thr)  & (conf >= conf_small)
            big_mask   = (area >= area_thr) & (conf >= conf_big)
            counts.append(int(small_mask.sum() + big_mask.sum()))
        return np.array(counts, dtype=float)

    # -------------------- подвыборка валидации --------------------
    def _subset_val(self, val_df: pd.DataFrame) -> pd.DataFrame:
        if self.tune_val_subsample is None:
            return val_df
        n = len(val_df)
        if isinstance(self.tune_val_subsample, float):
            k = max(1, int(round(n * self.tune_val_subsample)))
        else:
            k = int(self.tune_val_subsample)
        k = min(k, n)
        return val_df.sample(n=k, random_state=self.random_state).reset_index(drop=True)

    # -------------------- Ridge по резидуалу с K-fold CV --------------------
    def _fit_ridge_cv_on_residual(self, X: np.ndarray, y_true: np.ndarray, y_plain: np.ndarray):
        """
        Fit a Ridge model on the residual ``y_true - y_plain``, choosing alpha
        from ``self.ridge_alpha_grid`` by K-fold cross-validated RMSE.

        Args:
            X: feature matrix, one row per image.
            y_true: true counts.
            y_plain: uncalibrated counts for the same images.

        Returns:
            dict with ``model`` (Ridge fitted on all rows of the standardised
            features), ``alpha``, the standardisation vectors ``mu`` / ``sd``
            and the best ``cv_rmse``.

        Raises:
            ValueError: when ``self.ridge_alpha_grid`` is empty.
        """
        alphas = [float(a) for a in self.ridge_alpha_grid]
        if not alphas:
            raise ValueError("ridge_alpha_grid must contain at least one alpha")

        k = min(5, len(y_true)) if len(y_true) >= 3 else 2
        kf = KFold(n_splits=k, shuffle=True, random_state=self.random_state)

        # feature standardisation; constant columns keep a unit scale
        mu = X.mean(axis=0)
        sd = X.std(axis=0); sd[sd == 0] = 1.0
        Xs = (X - mu) / sd
        r = y_true - y_plain

        # split once so every alpha is scored on exactly the same folds
        folds = list(kf.split(Xs))

        best_alpha, best_cv = None, float("inf")
        for a in alphas:
            cv_scores = []
            for tr, va in folds:
                m = Ridge(alpha=a).fit(Xs[tr], r[tr])
                pr = m.predict(Xs[va])
                # sqrt(MSE) instead of `squared=False`, which newer scikit-learn no longer accepts
                cv_scores.append(float(np.sqrt(mean_squared_error(r[va], pr))))
            cv_rmse = float(np.mean(cv_scores))
            if cv_rmse < best_cv:
                best_alpha, best_cv = a, cv_rmse

        # final fit on all rows
        model = Ridge(alpha=best_alpha).fit(Xs, r)
        return dict(model=model, alpha=best_alpha, mu=mu, sd=sd, cv_rmse=best_cv)

    # -------------------- тюнинг и/или калибровка --------------------
    def _tune_and_or_calibrate(self, val_df: pd.DataFrame, imgsz: int):
        """
        Tune the counting hyper-parameters on validation data and optionally
        fit the Ridge residual calibrator; results are stored in ``self.calib_``.

        Stage 1 (``self.enable_tuning``):
          - area-gated (``self.enable_area_gate``): boxes are collected once per
            ``(iou, max_det)`` pair at a low base threshold, then every
            ``(conf_small, conf_big, area_thr)`` combination is scored on those
            cached boxes without further inference.
          - plain: every ``(conf, iou, max_det)`` combination is scored with a
            fresh inference pass.
          The combinations are shuffled with ``random`` and capped by
          ``self.tune_max_combinations``; the lowest RMSE against the true box
          counts wins. Without tuning the plain defaults
          ``conf=0.25, iou=0.5, max_det=1000`` are stored.

        Stage 2 (``self.enable_ridge``): a Ridge model is fitted on the residual
        between the true counts and the stage-1 counts, on the same subset.

        Args:
            val_df: validation rows; subsampled through ``self._subset_val``.
            imgsz: inference image size, also stored as ``calib_['imgsz']``.

        Raises:
            ValueError: when the tuning search space is empty.
        """
        val_sub = self._subset_val(val_df)
        paths = val_sub[self.image_col].tolist()
        y_true = np.array([len(self._parse_boxes(r)) for _, r in val_sub.iterrows()], dtype=float)

        def _rmse(y_pred):
            # sqrt(MSE) instead of `squared=False`, which newer scikit-learn no longer accepts
            return float(np.sqrt(mean_squared_error(y_true, y_pred)))

        # 1) tune plain or area-gated counting
        if self.enable_tuning:
            if self.enable_area_gate:
                # base threshold for collecting candidates, kept below the smallest conf_small
                min_cs = min(self.tune_conf_small_grid) if len(self.tune_conf_small_grid) else 0.10
                base_collect = max(0.03, min(self.gate_conf_base, min_cs - 0.02))
                if self.verbose:
                    print(f"[tune-gate] base_collect={base_collect:.3f} (min_cs={min_cs:.3f})")

                # 1) precompute the [conf, area] lists for every (iou, max_det);
                #    grids are converted up front so cache keys and lookups always agree
                iou_grid = tuple(float(v) for v in self.tune_iou_grid)
                md_grid  = tuple(int(v) for v in self.tune_max_det_grid)
                conf_area_by_key = {}
                total_prepasses = len(iou_grid) * len(md_grid)
                if self.verbose:
                    print(f"[tune-gate] precomputing boxes for {total_prepasses} (iou,max_det) pairs...")
                for iou_ in iou_grid:
                    for md_ in md_grid:
                        conf_area_by_key[(iou_, md_)] = self._detect_conf_area(
                            paths, imgsz=imgsz,
                            conf_base=float(base_collect),
                            iou=iou_, max_det=md_
                        )

                # 2) search cs/cb/area_thr + iou/max_det; the last tuple item is the
                #    effective base threshold (below cs) stored for inference
                all_combos = [
                    (iou_, md_, float(cs), float(cb), float(at),
                     float(max(0.03, min(float(self.gate_conf_base), float(cs) - 0.02))))
                    for iou_ in iou_grid
                    for md_ in md_grid
                    for cs in self.tune_conf_small_grid
                    for cb in self.tune_conf_big_grid
                    for at in self.tune_area_thr_grid
                ]

                random.shuffle(all_combos)
                full_space = len(all_combos)
                if self.tune_max_combinations is not None:
                    all_combos = all_combos[:int(self.tune_max_combinations)]
                if self.verbose:
                    print(f"[tune-gate] search combos: {len(all_combos)} (cap), full={full_space}")
                if not all_combos:
                    raise ValueError("area-gate tuning has an empty search space: "
                                     "check the tune_* grids and tune_max_combinations")

                best = dict(rmse=float("inf"), iou=None, max_det=None, cs=None, cb=None, at=None, base=None)
                for (iou_, md_, cs, cb, at, base_eff) in all_combos:
                    conf_area = conf_area_by_key[(iou_, md_)]
                    y_pred = self._count_with_area_gate(conf_area, cs, cb, at)
                    rmse = _rmse(y_pred)
                    if rmse < best["rmse"]:
                        best.update(dict(rmse=rmse, iou=iou_, max_det=md_,
                                         cs=cs, cb=cb, at=at, base=base_eff))

                if self.verbose:
                    print(f"[tune-gate] best: cs={best['cs']:.3f}, cb={best['cb']:.3f}, "
                          f"area_thr={best['at']:.4f}, iou={best['iou']}, max_det={best['max_det']}  RMSE={best['rmse']:.3f}")

                self.calib_.update(dict(
                    best_conf=None, best_iou=float(best['iou']), best_max_det=int(best['max_det']),
                    gate_conf_small=float(best['cs']), gate_conf_big=float(best['cb']),
                    gate_area_thr=float(best['at']), gate_conf_base=float(best['base'])
                ))
            else:
                combos = [(float(c), float(i), int(m))
                          for i in self.tune_iou_grid
                          for m in self.tune_max_det_grid
                          for c in self.tune_conf_grid]
                random.shuffle(combos)
                full_space = len(combos)
                if self.tune_max_combinations is not None:
                    combos = combos[:int(self.tune_max_combinations)]
                if self.verbose:
                    print(f"[tune-plain] search combos: {len(combos)} (cap), full={full_space}")
                if not combos:
                    raise ValueError("plain tuning has an empty search space: "
                                     "check the tune_* grids and tune_max_combinations")

                best = dict(rmse=float("inf"), conf=None, iou=None, max_det=None)
                for conf, iou_v, max_det in combos:
                    y_pred = self._raw_counts(paths, imgsz=imgsz, conf=conf, iou=iou_v, max_det=max_det)
                    rmse = _rmse(y_pred)
                    if rmse < best["rmse"]:
                        best.update(dict(rmse=rmse, conf=conf, iou=iou_v, max_det=max_det))
                if self.verbose:
                    print(f"[tune] best plain count: conf={best['conf']}, iou={best['iou']}, max_det={best['max_det']}  RMSE={best['rmse']:.3f}")

                self.calib_.update(dict(
                    best_conf=best['conf'], best_iou=best['iou'], best_max_det=best['max_det'],
                    gate_conf_small=None, gate_conf_big=None, gate_area_thr=None, gate_conf_base=None
                ))
        else:
            # no tuning: defaults for plain counting
            self.calib_.update(dict(
                best_conf=0.25, best_iou=0.5, best_max_det=1000,
                gate_conf_small=None, gate_conf_big=None, gate_area_thr=None, gate_conf_base=None
            ))

        # 2) Ridge calibration (on the residual, with CV)
        ridge_model, ridge_alpha = None, None
        if self.enable_ridge:
            # build y_plain on the same val_sub subset
            if self.enable_tuning and self.enable_area_gate and self.calib_.get("gate_conf_small") is not None:
                iou_use = float(self.calib_["best_iou"])
                md_use  = int(self.calib_["best_max_det"])
                conf_area = self._detect_conf_area(paths, imgsz=imgsz,
                                                   conf_base=float(self.calib_["gate_conf_base"]),
                                                   iou=iou_use, max_det=md_use)
                y_plain = self._count_with_area_gate(conf_area,
                                                     float(self.calib_["gate_conf_small"]),
                                                     float(self.calib_["gate_conf_big"]),
                                                     float(self.calib_["gate_area_thr"]))
                X = self._yolo_feats(paths, imgsz=imgsz,
                                     conf=float(self.calib_["gate_conf_base"]),
                                     iou=iou_use, max_det=md_use)
            else:
                conf_use = float(self.calib_["best_conf"] if self.calib_.get("best_conf") is not None else 0.25)
                iou_use  = float(self.calib_["best_iou"]  if self.calib_.get("best_iou")  is not None else 0.5)
                md_use   = int(self.calib_["best_max_det"] if self.calib_.get("best_max_det") is not None else 1000)
                y_plain  = self._raw_counts(paths, imgsz=imgsz, conf=conf_use, iou=iou_use, max_det=md_use)
                X = self._yolo_feats(paths, imgsz=imgsz, conf=conf_use, iou=iou_use, max_det=md_use)

            pack = self._fit_ridge_cv_on_residual(X, y_true, y_plain)
            ridge_model, ridge_alpha = pack["model"], pack["alpha"]
            if self.verbose:
                print(f"[calib] Ridge(residual) alpha={ridge_alpha}  CV-RMSE(resid)={pack['cv_rmse']:.3f}")

            # keep the standardisation next to the model
            self.calib_.update(dict(
                ridge_alpha=ridge_alpha, ridge_model=ridge_model,
                ridge_mu=pack["mu"], ridge_sd=pack["sd"], imgsz=imgsz
            ))
        else:
            self.calib_.update(dict(ridge_alpha=None, ridge_model=None, imgsz=imgsz))

    # -------------------- финальная валидация подсчёта --------------------
    def _validate_counting(self, val_df: pd.DataFrame):
        """
        Print RMSE / MAE of the counting pipeline on ``val_df``.

        The "plain" numbers use the tuned area gate when it is available,
        otherwise plain detection counts with the tuned (or default)
        thresholds. The "calibrated" numbers add the Ridge residual and clip
        at zero. The report is printed regardless of ``self.verbose``.

        NOTE(review): ``fit`` passes the same frame the calibrator was fitted
        on, so the calibrated figures are presumably optimistic — confirm
        before quoting them as hold-out results.
        """
        paths = val_df[self.image_col].tolist()
        y_true = np.array([len(self._parse_boxes(r)) for _, r in val_df.iterrows()], dtype=float)

        imgsz = self.calib_.get('imgsz') or 640

        def _rmse(y_pred):
            # sqrt(MSE) instead of `squared=False`, which newer scikit-learn no longer accepts
            return float(np.sqrt(mean_squared_error(y_true, y_pred)))

        gate_ready = (self.enable_tuning and self.enable_area_gate
                      and self.calib_.get("gate_conf_small") is not None)

        # plain/gate
        if gate_ready:
            iou_use = float(self.calib_["best_iou"])
            md_use  = int(self.calib_["best_max_det"])
            conf_area = self._detect_conf_area(paths, imgsz=imgsz, conf_base=self.calib_["gate_conf_base"],
                                               iou=iou_use, max_det=md_use)
            y_plain = self._count_with_area_gate(conf_area,
                                                 self.calib_["gate_conf_small"],
                                                 self.calib_["gate_conf_big"],
                                                 self.calib_["gate_area_thr"])
        else:
            conf  = self.calib_['best_conf'] if self.enable_tuning else 0.25
            iou_v   = self.calib_['best_iou']  if self.enable_tuning else 0.5
            max_det = self.calib_['best_max_det'] if self.enable_tuning else 1000
            y_plain = self._raw_counts(paths, imgsz, conf, iou_v, max_det)

        rmse_plain = _rmse(y_plain)
        mae_plain  = mean_absolute_error(y_true, y_plain)

        # calibrated
        if self.enable_ridge and self.calib_['ridge_model'] is not None:
            if gate_ready:
                X = self._yolo_feats(paths, imgsz=imgsz, conf=float(self.calib_["gate_conf_base"]),
                                     iou=float(self.calib_["best_iou"]), max_det=int(self.calib_["best_max_det"]))
            else:
                X = self._yolo_feats(paths, imgsz=imgsz,
                                     conf=(self.calib_["best_conf"] if self.calib_["best_conf"] is not None else 0.25),
                                     iou=(self.calib_["best_iou"] if self.calib_["best_iou"] is not None else 0.5),
                                     max_det=(self.calib_["best_max_det"] if self.calib_["best_max_det"] is not None else 1000))
            # same standardisation as used when the calibrator was fitted
            mu = self.calib_.get("ridge_mu"); sd = self.calib_.get("ridge_sd")
            Xs = (X - mu) / sd
            resid = self.calib_['ridge_model'].predict(Xs)
            y_cal = np.clip(y_plain + resid, 0, None)
            rmse_cal = _rmse(y_cal)
            mae_cal  = mean_absolute_error(y_true, y_cal)
            print(f"[val-count] plain: RMSE={rmse_plain:.3f}, MAE={mae_plain:.3f}  |  calibrated: RMSE={rmse_cal:.3f}, MAE={mae_cal:.3f}")
        else:
            print(f"[val-count] plain: RMSE={rmse_plain:.3f}, MAE={mae_plain:.3f}  |  calibrated: (disabled)")

    # -------------------- публичный инференс: детекции --------------------
    @torch.no_grad()
    def predict(self, df: pd.DataFrame,
                conf: float = 0.25, iou: float = 0.6,
                imgsz: int = 640, device: str | int | None = "auto",
                max_det: int = 300, agnostic_nms: bool = False) -> pd.DataFrame:
        """
        Run plain detection on the images listed in ``df[self.image_col]``
        (the count calibration is NOT applied here).

        Returns:
            DataFrame with one row per result: the image path reported by the
            model, ``count`` (number of boxes) and ``boxes_json``, a JSON list
            of ``{"cls", "conf", "x", "y", "w", "h"}`` dicts in normalised xywh.
        """
        assert self.image_col in df.columns
        self._ensure_model()

        requested = device if device is not None else self._infer_device()
        run_device = self._resolve_device(requested)

        image_paths = df[self.image_col].tolist()
        records = []
        for start in tqdm(range(0, len(image_paths), 64), disable=not self.verbose, desc="[predict]"):
            chunk = image_paths[start:start + 64]
            results = self._model(chunk, conf=conf, iou=iou, imgsz=imgsz,
                                  device=run_device, verbose=False, max_det=max_det,
                                  agnostic_nms=agnostic_nms)
            for result in results:
                detections = []
                if result.boxes is not None and len(result.boxes):
                    coords = result.boxes.xywhn.cpu().numpy()
                    scores = result.boxes.conf.cpu().numpy()
                    labels = result.boxes.cls.cpu().numpy().astype(int)
                    detections = [
                        {"cls": int(k), "conf": float(c),
                         "x": float(x), "y": float(y), "w": float(w), "h": float(h)}
                        for (x, y, w, h), c, k in zip(coords, scores, labels)
                    ]
                records.append({
                    self.image_col: result.path,
                    "count": len(detections),
                    "boxes_json": json.dumps(detections, ensure_ascii=False),
                })
        return pd.DataFrame(records)

    # -------------------- публичный инференс: подсчёт --------------------
    @torch.no_grad()
    def predict_counts(self, df: pd.DataFrame,
                       imgsz: int | None = None,
                       conf: float | None = None,
                       iou: float | None = None,
                       max_det: int | None = None,
                       device: str | int | None = "auto",
                       clamp_nonneg: bool = True,
                       do_round: bool = False) -> pd.DataFrame:
        """
        Count objects per image.

          - enable_area_gate=True and tuning was done -> two-threshold filtering
            (conf_small / conf_big) by box area, using the tuned values only;
            ``conf``, ``iou`` and ``max_det`` are not consulted on this path.
          - otherwise -> plain ``len(detections)`` with ``conf`` / ``iou`` /
            ``max_det``, each falling back to the tuned value and then to
            0.25 / 0.5 / 1000.
          - enable_ridge=True and a calibrator was fitted ->
            ``y = y_plain + Ridge(residual)``.

        Args:
            df: frame holding ``self.image_col``.
            imgsz: inference size; defaults to the calibrated size, then 640.
            device: ``"auto"`` or None keeps the device picked by
                ``self._infer_device``; any other value is resolved through
                ``self._resolve_device`` and used for this call only
                (``self._device`` is restored afterwards).
            clamp_nonneg: clip the result at zero.
            do_round: round the result to the nearest integer value.

        Returns:
            Copy of the image column plus a ``label`` column with the counts.
        """
        assert self.image_col in df.columns
        if self._model is None:
            self._model = YOLO(self.model_path or self.model_ckpt)

        imgsz = imgsz or self.calib_.get('imgsz') or 640
        paths = df[self.image_col].tolist()

        # The inference helpers read their device from self._infer_device(), so an
        # explicit request has to be installed on the instance for the duration
        # of the call; previously the resolved device was computed and dropped.
        explicit_device = device is not None and device != "auto"
        saved_device = getattr(self, "_device", None)
        if explicit_device:
            self._device = self._resolve_device(device)

        try:
            ridge = self.calib_.get("ridge_model") if self.enable_ridge else None

            def _residual(conf_, iou_, max_det_):
                # standardise with the stored statistics, then predict the residual
                X = self._yolo_feats(paths, imgsz=imgsz, conf=conf_, iou=iou_, max_det=max_det_)
                mu = self.calib_.get("ridge_mu"); sd = self.calib_.get("ridge_sd")
                return ridge.predict((X - mu) / sd)

            # area-gated path
            if self.enable_tuning and self.enable_area_gate and self.calib_.get("gate_conf_small") is not None:
                conf_base = float(self.calib_["gate_conf_base"])
                iou_use   = float(self.calib_["best_iou"])
                max_det_use = int(self.calib_["best_max_det"])
                conf_area = self._detect_conf_area(paths, imgsz=imgsz, conf_base=conf_base,
                                                   iou=iou_use, max_det=max_det_use)
                cs = float(self.calib_["gate_conf_small"])
                cb = float(self.calib_["gate_conf_big"])
                at = float(self.calib_["gate_area_thr"])
                y = self._count_with_area_gate(conf_area, cs, cb, at)
                if ridge is not None:
                    y = y + _residual(conf_base, iou_use, max_det_use)
            else:
                # plain path
                if self.enable_tuning and self.calib_.get("best_conf") is not None:
                    conf_def, iou_def, max_det_def = (self.calib_['best_conf'], self.calib_['best_iou'],
                                                      self.calib_['best_max_det'])
                else:
                    conf_def, iou_def, max_det_def = 0.25, 0.5, 1000

                use_conf   = conf    if conf    is not None else conf_def
                use_iou    = iou     if iou     is not None else iou_def
                use_maxdet = max_det if max_det is not None else max_det_def

                if ridge is not None:
                    resid = _residual(use_conf, use_iou, use_maxdet)
                    # base plain count that the residual is added to
                    y = self._raw_counts(paths, imgsz=imgsz, conf=use_conf, iou=use_iou, max_det=use_maxdet) + resid
                else:
                    y = self._raw_counts(paths, imgsz=imgsz, conf=use_conf, iou=use_iou, max_det=use_maxdet)
        finally:
            if explicit_device:
                self._device = saved_device

        if clamp_nonneg: y = np.clip(y, 0, None)
        if do_round:     y = np.rint(y)

        out = df[[self.image_col]].copy()
        out["label"] = y
        return out

    # -------------------- housekeeping --------------------
    def cleanup(self):
        if self._tmpdir_owned and os.path.isdir(self.data_root):
            shutil.rmtree(self.data_root, ignore_errors=True)

Пример использования (конфигурация специализирована для подсчёта объектов).

In [None]:
# Build a counting-oriented pipeline: one class, area-gated tuning plus Ridge calibration.
pipe = YOLODetectionPipeline(
    model_ckpt="yolov8l.pt",
    image_col="image_path",
    boxes_col="label",
    class_names=["obj"],
    verbose=True,
    use_symlinks=True,

    # post-fit stages run by fit(): threshold tuning, Ridge residual calibration, count report
    enable_tuning=True,
    enable_ridge=True,
    validate_count=True,

    # area gate: separate confidence thresholds for small and big boxes;
    # enable_tta_flip only affects plain counting (extra augment=True pass)
    enable_area_gate=True,
    enable_tta_flip=False,

    # None = tune on the whole validation frame; at most 500 shuffled combinations are scored
    tune_val_subsample=None,
    tune_max_combinations=500,

    # search grids; tune_conf_grid is read only by the plain (non-gated) tuning branch
    tune_conf_grid=(0.18, 0.21, 0.24, 0.27, 0.30, 0.33, 0.36, 0.39, 0.42),
    tune_iou_grid=(0.4, 0.45, 0.50, 0.55, 0.60, 0.65),
    tune_max_det_grid=(50, 100, 200, 300, 600),
    tune_conf_small_grid=(0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20),
    tune_conf_big_grid=(0.28, 0.32, 0.36, 0.40, 0.45, 0.50, 0.55),
    tune_area_thr_grid=(0.0006, 0.0008, 0.0010, 0.0012, 0.0015, 0.0018),
    ridge_alpha_grid=(0.01, 0.03, 0.05, 0.075, 0.1, 0.3, 0.6, 1.0, 2.0),

    # gate_conf_base: upper bound of the low threshold used to collect candidate boxes;
    # small_thr / big_thr: normalised-area cut-offs for the frac_small/mid/big Ridge features
    gate_conf_base=0.07,
    small_thr=0.0010,
    big_thr=0.003,

    random_state=42
)

# train_df / val_df are expected to exist already (defined outside this cell).
# device='0,1' trains on two GPUs; the inference helpers then use the first card.
# lr0 is not a named fit() parameter, so it travels through **extra_train_kwargs to YOLO.train.
pipe.fit(
    train_df=train_df,
    val_df=val_df,
    epochs=60,
    imgsz=640,
    batch=32,
    device='0,1',
    workers=4,
    lr0=0.01,
    cos_lr=True,
    rect=True,
    iou=0.5
)