# Дообучение encoder'а для классификации токенов.

Реализация классификатора.

In [None]:
# Pinned environment for this notebook (IPython `!` shell escape; torch wheel built for CUDA 12.4).
# NOTE(review): `datasets` and `scikit-learn` are imported below but not pinned here, and
# transformers' Trainer presumably also needs `accelerate`; they are assumed to come in as
# dependencies of `evaluate`/`seqeval` or to be preinstalled — confirm for a clean environment.
!pip install --upgrade --no-cache-dir \
  --extra-index-url https://download.pytorch.org/whl/cu124 \
  numpy==1.26.4 \
  pandas==2.2.3 \
  tqdm==4.67.1 \
  transformers==4.51.3 \
  evaluate==0.4.5 \
  torch==2.6.0+cu124 \
  seqeval==1.2.2

import os
# Turn off HF tokenizers' internal parallelism before any tokenizer is created; presumably to
# avoid the fork warning / contention with DataLoader workers (fit() uses up to 4 workers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gc
import math
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch
import evaluate
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForTokenClassification,
    TrainerCallback,
    PrinterCallback,
    EarlyStoppingCallback,
)
from transformers.modeling_outputs import ModelOutput
import pandas as pd
from datasets import Dataset
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm


def set_seed(seed: int = 42):
    """
    Seed every random generator used by the notebook with one value.

    Covers Python's ``random`` module, NumPy's global generator, PyTorch's CPU
    generator and the generators of all CUDA devices.

    :param seed: seed value shared by all generators.
    :return: None
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


class PbarConsoleLogger(TrainerCallback):
    """
    Trainer callback that drives an externally created tqdm bar and prints one
    console line per evaluation.

    Progress and printed step numbers are cumulative over consecutive
    ``Trainer.train()`` calls on the same trainer. Each ``train()`` call starts a
    new ``TrainerState`` whose ``global_step`` begins at 0, so the steps finished
    by earlier calls are carried in ``step_offset``. Without the offset, a later
    call reuses step numbers that were already seen. The bar then stalls, and
    evaluation lines whose step was already printed are silently skipped.
    """

    # Evaluation bookkeeping keys that are not quality metrics (hidden from postfix/console).
    _NON_METRIC_KEYS = frozenset(
        {'eval_loss', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch'}
    )

    def __init__(self, pbar):
        """
        :param pbar: tqdm progress bar whose ``total`` is the planned number of optimizer steps.
        """
        self.pbar = pbar
        self.last_logs = {}              # latest numeric value of every logged key
        self.last_train_loss = None      # most recent training loss, shown on eval lines
        self.printed_eval_steps = set()  # cumulative steps whose eval line was already printed
        self.step_offset = 0             # optimizer steps completed by earlier train() calls

    def _step(self, state) -> int:
        """Return the cumulative step: steps of earlier train() calls plus the current global_step."""
        return self.step_offset + int(state.global_step or 0)

    def _advance_pbar(self, state):
        """Move the bar forward (never backward) to the cumulative step, capped at its total."""
        n = min(self._step(state), self.pbar.total)
        if n > self.pbar.n:
            self.pbar.update(n - self.pbar.n)

    def _fmt_postfix(self):
        """Build the 'loss … | val … | <metric> …' postfix from the latest logged values."""
        parts = []
        if 'loss' in self.last_logs:
            parts.append(f"loss {self.last_logs['loss']:.4f}")
        if 'eval_loss' in self.last_logs:
            parts.append(f"val {self.last_logs['eval_loss']:.4f}")
        for k, v in self.last_logs.items():
            # last_logs only ever holds floats (see on_log), so no conversion can fail here.
            if k.startswith('eval_') and k not in self._NON_METRIC_KEYS:
                parts.append(f"{k.replace('eval_', '')} {v:.4f}")
        return " | ".join(parts)

    def on_step_end(self, args, state, control, **kwargs):
        self._advance_pbar(state)
        if self.last_logs:
            self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.last_logs[k] = float(v)
        if 'loss' in logs and isinstance(logs['loss'], (int, float)):
            self.last_train_loss = float(logs['loss'])

        self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

        if not any(k.startswith('eval_') for k in logs):
            return

        # Print each evaluation step at most once (cumulative step, unique across train() calls).
        step = self._step(state)
        if step in self.printed_eval_steps:
            return
        self.printed_eval_steps.add(step)

        train_loss_str = f"{self.last_train_loss:.10f}" if self.last_train_loss is not None else "n/a"
        val_loss = logs.get('eval_loss', None)
        val_loss_str = f"{float(val_loss):.10g}" if isinstance(val_loss, (int, float)) else "n/a"

        extra_parts = []
        for k, v in logs.items():
            if k.startswith('eval_') and k not in self._NON_METRIC_KEYS:
                metric_name = k.replace('eval_', '')
                try:
                    extra_parts.append(f"val {metric_name}: {float(v):.10f}")
                except (TypeError, ValueError):
                    # Non-numeric log values are simply not printed.
                    pass

        line = f"step: {step}, train loss: {train_loss_str}, val loss: {val_loss_str}"
        if extra_parts:
            line += ", " + ", ".join(extra_parts)
        tqdm.write(line)

    def on_train_end(self, args, state, control, **kwargs):
        self._advance_pbar(state)
        self.pbar.refresh()
        # The next train() call restarts global_step at 0: bank the steps of this call.
        self.step_offset += int(state.global_step or 0)


class WeightedTokenCETrainer(Trainer):
    """
    Trainer whose loss is token-level cross-entropy with optional per-class weights.

    Positions labelled -100 are ignored by the loss.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        """
        :param class_weights: optional per-class weights (anything ``torch.as_tensor`` accepts);
                              stored as a float32 tensor and moved to the logits' device on use.
        All other arguments are forwarded unchanged to ``Trainer``.
        """
        super().__init__(*args, **kwargs)
        if class_weights is None:
            self.class_weights = None
        else:
            self.class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        # The label-tiling warning is printed at most once per trainer.
        self._warned_label_tiling = False

    def _labels_for_logits(self, logits, labels):
        """
        Return labels whose batch dimension matches ``logits``.

        If logits have ``ngpu`` times more rows than labels (with more than one
        visible CUDA device), each label row is repeated ``ngpu`` times. Any other
        mismatch raises ValueError.
        """
        if logits.size(0) == labels.size(0):
            return labels

        ngpu = torch.cuda.device_count()
        if not (ngpu > 1 and logits.size(0) == labels.size(0) * ngpu):
            raise ValueError(f"Batch size mismatch: logits {tuple(logits.shape)} vs labels {tuple(labels.shape)}")

        labels = labels.repeat_interleave(ngpu, dim=0)
        if not self._warned_label_tiling:
            print(f"[Warning] DataParallel удвоил batch для logits. "
                  f"Повторяем labels x{ngpu}. logits: {tuple(logits.shape)}, labels: {tuple(labels.shape)}")
            self._warned_label_tiling = True
        return labels

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the (optionally class-weighted) mean token cross-entropy.

        NOTE(review): ``num_items_in_batch`` is accepted for Trainer compatibility but unused.
        The loss is a per-batch mean, which presumably relies on Trainer dividing by
        gradient_accumulation_steps for models whose forward takes no loss kwargs. Confirm
        this for the checkpoint in use.
        """
        # Popped so that the model's forward does not receive labels.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits if hasattr(outputs, "logits") else outputs["logits"]
        labels = self._labels_for_logits(logits, labels)

        weight = self.class_weights.to(logits.device) if self.class_weights is not None else None
        criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
        n_classes = logits.size(-1)
        loss = criterion(logits.view(-1, n_classes), labels.view(-1))
        if return_outputs:
            return loss, outputs
        return loss


class TokenClassification:
    """
    Класс-обёртка для обучения и инференса моделей токен-классификации (NER/POS и т.д.)
    на базе Hugging Face Transformers. Поддерживает обучение со «скользящим окном»,
    выравнивание меток слов с субтокенами, расчёт весов классов по словам, раннюю остановку,
    агрегирование логитов по перекрывающимся окнам и извлечение эмбеддингов слов.
    """

    def __init__(
        self,
        checkpoint: str,
        label2id: Dict[str, int],
        tokens_column_name: str,
        tags_column_name: str
    ):
        """
        Инициализирует модель, токенайзер и инфраструктуру для обучения/инференса.

        :param checkpoint: имя/путь модели в Hugging Face (например, 'bert-base-cased').
        :param label2id: словарь отображения строковых меток в целочисленные id.
        :param tokens_column_name: имя колонки DataFrame с токенами (словами).
        :param tags_column_name: имя колонки DataFrame с метками (строки или уже id).
        :return: None
        :raises: ValueError при некорректных входных параметрах (например, пустой label2id).
        """
        self.id2label = {v: k for k, v in label2id.items()}
        self.label2id = label2id

        self.model = AutoModelForTokenClassification.from_pretrained(
            checkpoint,
            num_labels=len(self.id2label),
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.tokens_column_name = tokens_column_name
        self.tags_column_name = tags_column_name

        # Градиентный чекпоинтинг (если поддерживается моделью)
        try:
            self.model.gradient_checkpointing_enable()
        except Exception:
            pass

        self.trainer: Optional[Trainer] = None
        self.compute_metrics = None
        self.data_collator = DataCollatorForTokenClassification(tokenizer=self.tokenizer)
        self.progress_callback: Optional[TrainerCallback] = None

    # ------------------------------
    # Вспомогательные хелперы
    # ------------------------------
    @staticmethod
    def _labels_are_strings(labels_col_list) -> bool:
        """
        Определяет, представлены ли метки строками (а не id).

        :param labels_col_list: итерируемая коллекция списков меток (по документам).
        :return: True, если метки строковые; False, если уже id или все пусто.
        :raises: None
        """
        for tags in labels_col_list:
            if isinstance(tags, (list, tuple)) and len(tags) > 0:
                return isinstance(tags[0], str)
        return False

    def _label_to_id(self, tag: str) -> int:
        """
        Преобразует строковую метку в id согласно self.label2id.

        :param tag: строковая метка.
        :return: целочисленный id метки.
        :raises ValueError: если метка отсутствует в label2id.
        """
        if tag not in self.label2id:
            raise ValueError(
                f"Unknown label encountered: '{tag}'. "
                f"Known labels: {sorted(self.label2id.keys())}"
            )
        return int(self.label2id[tag])

    def _assert_tokens_labels_same_len(self, tokens_seq, labels_seq):
        """
        Проверяет совпадение длины списков токенов и меток для каждого документа.

        :param tokens_seq: iterable со списками токенов (по документам).
        :param labels_seq: iterable со списками меток (по документам).
        :return: None
        :raises ValueError: если типы неверны или длины не совпадают.
        """
        for i, (toks, labs) in enumerate(zip(tokens_seq, labels_seq)):
            if not isinstance(toks, (list, tuple)) or not isinstance(labs, (list, tuple)):
                raise ValueError(
                    f"Row {i}: tokens/labels must be lists, got "
                    f"{type(toks).__name__} and {type(labs).__name__}"
                )
            if len(toks) != len(labs):
                raise ValueError(
                    f"Row {i}: tokens and labels length mismatch: "
                    f"{len(toks)} vs {len(labs)}"
                )

    @staticmethod
    def _to_token_list(obj):
        """
        Приводит значение ячейки к списку токенов.

        :param obj: значение колонки токенов (list/tuple/np.ndarray/None/другое).
        :return: список токенов (или пустой список при неподдерживаемом типе).
        :raises: None
        """
        if obj is None:
            return []
        if isinstance(obj, (list, tuple)):
            return list(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return []

    def _get_effective_max_length(self) -> int:
        """
        Возвращает безопасную максимальную длину контекста:
        min(model.config.max_position_embeddings, tokenizer.model_max_length),
        игнорируя «бесконечные» значения токенайзера.

        :return: целочисленное значение безопасной максимальной длины.
        :raises: None
        """
        m_conf = int(getattr(self.model.config, "max_position_embeddings", 512) or 512)
        m_tok = int(getattr(self.tokenizer, "model_max_length", 512) or 512)
        if m_tok > 100000:
            return m_conf
        return min(m_conf, m_tok)

    @staticmethod
    def _sanitize_stride(stride: int, max_length: int) -> int:
        """
        Ограничивает stride до [0, max_length - 2], учитывая спецтокены.

        :param stride: желаемый страйд перекрытия.
        :param max_length: безопасная максимальная длина контекста.
        :return: целочисленное безопасное значение stride.
        :raises: None
        """
        stride = int(max(0, stride))
        return int(min(stride, max(0, max_length - 2)))

    # ------------------------------
    # Алгоритмика
    # ------------------------------
    @staticmethod
    def _align_labels_with_word_ids(labels_ids: List[int], word_ids: List[Optional[int]]) -> List[int]:
        """
        Выравнивает метки слов по субтокенам: первый субтокен слова получает метку,
        последующие субтокены — -100 (игнор в CrossEntropy).

        :param labels_ids: список меток по словам (id), длина = числу слов.
        :param word_ids: список индексов слов для каждого субтокена (tokenizer.word_ids()).
        :return: список меток по длине субтокенов, с -100 для игнорируемых позиций.
        :raises: None
        """
        new_labels = []
        prev_word_id = None
        L = len(labels_ids)
        for wid in word_ids:
            if wid is None or wid < 0 or wid >= L:
                new_labels.append(-100)
            else:
                if wid != prev_word_id:
                    new_labels.append(labels_ids[wid])
                else:
                    new_labels.append(-100)
            prev_word_id = wid
        return new_labels

    def _tokenize_and_align_chunk(
        self,
        docs_tokens: List[List[str]],
        docs_labels_ids: List[List[int]],
        max_length: int,
        stride: int
    ) -> Dataset:
        """
        Tokenize documents into overlapping windows and attach word-aligned labels.

        :param docs_tokens: per-document word lists.
        :param docs_labels_ids: per-document label ids, parallel to docs_tokens.
        :param max_length: maximum window length passed to the tokenizer.
        :param stride: overlap between consecutive windows of one document.
        :return: HF Dataset with one row per window and columns input_ids,
                 attention_mask and labels (-100 at ignored positions).
        :raises: None
        """
        encoding = self.tokenizer(
            docs_tokens,
            is_split_into_words=True,
            truncation=True,
            max_length=max_length,
            stride=stride,
            return_overflowing_tokens=True,
        )
        # Window index -> index of its source document in docs_tokens.
        window_to_doc = encoding.pop("overflow_to_sample_mapping")
        labels_per_window = [
            self._align_labels_with_word_ids(
                docs_labels_ids[int(window_to_doc[w])],
                encoding.word_ids(batch_index=w),
            )
            for w in range(len(encoding["input_ids"]))
        ]
        return Dataset.from_dict(
            {
                "input_ids": encoding["input_ids"],
                "attention_mask": encoding["attention_mask"],
                "labels": labels_per_window,
            }
        )

    def _count_total_chunks(
        self,
        docs_tokens: List[List[str]],
        max_length: int,
        stride: int,
        batch_docs: int = 64
    ) -> int:
        """
        Подсчитывает число чанков (окон) после токенизации набора документов.

        :param docs_tokens: списки токенов по документам.
        :param max_length: безопасная максимальная длина контекста.
        :param stride: перекрытие между окнами.
        :param batch_docs: размер батча документов при токенизации.
        :return: общее число чанков (int).
        :raises: None
        """
        total = 0
        for i in range(0, len(docs_tokens), batch_docs):
            batch = docs_tokens[i:i + batch_docs]
            enc = self.tokenizer(
                batch,
                is_split_into_words=True,
                return_overflowing_tokens=True,
                max_length=max_length,
                stride=stride,
                truncation=True
            )
            total += len(enc["input_ids"])
        return total

    def _compute_class_weights_over_words(self, docs_labels_ids: List[List[int]]) -> np.ndarray:
        """
        Считает веса классов по словам (без влияния overlap-окон).

        :param docs_labels_ids: списки меток (id) по документам.
        :return: массив весов классов shape=(num_labels,), dtype=float32.
        :raises: None
        """
        num_labels = len(self.id2label)
        counts = np.zeros(num_labels, dtype=np.int64)
        for labs in docs_labels_ids:
            if isinstance(labs, (list, tuple)) and len(labs) > 0:
                arr = np.asarray(labs, dtype=np.int64)
                arr = arr[(arr >= 0) & (arr < num_labels)]
                if arr.size > 0:
                    counts += np.bincount(arr, minlength=num_labels)
        N = counts.sum()
        weights = np.zeros(num_labels, dtype=np.float32)
        if N > 0:
            nonzero = counts > 0
            weights[nonzero] = N / (num_labels * counts[nonzero].astype(np.float32))
        return weights

    @staticmethod
    def _normalize_clip_weights(w: np.ndarray, clip: float = 5.0) -> np.ndarray:
        """
        Нормирует и клипует веса классов: клип сверху до clip и нормировка
        положительных весов к среднему ~1.0.

        :param w: исходные веса классов.
        :param clip: верхняя граница клипа (None/<=0 — без клипа).
        :return: нормированные веса dtype=float32.
        :raises: None
        """
        w = np.asarray(w, dtype=np.float32)
        if clip is not None and clip > 0:
            w = np.minimum(w, clip)
        mask = w > 0
        mean = float(np.mean(w[mask])) if np.any(mask) else 1.0
        if mean > 0:
            w = w / mean
        return w

    def _setup_compute_metrics(self):
        """
        Build the seqeval metric function and store it in ``self.compute_metrics``.

        The function takes a ``(predictions, labels)`` tuple/list or an object with
        ``.predictions`` / ``.label_ids``. It argmaxes the logits over the last axis,
        drops positions whose gold label is -100 and maps ids to strings through
        ``self.id2label`` (looked up at call time). It returns:
        - precision / recall / f1 / accuracy: seqeval overall scores (0.0 if absent);
        - f1_{entity}: per-entity F1 for every entity seqeval reports.

        NOTE(review): seqeval parses IOB-style tags; non-IOB tag sets (e.g. POS) may
        produce warnings or misleading entity scores — confirm for the label set used.

        :return: None
        :raises: None
        """
        seqeval = evaluate.load("seqeval")

        def compute_seqeval_metrics(eval_pred):
            if isinstance(eval_pred, (tuple, list)):
                logits, label_ids = eval_pred
            else:
                logits, label_ids = eval_pred.predictions, eval_pred.label_ids

            pred_ids = np.argmax(logits, axis=2)

            pred_tags, gold_tags = [], []
            for pred_row, gold_row in zip(pred_ids, label_ids):
                kept = [(pid, gid) for pid, gid in zip(pred_row, gold_row) if gid != -100]
                pred_tags.append([self.id2label[pid] for pid, _ in kept])
                gold_tags.append([self.id2label[gid] for _, gid in kept])

            scores = seqeval.compute(predictions=pred_tags, references=gold_tags)

            metrics = {
                name: scores.get(f"overall_{name}", 0.0)
                for name in ("precision", "recall", "f1", "accuracy")
            }
            metrics.update({
                f"f1_{entity}": float(entity_scores["f1"])
                for entity, entity_scores in scores.items()
                if isinstance(entity_scores, dict) and "f1" in entity_scores
            })
            return metrics

        self.compute_metrics = compute_seqeval_metrics

    def _prepare_dataset_with_sliding_window(self, df: pd.DataFrame, max_length: int, stride: int) -> Dataset:
        """
        Build a windowed HF Dataset (used for validation) from a DataFrame.

        String tags are converted to ids first. Rows whose tokens and labels are both
        empty are skipped. The remaining rows are tokenized into overlapping windows.

        NOTE(review): with stride > 0 a word inside an overlap is labelled in every window
        that contains it, so it is scored more than once by the evaluation metrics.

        :param df: DataFrame with the token and tag columns.
        :param max_length: maximum window length passed to the tokenizer.
        :param stride: overlap between consecutive windows.
        :return: HF Dataset with input_ids, attention_mask and labels (empty if no rows remain).
        :raises ValueError: on non-list/tuple cells or token/label length mismatch.
        """
        tokens_col = df[self.tokens_column_name].tolist()
        labels_col = df[self.tags_column_name].tolist()

        if self._labels_are_strings(labels_col):
            labels_col = [[self._label_to_id(tag) for tag in tags] for tags in labels_col]

        kept_tokens, kept_labels = [], []
        for row, (toks, labs) in enumerate(zip(tokens_col, labels_col)):
            if not (isinstance(toks, (list, tuple)) and isinstance(labs, (list, tuple))):
                raise ValueError(
                    f"Row {row}: tokens/labels must be lists, got "
                    f"{type(toks).__name__} and {type(labs).__name__}"
                )
            if not toks and not labs:
                continue
            if len(toks) != len(labs):
                raise ValueError(
                    f"Row {row}: tokens and labels length mismatch: "
                    f"{len(toks)} vs {len(labs)}"
                )
            kept_tokens.append(list(toks))
            kept_labels.append(list(labs))

        if not kept_tokens:
            return Dataset.from_dict({"input_ids": [], "attention_mask": [], "labels": []})
        return self._tokenize_and_align_chunk(kept_tokens, kept_labels, max_length, stride)

    # ------------------------------
    # Обучение
    # ------------------------------
    def fit(
        self,
        train_data: pd.DataFrame,
        epochs: int = 3,
        per_device_train_batch_size: int = 16,
        gradient_accumulation_steps: int = 1,
        test_size: float = 0.2,
        learning_rate: float = 2e-5,
        fp16: bool = True,
        stride: int = 128,
        logging_steps: int = 50,
        eval_steps: int = 100,
        output_dir: str = "./result",
        seed: int = 42,
        fit_chunk_size_docs: Optional[int] = None,
        early_stopping_patience: Optional[int] = 3,
        early_stopping_threshold: float = 0.0,
    ):
        """
        Fine-tune the token-classification model on a DataFrame.

        Training documents are processed in "chunks" of ``fit_chunk_size_docs``
        documents (all documents by default). Each chunk is tokenized into sliding
        windows, installed as ``trainer.train_dataset`` and trained with its own
        ``Trainer.train()`` call. Every ``train()`` call starts from a fresh
        ``TrainerState`` (``global_step`` = 0), so each call gets ``max_steps`` equal
        to the step count of its own chunk. A cumulative value would retrain each
        chunk for all the steps of the earlier chunks as well.

        :param train_data: DataFrame with the token and tag columns.
        :param epochs: number of passes over the training documents.
        :param per_device_train_batch_size: training batch size per device.
        :param gradient_accumulation_steps: number of gradient-accumulation steps.
        :param test_size: validation share (float) or absolute number of documents (other
                          numbers). It is clamped so both splits keep at least one document.
                          No validation split is made with fewer than 2 documents or a
                          falsy test_size.
        :param learning_rate: learning rate.
        :param fp16: use fp16 mixed precision (only applied when CUDA is available).
        :param stride: overlap between windows when tokenizing long documents.
        :param logging_steps: logging frequency in optimizer steps.
        :param eval_steps: evaluation/checkpoint frequency in optimizer steps (with a validation split).
        :param output_dir: directory for training artifacts (checkpoints).
        :param seed: seed for the split, document shuffling and all RNGs.
        :param fit_chunk_size_docs: documents per chunk / train() call (None or <= 0 = all documents).
        :param early_stopping_patience: consecutive non-improving evaluations before training stops;
                                        None or <= 0 disables early stopping.
        :param early_stopping_threshold: minimal *absolute* improvement of eval_f1 that resets
                                         the patience counter.
        :return: self (for chaining).
        :raises ValueError: on inconsistent data (token/label types or lengths, unknown tags).
        """
        set_seed(seed)

        max_length = self._get_effective_max_length()
        stride = self._sanitize_stride(stride, max_length)

        df_all = train_data.copy()
        if self._labels_are_strings(df_all[self.tags_column_name].tolist()):
            df_all[self.tags_column_name] = df_all[self.tags_column_name].apply(
                lambda tags: [self._label_to_id(tag) for tag in tags]
            )

        self._assert_tokens_labels_same_len(
            df_all[self.tokens_column_name].tolist(),
            df_all[self.tags_column_name].tolist()
        )

        # Robust train/val split: a float test_size is a fraction of the documents, any
        # other number an absolute count; clamped so both parts keep at least one document.
        n_total = len(df_all)
        use_eval = False
        test_size_abs = 0
        if n_total >= 2 and test_size and float(test_size) > 0:
            if isinstance(test_size, float):
                test_size_abs = int(round(n_total * float(test_size)))
            else:
                test_size_abs = int(test_size)
            if test_size_abs <= 0:
                test_size_abs = 1
            if test_size_abs >= n_total:
                test_size_abs = n_total - 1
            use_eval = test_size_abs > 0

        if use_eval:
            df_train, df_eval = train_test_split(df_all, test_size=test_size_abs, random_state=seed, shuffle=True)
        else:
            df_train = df_all
            df_eval = df_all.iloc[0:0]

        eval_dataset = None
        if len(df_eval) > 0:
            eval_dataset = self._prepare_dataset_with_sliding_window(df_eval, max_length, stride)

        train_docs_tokens = df_train[self.tokens_column_name].tolist()
        train_docs_labels = df_train[self.tags_column_name].tolist()

        # Class weights from word-level counts (unaffected by window overlap), clipped and normalised.
        class_weights = self._compute_class_weights_over_words(train_docs_labels)
        class_weights = self._normalize_clip_weights(class_weights, clip=5.0)

        self._setup_compute_metrics()

        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=max(4, per_device_train_batch_size // 2),
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=max(1, gradient_accumulation_steps * 2),
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.01,
            eval_strategy="steps" if eval_dataset is not None else "no",
            eval_steps=eval_steps,
            save_strategy="steps" if eval_dataset is not None else "no",
            save_steps=eval_steps,
            load_best_model_at_end=bool(eval_dataset is not None),
            metric_for_best_model="eval_f1",
            save_total_limit=1,
            logging_strategy="steps",
            logging_steps=logging_steps,
            report_to="none",
            fp16=bool(fp16 and torch.cuda.is_available()),
            dataloader_num_workers=min(4, os.cpu_count() or 4),
            seed=seed,
            remove_unused_columns=False,
            disable_tqdm=True,
            dataloader_pin_memory=True,
            gradient_checkpointing=True,
        )

        data_collator = self.data_collator

        def steps_for_size(n_samples: int, bsz: int, accum: int) -> int:
            # Optimizer steps for one pass over n_samples windows.
            # NOTE(review): assumes one device; with several GPUs the per-step batch is larger.
            return max(0, math.ceil(math.ceil(n_samples / max(1, bsz)) / max(1, accum)))

        def chunk_slices(n_docs: int, chunk_docs: int):
            for i in range(0, n_docs, chunk_docs):
                yield slice(i, min(i + chunk_docs, n_docs))

        n_docs = len(train_docs_tokens)
        # max(1, ...) keeps range() valid when the training split is empty.
        chunk_docs = int(fit_chunk_size_docs) if (fit_chunk_size_docs and fit_chunk_size_docs > 0) else max(1, n_docs)

        def epoch_order(ep: int) -> np.ndarray:
            # Document order of epoch `ep`; shared by the step pre-count and the training
            # loop so that both see exactly the same chunks.
            order = np.arange(n_docs)
            np.random.default_rng(seed + ep).shuffle(order)
            return order

        # Planned optimizer steps over all epochs and chunks (progress-bar total).
        total_steps = 0
        for ep in range(epochs):
            order = epoch_order(ep)
            for slc in chunk_slices(n_docs, chunk_docs):
                toks_chunk = [train_docs_tokens[i] for i in order[slc]]
                n_samples = self._count_total_chunks(toks_chunk, max_length, stride, batch_docs=64)
                total_steps += steps_for_size(n_samples, per_device_train_batch_size, gradient_accumulation_steps)

        # The Trainer needs a train_dataset at construction; real chunks are swapped in below.
        if n_docs > 0:
            init_chunk_ds = self._tokenize_and_align_chunk(
                [train_docs_tokens[0]], [train_docs_labels[0]], max_length, stride
            )
        else:
            init_chunk_ds = Dataset.from_dict({"input_ids": [], "attention_mask": [], "labels": []})

        self.trainer = WeightedTokenCETrainer(
            model=self.model,
            args=args,
            data_collator=data_collator,
            compute_metrics=self.compute_metrics if eval_dataset is not None else None,
            train_dataset=init_chunk_ds,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            class_weights=class_weights
        )
        # With disable_tqdm=True the Trainer prints raw log dicts; PbarConsoleLogger replaces that.
        try:
            self.trainer.remove_callback(PrinterCallback)
        except Exception:
            pass

        if total_steps > 0:
            # NOTE(review): recent transformers versions discard a scheduler created this way when
            # train() starts and build one for that call's max_steps, so warmup/cosine decay
            # presumably restart for every chunk — confirm for the pinned version.
            self.trainer.create_optimizer_and_scheduler(num_training_steps=total_steps)

        pbar = tqdm(total=total_steps, desc="Training Progress", unit="step")
        cb = PbarConsoleLogger(pbar)
        self.trainer.add_callback(cb)

        # Early stopping (only with a validation split and when requested).
        early_cb = None
        if eval_dataset is not None and (early_stopping_patience is not None) and (early_stopping_patience > 0):
            early_cb = EarlyStoppingCallback(
                early_stopping_patience=int(early_stopping_patience),
                early_stopping_threshold=float(early_stopping_threshold),
            )
            self.trainer.add_callback(early_cb)

        self.progress_callback = cb

        steps_done = 0
        try:
            for ep in range(epochs):
                order = epoch_order(ep)
                stop_training = False

                for slc in chunk_slices(n_docs, chunk_docs):
                    idx = order[slc]
                    toks_chunk = [train_docs_tokens[i] for i in idx]
                    labs_chunk = [train_docs_labels[i] for i in idx]

                    ds_chunk = self._tokenize_and_align_chunk(toks_chunk, labs_chunk, max_length, stride)
                    self.trainer.train_dataset = ds_chunk

                    chunk_steps = steps_for_size(len(ds_chunk), per_device_train_batch_size, gradient_accumulation_steps)
                    if chunk_steps == 0:
                        del ds_chunk
                        continue

                    # train() restarts global_step at 0, so max_steps is this chunk's own budget.
                    self.trainer.args.max_steps = chunk_steps
                    self.trainer.train()
                    steps_done += int(self.trainer.state.global_step or 0)

                    # Keep the shared bar in line with the optimizer steps actually performed.
                    target = min(steps_done, pbar.total)
                    if target > pbar.n:
                        pbar.update(target - pbar.n)

                    del ds_chunk
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    # EarlyStoppingCallback only ends the current train() call; stop the outer loops too.
                    # NOTE(review): best_metric lives in the per-call TrainerState, so improvements are
                    # judged relative to the current chunk only.
                    if early_cb is not None and early_cb.early_stopping_patience_counter >= early_cb.early_stopping_patience:
                        stop_training = True
                        break

                if stop_training:
                    break
        finally:
            pbar.close()
        return self

    # ------------------------------
    # Инференс
    # ------------------------------
    def _predict_single_document(self, tokens: List[str], stride: int) -> List[str]:
        """
        Predict string labels for one document using overlapping windows.

        The document is tokenized into windows of at most the effective max length,
        with ``stride`` tokens of overlap, and run through ``self.trainer.predict``.
        Predictions are mapped back to words via ``word_ids()``. Within each window,
        only the first sub-token of a word contributes to that word's prediction.

        :param tokens: the document's words (pre-split tokens).
        :param stride: requested overlap between windows; clamped by _sanitize_stride.
        :return: one string label per input word; [] for empty or non-list/tuple input.
                 Words that no window covers get the label with the smallest id.
        :raises: nothing explicitly. Note that ``self.trainer`` is initialised to None and,
                 in the visible code, only set by fit(); calling this earlier fails on
                 ``self.trainer.predict`` — presumably the public predict() ensures a
                 trainer exists; verify.
        """
        if not isinstance(tokens, (list, tuple)) or len(tokens) == 0:
            return []

        max_length = self._get_effective_max_length()
        stride = self._sanitize_stride(stride, max_length)

        # A batch of one document, split into overflowing windows.
        tokenized_inputs = self.tokenizer(
            [tokens],
            is_split_into_words=True,
            return_overflowing_tokens=True,
            max_length=max_length,
            stride=stride,
            truncation=True,
        )
        # All windows belong to the single document; the window->document map is not a
        # model input (remove_unused_columns=False would pass every column to the model).
        tokenized_inputs.pop("overflow_to_sample_mapping", None)

        # Fallback label for words no window covers: the smallest label id.
        default_id = int(min(self.id2label.keys())) if len(self.id2label) else 0
        default_label = self.id2label.get(default_id, str(default_id))

        if not isinstance(tokenized_inputs.get("input_ids", None), list) or len(tokenized_inputs["input_ids"]) == 0:
            return [default_label] * len(tokens)

        chunk_dataset = Dataset.from_dict(tokenized_inputs)
        outputs = self.trainer.predict(chunk_dataset)

        # Accept both an object with a .predictions attribute and a mapping.
        if hasattr(outputs, "predictions"):
            preds = outputs.predictions
        else:
            preds = outputs["predictions"]

        num_original_words = len(tokens)

        # Main path: raw logits of shape (num_windows, seq_len, num_labels).
        if isinstance(preds, np.ndarray) and preds.ndim == 3:
            num_labels = preds.shape[-1]
            word_logits = np.zeros((num_original_words, num_labels), dtype=np.float32)
            word_counts = np.zeros((num_original_words,), dtype=np.float32)

            for i in range(preds.shape[0]):
                chunk_logits = preds[i]
                # Windows whose word_ids() cannot be obtained are skipped.
                try:
                    chunk_word_ids = tokenized_inputs.word_ids(batch_index=i)
                except Exception:
                    continue
                if chunk_word_ids is None:
                    continue

                # Sum the logits of each word's first sub-token in this window
                # (a position whose previous token belongs to a different word).
                for token_pos, word_id in enumerate(chunk_word_ids):
                    if word_id is None:
                        continue
                    if token_pos == 0 or chunk_word_ids[token_pos - 1] != word_id:
                        if 0 <= word_id < num_original_words:
                            word_logits[word_id] += chunk_logits[token_pos]
                            word_counts[word_id] += 1.0

            # Average over the windows that saw each word (overlaps contribute several times).
            mask = word_counts > 0
            if np.any(mask):
                word_logits[mask] /= word_counts[mask, None]

            # Words seen by no window keep default_id.
            pred_ids = np.full(num_original_words, default_id, dtype=np.int32)
            if np.any(mask):
                pred_ids[mask] = word_logits[mask].argmax(-1)

            filled = [self.id2label.get(int(x), str(int(x))) for x in pred_ids]
            return filled

        # Fallback for non-3D predictions: values are read as per-token label ids
        # (presumably already argmax-ed — assumption). The first window and the first
        # sub-token that reach a word decide its label.
        if isinstance(preds, np.ndarray):
            if preds.ndim == 2:
                predictions = preds
            elif preds.ndim == 1:
                predictions = preds[None, :]
            else:
                predictions = preds.reshape(len(tokenized_inputs["input_ids"]), -1)
        else:
            predictions = np.asarray(preds)

        # -1 marks "no prediction yet" for a word.
        final_predictions = np.full(num_original_words, -1, dtype=np.int32)
        num_chunks = predictions.shape[0]
        for i in range(num_chunks):
            chunk_preds = predictions[i]
            try:
                chunk_word_ids = tokenized_inputs.word_ids(batch_index=i)
            except Exception:
                continue
            if chunk_word_ids is None:
                continue

            chunk_len = len(chunk_preds)
            for token_pos, word_id in enumerate(chunk_word_ids):
                if token_pos >= chunk_len:
                    break
                if word_id is None:
                    continue
                if 0 <= word_id < num_original_words and final_predictions[word_id] == -1:
                    final_predictions[word_id] = int(chunk_preds[token_pos])

        filled = [
            self.id2label.get(pid, default_label) if pid != -1 else default_label
            for pid in final_predictions
        ]
        return filled

    def predict(self, df: pd.DataFrame, stride: int = 128) -> List[List[str]]:
        """
        Predict word-level labels for every document of a DataFrame.

        Rows whose tokens cell is missing or empty get an empty label list; every other
        row is labelled with a sliding window by ``_predict_single_document``.

        :param df: DataFrame containing the tokens column (``self.tokens_column_name``).
        :param stride: overlap between neighbouring windows.
        :return: one list of string labels (one per word) for each row of ``df``, in order.
        :raises RuntimeError: if the model is not trained and a non-empty document is met.
        """
        labels_per_doc: List[List[str]] = []
        progress = tqdm(df.iterrows(), total=len(df), desc="Предсказание (sliding window)")
        for _, row in progress:
            try:
                raw_cell = row.get(self.tokens_column_name, None)
            except Exception:
                raw_cell = None

            words = self._to_token_list(raw_cell)
            if not len(words):
                labels_per_doc.append([])
                continue

            # The trained-model check only matters once there is something to label.
            model_missing = self.trainer is None or self.trainer.model is None
            if model_missing:
                raise RuntimeError("Модель не обучена. Вызовите .fit() перед .predict() для непустых документов.")

            labels_per_doc.append(self._predict_single_document(words, stride))
        return labels_per_doc

    # ------------------------------
    # Эмбеддинги
    # ------------------------------
    def _get_embeddings_single_document(self, tokens: List[str], stride: int, device: torch.device) -> np.ndarray:
        """
        Extract word-level embeddings for one document using a sliding window.

        The tokenizer splits the document into overlapping windows
        (``return_overflowing_tokens`` + ``stride``). The base encoder produces
        ``last_hidden_state`` for all windows at once. Each word's vector is the mean
        over all of its sub-word tokens in all windows that contain it.

        :param tokens: list of words of the document.
        :param stride: overlap between neighbouring windows (sanitized by ``_sanitize_stride``).
        :param device: device of the model (CPU/GPU); inputs are moved there.
        :return: float32 array of shape (num_words, hidden_size). Words that produced no
                 sub-word tokens get a zero vector.
        """
        max_length = self._get_effective_max_length()
        stride = self._sanitize_stride(stride, max_length)
        num_original_words = len(tokens)

        chunk_inputs = self.tokenizer(
            [tokens],
            is_split_into_words=True,
            return_overflowing_tokens=True,
            max_length=max_length,
            stride=stride,
            truncation=True,
            return_tensors="pt"
        ).to(device)
        # Not a model input; tolerate tokenizers that do not emit it (same as the predict path).
        chunk_inputs.pop("overflow_to_sample_mapping", None)

        with torch.no_grad():
            base_model = getattr(self.trainer.model, self.trainer.model.base_model_prefix)
            outputs = base_model(**chunk_inputs)

        # Accumulate in float32 regardless of the model's precision.
        chunk_embeddings = outputs.last_hidden_state.float()
        emb_device = chunk_embeddings.device
        hidden_size = chunk_embeddings.size(-1)

        word_sums = torch.zeros(num_original_words, hidden_size, device=emb_device)
        word_counts = torch.zeros(num_original_words, device=emb_device)

        for chunk_idx in range(chunk_embeddings.size(0)):
            word_ids = chunk_inputs.word_ids(batch_index=chunk_idx)
            positions = [
                pos for pos, wid in enumerate(word_ids)
                if wid is not None and 0 <= wid < num_original_words
            ]
            if not positions:
                continue
            pos_index = torch.tensor(positions, dtype=torch.long, device=emb_device)
            word_index = torch.tensor([word_ids[p] for p in positions], dtype=torch.long, device=emb_device)
            # Sum AND count every sub-word token, so the division below is a true mean.
            # Counting only the first sub-word would inflate words split into several pieces.
            word_sums.index_add_(0, word_index, chunk_embeddings[chunk_idx].index_select(0, pos_index))
            word_counts.index_add_(0, word_index, torch.ones(len(positions), device=emb_device))

        # clamp keeps never-seen words at exactly zero instead of dividing by zero.
        average_embeddings = word_sums / word_counts.clamp(min=1.0).unsqueeze(1)
        return average_embeddings.detach().cpu().numpy()

    def get_embeddings(self, df: pd.DataFrame, stride: int = 128) -> List[np.ndarray]:
        """
        Extract word embeddings for every document of a DataFrame.

        Token cells are normalized with ``_to_token_list``, exactly as in ``.predict()``.
        Numpy-array cells therefore no longer fail on an ambiguous truth-value check.

        :param df: DataFrame with the tokens column (``self.tokens_column_name``).
        :param stride: overlap between neighbouring windows.
        :return: one array per document of shape (num_words, hidden_size); empty
                 documents yield an array of shape (0, hidden_size).
        :raises RuntimeError: if the model is not trained.
        """
        if self.trainer is None or self.trainer.model is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        self.trainer.model.eval()
        device = self.trainer.model.device
        all_final_embeddings = []

        for _, row in tqdm(df.iterrows(), total=len(df), desc="Генерация эмбеддингов (sliding window)"):
            tokens = self._to_token_list(row[self.tokens_column_name])
            if len(tokens) == 0:
                all_final_embeddings.append(np.zeros((0, int(self.model.config.hidden_size)), dtype=np.float32))
                continue

            document_embeddings = self._get_embeddings_single_document(tokens, stride, device)
            all_final_embeddings.append(document_embeddings)

        return all_final_embeddings

Пример использования 1.

In [None]:
import pandas as pd

# --- Data: string tags; .fit() converts them to ids via label2id -------------
tokens_col, tags_col = "tokens", "tags"
label2id = {
    tag: tag_id
    for tag_id, tag in enumerate(["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"])
}

df_train = pd.DataFrame([
    {tokens_col: words, tags_col: tags}
    for words, tags in [
        (["John", "Doe", "lives", "in", "Berlin"], ["B-PER", "I-PER", "O", "O", "B-LOC"]),
        (["Mary", "works", "at", "Google"], ["B-PER", "O", "O", "B-ORG"]),
        (["Alice", "is", "from", "Paris"], ["B-PER", "O", "O", "B-LOC"]),
        (["IBM", "is", "in", "Armonk"], ["B-ORG", "O", "O", "B-LOC"]),
        (["Bob", "moved", "to", "London"], ["B-PER", "O", "O", "B-LOC"]),
        (["Google", "is", "in", "California"], ["B-ORG", "O", "O", "B-LOC"]),
    ]
])

# --- Model: only the arguments the class requires ----------------------------
CKPT = "prajjwal1/bert-tiny"
tc = TokenClassification(
    checkpoint=CKPT,
    label2id=label2id,
    tokens_column_name=tokens_col,
    tags_column_name=tags_col,
)

# --- Training with most of the .fit() knobs spelled out ----------------------
tc.fit(
    train_data=df_train,
    output_dir="./tokcls_max_param",
    seed=123,
    # optimisation
    epochs=3,
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    fp16=True,                    # mixed precision (presumably only applied when CUDA is available)
    # data handling
    test_size=0.33,               # hold out a validation split
    stride=64,                    # sliding-window overlap
    fit_chunk_size_docs=2,        # train on chunks of 2 documents
    # logging / evaluation / early stopping
    logging_steps=1,
    eval_steps=2,
    early_stopping_patience=2,    # stop after 2 evaluations without improvement
    early_stopping_threshold=0.0,
)

# Trainer metrics (per-entity scores such as eval_f1_PER appear for entity types seen in validation)
metrics = tc.trainer.evaluate()
print(
    "Eval metrics (subset):",
    {name: float(value) for name, value in metrics.items() if isinstance(value, (int, float))},
)

# Inference on the first three documents, with an inference-specific stride
df_infer = df_train.iloc[:3]
preds = tc.predict(df_infer, stride=32)
for i, (tokens, pred) in enumerate(zip(df_infer[tokens_col], preds), start=1):
    print(f"Doc {i}:")
    print(list(zip(tokens, pred)))

# Word embeddings: one (num_words, hidden_size) array per document
embs = tc.get_embeddings(df_infer, stride=32)
print("Embeddings shapes:", [emb.shape for emb in embs])

Пример использования 2.

In [None]:
import pandas as pd

tokens_col, tags_col = "tokens", "tags"
label2id = {"O": 0, "B-PER": 1}  # minimal tag set

# Tags are already integer ids here: 1 == "B-PER", 0 == "O"
df_small = pd.DataFrame([
    {tokens_col: [name, verb], tags_col: [1, 0]}
    for name, verb in (("John", "works"), ("Mary", "smiles"))
])

# Initialisation
CKPT = "prajjwal1/bert-tiny"
tc = TokenClassification(
    checkpoint=CKPT,
    tokens_column_name=tokens_col,
    tags_column_name=tags_col,
    label2id=label2id,
)

# Training, prediction and embeddings, all with default parameters
tc.fit(train_data=df_small)

preds = tc.predict(df_small)
print("Preds:", preds)

embs = tc.get_embeddings(df_small)
print("Embeddings shape for doc 0:", embs[0].shape)

# Дообучение классификатора, который работает с данными разных модальностей.

Реализация классификатора.

In [None]:
!pip install --upgrade --no-cache-dir \
  --extra-index-url https://download.pytorch.org/whl/cu124 \
  pillow==11.1.0 \
  numpy==1.26.4 \
  pandas==2.2.3 \
  tqdm==4.67.1 \
  transformers==4.51.3 \
  evaluate==0.4.5 \
  wav2clip==0.1.0 \
  torch==2.6.0+cu124 \
  torchaudio==2.6.0+cu124

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import math
import random
import gc
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Union

from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from transformers import TrainingArguments, Trainer
from transformers.trainer_callback import TrainerCallback, PrinterCallback, EarlyStoppingCallback
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate

# =========================
# Утилиты
# =========================

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device) with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

def to_pil(x: Union[str, os.PathLike, np.ndarray, 'Image.Image']) -> 'Image.Image':
    """
    Convert an image given as a PIL image, a file path or a NumPy array into an RGB PIL image.

    :param x: PIL image, path (``str`` or ``os.PathLike``), or array accepted by
              ``PIL.Image.fromarray``.
    :return: RGB ``PIL.Image.Image``.
    :raises ValueError: for any other input type.
    """
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    if isinstance(x, (str, os.PathLike)):
        # convert() returns a fully loaded copy, so the file can be closed right away.
        # The context manager releases the handle even for multi-frame formats.
        with Image.open(x) as img:
            return img.convert("RGB")
    if isinstance(x, np.ndarray):
        return Image.fromarray(x).convert("RGB")
    raise ValueError("Ожидается путь/np.ndarray/PIL.Image")

def load_audio(path: str, target_sr: int) -> np.ndarray:
    """
    Load an audio file as a mono float32 waveform sampled at ``target_sr``.

    Multi-channel audio is averaged across channels. The signal is resampled with
    ``torchaudio.functional.resample`` when the file's rate differs from ``target_sr``.

    :param path: path to an audio file readable by ``torchaudio.load``.
    :param target_sr: desired sampling rate in Hz.
    :return: 1-D ``np.ndarray`` of dtype float32.
    :raises RuntimeError: if torchaudio cannot be imported.
    """
    try:
        import torchaudio
    except Exception as e:
        raise RuntimeError("Требуется torchaudio: pip install torchaudio") from e

    signal, native_sr = torchaudio.load(path)
    is_multichannel = signal.size(0) > 1
    if is_multichannel:
        # Down-mix to mono while keeping a leading channel axis of size 1.
        signal = signal.mean(dim=0, keepdim=True)
    needs_resample = native_sr != target_sr
    if needs_resample:
        signal = torchaudio.functional.resample(signal, orig_freq=native_sr, new_freq=target_sr)
    mono = signal.squeeze(0)
    return mono.numpy().astype(np.float32)

def safe_load(component_cls, checkpoint: str, local_cache_dir: str = "./model_cache",
              local_files_only: Optional[bool] = None, **kwargs):
    """
    Load a Hugging Face component through ``from_pretrained`` using a local cache dir.

    :param component_cls: class exposing ``from_pretrained`` (model, tokenizer, processor, ...).
    :param checkpoint: hub id or local path.
    :param local_cache_dir: forwarded as ``cache_dir``.
    :param local_files_only: if ``None``, derived from the ``HF_HUB_OFFLINE`` environment
        variable. Values 1/ON/YES/TRUE in any case are truthy, matching huggingface_hub's parsing.
    :param kwargs: forwarded to ``from_pretrained``. Classes whose name contains
        "Tokenizer" default to ``use_fast=True``.
    :return: whatever ``component_cls.from_pretrained`` returns.
    """
    if local_files_only is None:
        offline_flag = os.environ.get("HF_HUB_OFFLINE", "0").strip().upper()
        local_files_only = offline_flag in {"1", "ON", "YES", "TRUE"}
    name = getattr(component_cls, "__name__", "")
    if "Tokenizer" in name:
        kwargs.setdefault("use_fast", True)
    return component_cls.from_pretrained(
        checkpoint, cache_dir=local_cache_dir, local_files_only=local_files_only, **kwargs
    )


# =========================
# Токенизатор батчевый
# =========================

class BatchTokenizer:
    """
    Wrapper around a Hugging Face tokenizer with an LRU cache and batch helpers.

    Note the two padding regimes of ``tokenize_batch``:
      * cached path (``use_cache=True`` and fewer than 100 texts): every text is padded
        to ``max_length``;
      * direct path (otherwise): the batch is padded to its longest sequence.
    In both paths ``input_ids`` / ``attention_mask`` / ``token_type_ids`` are int64.
    ``use_fast`` and ``device`` are stored but not used by this class.
    """

    _INT_KEYS = ("input_ids", "attention_mask", "token_type_ids")

    def __init__(
        self,
        tokenizer,
        max_length: int = 512,
        cache_size: int = 10000,
        batch_size: int = 256,
        use_fast: bool = True,
        device: str = "cpu"
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_fast = use_fast
        self.device = device
        # Per-instance memo of single-text tokenizations (keyed by the text string).
        self._cache = lru_cache(maxsize=cache_size)(self._tokenize_single)
        self.is_fast = getattr(tokenizer, "is_fast", False)
        if self.is_fast:
            print("✓ Используется Fast Tokenizer")

    def _tokenize_single(self, text: str) -> tuple:
        """Tokenize one text padded to ``max_length``; return hashable ``(key, np.ndarray)`` pairs."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return tuple((name, tensor.squeeze(0).cpu().numpy()) for name, tensor in encoded.items())

    def tokenize_batch(self, texts: List[str], use_cache: bool = True) -> Dict[str, torch.Tensor]:
        """Tokenize ``texts`` into a dict of batched tensors (see class docstring for padding)."""
        if not use_cache or len(texts) >= 100:
            encoded = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
            for name in encoded:
                if name in self._INT_KEYS:
                    encoded[name] = encoded[name].long()
            return encoded

        cached_rows = [dict(self._cache(text)) for text in texts]
        return {
            name: torch.tensor(
                np.stack([row[name] for row in cached_rows]),
                dtype=torch.long if name in self._INT_KEYS else torch.float32,
            )
            for name in cached_rows[0]
        }

    def tokenize_dataset_lazy(
        self,
        texts: List[str],
        batch_size: Optional[int] = None
    ) -> Generator[Dict[str, torch.Tensor], None, None]:
        """Yield uncached tokenized batches of ``batch_size`` (default ``self.batch_size``) texts."""
        step = batch_size or self.batch_size
        for start in range(0, len(texts), step):
            yield self.tokenize_batch(texts[start:start + step], use_cache=False)

    def clear_cache(self):
        """Drop every memoized single-text tokenization."""
        self._cache.cache_clear()


# =========================
# Универсальный датасет
# =========================

class MultiComboDataset(Dataset):
    """
    Map-style dataset over a DataFrame mixing text, image and audio columns.

    ``__getitem__`` returns a dict with:
      * ``"labels"``: ``label2id[row[target_col]]`` as int, or 0 when the DataFrame has
        no ``target_col`` column;
      * ``"text_tokens"`` (dict of tensors) or ``"text"`` (str), only when ``text_columns``
        is given. Text columns are joined with ``special_tokens["sep"]``;
      * ``"images"`` / ``"audios"``: flat lists of the raw cell values of the image / audio
        columns. list/tuple cells are expanded; None/NaN cells are skipped.

    Tokenized text is memoized per row index in ``tokenized_cache``, holding at most
    ``max_cache_size`` entries. Every memoized entry is padded to
    ``text_tokenizer.max_length``. Tensors from different tokenization calls (eager
    pretokenization batches or lazy per-item tokenization) therefore have equal length
    and can be stacked into one batch.
    """

    _INT_KEYS = ("input_ids", "attention_mask", "token_type_ids")

    def __init__(
        self,
        df: pd.DataFrame,
        target_col: str,
        label2id: Dict[Any, int],
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer: Optional["BatchTokenizer"] = None,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        pretokenize: bool = False,
        pretokenize_batch_size: int = 256,
        max_cache_size: int = 100000,
        tokenizer_returns_tensors: bool = False,
        cache_dir: Optional[str] = None
    ):
        self.df = df.reset_index(drop=True)
        self.target_col = target_col
        self.label2id = label2id

        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []

        self.text_tokenizer = text_tokenizer
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors

        # Upper bound for tokenized_cache, honoured by both eager and lazy tokenization.
        self.max_cache_size = max_cache_size
        self.tokenized_cache: Dict[int, Dict[str, torch.Tensor]] = {}
        self.cache_hits = 0
        self.cache_misses = 0

        if pretokenize and self.text_tokenizer and self.text_columns:
            self._pretokenize_texts(
                batch_size=pretokenize_batch_size,
                max_cache_size=min(max_cache_size, len(self.df))
            )

    def _join_text(self, row: pd.Series) -> str:
        """Join the text columns of ``row`` with the separator; NaN/None cells become ""."""
        sep = self.special_tokens.get("sep", " [SEP] ")
        return sep.join([str(row[c]) if pd.notnull(row[c]) else "" for c in self.text_columns])

    def _pretokenize_texts(self, batch_size: int, max_cache_size: int):
        """
        Eagerly tokenize the joined text of the first ``max_cache_size`` rows into the cache.

        Padding is forced to ``max_length``, the same as the per-item cached path of
        ``BatchTokenizer``. Padding to each batch's longest sequence would give rows from
        different pretokenization batches different lengths and break ``torch.stack``
        in the collate function.
        """
        print("Предварительная токенизация текстов...")
        indices = list(range(min(len(self.df), max_cache_size)))
        all_texts = [self._join_text(self.df.iloc[i]) for i in indices]
        hf_tokenizer = self.text_tokenizer.tokenizer
        max_length = self.text_tokenizer.max_length

        for start in range(0, len(indices), batch_size):
            batch_idx = indices[start:start + batch_size]
            batch_txt = all_texts[start:start + batch_size]
            tokenized = hf_tokenizer(
                batch_txt,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )

            for j, idx in enumerate(batch_idx):
                self.tokenized_cache[idx] = {
                    k: (v[j].clone().long() if k in self._INT_KEYS else v[j].clone())
                    for k, v in tokenized.items()
                }

        print(f"✓ Предварительно токенизировано {len(self.tokenized_cache)} текстов")

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.df.iloc[idx]
        item: Dict[str, Any] = {}
        if self.target_col in row:
            item["labels"] = int(self.label2id[row[self.target_col]])
        else:
            item["labels"] = 0

        if self.text_columns:
            if idx in self.tokenized_cache:
                cached = self.tokenized_cache[idx]
                item["text_tokens"] = {
                    k: (v.long() if k in self._INT_KEYS else v) for k, v in cached.items()
                }
                self.cache_hits += 1
            elif self.text_tokenizer is not None:
                text = self._join_text(row)
                # A single text goes through BatchTokenizer's cached path (padded to max_length).
                tokenized = self.text_tokenizer.tokenize_batch([text], use_cache=True)
                text_tokens = {k: (v[0].long() if k in self._INT_KEYS else v[0])
                               for k, v in tokenized.items()}
                item["text_tokens"] = text_tokens
                self.cache_misses += 1
                if len(self.tokenized_cache) < self.max_cache_size:
                    self.tokenized_cache[idx] = {k: t.clone() for k, t in text_tokens.items()}
            elif self.text_tokenizer_fn is not None:
                text_data = {c: str(row[c]) if pd.notnull(row[c]) else "" for c in self.text_columns}
                result = self.text_tokenizer_fn(text_data, self.special_tokens)
                if isinstance(result, dict) and 'input_ids' in result:
                    item["text_tokens"] = result
                    self.tokenizer_returns_tensors = True
                else:
                    item["text"] = result
            else:
                item["text"] = self._join_text(row)

        def _as_list(v):
            # None / float NaN -> no entries; list/tuple -> its elements; scalar -> one entry.
            if v is None or (isinstance(v, float) and np.isnan(v)):
                return []
            if isinstance(v, (list, tuple)):
                return list(v)
            return [v]

        if self.image_columns:
            imgs = []
            for c in self.image_columns:
                if c in row:
                    imgs.extend(_as_list(row[c]))
            item["images"] = imgs

        if self.audio_columns:
            auds = []
            for c in self.audio_columns:
                if c in row:
                    auds.extend(_as_list(row[c]))
            item["audios"] = auds

        return item

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache size, hit/miss counters and the hit rate (0.0 when nothing was served)."""
        total = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total if total > 0 else 0.0
        return {
            "cache_size": len(self.tokenized_cache),
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate
        }

    def clear_cache(self):
        """Empty the per-index cache, reset counters and clear the tokenizer's own cache."""
        self.tokenized_cache.clear()
        self.cache_hits = 0
        self.cache_misses = 0
        if self.text_tokenizer:
            self.text_tokenizer.clear_cache()


# =========================
# Универсальный бэкенд
# =========================

class BaseBackend(nn.Module):
    """
    Base interface for multimodal feature-extraction backends.

    Subclasses declare the modalities they handle (``supported``) and their output sizes
    (``embed_dim`` / ``out_dim_per_modality``). They implement ``collate`` (list of dataset
    items -> backend inputs) and ``encode`` (backend inputs -> dict of tensors, presumably
    keyed by modality; confirm in subclasses).

    ``__init__`` copies the mutable class-level defaults (``supported``,
    ``out_dim_per_modality``, ``special_tokens``) per instance. Mutating them on one
    backend then never leaks into the class or into other backends. Values declared
    at class level by a subclass are preserved by the copy.
    """
    name: str = "base"
    supported: set = set()
    embed_dim: int = 0
    out_dim_per_modality: Dict[str, int] = {}
    text_tokenizer_fn: Optional[Callable] = None
    batch_tokenizer: Optional["BatchTokenizer"] = None
    special_tokens: Dict[str, str] = {}
    tokenizer_returns_tensors: bool = False
    local_cache_dir: str = "./model_cache"

    def __init__(self):
        super().__init__()
        cls = type(self)
        self.supported = set(cls.supported)
        self.out_dim_per_modality = dict(cls.out_dim_per_modality)
        self.special_tokens = dict(cls.special_tokens)

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn a list of dataset items into backend inputs (must be overridden)."""
        raise NotImplementedError

    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        """Encode collated inputs on ``device`` (must be overridden)."""
        raise NotImplementedError

    def freeze_all(self):
        """Disable gradients for every parameter of the backend."""
        for p in self.parameters():
            p.requires_grad = False

    def get_out_dim(self, modality: str) -> int:
        """Output size for ``modality``; falls back to ``embed_dim`` if it is not registered."""
        return self.out_dim_per_modality.get(modality, self.embed_dim)

    def set_text_tokenizer(self, tokenizer_fn: Optional[Callable], special_tokens: Optional[Dict[str, str]] = None,
                           returns_tensors: bool = False):
        """Install a custom text tokenization function and its special tokens."""
        self.text_tokenizer_fn = tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = returns_tensors

    def set_batch_tokenizer(self, tokenizer, max_length: int = 512,
                            cache_size: int = 10000, batch_size: int = 256):
        """Wrap ``tokenizer`` in a ``BatchTokenizer`` and store it as ``self.batch_tokenizer``."""
        self.batch_tokenizer = BatchTokenizer(
            tokenizer=tokenizer,
            max_length=max_length,
            cache_size=cache_size,
            batch_size=batch_size,
            use_fast=True
        )


class UniversalMultiBackend(BaseBackend):
    name = "universal"
    
    class _ParamDeviceProxy(nn.Module):
        """
        nn.Module wrapper for encoders that expose no parameters, or are plain callables.

        Holds a zero-size, non-trainable dummy parameter, so ``.parameters()`` is never
        empty and follows the module's device. This presumably serves code that infers
        the device from the first parameter (TODO confirm the consumer). Calls are
        forwarded to the wrapped object.
        """

        def __init__(self, base, device: torch.device):
            super().__init__()
            # nn.Module instances are registered as a submodule, so .to()/state_dict see
            # them. Any other callable is kept as a plain attribute.
            self.base = base if isinstance(base, nn.Module) else None
            self._callable = base if not isinstance(base, nn.Module) else None
            self._dummy = nn.Parameter(torch.empty(0), requires_grad=False)
            with torch.no_grad():
                self._dummy.data = self._dummy.data.to(device)
            self._move_target(device)

        def _target(self):
            """The wrapped module or callable."""
            return self.base if self.base is not None else self._callable

        def _move_target(self, *args, **kwargs):
            """Best-effort ``.to(...)`` on the wrapped object; failures are deliberately ignored."""
            try:
                target = self._target()
                if hasattr(target, "to"):
                    target.to(*args, **kwargs)
            except Exception:
                pass

        def forward(self, *args, **kwargs):
            return self._target()(*args, **kwargs)

        def to(self, *args, **kwargs):
            # Accept every nn.Module.to signature (device, dtype, tensor, keywords);
            # the old `to(self, device, ...)` broke calls such as `.to(dtype=...)`.
            self._move_target(*args, **kwargs)
            return super().to(*args, **kwargs)
    
    def _preferred_device(self) -> torch.device:
        """Pick the best available device: CUDA first, then Apple MPS, otherwise CPU."""
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device_name = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        return torch.device(device_name)
    
    def _wrap_if_parameterless(self, model, device: torch.device):
        """
        Return ``model`` unchanged if it yields at least one parameter.

        Otherwise return it wrapped in ``_ParamDeviceProxy``. Objects without a
        ``parameters`` attribute, or whose parameter probe raises, are wrapped as well.
        """
        try:
            params = model.parameters() if hasattr(model, "parameters") else iter(())
            next(params)
        except Exception:
            # StopIteration (no parameters) or any failure while probing.
            return self._ParamDeviceProxy(model, device)
        return model

    def __init__(
        self,
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        freeze: bool = True,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        use_batch_tokenizer: bool = True,
        tokenizer_cache_size: int = 10000,
        tokenizer_batch_size: int = 256,
        local_cache_dir: str = "./model_cache"
    ):
        """
        Build one encoder per configured modality.

        :param text_model_config: enables the text encoder. Requires ``checkpoint``;
            optional ``model_type`` ('clip', 'clap', otherwise AutoModel) and ``max_length``.
        :param image_model_config: enables the image encoder. Requires ``checkpoint``;
            optional ``model_type`` ('clip', otherwise AutoModel), ``max_images``, ``image_agg``.
        :param audio_model_config: enables the audio encoder. ``model_type`` 'wav2clip',
            'clap' or AutoModel; ``checkpoint`` is required except for wav2clip.
        :param freeze: if True, all backend parameters get ``requires_grad=False``.
        :param text_tokenizer_fn / special_tokens / tokenizer_returns_tensors: custom text
            tokenization settings, stored as attributes.
        :param use_batch_tokenizer: wrap the text tokenizer in a ``BatchTokenizer``.
        :param tokenizer_cache_size / tokenizer_batch_size: ``BatchTokenizer`` settings.
        :param local_cache_dir: cache directory passed to ``safe_load``.
        """
        super().__init__()
        self.supported = set()
        self.out_dim_per_modality = {}
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.use_batch_tokenizer = use_batch_tokenizer
        self.tokenizer_cache_size = tokenizer_cache_size
        self.tokenizer_batch_size = tokenizer_batch_size
        self.local_cache_dir = local_cache_dir

        # The *_config dicts are shallow copies. The _init_* methods write derived keys
        # (dim, model_type, ...) into them, and the caller's dicts must not be mutated.
        self.text_model = None
        self.text_processor = None
        self.text_config = dict(text_model_config or {})
        if text_model_config:
            self._init_text_model(text_model_config)
            self.supported.add("text")

        self.image_model = None
        self.image_processor = None
        self.image_config = dict(image_model_config or {})
        if image_model_config:
            self._init_image_model(image_model_config)
            self.supported.add("image")

        self.audio_model = None
        self.audio_processor = None
        self.audio_config = dict(audio_model_config or {})
        if audio_model_config:
            self._init_audio_model(audio_model_config)
            self.supported.add("audio")

        if freeze:
            self.freeze_all()

    def _ensure_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """
        Coerce a feature tensor to shape (batch, features).

        ``None`` passes through. A 0-d tensor becomes (1, 1), and a 1-D tensor
        (features,) becomes (1, features). Tensors with more than two dims are
        flattened to (size(0), prod(rest)). ``flatten`` is used instead of ``view``:
        it also works for non-contiguous inputs and zero-size batches, where
        ``view(size(0), -1)`` raises.
        """
        if x is None:
            return None
        if x.dim() == 0:
            return x.reshape(1, 1)
        if x.dim() == 1:
            return x.unsqueeze(0)
        if x.dim() > 2:
            return x.flatten(start_dim=1)
        return x

    def _normalize_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """L2-normalize the rows of ``x`` after coercing it to 2-D; ``None``/empty tensors pass through."""
        flat = self._ensure_2d(x)
        if flat is None or flat.numel() == 0:
            return flat
        return F.normalize(flat, dim=-1, eps=1e-12)
    
    def _init_text_model(self, config: Dict[str, Any]):
        """
        Load the text encoder and tokenizer, then record text settings.

        ``config['checkpoint']`` is required. ``model_type`` (lower-cased, default 'auto'):
          * 'clip': ``CLIPModel`` + ``CLIPTokenizer``; dim = ``projection_dim``;
          * 'clap': ``ClapModel`` + the ClapProcessor's tokenizer (falls back to
            ``AutoTokenizer``); dim = ``projection_dim`` (default 512);
          * otherwise: ``AutoModel`` + ``AutoTokenizer``; dim = ``hidden_size``.
        ``max_length`` (default 512) is used for the optional ``BatchTokenizer`` and stored
        in ``self.text_config``.
        """
        from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPTokenizer, ClapModel, ClapProcessor

        checkpoint = config['checkpoint']
        model_type = config.get('model_type', 'auto').lower()
        cache = self.local_cache_dir

        print(f"Загрузка текстовой модели {checkpoint}...")

        if model_type == 'clip':
            self.text_model = safe_load(CLIPModel, checkpoint, local_cache_dir=cache)
            self.text_processor = safe_load(CLIPTokenizer, checkpoint, local_cache_dir=cache, use_fast=True)
            dim = self.text_model.config.projection_dim
        elif model_type == 'clap':
            self.text_model = safe_load(ClapModel, checkpoint, local_cache_dir=cache)
            clap_processor = safe_load(ClapProcessor, checkpoint, local_cache_dir=cache)
            tokenizer = getattr(clap_processor, 'tokenizer', None)
            if not tokenizer:
                tokenizer = safe_load(AutoTokenizer, checkpoint, local_cache_dir=cache, use_fast=True)
            self.text_processor = tokenizer
            dim = getattr(self.text_model.config, "projection_dim", 512)
        else:
            self.text_model = safe_load(AutoModel, checkpoint, local_cache_dir=cache)
            self.text_processor = safe_load(AutoTokenizer, checkpoint, local_cache_dir=cache, use_fast=True)
            dim = self.text_model.config.hidden_size

        self.text_model = self._wrap_if_parameterless(self.text_model, self._preferred_device())

        max_length = config.get('max_length', 512)
        if self.use_batch_tokenizer and self.text_processor is not None:
            self.set_batch_tokenizer(
                self.text_processor,
                max_length=max_length,
                cache_size=self.tokenizer_cache_size,
                batch_size=self.tokenizer_batch_size
            )

        self.text_config.update(max_length=max_length, dim=dim, model_type=model_type)
        self.out_dim_per_modality['text'] = dim

    def _init_image_model(self, config: Dict[str, Any]):
        """
        Load the image encoder and image processor, then record image settings.

        ``config['checkpoint']`` is required. ``model_type`` 'clip' loads ``CLIPModel`` +
        ``CLIPImageProcessor`` (dim = ``projection_dim``); anything else loads ``AutoModel``
        + ``AutoImageProcessor`` (dim = ``hidden_size``). Optional keys: ``max_images``
        (default 1) and ``image_agg`` (default 'concat'). With 'concat' the image output
        size is ``dim * max_images``, otherwise ``dim``.
        """
        from transformers import AutoModel, AutoImageProcessor, CLIPModel, CLIPImageProcessor

        checkpoint = config['checkpoint']
        model_type = config.get('model_type', 'auto').lower()
        cache = self.local_cache_dir

        print(f"Загрузка визуальной модели {checkpoint}...")

        is_clip = model_type == 'clip'
        model_cls, processor_cls = (CLIPModel, CLIPImageProcessor) if is_clip else (AutoModel, AutoImageProcessor)
        self.image_model = safe_load(model_cls, checkpoint, local_cache_dir=cache)
        self.image_processor = safe_load(processor_cls, checkpoint, local_cache_dir=cache)
        model_config = self.image_model.config
        dim = model_config.projection_dim if is_clip else model_config.hidden_size

        self.image_model = self._wrap_if_parameterless(self.image_model, self._preferred_device())

        max_images = config.get('max_images', 1)
        image_agg = config.get('image_agg', 'concat')
        self.image_config.update(max_images=max_images, image_agg=image_agg, dim=dim, model_type=model_type)

        self.out_dim_per_modality['image'] = dim * max_images if image_agg == 'concat' else dim
    
    def _init_audio_model(self, config: Dict[str, Any]):
        """
        Load the audio encoder (and processor, if any), then record audio settings.

        ``model_type`` (lower-cased, default 'auto') selects the loader:
          * 'wav2clip': the ``wav2clip`` package model (``get_model()`` or ``model``);
            no processor; dim = 512; default sample rate 16000;
          * 'clap': ``ClapModel`` + ``ClapProcessor``; dim = ``projection_dim`` (default
            512); sample rate from the processor or its feature extractor (default 48000);
          * otherwise: ``AutoModel`` + ``AutoProcessor``; dim = ``hidden_size``; sample
            rate from the feature extractor (default 16000).
        ``config['sr']``, if given, overrides the detected sample rate. ``max_audios``
        (default 1) and ``audio_agg`` (default 'concat') define the output size:
        ``dim * max_audios`` for 'concat', otherwise ``dim``.

        :raises ValueError: if ``checkpoint`` is missing for non-wav2clip models.
        :raises RuntimeError: if the wav2clip package exposes neither ``get_model`` nor ``model``.
        """
        from transformers import AutoModel, AutoProcessor, ClapModel, ClapProcessor
    
        model_type = config.get('model_type', 'auto').lower()
        checkpoint = config.get('checkpoint', None)
    
        print(f"Загрузка аудио модели (type={model_type})...")
    
        if model_type == 'wav2clip':
            import wav2clip as w2c
            # Kept on the instance, presumably for the encode step; confirm its use there.
            self._w2c = w2c
    
            # Support both package APIs: a get_model() factory or a `model` attribute
            # (called if callable, used as-is otherwise).
            w2c_model = None
            if hasattr(w2c, "get_model"):
                w2c_model = w2c.get_model()
            elif hasattr(w2c, "model"):
                m = w2c.model
                w2c_model = m() if callable(m) else m
            else:
                raise RuntimeError("wav2clip не содержит get_model()/model. Обновите пакет wav2clip.")
    
            self.audio_model = w2c_model
    
            # Best-effort move to CUDA; failures are deliberately ignored.
            try:
                if isinstance(self.audio_model, torch.nn.Module) and torch.cuda.is_available():
                    self.audio_model = self.audio_model.to("cuda")
            except Exception:
                pass
    
            self.audio_processor = None
            dim = 512
            sr = config.get('sr', 16000)

        elif model_type == 'clap':
            if checkpoint is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для CLAP")
            self.audio_model = safe_load(ClapModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(ClapProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = getattr(self.audio_model.config, "projection_dim", 512)
            # Sample rate: processor attribute first, then its feature extractor, else 48000.
            sr = getattr(self.audio_processor, "sampling_rate", None)
            if sr is None:
                fe = getattr(self.audio_processor, "feature_extractor", None)
                sr = getattr(fe, "sampling_rate", 48000) if fe is not None else 48000
    
        else:
            if checkpoint is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для аудио-моделей, кроме wav2clip")
            self.audio_model = safe_load(AutoModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(AutoProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = self.audio_model.config.hidden_size
            fe = getattr(self.audio_processor, "feature_extractor", None)
            sr = getattr(fe, "sampling_rate", 16000) if fe is not None else 16000

        dev = self._preferred_device()
        self.audio_model = self._wrap_if_parameterless(self.audio_model, dev)
    
        # An explicit config['sr'] wins over the sample rate detected above.
        self.audio_config['sr'] = config.get('sr', sr)
        self.audio_config['max_audios'] = config.get('max_audios', 1)
        self.audio_config['audio_agg'] = config.get('audio_agg', 'concat')
        self.audio_config['dim'] = dim
        self.audio_config['model_type'] = model_type
    
        # 'concat' reserves one dim-sized slot per audio; other aggregations keep dim.
        self.out_dim_per_modality['audio'] = (
            dim * self.audio_config['max_audios']
            if self.audio_config['audio_agg'] == 'concat' else dim
        )

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate raw dataset rows into one model-ready batch.

        A row is a dict that may contain ``"labels"``, pre-tokenized ``"text_tokens"`` or raw
        ``"text"``, ``"images"`` and ``"audios"``. Only modalities whose model is loaded on this
        backend are processed.

        Image/audio cells may hold a single item, a list of items or ``None``. Items of all rows
        are flattened into one list; ``image_counts`` / ``audio_counts`` record how many items
        each row contributed, so the encoder can regroup the embeddings per row.

        :param batch: List of dataset rows.
        :return: ``{"labels": LongTensor[len(batch)], "backend_inputs": dict}``.
        :raises TypeError / ValueError: if an audio item cannot be converted to a float array.
        """

        def _as_item_list(value) -> list:
            # A cell may hold one item, a list of items or None; None entries are dropped.
            items = value if isinstance(value, list) else ([] if value is None else [value])
            return [item for item in items if item is not None]

        def _to_waveform(audio):
            # File paths are decoded by load_audio; anything else (np.ndarray, torch.Tensor,
            # list of samples) becomes a flat float32 array. Unconvertible items raise instead
            # of being silently dropped (a dropped-but-counted item misaligns the batch).
            if isinstance(audio, str):
                return load_audio(audio, self.audio_config['sr'])
            if torch.is_tensor(audio):
                audio = audio.detach().to(torch.float32).cpu().numpy()
            arr = np.asarray(audio, dtype=np.float32)
            if arr.ndim > 1:
                arr = np.squeeze(arr)
            if arr.ndim > 1:
                arr = arr.reshape(-1)
            return arr

        # Rows without "labels" default to class 0.
        labels = torch.stack([torch.tensor(b.get("labels", 0), dtype=torch.long) for b in batch])

        backend_inputs: Dict[str, Any] = {}
        batch_size = len(batch)

        if self.text_model is not None:
            if "text_tokens" in batch[0]:
                # Dataset was pre-tokenized: stack per-key tensors (or build them from lists).
                text_inputs = {}
                for key in batch[0]["text_tokens"].keys():
                    if torch.is_tensor(batch[0]["text_tokens"][key]):
                        text_inputs[key] = torch.stack([b["text_tokens"][key] for b in batch])
                    else:
                        dtype = torch.long if key in ["input_ids", "attention_mask", "token_type_ids"] else torch.float32
                        text_inputs[key] = torch.tensor([b["text_tokens"][key] for b in batch], dtype=dtype)
                backend_inputs["text_inputs"] = text_inputs
            else:
                # Empty texts are replaced by a single space so the tokenizer never gets "".
                texts = [b.get("text", "") or " " for b in batch]
                if self.batch_tokenizer:
                    text_inputs = self.batch_tokenizer.tokenize_batch(texts, use_cache=True)
                else:
                    text_inputs = self.text_processor(
                        texts, padding=True, truncation=True,
                        max_length=self.text_config.get('max_length', 512),
                        return_tensors="pt"
                    )
                backend_inputs["text_inputs"] = {k: v for k, v in text_inputs.items()}

        if self.image_model is not None:
            flat_images, img_counts = [], []
            for b in batch:
                imgs = _as_item_list(b.get("images", []))
                img_counts.append(len(imgs))
                flat_images.extend(to_pil(img) for img in imgs)

            if len(flat_images) > 0:
                img_proc = self.image_processor(images=flat_images, return_tensors="pt")
                backend_inputs["image_inputs"] = {"pixel_values": img_proc["pixel_values"]}
            else:
                backend_inputs["image_inputs"] = {"pixel_values": None}

            backend_inputs["image_counts"] = torch.tensor(img_counts, dtype=torch.long)

        if self.audio_model is not None:
            flat_audios, aud_counts = [], []
            for b in batch:
                auds = _as_item_list(b.get("audios", []))
                aud_counts.append(len(auds))
                # Exactly one waveform per counted item keeps counts and embeddings aligned.
                flat_audios.extend(_to_waveform(a) for a in auds)

            if self.audio_config.get('model_type') == 'wav2clip':
                # wav2clip embeds raw waveforms itself; no processor is involved.
                backend_inputs["audio_inputs"] = {"raw_audios": flat_audios}
            elif len(flat_audios) > 0:
                if self.audio_config.get('model_type') == 'clap':
                    aud_proc = self.audio_processor(
                        audios=flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_features": aud_proc["input_features"]}
                else:
                    aud_proc = self.audio_processor(
                        flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_values": aud_proc["input_values"]}
            else:
                backend_inputs["audio_inputs"] = {"input_features": None, "input_values": None, "raw_audios": []}

            backend_inputs["audio_counts"] = torch.tensor(aud_counts, dtype=torch.long)

        backend_inputs["batch_size"] = batch_size
        return {"labels": labels, "backend_inputs": backend_inputs}

    def _aggregate_embeddings(
        self,
        embs: Optional[torch.Tensor],
        counts: List[int],
        max_k: int,
        dim_hint: int,
        agg_type: str,
        batch_size: int,
        device: torch.device
    ) -> torch.Tensor:
        if embs is None or (torch.is_tensor(embs) and embs.numel() == 0):
            feat_dim = int(dim_hint) if dim_hint is not None else 0
            out_dim = feat_dim * max_k if agg_type == 'concat' else feat_dim
            return torch.zeros((batch_size, out_dim), device=device, dtype=torch.float32)
    
        if not torch.is_tensor(embs):
            embs = torch.as_tensor(embs, device=device, dtype=torch.float32)
        if embs.dim() == 1:
            embs = embs.unsqueeze(0)
        elif embs.dim() > 2:
            embs = embs.view(embs.size(0), -1)
    
        N, D = embs.size()
        out_dim = (D * max_k) if agg_type == 'concat' else D
        out = torch.zeros((batch_size, out_dim), device=device, dtype=embs.dtype)
    
        offset = 0
        for i, c in enumerate(counts):
            if c <= 0 or offset >= N:
                continue
            take_n = min(c, N - offset)
            sample = embs[offset:offset + take_n]
            offset += take_n
    
            if agg_type == 'concat':
                take = sample[:max_k]
                if take.size(0) < max_k:
                    pad = torch.zeros((max_k - take.size(0), D), device=device, dtype=embs.dtype)
                    take = torch.cat([take, pad], dim=0)
                out[i] = take.reshape(-1)
            else:
                out[i] = sample.mean(dim=0)
    
        return F.normalize(out, dim=-1, eps=1e-12) if out.size(1) > 0 else out

    @torch.no_grad()
    def _wav2clip_embed(self, arr: np.ndarray, device: torch.device) -> torch.Tensor:
        arr = np.asarray(arr, dtype=np.float32)
        if arr.ndim > 1:
            arr = np.squeeze(arr)
        if arr.ndim > 1:
            arr = arr.reshape(-1)
        if arr.size < 512:
            arr = np.pad(arr, (0, 512 - arr.size), mode="constant")

        try:
            emb = self._w2c.embed_audio(arr, self.audio_model)
            emb = np.asarray(emb)
        except Exception:
            x = torch.from_numpy(arr).float().unsqueeze(0).to(device)
            y = self.audio_model(x)
            if isinstance(y, (tuple, list)):
                y = y[0]
            if torch.is_tensor(y):
                if y.dim() == 2 and y.size(0) == 1:
                    y = y.squeeze(0)
                emb = y.detach().cpu().numpy()
            else:
                emb = np.asarray(y)

        if emb.ndim > 1:
            emb = emb.reshape(-1)
        return torch.as_tensor(emb, device=device, dtype=torch.float32)
    
    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        """Encode a collated batch into one embedding per sample and modality.

        :param backend_inputs: Dict produced by ``collate``: ``text_inputs``, ``image_inputs`` +
                               ``image_counts``, ``audio_inputs`` + ``audio_counts``, ``batch_size``.
        :param device: Device on which the encoders are run.
        :return: Mapping ``modality -> [batch, features]`` for every loaded modality present in
                 ``backend_inputs``. Per-item image/audio embeddings are regrouped per sample by
                 ``_aggregate_embeddings``. When a forward pass reveals a feature size different
                 from the configured one, ``image_config``/``audio_config`` and
                 ``out_dim_per_modality`` are updated.
        :raises RuntimeError: if modalities disagree on the batch size.
        """
        results: Dict[str, torch.Tensor] = {}

        if self.text_model is not None and "text_inputs" in backend_inputs:
            # Move tensors only; non-tensor extras are passed through unchanged.
            text_inputs = {
                k: (v.to(device) if torch.is_tensor(v) else v)
                for k, v in backend_inputs["text_inputs"].items()
            }
            if hasattr(self.text_model, "get_text_features"):
                text_z = self.text_model.get_text_features(**text_inputs)
            else:
                outputs = self.text_model(**text_inputs)
                text_z = (
                    outputs.pooler_output
                    if getattr(outputs, "pooler_output", None) is not None
                    else outputs.last_hidden_state.mean(dim=1)
                )
            results["text"] = self._normalize_2d(text_z)

        if self.image_model is not None and "image_inputs" in backend_inputs:
            pi = backend_inputs["image_inputs"]["pixel_values"]
            counts = backend_inputs["image_counts"].tolist()
            total_images_needed = sum(counts)

            img_flat = None
            actual_img_dim = self.image_config.get("dim", 768)

            if pi is not None and pi.numel() > 0 and total_images_needed > 0:
                pi = pi.to(device)
                if pi.size(0) > total_images_needed:
                    pi = pi[:total_images_needed]

                if hasattr(self.image_model, "get_image_features"):
                    img_flat = self.image_model.get_image_features(pixel_values=pi)
                else:
                    outputs = self.image_model(pixel_values=pi)
                    img_flat = (
                        outputs.pooler_output
                        if getattr(outputs, "pooler_output", None) is not None
                        else outputs.last_hidden_state[:, 0]
                    )

                img_flat = self._normalize_2d(img_flat)
                actual_img_dim = img_flat.size(1) if img_flat is not None else actual_img_dim

            img_z = self._aggregate_embeddings(
                img_flat, counts,
                self.image_config["max_images"],
                actual_img_dim,
                self.image_config["image_agg"],
                len(counts),
                device
            )

            # Sync the configured dim only from a real forward pass (same rule as for audio).
            if img_flat is not None and actual_img_dim != self.image_config.get("dim"):
                self.image_config["dim"] = actual_img_dim
                self.out_dim_per_modality["image"] = (
                    actual_img_dim * self.image_config["max_images"]
                    if self.image_config["image_agg"] == "concat" else actual_img_dim
                )

            results["image"] = img_z

        if self.audio_model is not None and "audio_inputs" in backend_inputs:
            counts = backend_inputs["audio_counts"].tolist()
            total_audios_needed = sum(counts)

            aud_flat = None
            actual_aud_dim = self.audio_config.get("dim", 768)
            model_type = self.audio_config.get("model_type")

            if total_audios_needed > 0:
                if model_type == "clap":
                    af = backend_inputs["audio_inputs"]["input_features"]
                    if af is not None and af.numel() > 0:
                        af = af.to(device)
                        if af.size(0) > total_audios_needed:
                            af = af[:total_audios_needed]
                        # Audio encoders run in fp32 even under mixed-precision training.
                        with torch.autocast(device_type="cuda", enabled=False):
                            aud_flat = self.audio_model.get_audio_features(input_features=af.float())
                        aud_flat = self._normalize_2d(aud_flat.float())
                        actual_aud_dim = aud_flat.size(1)

                elif model_type == "wav2clip":
                    raw_list = backend_inputs["audio_inputs"].get("raw_audios", [])
                    if len(raw_list) > total_audios_needed:
                        raw_list = raw_list[:total_audios_needed]
                    if len(raw_list) > 0:
                        embs = [self._wav2clip_embed(arr, device) for arr in raw_list]
                        aud_flat = torch.stack(embs, dim=0)
                        aud_flat = self._normalize_2d(aud_flat)
                        actual_aud_dim = aud_flat.size(1)

                else:
                    av = backend_inputs["audio_inputs"]["input_values"]
                    if av is not None and av.numel() > 0:
                        av = av.to(device)
                        if av.size(0) > total_audios_needed:
                            av = av[:total_audios_needed]
                        # Out-of-place clamp: `av` may alias the caller's tensor (same device / slice view).
                        # NOTE(review): processors that normalise to zero-mean/unit-variance produce values
                        # outside [-1, 1]; confirm this clipping is intended.
                        av = av.clamp(-1.0, 1.0)
                        with torch.autocast(device_type="cuda", enabled=False):
                            outputs = self.audio_model(input_values=av.float())
                            feats = (
                                outputs.pooler_output
                                if getattr(outputs, "pooler_output", None) is not None
                                else outputs.last_hidden_state.mean(dim=1)
                            )
                        aud_flat = self._normalize_2d(feats.float())
                        actual_aud_dim = aud_flat.size(1)

            aud_z = self._aggregate_embeddings(
                aud_flat, counts,
                self.audio_config["max_audios"],
                actual_aud_dim,
                self.audio_config["audio_agg"],
                len(counts),
                device
            )

            if aud_flat is not None and actual_aud_dim != self.audio_config.get("dim"):
                self.audio_config["dim"] = actual_aud_dim
                self.out_dim_per_modality["audio"] = (
                    actual_aud_dim * self.audio_config["max_audios"]
                    if self.audio_config["audio_agg"] == "concat" else actual_aud_dim
                )

            results["audio"] = aud_z

        if results:
            bs_list = [v.size(0) for v in results.values()]
            if len(set(bs_list)) != 1:
                raise RuntimeError(f"Inconsistent batch sizes across modalities: {bs_list}")

        return results


# =========================
# Классификатор
# =========================

class SingleBackboneClassifier(nn.Module):
    """Classification head on top of a multi-modal backend.

    The backend encodes each requested modality into a fixed-size vector per sample. The
    vectors are fused (concatenated or averaged, always in the order image -> text -> audio)
    and passed through ``LayerNorm -> Linear -> GELU -> Dropout -> Linear``.
    """

    # Fusion order shared by __init__ (head input size) and _fuse (feature layout).
    _MODALITY_ORDER = ("image", "text", "audio")

    def __init__(
        self,
        backend: BaseBackend,
        modalities: List[str],
        num_labels: int,
        fusion: str = "concat",
        hidden: int = 512,
        dropout: float = 0.1
    ):
        """
        :param backend: Backend providing ``encode(backend_inputs, device)`` and ``get_out_dim(modality)``.
        :param modalities: Modalities to fuse; names outside image/text/audio are ignored.
        :param num_labels: Number of output classes.
        :param fusion: ``"concat"`` (feature sizes add up) or ``"mean"`` (all sizes must match).
        :param hidden: Width of the hidden layer of the head.
        :param dropout: Dropout probability inside the head.
        :raises ValueError: on an unknown ``fusion``, unequal sizes for ``"mean"``, or when
                            ``modalities`` contains none of image/text/audio.
        """
        super().__init__()
        self.backend = backend
        self.modalities = modalities
        self.fusion = fusion
        self.num_labels = num_labels

        if fusion not in ("concat", "mean"):
            raise ValueError('fusion должен быть "concat" или "mean"')
        order = self._ordered_modalities()
        if not order:
            raise ValueError('modalities должен содержать хотя бы одну из "image", "text", "audio"')

        dims = [self.backend.get_out_dim(m) for m in order]
        if fusion == "concat":
            in_dim = sum(dims)
        else:
            if len(set(dims)) != 1:
                raise ValueError('Для fusion="mean" размеры модальностей должны совпадать')
            in_dim = dims[0]

        self.head = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_labels)
        )

    def _ordered_modalities(self) -> List[str]:
        """Requested modalities in the canonical fusion order."""
        return [m for m in self._MODALITY_ORDER if m in self.modalities]

    def _backbone_models(self) -> List[Any]:
        """Backbone encoders held by the backend (entries may be None)."""
        return [getattr(self.backend, name, None) for name in ("text_model", "image_model", "audio_model")]

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Best-effort: enable gradient checkpointing on every backbone with trainable parameters.

        Errors are deliberately ignored because not every backbone supports checkpointing.
        """
        for m in self._backbone_models():
            if m is None:
                continue
            try:
                has_trainable = any(p.requires_grad for p in m.parameters()) if hasattr(m, "parameters") else False
            except Exception:
                has_trainable = False
            if not has_trainable:
                continue  # frozen backbones gain nothing from checkpointing
            try:
                cfg = getattr(m, "config", None)
                if cfg is not None and hasattr(cfg, "use_cache"):
                    cfg.use_cache = False  # turn off caching before checkpointing
            except Exception:
                pass
            try:
                if hasattr(m, "gradient_checkpointing_enable"):
                    try:
                        if gradient_checkpointing_kwargs is not None:
                            m.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
                        else:
                            m.gradient_checkpointing_enable()
                    except TypeError:
                        # Older APIs do not accept gradient_checkpointing_kwargs.
                        m.gradient_checkpointing_enable()
            except Exception:
                pass

    def gradient_checkpointing_disable(self):
        """Best-effort: disable gradient checkpointing on all backbones; errors are ignored."""
        for m in self._backbone_models():
            if m is None:
                continue
            try:
                if hasattr(m, "gradient_checkpointing_disable"):
                    m.gradient_checkpointing_disable()
            except Exception:
                pass

    @staticmethod
    def _find_tensor_device(obj) -> Optional[torch.device]:
        """Device of the first tensor found in a nested dict/list/tuple, or None if there is none."""
        if isinstance(obj, torch.Tensor):
            return obj.device
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return None
        for child in children:
            dev = SingleBackboneClassifier._find_tensor_device(child)
            if dev is not None:
                return dev
        return None

    def _infer_device_from_inputs(self, obj) -> torch.device:
        """Device to encode on: that of the first input tensor, else the head's parameters."""
        dev = self._find_tensor_device(obj)
        if dev is not None:
            return dev
        # No tensor in the inputs (e.g. only raw audio arrays): run where the model lives.
        param = next(self.head.parameters(), None)
        if param is not None:
            return param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _fuse(self, z: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Fuse per-modality embeddings into a single ``[batch, in_dim]`` tensor.

        Rank-3 inputs are mean-pooled over dim 1; higher ranks are flattened per row.
        All features are trimmed to the batch size of the first one.

        :raises ValueError: if ``z`` holds none of the requested modalities.
        """
        feats = []
        for m in self._ordered_modalities():
            t = z.get(m)
            if t is None:
                continue
            if t.dim() == 3:
                t = t.mean(dim=1)
            elif t.dim() > 3:
                t = t.reshape(t.size(0), -1)
            feats.append(t)
        if not feats:
            raise ValueError(f"Backend не вернул эмбеддинги для модальностей {self.modalities}")
        batch_size = feats[0].size(0)
        feats = [f[:batch_size] for f in feats]
        if self.fusion == "concat":
            return torch.cat(feats, dim=-1)
        if self.fusion == "mean":
            return torch.stack(feats, dim=0).mean(dim=0)
        raise ValueError('fusion должен быть "concat" или "mean"')

    def forward(self, backend_inputs: Dict[str, Any], labels: Optional[torch.Tensor] = None):
        """Encode, fuse and classify a batch.

        :param backend_inputs: Collated backend inputs.
        :param labels: Optional class indices; when given, unweighted cross-entropy is returned
                       as ``loss`` (``WeightedCETrainer`` pops labels and computes its own loss).
        :return: ``SequenceClassifierOutput`` with ``logits`` (and ``loss`` when labels are given).
        """
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        if not z:
            raise ValueError("Backend не вернул эмбеддинги")
        logits = self.head(self._fuse(z))
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels.to(logits.device).long())
        return SequenceClassifierOutput(loss=loss, logits=logits)

    @torch.no_grad()
    def get_embeddings(self, backend_inputs: Dict[str, Any], return_per_modality: bool = False):
        """Return fused embeddings (and optionally the per-modality dict) without gradients."""
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        fused = self._fuse(z)
        if return_per_modality:
            return fused, z
        return fused


# =========================
# Trainer с весами классов
# =========================

class WeightedCETrainer(Trainer):
    """``Trainer`` whose loss is (optionally class-weighted) cross-entropy on ``outputs.logits``.

    Weights are taken from ``class_weights`` when given; otherwise, if both ``train_labels``
    and ``num_labels`` are given, they are the balanced weights ``n / (num_labels * count_c)``.
    Classes absent from ``train_labels`` get weight 0. With neither, the loss is unweighted.
    """

    def __init__(self, *args, num_labels=None, train_labels=None, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:
            self.class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        elif train_labels is not None and num_labels is not None:
            self.class_weights = self._balanced_weights(train_labels, num_labels)
        else:
            self.class_weights = None

    @staticmethod
    def _balanced_weights(train_labels, num_labels: int) -> Optional[torch.Tensor]:
        """Balanced class weights from label frequencies; None when there are no labels.

        :raises ValueError: if a label lies outside ``[0, num_labels)``.
        """
        y = np.asarray(train_labels).astype(int).ravel()
        if y.size == 0:
            # All-zero weights would make the weighted mean 0/0 = NaN; fall back to unweighted.
            return None
        if y.min() < 0 or y.max() >= num_labels:
            raise ValueError(
                f"train_labels должны лежать в диапазоне [0, {num_labels}), "
                f"получено min={int(y.min())}, max={int(y.max())}"
            )
        counts = np.bincount(y, minlength=num_labels)
        n = counts.sum()
        w = np.zeros(num_labels, dtype=np.float32)
        nz = counts > 0
        w[nz] = n / (num_labels * counts[nz].astype(np.float32))
        return torch.tensor(w, dtype=torch.float32)

    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch: Optional[int] = None):
        """Cross-entropy (weighted by ``self.class_weights`` if set) between logits and ``inputs["labels"]``.

        ``num_items_in_batch`` is accepted for Trainer API compatibility and not used; the loss
        is a per-batch mean.
        """
        labels = inputs["labels"]
        # Build the model kwargs without mutating the caller's batch dict.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.logits

        if logits.size(0) != labels.size(0):
            raise ValueError(f"Batch size mismatch: logits batch={logits.size(0)} vs labels batch={labels.size(0)}")

        weight = self.class_weights.to(logits.device) if self.class_weights is not None else None
        loss = F.cross_entropy(logits, labels.long(), weight=weight)
        return (loss, outputs) if return_outputs else loss


# =========================
# Прогресс-логгер
# =========================

class PbarConsoleLogger(TrainerCallback):
    """Mirror Trainer progress and logs onto an externally created tqdm bar.

    The bar is advanced to ``state.global_step`` (capped at ``pbar.total`` when known). The
    postfix shows the latest train/val loss and eval metrics. One summary line per evaluation
    step is printed with ``tqdm.write``.
    """

    # Eval-time bookkeeping keys that are not metrics.
    _NON_METRIC_KEYS = frozenset(
        {'eval_loss', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch'}
    )

    def __init__(self, pbar):
        self.pbar = pbar
        self.last_logs = {}
        self.last_train_loss = None
        self.printed_eval_steps = set()
        self.tqdm = tqdm

    def _step(self, state) -> int:
        return int(state.global_step or 0)

    @staticmethod
    def _as_float(value) -> Optional[float]:
        """``float(value)`` or None for non-numeric values (e.g. nested per-class dicts)."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _advance_to(self, state) -> None:
        """Move the bar forward to the current global step; never backwards."""
        target = self._step(state)
        if self.pbar.total is not None:
            target = min(target, self.pbar.total)
        if target > self.pbar.n:
            self.pbar.update(target - self.pbar.n)

    def _fmt_postfix(self):
        # last_logs only holds floats (see on_log), so formatting cannot fail here.
        parts = []
        if 'loss' in self.last_logs:
            parts.append(f"loss {self.last_logs['loss']:.4f}")
        if 'eval_loss' in self.last_logs:
            parts.append(f"val {self.last_logs['eval_loss']:.4f}")
        for k, v in self.last_logs.items():
            if k.startswith('eval_') and k not in self._NON_METRIC_KEYS:
                parts.append(f"{k.replace('eval_', '')} {v:.4f}")
        return " | ".join(parts)

    def on_step_end(self, args, state, control, **kwargs):
        self._advance_to(state)
        if self.last_logs:
            self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.last_logs[k] = float(v)
        if 'loss' in logs and isinstance(logs['loss'], (int, float)):
            self.last_train_loss = float(logs['loss'])
        self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

        if not any(k.startswith('eval_') for k in logs.keys()):
            return
        step = self._step(state)
        if step in self.printed_eval_steps:
            return  # print each evaluation only once
        self.printed_eval_steps.add(step)

        train_loss_str = f"{self.last_train_loss:.10f}" if self.last_train_loss is not None else "n/a"
        val_loss = logs.get('eval_loss', None)
        val_loss_str = f"{float(val_loss):.10g}" if isinstance(val_loss, (int, float)) else "n/a"
        extra_parts = []
        for k, v in logs.items():
            if not k.startswith('eval_') or k in self._NON_METRIC_KEYS:
                continue
            value = self._as_float(v)
            if value is None:
                continue  # skip non-scalar metrics instead of crashing the training loop
            metric_name = k.replace('eval_', '')
            extra_parts.append(f"val {metric_name}: {value:.10f}")
        line = f"step: {step}, train loss: {train_loss_str}, val loss: {val_loss_str}"
        if extra_parts:
            line += ", " + ", ".join(extra_parts)
        self.tqdm.write(line)

    def on_train_end(self, args, state, control, **kwargs):
        self._advance_to(state)
        self.pbar.refresh()


# =========================
# Основной пайплайн
# =========================

class SingleModelMultiComboClassification:
    """
    Универсальный пайплайн для мульти-модальной классификации (text / image / audio) поверх моделей Hugging Face.
    Поддерживает автоматический выбор бэкенда (CLIP/CLAP/Auto), батчевую токенизацию текста, кэширование,
    чанковую тренировку, раннюю остановку, взвешивание классов, предсказание и извлечение эмбеддингов.

    Основные возможности:
    - Автоматическая сборка бэкенда: CLIP для связки text+image, CLAP для text+audio, либо произвольные Auto-модели.
    - Работа с тремя модальностями: text, image, audio (любой поднабор).
    - Батчевая токенизация текста с кэшем и опциональной предварительной токенизацией датасета.
    - Чанковая тренировка очень больших датасетов без перегрузки памяти.
    - Сбалансированная (взвешенная) кросс-энтропия на основе частот классов в тренировочных данных.
    - Прогресс-бар, ранняя остановка и выбор лучшей модели по метрике.
    - Предсказания и извлечение эмбеддингов (в том числе по модальностям).

    :param modalities: Список используемых модальностей из {"text", "image", "audio"}.
    :param num_labels: Число классов в задаче классификации.
    :param target_column_name: Имя колонки с целевой меткой в DataFrame.
    :param text_columns: Список текстовых колонок (используются, если выбрана модальность "text"). Значения будут
                         конкатенированы через special_tokens["sep"] при подготовке примеров.
    :param image_columns: Список колонок с изображениями (пути к файлам, PIL.Image, np.ndarray или списки таких объектов),
                          используется, если выбрана модальность "image".
    :param audio_columns: Список колонок с аудио (пути к файлам или массивы np.ndarray; моно, float32),
                          используется, если выбрана модальность "audio". Для чтения из файлов требуется torchaudio.
    :param text_tokenizer_fn: Кастомная функция токенизации текста (если не используется встроенный BatchTokenizer).
                              Сигнатура: fn(text_dict: Dict[str, str], special_tokens: Dict[str, str]) -> Union[Dict[str, Tensor], str].
                              Если возвращает dict с ключами вроде 'input_ids', считается, что функция сразу возвращает тензоры токенов;
                              иначе строку для последующей стандартной токенизации.
    :param special_tokens: Спец. токены/разделители для подготовки текста. По умолчанию {"sep": " [SEP] "}.
    :param tokenizer_returns_tensors: Флаг, сигнализирующий, что custom text_tokenizer_fn возвращает уже тензоры
                                      (dict c 'input_ids', 'attention_mask' и т.д.). Влияет на коллатор.
    :param backend: Режим сборки бэкенда. "auto" — подобрать оптимальные модели по модальностям;
                    "clip" — CLIP для текста и изображений; "clap" — CLAP для текста и аудио; любое иное — ручные конфиги.
    :param clip_checkpoint: Чекпойнт CLIP (используется при auto/clip), по умолчанию "openai/clip-vit-base-patch32".
    :param clap_checkpoint: Чекпойнт CLAP (используется при auto/clap), по умолчанию "laion/clap-htsat-unfused".
    :param text_model_config: Конфиг текстовой модели. Минимум: {"checkpoint": "...", "model_type": "..."}.
                              Дополнительно: "max_length" и т.д. Примеры model_type: "clip", "clap", "bert", "auto".
    :param image_model_config: Конфиг визуальной модели. Минимум: {"checkpoint": "...", "model_type": "..."}.
                               Дополнительно: "max_images", "image_agg" ("concat"|"mean") и т.д. Примеры model_type: "clip", "vit", "auto".
    :param audio_model_config: Конфиг аудио-модели. Минимум: {"checkpoint": "...", "model_type": "..."} (кроме "wav2clip").
                               Дополнительно: "max_audios", "audio_agg" ("concat"|"mean"), "sr". Примеры model_type: "clap", "wav2clip", "auto".
    :param fusion: Способ слияния модальностей в классификаторе: "concat" или "mean".
                   При "mean" размеры эмбеддингов всех модальностей должны совпадать.
    :param freeze_backbone: Если True — бэкенды заморожены (тренируется только классификационная "голова").
    :param clip_max_length: Максимальная длина текста для CLIP-токенизатора (по умолчанию 77).
    :param max_images_per_sample: Сколько изображений брать на сэмпл (усреднение или конкатенация задаются в image_model_config["image_agg"]).
    :param max_audios_per_sample: Сколько аудио брать на сэмпл (аналогично image, параметр audio_model_config["audio_agg"]).
    :param use_batch_tokenizer: Использовать BatchTokenizer для текста (ускоряет токенизацию и кэширует результаты).
    :param pretokenize_data: Предварительно токенизировать текст датасета (в памяти) для ускорения обучения/инференса.
    :param pretokenize_batch_size: Батч-размер при предварительной токенизации.
    :param tokenizer_cache_size: Размер LRU-кэша в BatchTokenizer.
    :param max_pretokenize_samples: Максимум сэмплов для предварительной токенизации на чанк/датасет.
    :param local_cache_dir: Локальная директория кэша моделей/процессоров HF.

    :return: None

    :raises ValueError: Если бэкенд не поддерживает выбранные модальности (внутренняя проверка соответствия).
                        Также возможны ошибки конфигов, например fusion="mean" при несовпадающих размерах эмбеддингов.
    :raises OSError: Ошибки загрузки моделей/процессоров из Hugging Face Hub (сетевые/офлайн проблемы, отсутствующие чекпойнты).
    :raises RuntimeError: Проблемы с устройством/драйвером (CUDA/MPS), несовместимость версий зависимостей и т.п.
    """
    def __init__(
        self,
        modalities: List[str],
        num_labels: int,
        target_column_name: str,
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        backend: str = "auto",
        clip_checkpoint: str = "openai/clip-vit-base-patch32",
        clap_checkpoint: str = "laion/clap-htsat-unfused",
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        fusion: str = "concat",
        freeze_backbone: bool = True,
        clip_max_length: int = 77,
        max_images_per_sample: int = 1,
        max_audios_per_sample: int = 1,
        use_batch_tokenizer: bool = True,
        pretokenize_data: bool = True,
        pretokenize_batch_size: int = 256,
        tokenizer_cache_size: int = 10000,
        max_pretokenize_samples: int = 100000,
        local_cache_dir: str = "./model_cache"
    ):
        """
        Store the pipeline configuration and build the backend immediately.

        Every argument is described in the class docstring. Modalities are de-duplicated and
        sorted. Missing column lists become empty lists. The backend name is lower-cased, and
        the per-sample image/audio limits are coerced to ``int``. Label mappings, model,
        trainer, metric function and progress callback start as ``None``.

        :raises ValueError: propagated from ``_build_backend()`` when the assembled backend
                            does not support the requested modalities.
        """
        # --- task / data layout ---
        self.modalities = sorted(set(modalities))
        self.num_labels = num_labels
        self.target_column_name = target_column_name
        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []

        # --- text preprocessing ---
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.use_batch_tokenizer = use_batch_tokenizer
        self.pretokenize_data = pretokenize_data
        self.pretokenize_batch_size = pretokenize_batch_size
        self.tokenizer_cache_size = tokenizer_cache_size
        self.max_pretokenize_samples = max_pretokenize_samples

        # --- backbone selection and head fusion ---
        self.backend_name = backend.lower()
        self.clip_checkpoint = clip_checkpoint
        self.clap_checkpoint = clap_checkpoint
        self.text_model_config = text_model_config
        self.image_model_config = image_model_config
        self.audio_model_config = audio_model_config
        self.fusion = fusion
        self.freeze_backbone = freeze_backbone
        self.clip_max_length = clip_max_length
        self.max_images_per_sample = int(max_images_per_sample)
        self.max_audios_per_sample = int(max_audios_per_sample)
        self.local_cache_dir = local_cache_dir

        # --- runtime state, populated later (backend right below, the rest by training) ---
        self.label2id: Optional[Dict[Any, int]] = None
        self.id2label: Optional[Dict[int, str]] = None
        self.backend: Optional[BaseBackend] = None
        self.model: Optional[SingleBackboneClassifier] = None
        self.trainer: Optional[Trainer] = None
        self.compute_metrics = None
        self.progress_callback: Optional[PbarConsoleLogger] = None

        self._build_backend()

    def _build_backend(self):
        """
        Complete the per-modality model configs and instantiate the multimodal backend.

        Defaults are chosen from ``self.backend_name``:
        - "clip", or "auto" with exactly {text, image}: CLIP configs for text and image;
        - "clap", or "auto" with exactly {text, audio}: CLAP configs for text and audio;
        - "auto" with any other modality set: BERT (text), ViT (image) and CLAP (audio)
          for each selected modality whose config is still ``None``;
        - any other name: user-supplied configs are used unchanged.
        In the CLIP/CLAP cases any falsy user config (``None`` or ``{}``) is replaced. In
        the generic "auto" case only ``None`` is. Configs of unselected modalities are
        not passed to the backend.

        :raises ValueError: if the built backend does not support every selected modality.
        """
        selected = set(self.modalities)
        mode = self.backend_name
        is_auto = mode == "auto"

        # Fresh dict per call so that configs are never shared between attributes.
        def clip_text():
            return {'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_length': self.clip_max_length}

        def clip_image():
            return {
                'checkpoint': self.clip_checkpoint, 'model_type': 'clip',
                'max_images': self.max_images_per_sample, 'image_agg': 'concat',
            }

        def clap_text():
            return {'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_length': 64}

        def clap_audio():
            return {
                'checkpoint': self.clap_checkpoint, 'model_type': 'clap',
                'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000,
            }

        if mode == "clip" or (is_auto and selected == {"text", "image"}):
            self.text_model_config = self.text_model_config or clip_text()
            self.image_model_config = self.image_model_config or clip_image()
        elif mode == "clap" or (is_auto and selected == {"text", "audio"}):
            self.text_model_config = self.text_model_config or clap_text()
            self.audio_model_config = self.audio_model_config or clap_audio()
        elif is_auto:
            if "text" in selected and self.text_model_config is None:
                self.text_model_config = {
                    'checkpoint': 'bert-base-multilingual-cased', 'model_type': 'bert', 'max_length': 512,
                }
            if "image" in selected and self.image_model_config is None:
                self.image_model_config = {
                    'checkpoint': 'google/vit-base-patch16-224', 'model_type': 'vit',
                    'max_images': self.max_images_per_sample, 'image_agg': 'concat',
                }
            if "audio" in selected and self.audio_model_config is None:
                self.audio_model_config = clap_audio()
        # Any other backend name: keep whatever configs the user provided.

        self.backend = UniversalMultiBackend(
            text_model_config=self.text_model_config if "text" in selected else None,
            image_model_config=self.image_model_config if "image" in selected else None,
            audio_model_config=self.audio_model_config if "audio" in selected else None,
            freeze=self.freeze_backbone,
            text_tokenizer_fn=self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors,
            use_batch_tokenizer=self.use_batch_tokenizer,
            tokenizer_cache_size=self.tokenizer_cache_size,
            tokenizer_batch_size=self.pretokenize_batch_size,
            local_cache_dir=self.local_cache_dir
        )

        if not selected.issubset(self.backend.supported):
            raise ValueError(f"Бэкенд {self.backend.name} не поддерживает модальности {self.modalities}")

    def _setup_metrics(self, metric_name: str):
        """
        Build ``self.compute_metrics`` for the Trainer: weighted F1 or accuracy over argmax predictions.

        :param metric_name: "f1" (weighted average) or "accuracy", case-insensitive.
        :raises ValueError: for any other metric name (checked before anything is loaded).
        """
        metric_name = metric_name.lower()
        if metric_name not in ("f1", "accuracy"):
            raise ValueError('metric_name должен быть "f1" или "accuracy"')
        metric = evaluate.load(metric_name)
        # Only F1 needs an averaging mode for multi-class labels.
        compute_kwargs = {"average": "weighted"} if metric_name == "f1" else {}

        def compute(p):
            logits = p.predictions
            # When the model returns several outputs the Trainer hands over a tuple;
            # assumes logits are the first element — TODO confirm against the model's outputs.
            if isinstance(logits, (tuple, list)):
                logits = logits[0]
            preds = np.asarray(logits).argmax(-1)
            return metric.compute(predictions=preds, references=p.label_ids, **compute_kwargs)

        self.compute_metrics = compute

    def _split(self, df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
        df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        n_eval = int(math.ceil(len(df) * test_size))
        return df.iloc[n_eval:].reset_index(drop=True), df.iloc[:n_eval].reset_index(drop=True)

    def _validate_data(self, df: pd.DataFrame):
        if "text" in self.modalities:
            if not self.text_columns:
                raise ValueError("Вы выбрали модальность 'text', но text_columns пустой.")
            missing = [c for c in self.text_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют текстовые колонки: {missing}")
        if "image" in self.modalities:
            if not self.image_columns:
                raise ValueError("Вы выбрали модальность 'image', но image_columns пуст.")
            missing = [c for c in self.image_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки изображений: {missing}")
        if "audio" in self.modalities:
            if not self.audio_columns:
                raise ValueError("Вы выбрали модальность 'audio', но audio_columns пуст.")
            missing = [c for c in self.audio_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки аудио: {missing}")

    def fit(
        self,
        train_data: pd.DataFrame,
        epochs: int = 3,
        test_size: float = 0.2,
        test_data: Optional[pd.DataFrame] = None,
        per_device_train_batch_size: int = 16,
        gradient_accumulation_steps: int = 1,
        learning_rate: float = 2e-4,
        metric_name: str = "f1",
        fp16: bool = True,
        logging_steps: int = 50,
        eval_steps: int = 200,
        output_dir: str = "./result",
        seed: int = 42,
        hidden: int = 512,
        dropout: float = 0.1,
        gradient_checkpointing: bool = False,
        fit_chunk_size: Optional[int] = None,
        clear_cache_every_n_chunks: int = 10,
        early_stopping_patience: Optional[int] = 3,
        early_stopping_threshold: float = 0.0
    ):
        """
        Train the classifier with validation, early stopping and optional chunked training.

        Internally uses WeightedCETrainer. Class weights are n_all / (num_labels * count), computed
        on the training split; classes absent from it get weight 0. A tqdm progress logger is
        attached, and text can optionally be pretokenized in memory.

        If ``test_data`` is None, ``train_data`` is shuffled and split into train/eval by
        ``test_size`` (no stratification). Large datasets can be trained chunk by chunk
        (``fit_chunk_size``) to bound memory usage.

        :param train_data: training DataFrame. It must contain ``target_column_name`` plus the
                           columns of every selected modality: text_columns (strings, NaN allowed),
                           image_columns (paths, PIL.Image, np.ndarray or lists of those),
                           audio_columns (audio paths or mono np.ndarray; paths need torchaudio).
        :param epochs: number of passes over the training split.
        :param test_size: eval fraction when ``test_data`` is None.
        :param test_data: explicit eval DataFrame; when given, ``test_size`` is ignored.
        :param per_device_train_batch_size: per-device training batch size.
        :param gradient_accumulation_steps: gradient accumulation steps
                                            (effective batch = batch_size * steps).
        :param learning_rate: initial learning rate (optimizer/scheduler are built by the Trainer).
        :param metric_name: metric for best-model selection and early stopping:
                            "f1" (weighted) or "accuracy"; case-insensitive.
        :param fp16: use half precision (only honoured when CUDA is available).
        :param logging_steps: logging period in steps.
        :param eval_steps: evaluation / checkpoint period in steps.
        :param output_dir: directory for checkpoints and logs.
        :param seed: random seed for reproducibility.
        :param hidden: hidden size of the classification head.
        :param dropout: dropout of the classification head.
        :param gradient_checkpointing: enable gradient checkpointing in trainable backbones.
        :param fit_chunk_size: chunk size for staged training (None/<=0 means one chunk per epoch).
        :param clear_cache_every_n_chunks: clear the chunk dataset's tokenization cache every N
                                           chunks; 0 or None disables it.
        :param early_stopping_patience: number of evaluations without improvement before stopping;
                                        None or <= 0 disables early stopping.
        :param early_stopping_threshold: minimal metric improvement that resets the patience counter.

        :return: self (for call chaining).

        :raises ValueError: missing modality columns in train_data/test_data; unknown metric_name;
                            more distinct classes in train_data than ``num_labels``; empty training
                            split; invalid head configuration (e.g. fusion="mean" with mismatched sizes).
        :raises RuntimeError: errors inside transformers.Trainer, device problems (CUDA OOM, MPS),
                              failures while reading images/audio.
        :raises OSError: errors reading data files (images/audio) or the model cache.
        """
        # Normalize once: _setup_metrics lower-cases internally, and metric_for_best_model must
        # use the same spelling as the metric keys the evaluator returns.
        metric_name = metric_name.lower()

        self._validate_data(train_data)
        if test_data is not None:
            self._validate_data(test_data)
        set_seed(seed)
        # Fail fast on an unknown metric before any expensive dataset/model work.
        self._setup_metrics(metric_name)

        classes = sorted(train_data[self.target_column_name].unique().tolist())
        if len(classes) > self.num_labels:
            # Label ids >= num_labels would break the class-weight computation and the loss.
            raise ValueError(
                f"В train_data {len(classes)} классов, но num_labels={self.num_labels}; увеличьте num_labels."
            )
        if self.num_labels != len(classes):
            print(f"Warning: num_labels={self.num_labels} != len(classes)={len(classes)}")
        self.label2id = {c: i for i, c in enumerate(classes)}
        self.id2label = {i: str(c) for c, i in self.label2id.items()}

        if test_data is None:
            df_train, df_eval = self._split(train_data, test_size=test_size, seed=seed)
        else:
            df_train, df_eval = train_data, test_data

        n_train = len(df_train)
        if n_train == 0:
            raise ValueError("Обучающая выборка пуста после разбиения: нечего обучать.")

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

        def make_dataset(frame: pd.DataFrame, **extra):
            """Build a MultiComboDataset with the pipeline-wide settings plus per-call options."""
            return MultiComboDataset(
                df=frame,
                target_col=self.target_column_name,
                label2id=self.label2id,
                text_columns=self.text_columns,
                image_columns=self.image_columns,
                audio_columns=self.audio_columns,
                text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
                text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
                special_tokens=self.special_tokens,
                tokenizer_returns_tensors=self.tokenizer_returns_tensors,
                **extra,
            )

        ds_eval = make_dataset(
            df_eval,
            pretokenize=(self.pretokenize_data and has_bt and len(df_eval) < 50000),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_eval), self.max_pretokenize_samples),
        )

        # Inverse-frequency class weights; classes missing from the training split keep weight 0.
        y_train_all = np.array([self.label2id[y] for y in df_train[self.target_column_name].tolist()], dtype=int)
        counts = np.bincount(y_train_all, minlength=self.num_labels)
        n_all = counts.sum()
        class_weights = np.zeros(self.num_labels, dtype=np.float32)
        nonzero = counts > 0
        class_weights[nonzero] = n_all / (self.num_labels * counts[nonzero].astype(np.float32))

        self.model = SingleBackboneClassifier(
            backend=self.backend,
            modalities=self.modalities,
            num_labels=self.num_labels,
            fusion=self.fusion,
            hidden=hidden,
            dropout=dropout
        )

        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=max(4, per_device_train_batch_size // 2),
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=max(1, gradient_accumulation_steps * 2),
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.01,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_strategy="steps",
            save_steps=eval_steps,
            load_best_model_at_end=True,
            metric_for_best_model=f"eval_{metric_name}",
            save_total_limit=1,
            logging_strategy="steps",
            logging_steps=logging_steps,
            report_to="none",
            fp16=fp16 and torch.cuda.is_available(),
            dataloader_num_workers=min(4, os.cpu_count() or 4),
            seed=seed,
            remove_unused_columns=False,
            gradient_checkpointing=gradient_checkpointing,
            dataloader_pin_memory=True,
            ddp_find_unused_parameters=False,
            disable_tqdm=True
        )

        def data_collator(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            return self.backend.collate(batch_list)

        def steps_for_size(sz: int, bsz: int, accum: int) -> int:
            """Optimizer steps needed for ``sz`` samples at batch ``bsz`` with ``accum`` accumulation."""
            return max(0, math.ceil(math.ceil(sz / max(1, bsz)) / max(1, accum)))

        def chunk_slices(index_array: np.ndarray, size: int):
            for i in range(0, len(index_array), size):
                yield index_array[i:i + size]

        chunk_size = fit_chunk_size if (fit_chunk_size and fit_chunk_size > 0) else n_train

        # Chunk lengths do not depend on the shuffle order, so a single epoch's count suffices.
        steps_per_epoch = sum(
            steps_for_size(min(chunk_size, n_train - start), per_device_train_batch_size, gradient_accumulation_steps)
            for start in range(0, n_train, chunk_size)
        )
        total_steps = steps_per_epoch * max(0, epochs)

        # Placeholder train set for Trainer construction; the real chunks are swapped in below.
        ds_train_init = make_dataset(df_train.iloc[:1], pretokenize=False)

        self.trainer = WeightedCETrainer(
            model=self.model,
            args=args,
            train_dataset=ds_train_init,
            eval_dataset=ds_eval,
            data_collator=data_collator,
            compute_metrics=self.compute_metrics,
            num_labels=self.num_labels,
            train_labels=y_train_all,
            class_weights=class_weights
        )
        # Console output is handled by PbarConsoleLogger instead.
        self.trainer.remove_callback(PrinterCallback)

        if early_stopping_patience is not None and early_stopping_patience > 0:
            self.trainer.add_callback(EarlyStoppingCallback(
                early_stopping_patience=int(early_stopping_patience),
                early_stopping_threshold=float(early_stopping_threshold)
            ))

        if total_steps > 0:
            self.trainer.create_optimizer_and_scheduler(num_training_steps=total_steps)

        pbar = tqdm(total=total_steps, desc="Training Progress", unit="step")
        cb = PbarConsoleLogger(pbar)
        self.trainer.add_callback(cb)
        self.progress_callback = cb

        steps_done = 0
        chunk_counter = 0
        try:
            for ep in range(epochs):
                order = np.arange(n_train)
                np.random.default_rng(seed + ep).shuffle(order)

                for slc in chunk_slices(order, chunk_size):
                    chunk_steps = steps_for_size(len(slc), per_device_train_batch_size, gradient_accumulation_steps)
                    if chunk_steps == 0:
                        continue

                    chunk_df = df_train.iloc[slc]
                    should_pretokenize = (
                        self.pretokenize_data and has_bt and len(slc) < self.max_pretokenize_samples and len(slc) > 100
                    )
                    ds_chunk = make_dataset(
                        chunk_df,
                        pretokenize=should_pretokenize,
                        pretokenize_batch_size=self.pretokenize_batch_size,
                        max_cache_size=min(len(slc), self.max_pretokenize_samples),
                    )
                    self.trainer.train_dataset = ds_chunk

                    # NOTE(review): stock transformers.Trainer.train() starts a fresh TrainerState
                    # (global_step=0) on every call unless resuming from a checkpoint. Unless
                    # WeightedCETrainer carries the step counter across calls, this cumulative
                    # max_steps makes each later chunk/epoch run `steps_done` extra steps, and the
                    # scheduler created above may be rebuilt per call. Confirm before relying on
                    # the total step budget.
                    self.trainer.args.max_steps = steps_done + chunk_steps
                    self.trainer.train()
                    steps_done += chunk_steps

                    chunk_counter += 1
                    # Guard against 0/None, which would otherwise raise ZeroDivisionError / TypeError.
                    if (
                        clear_cache_every_n_chunks
                        and clear_cache_every_n_chunks > 0
                        and chunk_counter % clear_cache_every_n_chunks == 0
                        and hasattr(ds_chunk, 'clear_cache')
                    ):
                        ds_chunk.clear_cache()
                        print(f"✓ Очищен кэш токенизации после {chunk_counter} чанков")

                    del ds_chunk, chunk_df
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
        finally:
            # Always release the progress bar and tokenizer cache, even if training fails.
            pbar.close()
            if getattr(self.backend, "batch_tokenizer", None):
                self.backend.batch_tokenizer.clear_cache()

        return self

    def predict(
        self,
        df: pd.DataFrame,
        return_label_str: bool = False,
        return_proba: bool = False,
        batch_size: Optional[int] = None
    ) -> np.ndarray:
        """
        Run inference on a new DataFrame and return predictions.

        If ``df`` has no ``target_column_name`` column, a copy is made and the column is filled
        with the first known class as a placeholder (the dataset always reads a label column).

        :param df: DataFrame with the same modality columns as used in training:
                   text_columns (strings), image_columns (paths, PIL.Image, np.ndarray or lists of
                   those), audio_columns (audio paths or mono float32 np.ndarray).
        :param return_label_str: if True, return string labels (via id2label) instead of class ids.
        :param return_proba: if True, return softmax probabilities of shape [N, num_labels];
                             takes precedence over ``return_label_str``.
        :param batch_size: temporarily overrides ``per_device_eval_batch_size`` for this call.

        :return: probabilities [N, num_labels] if ``return_proba``; otherwise an array of shape [N]
                 with string labels (``return_label_str=True``) or predicted class indices.

        :raises RuntimeError: if the model has not been trained (no trainer).
        :raises ValueError: data conversion problems (unexpected columns/types), collator/backend
                            errors (mismatched batch sizes).
        :raises OSError: errors reading source files (images/audio).
        """
        if self.trainer is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        df_c = df.copy()
        if self.target_column_name not in df_c.columns:
            df_c[self.target_column_name] = list(self.label2id.keys())[0]

        print(f"Preparing dataset for prediction ({len(df_c)} samples)...")

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

        ds = MultiComboDataset(
            df=df_c,
            target_col=self.target_column_name,
            label2id=self.label2id,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
            text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=(self.pretokenize_data and has_bt and len(df_c) < 10000),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_c), 10000),
            tokenizer_returns_tensors=self.tokenizer_returns_tensors
        )

        trainer_args = self.trainer.args
        original_bs = trainer_args.per_device_eval_batch_size
        original_disable_tqdm = trainer_args.disable_tqdm

        effective_batch_size = batch_size or original_bs
        num_batches = (len(df_c) + effective_batch_size - 1) // effective_batch_size
        print(f"Running predictions (batch_size={effective_batch_size}, num_batches={num_batches})...")

        # The Trainer picks its progress callback once, at construction, from disable_tqdm, and fit()
        # builds it with disable_tqdm=True. Flipping the flag alone therefore shows no bar, so a
        # ProgressCallback is attached for this call only (if one is not already registered).
        from transformers.trainer_callback import ProgressCallback
        attach_progress = not any(
            isinstance(cb, ProgressCallback) for cb in self.trainer.callback_handler.callbacks
        )

        try:
            if batch_size:
                trainer_args.per_device_eval_batch_size = batch_size
            trainer_args.disable_tqdm = False
            if attach_progress:
                self.trainer.add_callback(ProgressCallback)
            preds = self.trainer.predict(test_dataset=ds)
        finally:
            # Restore the shared trainer configuration even if prediction fails.
            if attach_progress:
                self.trainer.remove_callback(ProgressCallback)
            trainer_args.disable_tqdm = original_disable_tqdm
            trainer_args.per_device_eval_batch_size = original_bs
            if hasattr(ds, 'clear_cache'):
                ds.clear_cache()

        logits = preds.predictions
        if isinstance(logits, (tuple, list)):
            # Several model outputs come back as a tuple; assumes logits are first — TODO confirm.
            logits = logits[0]
        logits = np.asarray(logits)

        if return_proba:
            # Numerically stable softmax over the class axis.
            exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
            return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

        y_pred = np.argmax(logits, axis=-1)
        if return_label_str:
            return np.array([self.id2label[int(i)] for i in y_pred])
        return y_pred

    def get_embeddings(
        self,
        df: pd.DataFrame,
        batch_size: int = 32,
        return_per_modality: bool = False
    ):
        """
        Extract embeddings for the rows of ``df``: the fused embedding and, optionally,
        the per-modality embeddings. The fusion scheme (concat/mean) and the dimensions
        depend on the backend settings and the ``fusion`` parameter.

        If ``df`` has no ``target_column_name`` column, it is filled with a placeholder
        label (the first key of ``label2id``) so the dataset can be built.

        :param df: DataFrame with modality columns (same layout as for predict()).
        :param batch_size: Batch size used during extraction; must be >= 1.
        :param return_per_modality: If True, additionally return a dict of per-modality embeddings.

        :return:
            - return_per_modality=False:
                np.ndarray of shape [N, D_fused], D_fused being the post-fusion dimension.
            - return_per_modality=True:
                tuple (fused, per_mod):
                    - fused: np.ndarray [N, D_fused],
                    - per_mod: Dict[str, np.ndarray] keyed by modality ("text"/"image"/"audio"),
                               each value of shape [N, D_mod] and row-aligned with ``fused``.

        :raises RuntimeError: If the model is not trained (trainer/model missing).
        :raises ValueError: If ``batch_size`` < 1, ``df`` is empty, or the backend did not
            return embeddings for a requested modality in every batch.
        :raises OSError: May propagate from the backend when reading source image/audio files.
        """
        if self.trainer is None or self.model is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")
        # Validate up front: batch_size=0 would otherwise hit a ZeroDivisionError
        # in the batch-count computation below.
        if batch_size is None or batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        # An empty frame yields no batches and np.vstack would fail with an opaque error.
        if len(df) == 0:
            raise ValueError("Cannot extract embeddings from an empty DataFrame.")

        try:
            device = next(self.trainer.model.parameters()).device
        except StopIteration:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model.to(device).eval()

        df_c = df.copy()
        if self.target_column_name not in df_c.columns:
            # The dataset requires a target column; any known label works as a placeholder.
            df_c[self.target_column_name] = list(self.label2id.keys())[0]

        print(f"Preparing dataset for embeddings extraction ({len(df_c)} samples)...")

        has_bt = bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

        ds = MultiComboDataset(
            df=df_c,
            target_col=self.target_column_name,
            label2id=self.label2id,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
            text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=False,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors
        )

        def collate(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            return self.backend.collate(batch_list)

        def move_to_device(obj, target: torch.device):
            """Recursively move tensors nested in dicts/lists/tuples to ``target``."""
            if torch.is_tensor(obj):
                return obj.to(target)
            if isinstance(obj, dict):
                return {k: move_to_device(v, target) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                moved = [move_to_device(v, target) for v in obj]
                return moved if isinstance(obj, list) else type(obj)(moved)
            return obj

        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)
        fused_list = []
        per_mod_lists = {m: [] for m in self.modalities} if return_per_modality else None

        num_batches = (len(df_c) + batch_size - 1) // batch_size

        print(f"Extracting embeddings (batch_size={batch_size}, num_batches={num_batches})...")

        with torch.no_grad():
            for batch in tqdm(loader, total=num_batches, desc="Extracting embeddings", unit="batch", leave=True):
                bi = move_to_device(batch["backend_inputs"], device)
                fused, per = self.model.get_embeddings(backend_inputs=bi, return_per_modality=True)
                fused_list.append(fused.cpu().numpy())
                if return_per_modality:
                    per = per or {}
                    for m, chunks in per_mod_lists.items():
                        # Every batch must contribute to every modality; skipping a batch
                        # would silently misalign per-modality rows with the fused rows.
                        if m not in per:
                            raise ValueError(
                                f"Backend returned no '{m}' embeddings for a batch; "
                                "cannot build per-modality arrays aligned with the fused ones."
                            )
                        chunks.append(per[m].cpu().numpy())

        print("Concatenating embeddings...")
        fused_arr = np.vstack(fused_list)

        if not return_per_modality:
            print(f"✓ Embeddings shape: {fused_arr.shape}")
            return fused_arr

        per_mod = {m: np.vstack(chunks) for m, chunks in per_mod_lists.items()}
        print(f"✓ Fused embeddings shape: {fused_arr.shape}")
        for m, arr in per_mod.items():
            print(f"✓ {m.capitalize()} embeddings shape: {arr.shape}")

        return fused_arr, per_mod

Функции для создания фиктивных данных.

In [None]:
import numpy as np

def make_rand_image(h=512, w=512):
    """Return a random RGB image as a uint8 array of shape (h, w, 3).

    Pixel values are drawn uniformly from the full 0..255 range using NumPy's
    global random generator.
    """
    # randint's upper bound is exclusive, so 256 lets 255 occur; the previous
    # `rand() * 255` truncated to uint8 could never produce 255.
    return np.random.randint(0, 256, size=(h, w, 3), dtype=np.uint8)

def make_sine_audio(sr=48000, seconds=1.0, freq=440.0):
    """Generate a mono sine tone with amplitude 0.1 as float32 samples.

    :param sr: sample rate; the result has int(sr * seconds) samples.
    :param seconds: duration of the tone.
    :param freq: tone frequency in Hz.
    """
    n_samples = int(sr * seconds)
    timeline = np.linspace(0, seconds, n_samples, endpoint=False)
    phase = 2 * np.pi * freq * timeline
    wave = 0.1 * np.sin(phase)
    return wave.astype(np.float32)

Пример использования 1.

In [None]:
import pandas as pd
import torch

# Data with string labels, so predict(return_label_str=True) prints readable names.
df_clip_clap = pd.DataFrame([
    {"text": "A man riding a bike", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 440.0), "label": "sports"},
    {"text": "A cat lying on sofa", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 330.0), "label": "lifestyle"},
    {"text": "Stock market is volatile", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 550.0), "label": "business"},
    {"text": "Runner on the track", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 660.0), "label": "sports"},
    {"text": "New cafe opens downtown", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 220.0), "label": "lifestyle"},
    {"text": "Company reports revenue", "image": make_rand_image(), "audio": make_sine_audio(48000, 1.0, 770.0), "label": "business"},
])

pipe1 = SingleModelMultiComboClassification(
    modalities=["text", "image", "audio"],
    num_labels=3,                          # == number of unique labels
    target_column_name="label",
    text_columns=["text"],
    image_columns=["image"],
    audio_columns=["audio"],
    # Explicit backend configs
    text_model_config={
        "checkpoint": "openai/clip-vit-base-patch32",
        "model_type": "clip",
        "max_length": 77
    },
    image_model_config={
        "checkpoint": "openai/clip-vit-base-patch32",
        "model_type": "clip",
        "max_images": 1,
        "image_agg": "mean"
    },
    audio_model_config={
        "checkpoint": "laion/clap-htsat-unfused",
        "model_type": "clap",
        "sr": 48000,
        "max_audios": 1,
        "audio_agg": "mean"
    },
    fusion="mean",                         # all backbones output 512 dims → mean
    freeze_backbone=False,                 # backbones are fine-tuned too (True → linear probing)
    use_batch_tokenizer=True,              # fast batched tokenizer
    pretokenize_data=True,                 # tokenize texts up front
    pretokenize_batch_size=128,
    tokenizer_cache_size=5000,
    max_pretokenize_samples=100000,
    local_cache_dir="./model_cache"
)

pipe1.fit(
    train_data=df_clip_clap,
    epochs=2,
    test_size=0.33,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    metric_name="f1",
    fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is available
    logging_steps=1,
    eval_steps=2,
    output_dir="./mc_max_clip_clap",
    seed=42,
    hidden=512,
    dropout=0.2,
    gradient_checkpointing=True,
    fit_chunk_size=2,          # chunks of 2 samples
    clear_cache_every_n_chunks=5,
    early_stopping_patience=2,         # early stopping
    early_stopping_threshold=0.0
)

# Predictions as probabilities
probas1 = pipe1.predict(df_clip_clap.iloc[:3], return_proba=True)
print("Probas shape:", probas1.shape)

# Predictions as string labels
labels1 = pipe1.predict(df_clip_clap.iloc[:3], return_label_str=True)
print("Labels:", labels1)

# Embeddings (fused + per modality)
fused1, per1 = pipe1.get_embeddings(df_clip_clap.iloc[:3], batch_size=2, return_per_modality=True)
print("Fused shape:", fused1.shape)
for m, arr in per1.items():
    print(f"{m} emb shape:", arr.shape)

Пример использования 2.

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

def build_binary_multimodal_df(n_per_class: int = 50, sr: int = 16000, seed=None) -> pd.DataFrame:
    """Build a toy binary ("pets" vs "news") multimodal DataFrame.

    Each row holds a Russian text, a random RGB image, a 1-second sine tone at
    sample rate ``sr`` and a string label. Tone frequencies differ by class
    (260-360 Hz for "pets", 560-660 Hz for "news").

    :param n_per_class: number of rows generated for each class.
    :param sr: sample rate of the generated audio.
    :param seed: optional seed for text/frequency choices and row shuffling. When None,
        the global ``random`` module state is used. Images still come from NumPy's
        global RNG via make_rand_image.
    :return: shuffled DataFrame with columns text, image, audio, label.
    """
    # This cell does not import `random` at module level; import it here so the
    # function works on its own.
    import random

    rng = random if seed is None else random.Random(seed)

    # Vocabulary for "pets" texts
    pet_animals = ["кошка", "собака", "щенок", "кот", "котёнок", "пёс", "питомец", "котик", "пёсик", "котяра"]
    pet_actions = ["сидит", "лежит", "играет", "смотрит", "прячется", "спит", "тянется", "мурлычет", "исследует", "охотится"]
    pet_places = ["на подоконнике", "на диване", "на ковре", "на кухне", "в коробке", "на кресле", "у окна", "в саду", "на полу", "на стуле"]

    # Vocabulary for "news" texts
    news_subjects = ["власти города", "жители района", "аналитики", "эксперты", "журналисты", "компания", "департамент", "учёные", "инженеры", "ведомство"]
    news_verbs = ["обсудили", "обновили", "сообщили", "рассказали", "анонсировали", "заявили", "подтвердили", "планируют", "запустили", "увеличили"]
    news_topics = ["экономику", "политику", "транспорт", "погоду", "технологии", "культуру", "спорт", "здравоохранение", "образование", "экологию"]

    rows = []

    # PETS class
    pet_freqs = [260.0, 280.0, 300.0, 320.0, 340.0, 360.0]
    for _ in range(n_per_class):
        text = f"{rng.choice(pet_animals).capitalize()} {rng.choice(pet_actions)} {rng.choice(pet_places)}"
        img = make_rand_image()
        aud = make_sine_audio(sr, 1.0, rng.choice(pet_freqs))
        rows.append({"text": text, "image": img, "audio": aud, "label": "pets"})

    # NEWS class
    news_freqs = [560.0, 580.0, 600.0, 620.0, 640.0, 660.0]
    for _ in range(n_per_class):
        text = f"{rng.choice(news_subjects).capitalize()} {rng.choice(news_verbs)} новости про {rng.choice(news_topics)}"
        img = make_rand_image()
        aud = make_sine_audio(sr, 1.0, rng.choice(news_freqs))
        rows.append({"text": text, "image": img, "audio": aud, "label": "news"})

    rng.shuffle(rows)
    df = pd.DataFrame(rows)
    return df

import torch

train_data = build_binary_multimodal_df(n_per_class=12)
df_train, df_eval = train_test_split(
    train_data, test_size=0.3, random_state=42, shuffle=True,
    stratify=train_data['label'])

pipe3 = SingleModelMultiComboClassification(
    modalities=["text", "image", "audio"],
    num_labels=2,
    target_column_name="label",
    text_columns=["text"],
    image_columns=["image"],
    audio_columns=["audio"],
    text_model_config={
        "checkpoint": "DeepPavlov/rubert-base-cased",
        "model_type": "bert",
        "max_length": 256
    },
    image_model_config={
        "checkpoint": "google/vit-base-patch16-224",
        "model_type": "vit",
        "max_images": 1,
        "image_agg": "mean"
    },
    audio_model_config={
        "checkpoint": "facebook/wav2vec2-base-960h",
        "model_type": "auto",   # AutoModel + AutoProcessor
        "sr": 16000,
        "max_audios": 1,
        "audio_agg": "mean"
    },
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=True,
    pretokenize_data=True,
    pretokenize_batch_size=128,
    tokenizer_cache_size=5000,
    local_cache_dir="./model_cache"
)

pipe3.fit(
    train_data=df_train,
    test_data=df_eval,
    epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    metric_name="f1",
    fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is available
    logging_steps=1,
    eval_steps=1,
    output_dir="./mc_max_rubert_vit_w2v2",
    seed=2025,
    hidden=512,
    dropout=0.2,
    gradient_checkpointing=True,
    fit_chunk_size=8,
    clear_cache_every_n_chunks=3,
    early_stopping_patience=2,
    early_stopping_threshold=0.0
)

# Inference on the held-out split produced above.
print("Pred (labels):", pipe3.predict(df_eval, return_label_str=True))
fused3, per3 = pipe3.get_embeddings(df_eval, batch_size=2, return_per_modality=True)
print("Fused:", fused3.shape, "| text:", per3["text"].shape, "| image:", per3["image"].shape, "| audio:", per3["audio"].shape)

Пример использования 3.

In [None]:
import pandas as pd

df_min = pd.DataFrame({
    "text": ["Привет, как дела?", "Сегодня отличная погода", "До встречи!"],
    "label": ["greet", "weather", "greet"],
})

# Text-only pipeline: only the text backbone needs a config.
pipe_min = SingleModelMultiComboClassification(
    modalities=["text"],
    num_labels=2,  # two distinct labels in df_min: greet, weather
    target_column_name="label",
    text_columns=["text"],
    text_model_config=dict(
        checkpoint="DeepPavlov/rubert-base-cased",
        model_type="bert",
        max_length=128,
    ),
    fusion="concat",  # fusion choice does not matter with a single modality
    freeze_backbone=True,
)

# Minimal fit: every other fit() argument keeps its default.
pipe_min.fit(train_data=df_min)

print("Pred (ids):", pipe_min.predict(df_min))
print("Pred (labels):", pipe_min.predict(df_min, return_label_str=True))
emb_min = pipe_min.get_embeddings(df_min)
print("Embeddings shape:", emb_min.shape)

# Дообучение регрессора, который работает с данными разных модальностей.

Реализация регрессора.

In [None]:
!pip install --upgrade --no-cache-dir \
  --extra-index-url https://download.pytorch.org/whl/cu124 \
  pillow==11.1.0 \
  numpy==1.26.4 \
  pandas==2.2.3 \
  tqdm==4.67.1 \
  transformers==4.51.3 \
  evaluate==0.4.5 \
  wav2clip==0.1.0 \
  torch==2.6.0+cu124 \
  torchaudio==2.6.0+cu124


import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import math
import random
import gc
from functools import lru_cache
from typing import Any, Callable, Dict, Generator, List, Optional, Union

from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from transformers import TrainingArguments, Trainer
from transformers.trainer_callback import TrainerCallback, PrinterCallback, EarlyStoppingCallback
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate


# =========================
# Утилиты
# =========================

def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Seeds, in order: ``random``, NumPy's global generator, PyTorch on the CPU,
    and PyTorch on all CUDA devices.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


def to_pil(x: Union[str, "os.PathLike", np.ndarray, 'Image.Image']) -> 'Image.Image':
    """Convert an image given as a path, NumPy array or PIL image to an RGB PIL image.

    :param x: file path (``str`` or ``os.PathLike``), an array accepted by
        ``Image.fromarray``, or an existing ``PIL.Image.Image``.
    :return: an RGB ``PIL.Image.Image``.
    :raises ValueError: for unsupported input types.
    """
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    if isinstance(x, (str, os.PathLike)):
        # Open in a context manager so the file handle is released; convert()
        # loads the pixel data and returns a new image, so it outlives the file.
        with Image.open(x) as img:
            return img.convert("RGB")
    if isinstance(x, np.ndarray):
        return Image.fromarray(x).convert("RGB")
    raise ValueError("Ожидается путь/np.ndarray/PIL.Image")


def load_audio(path: str, target_sr: int) -> np.ndarray:
    """Load an audio file as a mono float32 waveform resampled to ``target_sr``.

    Multi-channel audio is downmixed by averaging the channels. torchaudio is
    imported lazily; if it cannot be imported, a RuntimeError is raised.
    """
    try:
        import torchaudio
    except Exception as exc:
        raise RuntimeError("Требуется torchaudio: pip install torchaudio") from exc

    audio, native_sr = torchaudio.load(path)
    if audio.size(0) > 1:
        # Downmix to a single channel, keeping the channel dimension.
        audio = audio.mean(dim=0, keepdim=True)
    if native_sr != target_sr:
        audio = torchaudio.functional.resample(audio, orig_freq=native_sr, new_freq=target_sr)

    mono = audio.squeeze(0)
    return mono.numpy().astype(np.float32)


def safe_load(component_cls, checkpoint: str, local_cache_dir: str = "./model_cache",
              local_files_only: Optional[bool] = None, **kwargs):
    """Call ``component_cls.from_pretrained`` with cache/offline settings filled in.

    :param component_cls: class exposing ``from_pretrained`` (model, tokenizer, processor, ...).
    :param checkpoint: checkpoint name or path passed through as the first argument.
    :param local_cache_dir: forwarded as ``cache_dir``.
    :param local_files_only: when None, derived from the ``HF_HUB_OFFLINE`` environment
        variable, which is treated as enabled for "1"/"ON"/"YES"/"TRUE" (case-insensitive),
        matching huggingface_hub.
    :param kwargs: extra keyword arguments for ``from_pretrained``. For classes whose name
        contains "Tokenizer", ``use_fast`` defaults to True unless given.
    :return: whatever ``from_pretrained`` returns.
    """
    if local_files_only is None:
        # Previously only the exact value "1" was honoured; huggingface_hub also
        # accepts the other truthy spellings.
        flag = os.environ.get("HF_HUB_OFFLINE", "0").strip().upper()
        local_files_only = flag in {"1", "ON", "YES", "TRUE"}
    name = getattr(component_cls, "__name__", "")
    if "Tokenizer" in name:
        kwargs.setdefault("use_fast", True)
    return component_cls.from_pretrained(
        checkpoint, cache_dir=local_cache_dir, local_files_only=local_files_only, **kwargs
    )


# =========================
# Токенизатор батчевый
# =========================

class BatchTokenizer:
    """Wrap a tokenizer with batched calls and an LRU cache of single-text results.

    ``tokenize_batch`` has two paths:
      * cached path (``use_cache=True`` and fewer than 100 texts): each text is
        tokenized alone with ``padding="max_length"`` and memoized, so every row
        has exactly ``max_length`` tokens; non-integer fields become float32;
      * batched path: the whole list goes through one tokenizer call with
        ``padding=True`` (padded to the longest text in that call).
    Fields in ``_INT_KEYS`` are always returned as ``torch.long``.
    """

    # Token fields that are always converted to torch.long.
    _INT_KEYS = frozenset({"input_ids", "attention_mask", "token_type_ids"})

    def __init__(
        self,
        tokenizer,
        max_length: int = 512,
        cache_size: int = 10000,
        batch_size: int = 256,
        use_fast: bool = True,
        device: str = "cpu"
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        # use_fast and device are stored as attributes; this class does not read them.
        self.use_fast = use_fast
        self.device = device
        # Per-instance memoization of single-text tokenization, bounded by cache_size.
        self._cache = lru_cache(maxsize=cache_size)(self._tokenize_single)
        self.is_fast = hasattr(tokenizer, "is_fast") and tokenizer.is_fast
        if self.is_fast:
            print("✓ Используется Fast Tokenizer")

    def _tokenize_single(self, text: str) -> tuple:
        """Tokenize one text padded to ``max_length``; return ((key, np.ndarray), ...)."""
        result = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        # Drop the batch dimension and store NumPy arrays so the cached value is
        # detached from any device.
        return tuple((k, v.squeeze(0).cpu().numpy()) for k, v in result.items())

    def tokenize_batch(self, texts: List[str], use_cache: bool = True) -> Dict[str, torch.Tensor]:
        """Tokenize ``texts`` into a dict of batched tensors (empty dict for no texts)."""
        if len(texts) == 0:
            # Nothing to tokenize; the cached path would otherwise fail on rows[0].
            return {}
        if use_cache and len(texts) < 100:
            rows = [dict(self._cache(text)) for text in texts]
            batch_dict: Dict[str, torch.Tensor] = {}
            for key in rows[0]:
                dtype = torch.long if key in self._INT_KEYS else torch.float32
                batch_dict[key] = torch.tensor(np.stack([r[key] for r in rows]), dtype=dtype)
            return batch_dict

        result = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        for key in result:
            if key in self._INT_KEYS:
                result[key] = result[key].long()
        return result

    def tokenize_dataset_lazy(
        self,
        texts: List[str],
        batch_size: Optional[int] = None
    ) -> Generator[Dict[str, torch.Tensor], None, None]:
        """Yield tokenized chunks of ``texts`` (uncached), ``batch_size`` texts at a time."""
        batch_size = batch_size or self.batch_size
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            yield self.tokenize_batch(batch, use_cache=False)

    def clear_cache(self):
        """Drop all memoized single-text tokenizations."""
        self._cache.cache_clear()


# =========================
# Универсальный датасет
# =========================

class MultiComboDataset(Dataset):
    """Map-style dataset turning DataFrame rows into multimodal regression items.

    Each item is a dict with:
      * ``labels``: the target as ``float`` (0.0 when the column is absent or the
        value cannot be converted);
      * ``text_tokens`` (dict of tensors) or ``text`` (raw string or tokenizer_fn
        output), built from ``text_columns`` joined with ``special_tokens["sep"]``
        (only when ``text_columns`` is non-empty);
      * ``images`` / ``audios``: flat lists gathered from the respective columns
        (only when those columns are configured).

    Tokenized texts are memoized in ``tokenized_cache`` (row index → tensors). The
    cache holds at most ``max_cache_size`` entries, whether filled up front
    (``pretokenize=True``) or lazily in ``__getitem__``.
    """

    # Token fields that are always converted to torch.long.
    _INT_KEYS = frozenset({"input_ids", "attention_mask", "token_type_ids"})

    def __init__(
        self,
        df: pd.DataFrame,
        target_col: str,
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer: Optional["BatchTokenizer"] = None,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        pretokenize: bool = False,
        pretokenize_batch_size: int = 256,
        max_cache_size: int = 100000,
        tokenizer_returns_tensors: bool = False,
        cache_dir: Optional[str] = None
    ):
        self.df = df.reset_index(drop=True)
        self.target_col = target_col

        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []

        self.text_tokenizer = text_tokenizer
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        # Upper bound on tokenized_cache for both pre-tokenization and lazy caching.
        # cache_dir is accepted for interface compatibility and is not used here.
        self.max_cache_size = max_cache_size

        self.tokenized_cache: Dict[int, Dict[str, torch.Tensor]] = {}
        self.cache_hits = 0
        self.cache_misses = 0

        if pretokenize and self.text_tokenizer and self.text_columns:
            self._pretokenize_texts(
                batch_size=pretokenize_batch_size,
                max_cache_size=min(max_cache_size, len(self.df))
            )

    def _join_text(self, row: pd.Series) -> str:
        """Join the row's text columns with the separator; missing values become ""."""
        sep = self.special_tokens.get("sep", " [SEP] ")
        return sep.join([str(row[c]) if pd.notnull(row[c]) else "" for c in self.text_columns])

    def _pretokenize_texts(self, batch_size: int, max_cache_size: int):
        """Tokenize the first ``max_cache_size`` rows in batches and fill the cache."""
        print("Предварительная токенизация текстов...")
        indices = list(range(min(len(self.df), max_cache_size)))
        all_texts = [self._join_text(self.df.iloc[i]) for i in indices]

        for start in range(0, len(indices), batch_size):
            batch_idx = indices[start:start + batch_size]
            batch_txt = all_texts[start:start + batch_size]
            tokenized = self.text_tokenizer.tokenize_batch(batch_txt, use_cache=False)

            for j, idx in enumerate(batch_idx):
                token_dict: Dict[str, torch.Tensor] = {}
                for k, v in tokenized.items():
                    t = v[j]
                    # Clone so each cached row does not keep the whole batch tensor alive.
                    token_dict[k] = t.clone().long() if k in self._INT_KEYS else t.clone()
                self.tokenized_cache[idx] = token_dict

        print(f"✓ Предварительно токенизировано {len(self.tokenized_cache)} текстов")

    def __len__(self) -> int:
        return len(self.df)

    @staticmethod
    def _as_list(value) -> List[Any]:
        """Normalize a cell to a list: None/NaN → [], list/tuple → list, else [value]."""
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.df.iloc[idx]
        item: Dict[str, Any] = {}

        # Regression target: always a float, 0.0 when missing or not convertible.
        if self.target_col in row:
            try:
                item["labels"] = float(row[self.target_col])
            except Exception:
                item["labels"] = 0.0
        else:
            item["labels"] = 0.0

        # Text
        if self.text_columns:
            if idx in self.tokenized_cache:
                cached = self.tokenized_cache[idx]
                text_tokens: Dict[str, torch.Tensor] = {}
                for k, v in cached.items():
                    text_tokens[k] = v.long() if k in self._INT_KEYS else v
                item["text_tokens"] = text_tokens
                self.cache_hits += 1
            elif self.text_tokenizer is not None:
                text = self._join_text(row)
                tokenized = self.text_tokenizer.tokenize_batch([text], use_cache=True)
                text_tokens = {k: (v[0].long() if k in self._INT_KEYS else v[0])
                               for k, v in tokenized.items()}
                item["text_tokens"] = text_tokens
                self.cache_misses += 1
                # Respect the configured cache bound (was a hard-coded 100000).
                if len(self.tokenized_cache) < self.max_cache_size:
                    self.tokenized_cache[idx] = {k: t.clone() for k, t in text_tokens.items()}
            elif self.text_tokenizer_fn is not None:
                text_data = {c: str(row[c]) if pd.notnull(row[c]) else "" for c in self.text_columns}
                result = self.text_tokenizer_fn(text_data, self.special_tokens)
                if isinstance(result, dict) and 'input_ids' in result:
                    item["text_tokens"] = result
                    self.tokenizer_returns_tensors = True
                else:
                    item["text"] = result
            else:
                item["text"] = self._join_text(row)

        # Images
        if self.image_columns:
            imgs = []
            for c in self.image_columns:
                if c in row:
                    imgs.extend(self._as_list(row[c]))
            item["images"] = imgs

        # Audio
        if self.audio_columns:
            auds = []
            for c in self.audio_columns:
                if c in row:
                    auds.extend(self._as_list(row[c]))
            item["audios"] = auds

        return item

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache size, hit/miss counters and the hit rate."""
        total = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total if total > 0 else 0.0
        return {
            "cache_size": len(self.tokenized_cache),
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate
        }

    def clear_cache(self):
        """Empty the token cache, reset counters and clear the tokenizer's own cache."""
        self.tokenized_cache.clear()
        self.cache_hits = 0
        self.cache_misses = 0
        if self.text_tokenizer:
            self.text_tokenizer.clear_cache()


# =========================
# Универсальный бэкенд
# =========================

class BaseBackend(nn.Module):
    """Base class for modality backends.

    Subclasses implement ``collate`` (dataset items → backend inputs) and
    ``encode`` (backend inputs → per-modality tensors). The base class also holds
    the text tokenization settings and a helper to freeze all parameters.

    NOTE(review): ``supported``, ``out_dim_per_modality`` and ``special_tokens`` are
    mutable class-level defaults shared by every instance that does not reassign
    them; assign a per-instance copy before mutating them in place.
    """

    name: str = "base"
    supported: set = set()
    embed_dim: int = 0
    out_dim_per_modality: Dict[str, int] = {}
    text_tokenizer_fn: Optional[Callable] = None
    batch_tokenizer: Optional[BatchTokenizer] = None
    special_tokens: Dict[str, str] = {}
    tokenizer_returns_tensors: bool = False
    local_cache_dir: str = "./model_cache"

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Assemble a list of dataset items into backend inputs (subclass hook)."""
        raise NotImplementedError

    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        """Encode backend inputs into per-modality tensors (subclass hook)."""
        raise NotImplementedError

    def freeze_all(self):
        """Disable gradients for every parameter of this module."""
        for param in self.parameters():
            param.requires_grad_(False)

    def get_out_dim(self, modality: str) -> int:
        """Output size for ``modality``; falls back to ``embed_dim`` when not listed."""
        dims = self.out_dim_per_modality
        return dims[modality] if modality in dims else self.embed_dim

    def set_text_tokenizer(self, tokenizer_fn: Optional[Callable], special_tokens: Optional[Dict[str, str]] = None,
                           returns_tensors: bool = False):
        """Register a custom text tokenization function and its separator tokens."""
        self.text_tokenizer_fn = tokenizer_fn
        self.special_tokens = special_tokens if special_tokens else {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = returns_tensors

    def set_batch_tokenizer(self, tokenizer, max_length: int = 512,
                            cache_size: int = 10000, batch_size: int = 256):
        """Wrap ``tokenizer`` in a BatchTokenizer and attach it to the backend."""
        wrapped = BatchTokenizer(
            tokenizer=tokenizer,
            max_length=max_length,
            cache_size=cache_size,
            batch_size=batch_size,
            use_fast=True
        )
        self.batch_tokenizer = wrapped


class UniversalMultiBackend(BaseBackend):
    """Backend that wraps up to three encoders: text, image and audio.

    A modality is enabled by passing its ``*_model_config`` dict (``checkpoint``
    is required except for ``model_type == "wav2clip"`` audio). ``collate``
    converts raw samples into model inputs and ``encode`` returns one
    L2-normalized 2-D tensor per enabled modality. A sample may carry several
    images/audios; their embeddings are combined per sample by
    ``_aggregate_embeddings`` ("concat" keeps/zero-pads ``max_images`` /
    ``max_audios`` slots, any other value averages).
    """

    name = "universal"

    class _ParamDeviceProxy(nn.Module):
        """Give a parameterless model (or plain callable) a device-tracking parameter.

        Code that discovers a module's device via ``next(module.parameters())``
        fails on parameterless modules; this proxy registers an empty, frozen
        ``_dummy`` parameter that follows ``.to()`` and forwards every call to
        the wrapped object.
        """

        def __init__(self, base, device: torch.device):
            super().__init__()
            # nn.Module instances are registered as a submodule; anything else
            # is kept as a plain callable.
            self.base = base if isinstance(base, nn.Module) else None
            self._callable = base if not isinstance(base, nn.Module) else None
            self._dummy = nn.Parameter(torch.empty(0), requires_grad=False)
            with torch.no_grad():
                self._dummy.data = self._dummy.data.to(device)
            self._move_target(device)

        def _target(self):
            """Return the wrapped module or callable."""
            return self.base if self.base is not None else self._callable

        def _move_target(self, device) -> None:
            """Best-effort ``.to(device)`` on the wrapped object."""
            try:
                target = self._target()
                if hasattr(target, "to"):
                    target.to(device)
            except Exception:
                # Deliberately best-effort: a wrapped callable may expose a `.to`
                # that does not accept a device; that must not break wrapping.
                pass

        def forward(self, *args, **kwargs):
            return self._target()(*args, **kwargs)

        def to(self, device, *args, **kwargs):
            self._dummy.data = self._dummy.data.to(device)
            self._move_target(device)
            return super().to(device, *args, **kwargs)

    def _model_device(self, model, default: torch.device) -> torch.device:
        """Device of the first parameter (else first buffer) of *model*, or *default*."""
        try:
            return next(model.parameters()).device
        except StopIteration:
            pass
        try:
            buf = next(model.buffers())
            return buf.device
        except StopIteration:
            pass
        return default

    def _preferred_device(self) -> torch.device:
        """Pick CUDA, then MPS, then CPU, depending on availability."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _wrap_if_parameterless(self, model, device: torch.device):
        """Return *model* unchanged if it has parameters, else a ``_ParamDeviceProxy``."""
        try:
            it = model.parameters() if hasattr(model, "parameters") else iter(())
            next(it)
            return model
        except Exception:  # includes StopIteration (no parameters at all)
            return self._ParamDeviceProxy(model, device)

    def __init__(
        self,
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        freeze: bool = True,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        use_batch_tokenizer: bool = True,
        tokenizer_cache_size: int = 10000,
        tokenizer_batch_size: int = 256,
        local_cache_dir: str = "./model_cache"
    ):
        super().__init__()
        self.supported = set()
        self.out_dim_per_modality = {}
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.use_batch_tokenizer = use_batch_tokenizer
        self.tokenizer_cache_size = tokenizer_cache_size
        self.tokenizer_batch_size = tokenizer_batch_size
        self.local_cache_dir = local_cache_dir

        # NOTE(review): each *_config below aliases the caller's dict, and the
        # _init_* methods write derived keys ('dim', 'model_type', ...) into it.
        self.text_model = None
        self.text_processor = None
        self.text_config = text_model_config or {}
        if text_model_config:
            self._init_text_model(text_model_config)
            self.supported.add("text")

        self.image_model = None
        self.image_processor = None
        self.image_config = image_model_config or {}
        if image_model_config:
            self._init_image_model(image_model_config)
            self.supported.add("image")

        self.audio_model = None
        self.audio_processor = None
        self.audio_config = audio_model_config or {}
        if audio_model_config:
            self._init_audio_model(audio_model_config)
            self.supported.add("audio")

        if freeze:
            self.freeze_all()

    def _ensure_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Reshape *x* to 2-D: 1-D gains a batch axis, >2-D is flattened per row."""
        if x is None:
            return None
        if x.dim() == 1:
            return x.unsqueeze(0)
        if x.dim() > 2:
            return x.view(x.size(0), -1)
        return x

    def _normalize_2d(self, x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """``_ensure_2d`` followed by row-wise L2 normalization (skipped if empty)."""
        x = self._ensure_2d(x)
        return F.normalize(x, dim=-1) if x is not None and x.numel() > 0 else x

    def _init_text_model(self, config: Dict[str, Any]):
        """Load the text encoder/tokenizer selected by ``config['model_type']``."""
        from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPTokenizer, ClapModel, ClapProcessor

        checkpoint = config['checkpoint']
        model_type = config.get('model_type', 'auto').lower()

        print(f"Загрузка текстовой модели {checkpoint}...")

        if model_type == 'clip':
            self.text_model = safe_load(CLIPModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.text_processor = safe_load(CLIPTokenizer, checkpoint, local_cache_dir=self.local_cache_dir, use_fast=True)
            dim = self.text_model.config.projection_dim
        elif model_type == 'clap':
            self.text_model = safe_load(ClapModel, checkpoint, local_cache_dir=self.local_cache_dir)
            proc = safe_load(ClapProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            self.text_processor = getattr(proc, 'tokenizer', None) or safe_load(AutoTokenizer, checkpoint, local_cache_dir=self.local_cache_dir, use_fast=True)
            dim = getattr(self.text_model.config, "projection_dim", 512)
        else:
            self.text_model = safe_load(AutoModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.text_processor = safe_load(AutoTokenizer, checkpoint, local_cache_dir=self.local_cache_dir, use_fast=True)
            dim = self.text_model.config.hidden_size

        dev = self._preferred_device()
        self.text_model = self._wrap_if_parameterless(self.text_model, dev)

        if self.use_batch_tokenizer and self.text_processor is not None:
            self.set_batch_tokenizer(
                self.text_processor,
                max_length=config.get('max_length', 512),
                cache_size=self.tokenizer_cache_size,
                batch_size=self.tokenizer_batch_size
            )

        self.text_config['max_length'] = config.get('max_length', 512)
        self.text_config['dim'] = dim
        self.text_config['model_type'] = model_type
        self.out_dim_per_modality['text'] = dim

    def _init_image_model(self, config: Dict[str, Any]):
        """Load the image encoder/processor and record the aggregated output size."""
        from transformers import AutoModel, AutoImageProcessor, CLIPModel, CLIPImageProcessor

        checkpoint = config['checkpoint']
        model_type = config.get('model_type', 'auto').lower()

        print(f"Загрузка визуальной модели {checkpoint}...")

        if model_type == 'clip':
            self.image_model = safe_load(CLIPModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.image_processor = safe_load(CLIPImageProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = self.image_model.config.projection_dim
        else:
            self.image_model = safe_load(AutoModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.image_processor = safe_load(AutoImageProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = self.image_model.config.hidden_size

        dev = self._preferred_device()
        self.image_model = self._wrap_if_parameterless(self.image_model, dev)

        self.image_config['max_images'] = config.get('max_images', 1)
        self.image_config['image_agg'] = config.get('image_agg', 'concat')
        self.image_config['dim'] = dim
        self.image_config['model_type'] = model_type

        # "concat" reserves max_images slots of size dim per sample.
        self.out_dim_per_modality['image'] = (dim * self.image_config['max_images']) if self.image_config['image_agg'] == 'concat' else dim

    def _init_audio_model(self, config: Dict[str, Any]):
        """Load the audio encoder (wav2clip, CLAP or an AutoModel) and its sample rate."""
        from transformers import AutoModel, AutoProcessor, ClapModel, ClapProcessor

        model_type = config.get('model_type', 'auto').lower()
        checkpoint = config.get('checkpoint', None)

        print(f"Загрузка аудио модели (type={model_type})...")

        if model_type == 'wav2clip':
            import wav2clip as w2c
            self._w2c = w2c

            w2c_model = None
            if hasattr(w2c, "get_model"):
                w2c_model = w2c.get_model()
            elif hasattr(w2c, "model"):
                m = w2c.model
                w2c_model = m() if callable(m) else m
            else:
                raise RuntimeError("wav2clip не содержит get_model()/model. Обновите пакет wav2clip.")

            self.audio_model = w2c_model

            try:
                if isinstance(self.audio_model, torch.nn.Module) and torch.cuda.is_available():
                    self.audio_model = self.audio_model.to("cuda")
            except Exception:
                # Best effort: stay on the current device if the move fails.
                pass

            self.audio_processor = None
            dim = 512
            sr = config.get('sr', 16000)

        elif model_type == 'clap':
            if checkpoint is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для CLAP")
            self.audio_model = safe_load(ClapModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(ClapProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = getattr(self.audio_model.config, "projection_dim", 512)
            sr = getattr(self.audio_processor, "sampling_rate", None)
            if sr is None:
                fe = getattr(self.audio_processor, "feature_extractor", None)
                sr = getattr(fe, "sampling_rate", 48000) if fe is not None else 48000

        else:
            if checkpoint is None:
                raise ValueError("audio_model_config['checkpoint'] обязателен для аудио-моделей, кроме wav2clip")
            self.audio_model = safe_load(AutoModel, checkpoint, local_cache_dir=self.local_cache_dir)
            self.audio_processor = safe_load(AutoProcessor, checkpoint, local_cache_dir=self.local_cache_dir)
            dim = self.audio_model.config.hidden_size
            fe = getattr(self.audio_processor, "feature_extractor", None)
            sr = getattr(fe, "sampling_rate", 16000) if fe is not None else 16000

        dev = self._preferred_device()
        self.audio_model = self._wrap_if_parameterless(self.audio_model, dev)

        self.audio_config['sr'] = config.get('sr', sr)
        self.audio_config['max_audios'] = config.get('max_audios', 1)
        self.audio_config['audio_agg'] = config.get('audio_agg', 'concat')
        self.audio_config['dim'] = dim
        self.audio_config['model_type'] = model_type

        self.out_dim_per_modality['audio'] = (
            dim * self.audio_config['max_audios']
            if self.audio_config['audio_agg'] == 'concat' else dim
        )

    def _audio_to_array(self, a: Any) -> np.ndarray:
        """Convert one audio item into a flat float32 waveform array.

        Paths (``str`` / ``os.PathLike``) are passed to ``load_audio`` at
        ``audio_config['sr']`` and its result is returned as-is. Tensors, arrays
        and numeric sequences are cast to float32; multi-dimensional input is
        squeezed and, if still multi-dimensional, flattened.

        :raises TypeError: if the item cannot be interpreted as a waveform.
        """
        if isinstance(a, str):
            return load_audio(a, self.audio_config['sr'])
        if isinstance(a, os.PathLike):
            return load_audio(os.fspath(a), self.audio_config['sr'])
        if torch.is_tensor(a):
            # .float() first: numpy has no bfloat16.
            a = a.detach().float().cpu().numpy()
        try:
            aa = np.asarray(a, dtype=np.float32)
        except (TypeError, ValueError) as err:
            raise TypeError(f"Unsupported audio item of type {type(a).__name__}") from err
        if aa.ndim > 1:
            aa = np.squeeze(aa)
        if aa.ndim > 1:
            aa = aa.reshape(-1)
        return aa

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build ``{"labels": Tensor[B, ...], "backend_inputs": dict}`` from raw samples.

        Samples are dicts that may contain ``labels`` (default 0.0), ``text`` or
        pre-tokenized ``text_tokens``, ``images`` and ``audios`` (single item or
        list). ``backend_inputs`` holds processor outputs per modality, the
        per-sample item counts (``image_counts`` / ``audio_counts``) and
        ``batch_size``.
        """
        labels = torch.stack([torch.tensor(b.get("labels", 0.0), dtype=torch.float32) for b in batch])

        backend_inputs: Dict[str, Any] = {}
        batch_size = len(batch)

        # Text
        if self.text_model is not None:
            if "text_tokens" in batch[0]:
                text_inputs = {}
                for key in batch[0]["text_tokens"].keys():
                    if torch.is_tensor(batch[0]["text_tokens"][key]):
                        text_inputs[key] = torch.stack([b["text_tokens"][key] for b in batch])
                    else:
                        dtype = torch.long if key in ["input_ids", "attention_mask", "token_type_ids"] else torch.float32
                        text_inputs[key] = torch.tensor([b["text_tokens"][key] for b in batch], dtype=dtype)
                backend_inputs["text_inputs"] = text_inputs
            else:
                # Empty text is replaced by a single space so the tokenizer gets a non-empty string.
                texts = [b.get("text", "") or " " for b in batch]
                if self.batch_tokenizer:
                    text_inputs = self.batch_tokenizer.tokenize_batch(texts, use_cache=True)
                else:
                    text_inputs = self.text_processor(
                        texts, padding=True, truncation=True,
                        max_length=self.text_config.get('max_length', 512),
                        return_tensors="pt"
                    )
                backend_inputs["text_inputs"] = {k: v for k, v in text_inputs.items()}

        # Images
        if self.image_model is not None:
            images_lists = [b.get("images", []) for b in batch]
            flat_images, img_counts = [], []
            for lst in images_lists:
                lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
                lst = [img for img in lst if img is not None]
                img_counts.append(len(lst))
                for img in lst:
                    flat_images.append(to_pil(img))

            if len(flat_images) > 0:
                img_proc = self.image_processor(images=flat_images, return_tensors="pt")
                backend_inputs["image_inputs"] = {"pixel_values": img_proc["pixel_values"]}
            else:
                backend_inputs["image_inputs"] = {"pixel_values": None}

            backend_inputs["image_counts"] = torch.tensor(img_counts, dtype=torch.long)

        # Audio
        if self.audio_model is not None:
            audios_lists = [b.get("audios", []) for b in batch]
            flat_audios, aud_counts = [], []
            for lst in audios_lists:
                lst = lst if isinstance(lst, list) else ([] if lst is None else [lst])
                lst = [a for a in lst if a is not None]
                aud_counts.append(len(lst))
                # Every counted item must yield exactly one waveform; otherwise
                # `encode` would attribute embeddings to the wrong samples.
                flat_audios.extend(self._audio_to_array(a) for a in lst)
            if self.audio_config.get('model_type') == 'wav2clip':
                backend_inputs["audio_inputs"] = {"raw_audios": flat_audios}
            elif len(flat_audios) > 0:
                if self.audio_config.get('model_type') == 'clap':
                    aud_proc = self.audio_processor(
                        audios=flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_features": aud_proc["input_features"]}
                else:
                    aud_proc = self.audio_processor(
                        flat_audios, sampling_rate=self.audio_config['sr'], padding=True, return_tensors="pt"
                    )
                    backend_inputs["audio_inputs"] = {"input_values": aud_proc["input_values"]}
            else:
                backend_inputs["audio_inputs"] = {"input_features": None, "input_values": None, "raw_audios": []}

            backend_inputs["audio_counts"] = torch.tensor(aud_counts, dtype=torch.long)

        backend_inputs["batch_size"] = batch_size
        return {"labels": labels, "backend_inputs": backend_inputs}

    def _aggregate_embeddings(
        self,
        embs: Optional[torch.Tensor],
        counts: List[int],
        max_k: int,
        dim_hint: int,
        agg_type: str,
        batch_size: int,
        device: torch.device
    ) -> torch.Tensor:
        """Fold flat per-item embeddings into one row per sample.

        :param embs: embeddings of all items in sample order (reshaped to 2-D
            ``[N, D]``); may be None or empty.
        :param counts: number of items that belong to each sample, in order.
        :param max_k: number of slots per sample for "concat".
        :param dim_hint: feature size used when ``embs`` is None/empty.
        :param agg_type: "concat" flattens the first ``max_k`` items (zero-padded);
            any other value averages the sample's items.
        :param batch_size: number of output rows.
        :return: ``[batch_size, D*max_k]`` for "concat" else ``[batch_size, D]``,
            L2-normalized per row; samples without items stay all-zero.
        """
        if embs is None or (torch.is_tensor(embs) and embs.numel() == 0):
            feat_dim = int(dim_hint) if dim_hint is not None else 0
            out_dim = feat_dim * max_k if agg_type == 'concat' else feat_dim
            return torch.zeros((batch_size, out_dim), device=device, dtype=torch.float32)

        if not torch.is_tensor(embs):
            embs = torch.as_tensor(embs, device=device, dtype=torch.float32)
        if embs.dim() == 1:
            embs = embs.unsqueeze(0)
        elif embs.dim() > 2:
            embs = embs.view(embs.size(0), -1)

        N, D = embs.size()
        out_dim = (D * max_k) if agg_type == 'concat' else D
        out = torch.zeros((batch_size, out_dim), device=device, dtype=embs.dtype)

        offset = 0
        for i, c in enumerate(counts):
            if c <= 0 or offset >= N:
                continue
            take_n = min(c, N - offset)
            sample = embs[offset:offset + take_n]
            offset += take_n

            if agg_type == 'concat':
                take = sample[:max_k]
                if take.size(0) < max_k:
                    pad = torch.zeros((max_k - take.size(0), D), device=device, dtype=embs.dtype)
                    take = torch.cat([take, pad], dim=0)
                out[i] = take.reshape(-1)
            else:
                out[i] = sample.mean(dim=0)

        return F.normalize(out, dim=-1) if out.size(1) > 0 else out

    @torch.no_grad()
    def _wav2clip_embed(self, arr: np.ndarray, device: torch.device) -> torch.Tensor:
        """Embed one waveform with wav2clip and return a flat float32 tensor on *device*.

        Waveforms shorter than 512 samples are zero-padded to 512. If
        ``wav2clip.embed_audio`` fails, the model is called directly on a
        ``[1, T]`` tensor as a fallback.
        """
        arr = np.asarray(arr, dtype=np.float32)
        if arr.ndim > 1:
            arr = np.squeeze(arr)
        if arr.ndim > 1:
            arr = arr.reshape(-1)
        if arr.size < 512:
            arr = np.pad(arr, (0, 512 - arr.size), mode="constant")

        try:
            emb = self._w2c.embed_audio(arr, self.audio_model)
            emb = np.asarray(emb)
        except Exception:
            # Fallback for wav2clip versions / wrapped models where embed_audio
            # does not work: call the model on a single-item batch.
            model_dev = self._model_device(self.audio_model, default=device)
            x = torch.from_numpy(arr).float().unsqueeze(0).to(model_dev)
            y = self.audio_model(x)
            if isinstance(y, (tuple, list)):
                y = y[0]
            if torch.is_tensor(y):
                if y.dim() == 2 and y.size(0) == 1:
                    y = y.squeeze(0)
                emb = y.detach().cpu().numpy()
            else:
                emb = np.asarray(y)

        if emb.ndim > 1:
            emb = emb.reshape(-1)
        return torch.as_tensor(emb, device=device, dtype=torch.float32)

    def encode(self, backend_inputs: Dict[str, Any], device: torch.device) -> Dict[str, torch.Tensor]:
        """Run the enabled encoders on inputs produced by ``collate``.

        :param backend_inputs: the ``backend_inputs`` dict from ``collate``.
        :param device: device the input tensors are moved to.
        :return: ``{modality: Tensor[B, out_dim]}`` for each enabled modality
            present in ``backend_inputs``; rows are L2-normalized.
        :raises RuntimeError: if modalities return different batch sizes.
        """
        results: Dict[str, torch.Tensor] = {}

        # Text
        if self.text_model is not None and "text_inputs" in backend_inputs:
            text_inputs = {k: v.to(device) for k, v in backend_inputs["text_inputs"].items()}
            if hasattr(self.text_model, "get_text_features"):
                text_z = self.text_model.get_text_features(**text_inputs)
            else:
                outputs = self.text_model(**text_inputs)
                text_z = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state.mean(dim=1)
            results["text"] = self._normalize_2d(text_z)

        # Images
        if self.image_model is not None and "image_inputs" in backend_inputs:
            pi = backend_inputs["image_inputs"]["pixel_values"]
            counts = backend_inputs["image_counts"].tolist()
            total_images_needed = sum(counts)

            img_flat = None
            actual_img_dim = self.image_config.get("dim", 768)

            if pi is not None and pi.numel() > 0 and total_images_needed > 0:
                pi = pi.to(device)
                if pi.size(0) > total_images_needed:
                    pi = pi[:total_images_needed]

                if hasattr(self.image_model, "get_image_features"):
                    img_flat = self.image_model.get_image_features(pixel_values=pi)
                else:
                    outputs = self.image_model(pixel_values=pi)
                    # CLS token when the model has no pooler.
                    img_flat = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state[:, 0]

                img_flat = self._normalize_2d(img_flat)
                actual_img_dim = img_flat.size(1) if img_flat is not None else actual_img_dim

            img_z = self._aggregate_embeddings(
                img_flat, counts,
                self.image_config["max_images"],
                actual_img_dim,
                self.image_config["image_agg"],
                len(counts),
                device
            )

            # Keep the advertised output size in sync with what the model produced.
            if actual_img_dim != self.image_config.get("dim"):
                self.image_config["dim"] = actual_img_dim
                self.out_dim_per_modality["image"] = (
                    actual_img_dim * self.image_config["max_images"]
                    if self.image_config["image_agg"] == "concat" else actual_img_dim
                )

            results["image"] = img_z

        # Audio
        if self.audio_model is not None and "audio_inputs" in backend_inputs:
            counts = backend_inputs["audio_counts"].tolist()
            total_audios_needed = sum(counts)

            aud_flat = None
            actual_aud_dim = self.audio_config.get("dim", 768)
            model_type = self.audio_config.get("model_type")

            if total_audios_needed > 0:
                if model_type == "clap":
                    af = backend_inputs["audio_inputs"]["input_features"]
                    if af is not None and af.numel() > 0:
                        af = af.to(device)
                        if af.size(0) > total_audios_needed:
                            af = af[:total_audios_needed]
                        # Audio encoders run in float32 even under CUDA mixed precision.
                        with torch.amp.autocast("cuda", enabled=False):
                            aud_flat = self.audio_model.get_audio_features(input_features=af.float())
                        aud_flat = self._normalize_2d(aud_flat.float())
                        actual_aud_dim = aud_flat.size(1)

                elif model_type == "wav2clip":
                    raw_list = backend_inputs["audio_inputs"].get("raw_audios", [])
                    if len(raw_list) > total_audios_needed:
                        raw_list = raw_list[:total_audios_needed]
                    if len(raw_list) > 0:
                        embs = [self._wav2clip_embed(arr, device) for arr in raw_list]
                        aud_flat = torch.stack(embs, dim=0)
                        aud_flat = self._normalize_2d(aud_flat)
                        actual_aud_dim = aud_flat.size(1)

                else:
                    av = backend_inputs["audio_inputs"]["input_values"]
                    if av is not None and av.numel() > 0:
                        av = av.to(device)
                        if av.size(0) > total_audios_needed:
                            av = av[:total_audios_needed]
                        # Out-of-place clamp: `.to(device)` may return the caller's
                        # tensor itself, which must not be modified.
                        av = av.clamp(-1.0, 1.0)
                        with torch.amp.autocast("cuda", enabled=False):
                            outputs = self.audio_model(input_values=av.float())
                            feats = outputs.pooler_output if getattr(outputs, "pooler_output", None) is not None else outputs.last_hidden_state.mean(dim=1)
                        aud_flat = self._normalize_2d(feats.float())
                        actual_aud_dim = aud_flat.size(1)

            aud_z = self._aggregate_embeddings(
                aud_flat, counts,
                self.audio_config["max_audios"],
                actual_aud_dim,
                self.audio_config["audio_agg"],
                len(counts),
                device
            )

            if aud_flat is not None and actual_aud_dim != self.audio_config.get("dim"):
                self.audio_config["dim"] = actual_aud_dim
                self.out_dim_per_modality["audio"] = (
                    actual_aud_dim * self.audio_config["max_audios"]
                    if self.audio_config["audio_agg"] == "concat" else actual_aud_dim
                )

            results["audio"] = aud_z

        # Strict batch-size consistency check across modalities.
        if results:
            bs_list = [v.size(0) for v in results.values()]
            if len(set(bs_list)) != 1:
                raise RuntimeError(f"Inconsistent batch sizes across modalities: {bs_list}")

        return results


# =========================
# Классификатор (голова регрессии)
# =========================

class SingleBackboneClassifier(nn.Module):
    """MLP head on top of a multimodal ``BaseBackend``.

    The backend returns one embedding tensor per modality; the tensors of the
    modalities listed in ``modalities`` are fused in the fixed order
    image -> text -> audio ("concat" along features, or element-wise "mean")
    and passed through LayerNorm -> Linear -> GELU -> Dropout -> Linear,
    producing ``num_labels`` outputs per sample.
    """

    def __init__(
        self,
        backend: BaseBackend,
        modalities: List[str],
        num_labels: int,
        fusion: str = "concat",
        hidden: int = 512,
        dropout: float = 0.1
    ):
        super().__init__()
        self.backend = backend
        self.modalities = modalities
        self.fusion = fusion
        self.num_labels = num_labels

        order = self._ordered_modalities()
        if fusion == "concat":
            in_dim = sum(self.backend.get_out_dim(m) for m in order)
        elif fusion == "mean":
            dims = [self.backend.get_out_dim(m) for m in order]
            if len(set(dims)) != 1:
                raise ValueError('Для fusion="mean" размеры модальностей должны совпадать')
            in_dim = dims[0]
        else:
            raise ValueError('fusion должен быть "concat" или "mean"')

        self.head = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_labels)
        )

    def _ordered_modalities(self) -> List[str]:
        """Requested modalities in the fixed fusion order image, text, audio."""
        return [m for m in ("image", "text", "audio") if m in self.modalities]

    def _backbone_models(self):
        """Yield the backend's text/image/audio models that are not None."""
        for attr in ("text_model", "image_model", "audio_model"):
            m = getattr(self.backend, attr, None)
            if m is not None:
                yield m

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on every backbone that has trainable parameters.

        Also sets ``config.use_cache = False`` where present. All steps are
        best-effort: models that do not support them are skipped silently.
        """
        for m in self._backbone_models():
            try:
                has_trainable = any(p.requires_grad for p in m.parameters()) if hasattr(m, "parameters") else False
            except Exception:
                has_trainable = False
            if not has_trainable:
                continue  # frozen encoders gain nothing from checkpointing
            try:
                cfg = getattr(m, "config", None)
                if cfg is not None and hasattr(cfg, "use_cache"):
                    cfg.use_cache = False
            except Exception:
                pass
            try:
                if hasattr(m, "gradient_checkpointing_enable"):
                    try:
                        if gradient_checkpointing_kwargs is not None:
                            m.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
                        else:
                            m.gradient_checkpointing_enable()
                    except TypeError:
                        # Older models do not accept gradient_checkpointing_kwargs.
                        m.gradient_checkpointing_enable()
            except Exception:
                pass

    def gradient_checkpointing_disable(self):
        """Best-effort ``gradient_checkpointing_disable()`` on every backbone."""
        for m in self._backbone_models():
            try:
                if hasattr(m, "gradient_checkpointing_disable"):
                    m.gradient_checkpointing_disable()
            except Exception:
                pass

    def _find_tensor_device(self, obj) -> Optional[torch.device]:
        """Device of the first tensor found in a nested dict/list/tuple, else None."""
        if isinstance(obj, torch.Tensor):
            return obj.device
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return None
        for child in children:
            d = self._find_tensor_device(child)
            if d is not None:
                return d
        return None

    def _infer_device_from_inputs(self, obj) -> torch.device:
        """Device of the first tensor in *obj*; falls back to the head's device.

        The whole nested structure is searched, so non-tensor entries (None,
        ints, lists of numpy arrays) no longer stop the search early.
        """
        d = self._find_tensor_device(obj)
        if d is not None:
            return d
        try:
            return next(self.head.parameters()).device
        except StopIteration:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _fuse(self, z: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Fuse per-modality embeddings into one ``[B, in_dim]`` tensor.

        3-D inputs are mean-pooled over axis 1; >3-D inputs are flattened per
        row. All tensors are truncated to the batch size of the first one.

        :raises ValueError: if none of the requested modalities is in *z*.
        """
        feats = []
        for m in self._ordered_modalities():
            if m not in z:
                continue
            t = z[m]
            if t.dim() == 3:
                t = t.mean(dim=1)
            elif t.dim() > 3:
                t = t.view(t.size(0), -1)
            feats.append(t)
        if not feats:
            raise ValueError(
                f"None of the requested modalities {self.modalities} were returned by the backend: {list(z)}"
            )
        batch_size = feats[0].size(0)
        feats = [f[:batch_size] for f in feats]
        if self.fusion == "concat":
            return torch.cat(feats, dim=-1)
        return torch.stack(feats, dim=0).mean(dim=0)

    def forward(self, backend_inputs: Dict[str, Any], labels: Optional[torch.Tensor] = None):
        """Encode, fuse and apply the head.

        ``labels`` is accepted so ``Trainer`` can pass it through, but it is not
        used here: the returned ``SequenceClassifierOutput`` carries only logits.
        """
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        if not z:
            raise ValueError("Backend не вернул эмбеддинги")
        fused = self._fuse(z)
        logits = self.head(fused)
        return SequenceClassifierOutput(logits=logits)

    @torch.no_grad()
    def get_embeddings(self, backend_inputs: Dict[str, Any], return_per_modality: bool = False):
        """Return the fused embeddings (and the per-modality dict if requested)."""
        device = self._infer_device_from_inputs(backend_inputs)
        z = self.backend.encode(backend_inputs, device=device)
        fused = self._fuse(z)
        if return_per_modality:
            return fused, z
        return fused


# =========================
# Trainer для регрессии (MSE)
# =========================

class MSETrainer(Trainer):
    """``transformers.Trainer`` that optimizes mean squared error on the model logits."""

    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch: Optional[int] = None):
        """Pop ``labels`` from *inputs*, run *model* and return the MSE loss.

        Logits whose last axis has size 1 (and rank > 1) are squeezed; labels
        are cast to float32 and reshaped to the prediction shape.
        ``num_items_in_batch`` is accepted for ``Trainer`` compatibility and
        ignored (the loss is a plain batch mean). Returns ``(loss, outputs)``
        when *return_outputs* is true, else just the loss.
        """
        target = inputs.pop("labels").to(torch.float32)
        outputs = model(**inputs)
        logits = outputs.logits
        if logits.dim() != 1 and logits.size(-1) == 1:
            predictions = logits.squeeze(-1)
        else:
            predictions = logits
        loss = F.mse_loss(predictions, target.view_as(predictions))
        if return_outputs:
            return loss, outputs
        return loss


# =========================
# Прогресс-логгер
# =========================

class PbarConsoleLogger(TrainerCallback):
    """Trainer callback that mirrors progress into an externally created tqdm bar.

    The bar is advanced to ``state.global_step`` (capped at ``pbar.total`` when
    it is set), the latest numeric log values are shown as the bar postfix, and
    each evaluation step is printed once as a summary line via ``tqdm.write``.
    """

    # Log keys that are bookkeeping rather than evaluation metrics.
    _EVAL_EXCLUDE = ('eval_loss', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch')

    def __init__(self, pbar):
        self.pbar = pbar
        self.last_logs = {}  # latest numeric value per log key
        self.last_train_loss = None  # most recent training loss seen in logs
        self.printed_eval_steps = set()  # steps whose eval summary was already printed
        self.tqdm = tqdm

    def _step(self, state) -> int:
        return int(state.global_step or 0)

    @staticmethod
    def _as_float(value) -> Optional[float]:
        """``float(value)``, or None when the value is not a scalar number."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _advance(self, state) -> None:
        """Move the bar forward to the trainer's global step (never backwards)."""
        n = self._step(state)
        total = self.pbar.total
        if total is not None:
            n = min(n, total)
        if n > self.pbar.n:
            self.pbar.update(n - self.pbar.n)

    def _fmt_postfix(self):
        """Compact "loss … | val … | metric …" string built from ``last_logs``."""
        parts = []
        if 'loss' in self.last_logs:
            parts.append(f"loss {self.last_logs['loss']:.4f}")
        if 'eval_loss' in self.last_logs:
            parts.append(f"val {self.last_logs['eval_loss']:.4f}")
        for k, v in self.last_logs.items():
            if k.startswith('eval_') and k not in self._EVAL_EXCLUDE:
                parts.append(f"{k.replace('eval_', '')} {v:.4f}")
        return " | ".join(parts)

    def on_step_end(self, args, state, control, **kwargs):
        self._advance(state)
        if self.last_logs:
            self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        # Only int/float values are kept for the postfix.
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.last_logs[k] = float(v)
        if 'loss' in logs and isinstance(logs['loss'], (int, float)):
            self.last_train_loss = float(logs['loss'])

        self.pbar.set_postfix_str(self._fmt_postfix(), refresh=False)
        self.pbar.refresh()

        if not any(k.startswith('eval_') for k in logs.keys()):
            return

        step = self._step(state)
        if step in self.printed_eval_steps:
            return
        self.printed_eval_steps.add(step)

        train_loss_str = f"{self.last_train_loss:.10f}" if self.last_train_loss is not None else "n/a"
        val_loss = logs.get('eval_loss', None)
        val_loss_str = f"{float(val_loss):.10g}" if isinstance(val_loss, (int, float)) else "n/a"

        extra_parts = []
        for k, v in logs.items():
            if not k.startswith('eval_') or k in self._EVAL_EXCLUDE:
                continue
            value = self._as_float(v)
            if value is None:
                # Non-scalar metrics (dicts, strings, arrays) cannot be printed as
                # a number; skip them instead of crashing the training loop.
                continue
            metric_name = k.replace('eval_', '')
            extra_parts.append(f"val {metric_name}: {value:.10f}")

        line = f"step: {step}, train loss: {train_loss_str}, val loss: {val_loss_str}"
        if extra_parts:
            line += ", " + ", ".join(extra_parts)
        self.tqdm.write(line)

    def on_train_end(self, args, state, control, **kwargs):
        self._advance(state)
        self.pbar.refresh()


# =========================
# Основной пайплайн (регрессия)
# =========================

class SingleModelMultiComboRegression:
    """
    High-level multimodal regression pipeline (text / image / audio) on top of Hugging Face models
    and wav2clip. Builds the backend for the requested modality set automatically and supports batched
    text tokenization, dataset pre-tokenization, chunked training, early stopping and embedding extraction.

    Overview:
    - Modality combinations: ["text"], ["image"], ["audio"], ["text","image"], ["text","audio"], ["text","image","audio"].
    - Models:
        • Text: Auto (BERT/DistilBERT, ...), CLIP text, CLAP text.
        • Image: Auto (ViT, ...), CLIP image.
        • Audio: CLAP audio, Auto (Wav2Vec2), wav2clip (no HF processor).
    - Embedding fusion: "concat" or "mean" (for "mean" the embedding sizes must match).
    - A regression head is trained on top of the modality encoders (frozen by default).
    - Metrics: rmse, mae, r2 (computed with numpy; the `evaluate` library is not required).
    - predict() returns a float array of shape [N] or [N, num_labels].
    - Embeddings: fused (and optionally per modality).

    :param modalities: Modalities to use, a subset of {"text","image","audio"}.
    :param num_labels: Target dimensionality (usually 1); output size of the regression head.
    :param target_column_name: Name of the target column (float values) in the DataFrame.
    :param text_columns: Text columns; their values are joined with special_tokens["sep"].
    :param image_columns: Image columns (paths, PIL.Image.Image, np.ndarray or lists of these).
    :param audio_columns: Audio columns (file paths or mono float32 np.ndarray; paths need torchaudio).
    :param text_tokenizer_fn: Optional custom text preprocessing function.
                              If it returns a dict with 'input_ids', the result is treated as token tensors.
                              If it returns a string, the standard processor tokenization is applied to it.
    :param special_tokens: Special tokens (separators); defaults to {"sep": " [SEP] "}.
    :param tokenizer_returns_tensors: Compatibility flag for external tokenizers (when text_tokenizer_fn returns tensors).
    :param backend: Backend build mode: "auto" | "clip" | "clap" | anything else (explicit configs only).
    :param clip_checkpoint: Default CLIP checkpoint (used when a CLIP config is filled in automatically).
    :param clap_checkpoint: Default CLAP checkpoint.
    :param text_model_config: Text model config, e.g.
                              {"checkpoint": "...", "model_type": "clip|clap|auto|bert", "max_length": int}.
    :param image_model_config: Image model config, e.g.
                               {"checkpoint": "...", "model_type": "clip|vit|auto", "max_images": int, "image_agg": "concat|mean"}.
    :param audio_model_config: Audio model config, e.g.
                               {"checkpoint": "...", "model_type": "clap|auto|wav2clip", "sr": int, "max_audios": int, "audio_agg": "concat|mean"}.
    :param fusion: Modality fusion: "concat" (default) or "mean" (embedding sizes must match for "mean").
    :param freeze_backbone: Freeze the modality encoders (True by default).
    :param clip_max_length: Max token length for CLIP text (usually 77).
    :param max_images_per_sample: Max images per sample (used by the "concat"/"mean" aggregation).
    :param max_audios_per_sample: Max audio clips per sample (same as for images).
    :param use_batch_tokenizer: Use the batched text tokenizer (LRU cache, faster).
    :param pretokenize_data: Pre-tokenize datasets (speeds up training/prediction).
    :param pretokenize_batch_size: Batch size used for pre-tokenization.
    :param tokenizer_cache_size: LRU cache size of the batched tokenizer.
    :param max_pretokenize_samples: Row limit for pre-tokenization (per chunk/split).
    :param local_cache_dir: Local cache directory for HF models/processors.

    :return: None

    :raises ValueError:
        - If the requested modalities are not supported by the built backend.
        - If fusion="mean" and the embedding sizes differ.
        - If columns required by the selected modalities are missing from the data.
    :raises OSError: Model/processor loading errors from the HF Hub or local cache (network/offline).
    :raises RuntimeError: Device errors (CUDA/MPS), batch-size mismatches between modalities, etc.
    """
    def __init__(
        self,
        modalities: List[str],
        num_labels: int,
        target_column_name: str,
        text_columns: Optional[List[str]] = None,
        image_columns: Optional[List[str]] = None,
        audio_columns: Optional[List[str]] = None,
        text_tokenizer_fn: Optional[Callable] = None,
        special_tokens: Optional[Dict[str, str]] = None,
        tokenizer_returns_tensors: bool = False,
        backend: str = "auto",
        clip_checkpoint: str = "openai/clip-vit-base-patch32",
        clap_checkpoint: str = "laion/clap-htsat-unfused",
        text_model_config: Optional[Dict[str, Any]] = None,
        image_model_config: Optional[Dict[str, Any]] = None,
        audio_model_config: Optional[Dict[str, Any]] = None,
        fusion: str = "concat",
        freeze_backbone: bool = True,
        clip_max_length: int = 77,
        max_images_per_sample: int = 1,
        max_audios_per_sample: int = 1,
        use_batch_tokenizer: bool = True,
        pretokenize_data: bool = True,
        pretokenize_batch_size: int = 256,
        tokenizer_cache_size: int = 10000,
        max_pretokenize_samples: int = 100000,
        local_cache_dir: str = "./model_cache"
    ):
        # Deduplicated and sorted, so the modality order is deterministic.
        self.modalities = sorted(list(set(modalities)))
        self.num_labels = num_labels
        self.target_column_name = target_column_name
        self.text_columns = text_columns or []
        self.image_columns = image_columns or []
        self.audio_columns = audio_columns or []
        self.text_tokenizer_fn = text_tokenizer_fn
        self.special_tokens = special_tokens or {"sep": " [SEP] "}
        self.tokenizer_returns_tensors = tokenizer_returns_tensors
        self.backend_name = backend.lower()
        self.clip_checkpoint = clip_checkpoint
        self.clap_checkpoint = clap_checkpoint
        self.text_model_config = text_model_config
        self.image_model_config = image_model_config
        self.audio_model_config = audio_model_config
        self.fusion = fusion
        self.freeze_backbone = freeze_backbone
        self.clip_max_length = clip_max_length
        self.max_images_per_sample = int(max_images_per_sample)
        self.max_audios_per_sample = int(max_audios_per_sample)

        self.use_batch_tokenizer = use_batch_tokenizer
        self.pretokenize_data = pretokenize_data
        self.pretokenize_batch_size = pretokenize_batch_size
        self.tokenizer_cache_size = tokenizer_cache_size
        self.max_pretokenize_samples = max_pretokenize_samples
        self.local_cache_dir = local_cache_dir

        # Populated by _build_backend() / fit().
        self.backend: Optional[BaseBackend] = None
        self.model: Optional[SingleBackboneClassifier] = None
        self.trainer: Optional[Trainer] = None
        self.compute_metrics = None
        self.progress_callback: Optional[PbarConsoleLogger] = None

        self._build_backend()

    def _build_backend(self):
        """
        Fill in default per-modality model configs according to ``backend`` and build the
        UniversalMultiBackend. Explicitly passed configs always take precedence over defaults.

        :raises ValueError: If the built backend does not support all selected modalities.
        """
        mods = set(self.modalities)
        name = self.backend_name

        if name == "auto":
            if mods == {"text", "image"}:
                # Text + image: use one shared CLIP checkpoint for both encoders.
                self.text_model_config = self.text_model_config or {
                    'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_length': self.clip_max_length
                }
                self.image_model_config = self.image_model_config or {
                    'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'
                }
            elif mods == {"text", "audio"}:
                # Text + audio: use one shared CLAP checkpoint for both encoders.
                self.text_model_config = self.text_model_config or {
                    'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_length': 64
                }
                self.audio_model_config = self.audio_model_config or {
                    'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000
                }
            else:
                # Any other combination: independent per-modality defaults.
                if "text" in mods and self.text_model_config is None:
                    self.text_model_config = {'checkpoint': 'bert-base-multilingual-cased', 'model_type': 'bert', 'max_length': 512}
                if "image" in mods and self.image_model_config is None:
                    self.image_model_config = {'checkpoint': 'google/vit-base-patch16-224', 'model_type': 'vit', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'}
                if "audio" in mods and self.audio_model_config is None:
                    self.audio_model_config = {'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000}

        elif name == "clip":
            self.text_model_config = self.text_model_config or {
                'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_length': self.clip_max_length
            }
            self.image_model_config = self.image_model_config or {
                'checkpoint': self.clip_checkpoint, 'model_type': 'clip', 'max_images': self.max_images_per_sample, 'image_agg': 'concat'
            }
        elif name == "clap":
            self.text_model_config = self.text_model_config or {
                'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_length': 64
            }
            self.audio_model_config = self.audio_model_config or {
                'checkpoint': self.clap_checkpoint, 'model_type': 'clap', 'max_audios': self.max_audios_per_sample, 'audio_agg': 'concat', 'sr': 48000
            }
        # Any other backend name: only the explicitly passed configs are used.

        self.backend = UniversalMultiBackend(
            text_model_config=self.text_model_config if "text" in mods else None,
            image_model_config=self.image_model_config if "image" in mods else None,
            audio_model_config=self.audio_model_config if "audio" in mods else None,
            freeze=self.freeze_backbone,
            text_tokenizer_fn=self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors,
            use_batch_tokenizer=self.use_batch_tokenizer,
            tokenizer_cache_size=self.tokenizer_cache_size,
            tokenizer_batch_size=self.pretokenize_batch_size,
            local_cache_dir=self.local_cache_dir
        )

        if not set(self.modalities).issubset(self.backend.supported):
            raise ValueError(f"Бэкенд {self.backend.name} не поддерживает модальности {self.modalities}")

    @staticmethod
    def _greater_is_better(metric_name: str) -> bool:
        """
        Return True when a larger value of ``metric_name`` is better.

        rmse/mae are errors (lower is better) and r2 is a score (higher is better). Hugging Face
        TrainingArguments defaults ``greater_is_better`` to True for any metric name that does not
        end in "loss", so fit() passes this value explicitly.
        """
        return metric_name.lower() == "r2"

    @staticmethod
    def _steps_for_size(sz: int, bsz: int, accum: int) -> int:
        """Number of optimizer steps for ``sz`` rows at batch size ``bsz`` with ``accum`` accumulation steps."""
        return max(0, math.ceil(math.ceil(sz / max(1, bsz)) / max(1, accum)))

    @staticmethod
    def _early_stopping_exhausted(callback) -> bool:
        """Return True when an EarlyStoppingCallback (or None) has used up its patience."""
        if callback is None:
            return False
        patience = getattr(callback, "early_stopping_patience", None)
        counter = getattr(callback, "early_stopping_patience_counter", 0)
        return patience is not None and counter >= patience

    def _has_batch_tokenizer(self) -> bool:
        """True when batched tokenization is enabled and the backend exposes a batch tokenizer."""
        return bool(self.use_batch_tokenizer and getattr(self.backend, "batch_tokenizer", None))

    def _make_dataset(self, df: pd.DataFrame, has_bt: bool, pretokenize: bool, **extra):
        """
        Build a MultiComboDataset over ``df`` with this pipeline's columns and tokenization settings.

        :param df: Rows to wrap (must contain the target column).
        :param has_bt: Whether to use the backend's batch tokenizer instead of text_tokenizer_fn.
        :param pretokenize: Whether the dataset should pre-tokenize its texts.
        :param extra: Extra keyword arguments forwarded to MultiComboDataset
                      (e.g. pretokenize_batch_size, max_cache_size).
        """
        return MultiComboDataset(
            df=df,
            target_col=self.target_column_name,
            text_columns=self.text_columns,
            image_columns=self.image_columns,
            audio_columns=self.audio_columns,
            text_tokenizer=self.backend.batch_tokenizer if has_bt else None,
            text_tokenizer_fn=None if has_bt else self.text_tokenizer_fn,
            special_tokens=self.special_tokens,
            pretokenize=pretokenize,
            tokenizer_returns_tensors=self.tokenizer_returns_tensors,
            **extra
        )

    def _setup_metrics(self, metric_name: str):
        """
        Create ``self.compute_metrics`` for the Trainer.

        The returned callable takes an object with ``predictions`` and ``label_ids`` arrays. Trailing
        singleton label dimensions are squeezed. For 2-D outputs the metric is computed per target
        column and then averaged.

        :param metric_name: "rmse" | "mae" | "r2" (case-insensitive).
        :raises ValueError: For any other metric name.
        """
        name = metric_name.lower()
        if name not in ("rmse", "mae", "r2"):
            raise ValueError('metric_name для регрессии должен быть "rmse", "mae" или "r2"')

        def compute(p):
            preds = p.predictions
            y = p.label_ids
            preds = preds.squeeze(-1) if preds.ndim == 2 and preds.shape[-1] == 1 else preds
            y = y.squeeze(-1) if y.ndim == 2 and y.shape[-1] == 1 else y
            # Per-column reduction for multi-target regression, global for a single target.
            axis = 0 if preds.ndim == 2 else None
            if name == "rmse":
                err = preds - y
                mse = np.mean(err**2, axis=axis)
                rmse = np.sqrt(mse)
                return {"rmse": float(np.mean(rmse))}
            elif name == "mae":
                mae = np.mean(np.abs(preds - y), axis=axis)
                return {"mae": float(np.mean(mae))}
            else:
                y_mean = np.mean(y, axis=axis, keepdims=True) if preds.ndim == 2 else np.mean(y)
                ss_res = np.sum((y - preds) ** 2, axis=axis)
                ss_tot = np.sum((y - y_mean) ** 2, axis=axis)
                # The epsilon avoids division by zero for a constant target.
                r2 = 1.0 - (ss_res / (ss_tot + 1e-12))
                return {"r2": float(np.mean(r2))}

        self.compute_metrics = compute

    def _split(self, df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
        """
        Shuffle ``df`` with ``seed`` and split it.

        :return: (train_df, eval_df), where eval_df has ceil(len(df) * test_size) rows.
        """
        df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        n_eval = int(math.ceil(len(df) * test_size))
        return df.iloc[n_eval:].reset_index(drop=True), df.iloc[:n_eval].reset_index(drop=True)

    def _validate_data(self, df: pd.DataFrame):
        """
        Check that every selected modality has configured columns and that they exist in ``df``.

        :raises ValueError: If a modality has no columns or some columns are missing.
        """
        if "text" in self.modalities:
            if not self.text_columns:
                raise ValueError("Вы выбрали модальность 'text', но text_columns пустой.")
            missing = [c for c in self.text_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют текстовые колонки: {missing}")
        if "image" in self.modalities:
            if not self.image_columns:
                raise ValueError("Вы выбрали модальность 'image', но image_columns пуст.")
            missing = [c for c in self.image_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки изображений: {missing}")
        if "audio" in self.modalities:
            if not self.audio_columns:
                raise ValueError("Вы выбрали модальность 'audio', но audio_columns пуст.")
            missing = [c for c in self.audio_columns if c not in df.columns]
            if missing:
                raise ValueError(f"В DataFrame отсутствуют колонки аудио: {missing}")

    def fit(
        self,
        train_data: pd.DataFrame,
        epochs: int = 3,
        test_size: float = 0.2,
        test_data: Optional[pd.DataFrame] = None,
        per_device_train_batch_size: int = 16,
        gradient_accumulation_steps: int = 1,
        learning_rate: float = 2e-4,
        metric_name: str = "rmse",
        fp16: bool = True,
        logging_steps: int = 50,
        eval_steps: int = 200,
        output_dir: str = "./result_reg",
        seed: int = 42,
        hidden: int = 256,
        dropout: float = 0.1,
        gradient_checkpointing: bool = False,
        fit_chunk_size: Optional[int] = None,
        clear_cache_every_n_chunks: int = 10,
        early_stopping_patience: Optional[int] = 3,
        early_stopping_threshold: float = 0.0
    ):
        """
        Train the regression head on top of the modality encoders (frozen by default).

        Supports pre-tokenization, chunked training, early stopping, gradient checkpointing and a
        console progress bar.

        Data:
        - train_data/test_data: DataFrames with the column target_column_name (float).
        - For every modality, the corresponding text/image/audio columns.

        Training:
        - Trainer = MSETrainer (MSE loss).
        - Best-model metric: eval_<metric_name>. Lower is better for rmse/mae; higher is better for r2.
        - eval_strategy="steps" (every eval_steps); checkpoints are saved every eval_steps.
        - Without chunking (fit_chunk_size is None, non-positive, or >= number of training rows),
          training is one continuous Trainer run of all epochs.
        - With chunking, each epoch's shuffled rows are split into chunks of fit_chunk_size, and each
          chunk is trained by its own Trainer.train() call to limit memory use.
          The optimizer, the LR schedule and the progress bar span all chunks. Evaluation and saving
          cadence restarts with every chunk because each train() call starts at step 0.

        :param train_data: Training DataFrame with the column target_column_name.
        :param epochs: Number of epochs.
        :param test_size: Validation fraction (used when test_data is not given).
        :param test_data: Separate validation DataFrame.
        :param per_device_train_batch_size: Per-device batch size.
        :param gradient_accumulation_steps: Gradient accumulation (effective batch = batch_size * steps).
        :param learning_rate: Base learning rate.
        :param metric_name: "rmse" | "mae" | "r2" (case-insensitive).
        :param fp16: Enable fp16 (effective only when CUDA is available).
        :param logging_steps: Trainer logging interval (with a single training step the train loss may not be logged).
        :param eval_steps: Evaluation/saving interval.
        :param output_dir: Directory for results and checkpoints.
        :param seed: Random seed.
        :param hidden: Hidden size of the regression head.
        :param dropout: Dropout in the head.
        :param gradient_checkpointing: Enable gradient checkpointing (forwarded to encoders that support it).
        :param fit_chunk_size: Chunk size for the training rows (None means no chunking).
        :param clear_cache_every_n_chunks: Clear the tokenization cache every N chunks (<= 0 disables this).
        :param early_stopping_patience: Early-stopping patience (None or <= 0 disables early stopping).
        :param early_stopping_threshold: Minimal metric improvement that resets the patience counter.

        :return: self

        :raises ValueError: Invalid configuration (e.g. missing required columns or an unknown metric).
        :raises OSError: Errors reading/loading data or models.
        :raises RuntimeError: Device errors (CUDA/MPS), batch-size mismatches when fusing modalities.
        """
        self._validate_data(train_data)
        set_seed(seed)

        # compute_metrics always reports the lower-case key ("rmse"/"mae"/"r2"). Normalizing here
        # keeps metric_for_best_model (e.g. "eval_rmse") consistent with it.
        metric_name = metric_name.lower()

        df_train, df_eval = (train_data, test_data) if test_data is not None else self._split(train_data, test_size, seed)

        has_bt = self._has_batch_tokenizer()

        ds_eval = self._make_dataset(
            df_eval,
            has_bt,
            pretokenize=(self.pretokenize_data and has_bt and len(df_eval) < 50000),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_eval), self.max_pretokenize_samples),
        )

        self.model = SingleBackboneClassifier(
            backend=self.backend,
            modalities=self.modalities,
            num_labels=self.num_labels,
            fusion=self.fusion,
            hidden=hidden,
            dropout=dropout
        )
        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self._setup_metrics(metric_name)

        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=max(4, per_device_train_batch_size // 2),
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=max(1, gradient_accumulation_steps * 2),
            learning_rate=learning_rate,
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            weight_decay=0.0,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_strategy="steps",
            save_steps=eval_steps,
            load_best_model_at_end=True,
            metric_for_best_model=f"eval_{metric_name}",
            # Without this, HF assumes "greater is better" for rmse/mae and would keep
            # the worst checkpoint (and early-stop on improvements).
            greater_is_better=self._greater_is_better(metric_name),
            save_total_limit=1,
            logging_strategy="steps",
            logging_steps=logging_steps,
            report_to="none",
            fp16=fp16 and torch.cuda.is_available(),
            dataloader_num_workers=min(4, os.cpu_count() or 4),
            seed=seed,
            remove_unused_columns=False,
            gradient_checkpointing=gradient_checkpointing,
            dataloader_pin_memory=True,
            ddp_find_unused_parameters=False,
            disable_tqdm=True
        )

        def regression_collator(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            # Regression targets must be float32 for the MSE loss.
            out = self.backend.collate(batch_list)
            if "labels" in out:
                out["labels"] = out["labels"].to(torch.float32)
            return out

        bsz, accum = per_device_train_batch_size, gradient_accumulation_steps
        n_train = len(df_train)
        # No chunking means one chunk spanning all rows. max(1, ...) keeps range() valid for an empty frame.
        chunk_size = fit_chunk_size if (fit_chunk_size and fit_chunk_size > 0) else max(1, n_train)

        # Chunk sizes depend only on n_train and chunk_size, not on the shuffle order,
        # so every epoch needs the same number of steps.
        steps_per_epoch = sum(
            self._steps_for_size(min(chunk_size, n_train - start), bsz, accum)
            for start in range(0, n_train, chunk_size)
        )
        total_steps = steps_per_epoch * max(0, epochs)

        # The Trainer needs some train_dataset at construction; the real data is assigned below.
        ds_train_init = (
            self._make_dataset(df_train.iloc[:1], has_bt, pretokenize=False)
            if n_train > 0 else ds_eval
        )

        self.trainer = MSETrainer(
            model=self.model,
            args=args,
            train_dataset=ds_train_init,
            eval_dataset=ds_eval,
            data_collator=regression_collator,
            compute_metrics=self.compute_metrics
        )
        self.trainer.remove_callback(PrinterCallback)

        esc = None
        if (ds_eval is not None) and (early_stopping_patience is not None) and (early_stopping_patience > 0):
            esc = EarlyStoppingCallback(
                early_stopping_patience=int(early_stopping_patience),
                early_stopping_threshold=float(early_stopping_threshold)
            )
            self.trainer.add_callback(esc)

        if total_steps > 0:
            # One optimizer and one LR schedule for the whole run. The scheduler is assigned
            # directly, so Trainer.train() reuses it on every call instead of building a new
            # warmup + cosine schedule for each chunk.
            from transformers import get_scheduler
            self.trainer.create_optimizer()
            self.trainer.lr_scheduler = get_scheduler(
                args.lr_scheduler_type,
                optimizer=self.trainer.optimizer,
                num_warmup_steps=args.get_warmup_steps(total_steps),
                num_training_steps=total_steps,
                scheduler_specific_kwargs=args.lr_scheduler_kwargs,
            )

        class _OffsetPbarLogger(PbarConsoleLogger):
            # Each chunk is a separate Trainer.train() call whose global_step restarts at 0.
            # Shifting the step returned by _step keeps the bar position and the printed
            # "step:" values cumulative across chunks.
            def __init__(self, bar):
                super().__init__(bar)
                self.step_offset = 0

            def _step(self, state) -> int:
                return self.step_offset + super()._step(state)

        pbar = tqdm(total=total_steps, desc="Training Progress", unit="step")
        cb = _OffsetPbarLogger(pbar)
        self.trainer.add_callback(cb)
        self.progress_callback = cb

        if total_steps > 0 and chunk_size >= n_train:
            # No chunking: a single continuous Trainer run. Evaluation/saving cadence,
            # early stopping and the LR schedule all share one global step counter.
            ds_full = self._make_dataset(
                df_train,
                has_bt,
                pretokenize=(self.pretokenize_data and has_bt and n_train < self.max_pretokenize_samples and n_train > 100),
                pretokenize_batch_size=self.pretokenize_batch_size,
                max_cache_size=min(n_train, self.max_pretokenize_samples),
            )
            self.trainer.train_dataset = ds_full
            self.trainer.args.max_steps = total_steps
            self.trainer.train()

            del ds_full
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        elif total_steps > 0:
            steps_done = 0
            chunk_counter = 0
            stop_training = False

            for ep in range(epochs):
                rng = np.random.default_rng(seed + ep)
                shuffled = np.arange(n_train)
                rng.shuffle(shuffled)

                for start in range(0, n_train, chunk_size):
                    slc = shuffled[start:start + chunk_size]
                    chunk_steps = self._steps_for_size(len(slc), bsz, accum)
                    if chunk_steps == 0:
                        continue

                    chunk_df = df_train.iloc[slc]
                    ds_chunk = self._make_dataset(
                        chunk_df,
                        has_bt,
                        pretokenize=(self.pretokenize_data and has_bt and len(slc) < self.max_pretokenize_samples and len(slc) > 100),
                        pretokenize_batch_size=self.pretokenize_batch_size,
                        max_cache_size=min(len(slc), self.max_pretokenize_samples),
                    )
                    self.trainer.train_dataset = ds_chunk

                    # Every Trainer.train() call starts a fresh TrainerState at global_step 0, so
                    # max_steps is this chunk's own budget. A cumulative value would train
                    # steps_done + chunk_steps steps on this chunk alone.
                    self.trainer.args.max_steps = chunk_steps
                    cb.step_offset = steps_done
                    self.trainer.train()
                    steps_done += chunk_steps

                    chunk_counter += 1
                    if clear_cache_every_n_chunks > 0 and chunk_counter % clear_cache_every_n_chunks == 0:
                        if hasattr(ds_chunk, 'clear_cache'):
                            ds_chunk.clear_cache()
                            print(f"✓ Очищен кэш токенизации после {chunk_counter} чанков")

                    del ds_chunk, chunk_df
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    # EarlyStoppingCallback only ends the current train() call; also skip the
                    # remaining chunks and epochs once its patience is used up.
                    if self._early_stopping_exhausted(esc):
                        stop_training = True
                        break

                if stop_training:
                    break

        pbar.close()

        if getattr(self.backend, "batch_tokenizer", None):
            self.backend.batch_tokenizer.clear_cache()

        return self

    def predict(
        self,
        df: pd.DataFrame,
        batch_size: Optional[int] = None
    ) -> np.ndarray:
        """
        Run regression inference.

        :param df: DataFrame with the modality columns (a missing target column is filled with zeros).
        :param batch_size: Temporarily override per_device_eval_batch_size for this call.

        :return: np.ndarray of shape [N] (num_labels=1) or [N, num_labels] (multi-target regression).

        :raises RuntimeError: If the model has not been trained (no trainer).
        :raises OSError: Errors reading the source data (image/audio paths).
        :raises ValueError: Internal collator/backend errors (e.g. batch-size mismatch between modalities).
        """
        if self.trainer is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")

        df_c = df.copy()
        if self.target_column_name not in df_c.columns:
            df_c[self.target_column_name] = 0.0

        print(f"Preparing dataset for prediction ({len(df_c)} samples)...")

        has_bt = self._has_batch_tokenizer()
        ds = self._make_dataset(
            df_c,
            has_bt,
            pretokenize=(self.pretokenize_data and has_bt and len(df_c) < 10000),
            pretokenize_batch_size=self.pretokenize_batch_size,
            max_cache_size=min(len(df_c), 10000),
        )

        targs = self.trainer.args
        original_bs = targs.per_device_eval_batch_size
        original_disable_tqdm = targs.disable_tqdm

        effective_batch_size = batch_size or original_bs
        num_batches = (len(df_c) + effective_batch_size - 1) // effective_batch_size

        print(f"Running predictions (batch_size={effective_batch_size}, num_batches={num_batches})...")

        # The overrides are restored in `finally`, so a failing prediction does not leave the
        # shared TrainingArguments modified for later calls.
        try:
            if batch_size:
                targs.per_device_eval_batch_size = batch_size
            targs.disable_tqdm = False
            preds = self.trainer.predict(test_dataset=ds)
        finally:
            targs.per_device_eval_batch_size = original_bs
            targs.disable_tqdm = original_disable_tqdm
            if hasattr(ds, 'clear_cache'):
                ds.clear_cache()

        y = preds.predictions
        return y.squeeze(-1) if y.ndim == 2 and y.shape[-1] == 1 else y

    def get_embeddings(self, df: pd.DataFrame, batch_size: int = 32, return_per_modality: bool = False):
        """
        Extract embeddings for the input rows.

        :param df: DataFrame with the modality columns (a missing target column is filled with zeros).
        :param batch_size: Batch size used during extraction.
        :param return_per_modality: If True, also return per-modality embeddings besides the fused ones.

        :return:
            - If return_per_modality=False:
                np.ndarray fused of shape [N, D_fused], where D_fused is the size after fusion (concat/mean).
            - If return_per_modality=True:
                (fused, per_mod) tuple:
                    fused: np.ndarray [N, D_fused],
                    per_mod: Dict[str, np.ndarray] keyed by "text"/"image"/"audio", values [N, D_mod].

        :raises RuntimeError: If the model has not been trained (no trainer/model).
        :raises ValueError: If ``df`` is empty (there is nothing to stack).
        :raises OSError: Errors reading the source data (image/audio paths).
        """
        if self.trainer is None or self.model is None:
            raise RuntimeError("Модель не обучена. Вызовите .fit().")
        if len(df) == 0:
            raise ValueError("get_embeddings() received an empty DataFrame.")

        try:
            device = next(self.trainer.model.parameters()).device
        except StopIteration:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model.to(device).eval()

        df_c = df.copy()
        if self.target_column_name not in df_c.columns:
            df_c[self.target_column_name] = 0.0

        print(f"Preparing dataset for embeddings extraction ({len(df_c)} samples)...")

        has_bt = self._has_batch_tokenizer()
        ds = self._make_dataset(df_c, has_bt, pretokenize=False)

        def collate(batch_list: List[Dict[str, Any]]) -> Dict[str, Any]:
            return self.backend.collate(batch_list)

        def move_to_device(obj, device: torch.device):
            # Recursively move tensors inside nested dicts/lists/tuples to `device`.
            if torch.is_tensor(obj):
                return obj.to(device)
            if isinstance(obj, dict):
                return {k: move_to_device(v, device) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                t = [move_to_device(v, device) for v in obj]
                return type(obj)(t) if not isinstance(obj, list) else t
            return obj

        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)
        fused_list = []
        per_mod_lists = {m: [] for m in self.modalities} if return_per_modality else None

        num_batches = (len(df_c) + batch_size - 1) // batch_size

        print(f"Extracting embeddings (batch_size={batch_size}, num_batches={num_batches})...")

        with torch.no_grad():
            for batch in tqdm(loader, total=num_batches, desc="Extracting embeddings", unit="batch", leave=True):
                bi = move_to_device(batch["backend_inputs"], device)
                fused, per = self.model.get_embeddings(backend_inputs=bi, return_per_modality=True)
                fused_list.append(fused.cpu().numpy())
                if return_per_modality:
                    for m in per_mod_lists.keys():
                        if m in per:
                            per_mod_lists[m].append(per[m].cpu().numpy())

        print("Concatenating embeddings...")
        fused_arr = np.vstack(fused_list)

        if not return_per_modality:
            print(f"✓ Embeddings shape: {fused_arr.shape}")
            return fused_arr

        per_mod = {m: np.vstack(chunks) for m, chunks in per_mod_lists.items()}
        print(f"✓ Fused embeddings shape: {fused_arr.shape}")
        for m, arr in per_mod.items():
            print(f"✓ {m.capitalize()} embeddings shape: {arr.shape}")

        return fused_arr, per_mod

Пример использования 1.

In [None]:
import numpy as np
import pandas as pd

# Demo: text (CLIP) + audio (CLAP) regression on a tiny synthetic dataset.
# One seed drives both the synthetic targets and the training run. The targets
# come from a local Generator instead of the global NumPy RNG, so re-running
# this cell reproduces the same data regardless of what ran before it.
DEMO_SEED = 42
data_rng = np.random.default_rng(DEMO_SEED)

# Data: every row gets a short title, the same 48000-sample sine clip
# (one 2*pi period, amplitude 0.1 — 1 s at the configured sr=48000) and a
# random float32 target. Each row receives its own copy of the clip.
n = 8
sine_clip = np.sin(np.linspace(0, 2*np.pi, 48000)).astype(np.float32) * 0.1
df = pd.DataFrame({
    "title": [f"some short text {i}" for i in range(n)],
    "audio": [sine_clip.copy() for _ in range(n)],
    "target": data_rng.standard_normal(n).astype(np.float32),
})

# Pipeline initialization (text=CLIP, audio=CLAP)
pipe = SingleModelMultiComboRegression(
    modalities=["text", "audio"],
    num_labels=1,
    target_column_name="target",
    text_columns=["title"],
    audio_columns=["audio"],
    text_model_config={
        "checkpoint": "openai/clip-vit-base-patch32",
        "model_type": "clip",
        "max_length": 77
    },
    audio_model_config={
        "checkpoint": "laion/clap-htsat-unfused",
        "model_type": "clap",
        "sr": 48000,
        "max_audios": 2,
        "audio_agg": "concat"
    },
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=True,
    pretokenize_data=True,
    pretokenize_batch_size=64,
    tokenizer_cache_size=10000,
    max_pretokenize_samples=50000,
    local_cache_dir="./model_cache"
)

# Training.
# NOTE(review): test_size=0.25 leaves only ~6 training rows. With batch 4 and
# gradient accumulation 2 that is roughly one optimizer step per epoch, so
# logging_steps/eval_steps=10 are likely never reached. In that case nothing
# is evaluated and early stopping has nothing to act on. Lower them (e.g. to 1)
# to observe evaluation in this toy run.
pipe.fit(
    train_data=df,
    epochs=2,
    test_size=0.25,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    metric_name="rmse",
    fp16=True,                       # presumably only enabled when CUDA is available — confirm in fit()
    logging_steps=10,
    eval_steps=10,
    output_dir="./out_clip_clap",
    seed=DEMO_SEED,
    hidden=512,
    dropout=0.1,
    gradient_checkpointing=True,     # enable gradient checkpointing where supported
    fit_chunk_size=None,
    clear_cache_every_n_chunks=2,
    early_stopping_patience=3,
    early_stopping_threshold=0.0
)

# Predictions
y_pred = pipe.predict(df, batch_size=4)
print("Pred shape:", y_pred.shape)

# Embeddings: fused vector plus one array per modality
emb_fused, emb_per = pipe.get_embeddings(df, batch_size=4, return_per_modality=True)
print("Fused:", emb_fused.shape, "| text:", emb_per["text"].shape, "| audio:", emb_per["audio"].shape)

Пример использования 2.

In [None]:
import numpy as np
import pandas as pd

# Demo: text (DistilBERT) + image (ViT) + audio (wav2vec2) regression on a tiny
# synthetic dataset. The same seed drives the synthetic data and training; a
# local Generator keeps the data reproducible across re-runs of this cell.
DEMO_SEED = 123
data_rng = np.random.default_rng(DEMO_SEED)

# Data: per row a short text, a 224x224x3 uint8 noise image, 16000 samples of
# float32 noise audio (scaled by 0.05) and a random float32 target.
n = 9
df = pd.DataFrame({
    "text": [f"example text #{i}" for i in range(n)],
    "image": [(data_rng.random((224, 224, 3)) * 255).astype(np.uint8) for _ in range(n)],
    "audio": [(data_rng.standard_normal(16000).astype(np.float32) * 0.05) for _ in range(n)],
    "target": data_rng.standard_normal(n).astype(np.float32),
})

pipe = SingleModelMultiComboRegression(
    modalities=["text", "image", "audio"],
    num_labels=1,
    target_column_name="target",
    text_columns=["text"],
    image_columns=["image"],
    audio_columns=["audio"],
    text_model_config={
        "checkpoint": "distilbert-base-uncased",
        "model_type": "auto",
        "max_length": 128
    },
    image_model_config={
        "checkpoint": "google/vit-base-patch16-224",
        "model_type": "vit",
        "max_images": 2,
        "image_agg": "mean"
    },
    audio_model_config={
        "checkpoint": "facebook/wav2vec2-base-960h",
        "model_type": "auto",
        "sr": 16000,
        "max_audios": 2,
        "audio_agg": "mean"
    },
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=True,
    pretokenize_data=True,
    pretokenize_batch_size=32,
    tokenizer_cache_size=20000,
    max_pretokenize_samples=100000,
    local_cache_dir="./model_cache"
)

# NOTE(review): test_size=0.3 leaves ~6 training rows, i.e. ~2 optimizer steps
# per epoch at batch 3, so ~4 steps in total — below eval_steps/logging_steps=5.
# Depending on where fit() applies chunking, fit_chunk_size=6 may also produce
# a single chunk. Lower these values to actually see evaluation and chunking.
pipe.fit(
    train_data=df,
    epochs=2,
    test_size=0.3,
    per_device_train_batch_size=3,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    metric_name="r2",
    fp16=True,
    logging_steps=5,
    eval_steps=5,
    output_dir="./out_auto_triplet",
    seed=DEMO_SEED,
    hidden=384,
    dropout=0.1,
    gradient_checkpointing=True,
    fit_chunk_size=6,                # demonstrates chunked training
    clear_cache_every_n_chunks=2,
    early_stopping_patience=2
)

y = pipe.predict(df, batch_size=3)
print("Pred shape:", y.shape)

# Fused embedding plus one array per modality
fused, per_mod = pipe.get_embeddings(df, batch_size=3, return_per_modality=True)
print("Fused:", fused.shape, "| text:", per_mod["text"].shape, "| image:", per_mod["image"].shape, "| audio:", per_mod["audio"].shape)

Пример использования 3.

In [None]:
import numpy as np
import pandas as pd

# Demo: audio-only regression with wav2clip on a tiny synthetic dataset.
# A local seeded Generator makes the random targets reproducible; the same
# seed is passed to fit().
DEMO_SEED = 42
data_rng = np.random.default_rng(DEMO_SEED)

# Data: every row gets its own copy of a 16000-sample sine clip (exactly one
# 2*pi period, amplitude 0.05) and a random float32 target.
n = 5
sine_clip = np.sin(np.linspace(0, 2*np.pi, 16000)).astype(np.float32) * 0.05
df = pd.DataFrame({
    "audio": [sine_clip.copy() for _ in range(n)],
    "target": data_rng.standard_normal(n).astype(np.float32),
})

pipe = SingleModelMultiComboRegression(
    modalities=["audio"],
    num_labels=1,
    target_column_name="target",
    audio_columns=["audio"],
    audio_model_config={"model_type": "wav2clip", "sr": 16000, "max_audios": 1, "audio_agg": "mean"},
    fusion="concat",
    freeze_backbone=True,
    use_batch_tokenizer=False,      # no text modality here, so no tokenizer is needed
    pretokenize_data=False
)

pipe.fit(
    df,
    epochs=1,
    test_size=0.33,
    per_device_train_batch_size=2,
    eval_steps=1,
    logging_steps=1,
    output_dir="./out_w2c",
    seed=DEMO_SEED,
)

pred = pipe.predict(df)
print("Pred shape:", pred.shape)

# Without return_per_modality a single fused array is returned.
emb = pipe.get_embeddings(df)
print("Embeddings:", emb.shape)    # expected: [N, 512]