# ReRanker с catboostranker (лучше подбор весов работает, но можно попробовать это при больших данных)

In [None]:
import re
import numpy as np
import pickle
from pathlib import Path
from tqdm.auto import tqdm
from collections import Counter
import torch
from rouge_score import rouge_scorer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from catboost import CatBoostRanker, Pool
from scipy.spatial.distance import cosine


class FeaturesExtractor:
    """извлечение признаков для ранжирования"""
    
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        
        # ленивая инициализация моделей
        self._fluency_model = None
        self._fluency_tokenizer = None
        self._embedding_model = None
        self._embedding_tokenizer = None
        self._rouge_scorer = None
        
        # регистр признаков (для гибкости)
        self.enabled_features = set()
    
    def __getstate__(self):
        state = self.__dict__.copy()
        state['_fluency_model'] = None
        state['_fluency_tokenizer'] = None
        state['_embedding_model'] = None
        state['_embedding_tokenizer'] = None
        state['_rouge_scorer'] = None
        return state
    
    def __setstate__(self, state):
        self.__dict__.update(state)
    
    def enable_features(self, *feature_names):
        """включить признаки: extractor.enable_features('coverage', 'length', 'embeddings')"""
        self.enabled_features.update(feature_names)
    
    def disable_features(self, *feature_names):
        """выключить признаки"""
        self.enabled_features.difference_update(feature_names)
    
    # =============== ЗАГРУЗКА МОДЕЛЕЙ ===============
    
    def _load_rouge_scorer(self):
        if self._rouge_scorer is None:
            self._rouge_scorer = rouge_scorer.RougeScorer(['rouge2', 'rougeL'], use_stemmer=True)
        return self._rouge_scorer
    
    def _load_fluency_model(self):
        if self._fluency_model is None:
            print("загрузка модели для fluency (gpt-2)...")
            self._fluency_tokenizer = AutoTokenizer.from_pretrained('gpt2')
            self._fluency_model = AutoModelForCausalLM.from_pretrained('gpt2').to(self.device)
            self._fluency_model.eval()
        return self._fluency_model, self._fluency_tokenizer
    
    def _load_embedding_model(self):
        if self._embedding_model is None:
            print("загрузка модели для embeddings (sentence-transformers)...")
            model_name = 'sentence-transformers/all-MiniLM-L6-v2'
            self._embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._embedding_model = AutoModel.from_pretrained(model_name).to(self.device)
            self._embedding_model.eval()
        return self._embedding_model, self._embedding_tokenizer
    
    # =============== БАЗОВЫЕ МЕТРИКИ (из оригинального класса) ===============
    
    @staticmethod
    def _get_ngrams(text, n):
        words = text.lower().split()
        return Counter([' '.join(words[i:i+n]) for i in range(len(words)-n+1)])
    
    def compute_coverage(self, candidate, context):
        """покрытие слов из контекста"""
        if context is None:
            return 0.0
        words = context.replace(',', '').lower().split()
        if not words:
            return 0.0
        covered = sum(1 for word in words if word in candidate.lower())
        return covered / len(words)
    
    def compute_fluency(self, candidate):
        """perplexity-based fluency"""
        model, tokenizer = self._load_fluency_model()
        inputs = tokenizer(candidate, return_tensors='pt', truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss
        
        perplexity = torch.exp(loss).item()
        return 1.0 / (1.0 + perplexity / 30.0)
    
    def compute_grammar(self, candidate):
        """эвристическая грамматика"""
        score = 1.0
        if candidate and not candidate[0].isupper():
            score -= 0.2
        if candidate and not candidate.rstrip()[-1] in '.!?':
            score -= 0.2
        if '  ' in candidate:
            score -= 0.1
        if candidate.count('"') % 2 != 0:
            score -= 0.1
        if len(candidate.split()) < 3:
            score -= 0.3
        return max(0.0, score)
    
    def compute_length_simple(self, candidate, target_min=10, target_max=20):
        """соответствие целевой длине"""
        words = len(candidate.split())
        if target_min <= words <= target_max:
            return 1.0
        elif words < target_min:
            return words / target_min
        else:
            return max(0, 1 - (words - target_max) / (target_max * 0.5))
    
    def compute_repetition_penalty(self, candidate):
        """штраф за повторы"""
        words = candidate.lower().split()
        if len(words) < 3:
            return 1.0
        
        unique_words = len(set(words))
        total_words = len(words)
        unigram_diversity = unique_words / total_words if total_words > 0 else 0
        
        bigrams = self._get_ngrams(candidate, 2)
        unique_bigrams = len(bigrams)
        total_bigrams = sum(bigrams.values())
        bigram_diversity = unique_bigrams / total_bigrams if total_bigrams > 0 else 0
        
        return 0.5 * unigram_diversity + 0.5 * bigram_diversity
    
    def compute_lexical_diversity(self, candidate):
        """лексическое разнообразие"""
        words = candidate.lower().split()
        if len(words) < 5:
            return 0.5
        
        unique_words = len(set(words))
        total_words = len(words)
        ttr = unique_words / total_words
        return min(1.0, ttr * (1 + np.log(total_words) / 5))
    
    def compute_extractive_coverage(self, candidate, context):
        """покрытие n-грамм"""
        if context is None:
            return 0.0
        
        source_ngrams = {n: self._get_ngrams(context, n) for n in [1, 2, 3, 4]}
        
        ngram_scores = []
        for n in [1, 2, 3, 4]:
            cand_ngrams = self._get_ngrams(candidate, n)
            if not cand_ngrams:
                continue
            overlap = sum((cand_ngrams & source_ngrams[n]).values())
            total = sum(cand_ngrams.values())
            ngram_scores.append(overlap / total)
        
        return np.mean(ngram_scores) if ngram_scores else 0.0
    
    # =============== НОВЫЕ ПРИЗНАКИ ===============
    
    def compute_length_candidate(self, candidate):
        """длина кандидата (в словах)"""
        return len(candidate.split())
    
    def compute_length_context(self, context):
        """длина контекста (в словах)"""
        if context is None:
            return 0
        return len(context.split())
    
    def compute_length_ratio(self, candidate, context):
        """отношение длин candidate / context"""
        if context is None or len(context.split()) == 0:
            return 0.0
        return len(candidate.split()) / len(context.split())
    
    def compute_char_length_candidate(self, candidate):
        """длина в символах"""
        return len(candidate)
    
    def compute_embedding_similarity(self, candidate, context):
        """косинусная близость embeddings"""
        if context is None:
            return 0.0
        
        model, tokenizer = self._load_embedding_model()
        
        # mean pooling
        def get_embedding(text):
            inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512, padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = model(**inputs)
                # mean pooling
                embeddings = outputs.last_hidden_state
                attention_mask = inputs['attention_mask']
                mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
                sum_embeddings = torch.sum(embeddings * mask_expanded, 1)
                sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                mean_pooled = sum_embeddings / sum_mask
            
            return mean_pooled.cpu().numpy()[0]
        
        emb_candidate = get_embedding(candidate)
        emb_context = get_embedding(context)
        
        # косинусная близость (1 - distance)
        return 1 - cosine(emb_candidate, emb_context)
    
    def compute_rouge_l(self, candidate, context):
        """rouge-l с контекстом"""
        if context is None:
            return 0.0
        scorer = self._load_rouge_scorer()
        result = scorer.score(context, candidate)
        return result['rougeL'].fmeasure
    
    def compute_rouge_2(self, candidate, context):
        """rouge-2 с контекстом"""
        if context is None:
            return 0.0
        scorer = self._load_rouge_scorer()
        result = scorer.score(context, candidate)
        return result['rouge2'].fmeasure
    
    def compute_word_overlap(self, candidate, context):
        """процент слов из candidate, которые есть в context"""
        if context is None:
            return 0.0
        
        candidate_words = set(candidate.lower().split())
        context_words = set(context.lower().split())
        
        if not candidate_words:
            return 0.0
        
        overlap = len(candidate_words & context_words)
        return overlap / len(candidate_words)
    
    def compute_unique_words_ratio(self, candidate):
        """отношение уникальных слов к общему количеству"""
        words = candidate.lower().split()
        if not words:
            return 0.0
        return len(set(words)) / len(words)
    
    def compute_avg_word_length(self, candidate):
        """средняя длина слова"""
        words = candidate.split()
        if not words:
            return 0.0
        return np.mean([len(w) for w in words])
    
    def compute_sentence_count(self, candidate):
        """количество предложений"""
        sentences = re.split(r'[.!?]+', candidate)
        sentences = [s.strip() for s in sentences if s.strip()]
        return len(sentences)
    
    # =============== ОСНОВНОЙ МЕТОД ИЗВЛЕЧЕНИЯ ===============
    
    def extract_features(self, candidate, context=None, feature_params=None):
        """
        извлечь все включенные признаки
        
        args:
            candidate: str - текст кандидата
            context: str - контекст (опционально)
            feature_params: dict - параметры для признаков
        
        returns:
            dict: {feature_name: value}
        """
        feature_params = feature_params or {}
        features = {}
        
        # базовые метрики
        if 'coverage' in self.enabled_features:
            features['coverage'] = self.compute_coverage(candidate, context)
        
        if 'fluency' in self.enabled_features:
            features['fluency'] = self.compute_fluency(candidate)
        
        if 'grammar' in self.enabled_features:
            features['grammar'] = self.compute_grammar(candidate)
        
        if 'length_simple' in self.enabled_features:
            params = feature_params.get('length_simple', {'target_min': 10, 'target_max': 20})
            features['length_simple'] = self.compute_length_simple(
                candidate, params['target_min'], params['target_max']
            )
        
        if 'repetition_penalty' in self.enabled_features:
            features['repetition_penalty'] = self.compute_repetition_penalty(candidate)
        
        if 'lexical_diversity' in self.enabled_features:
            features['lexical_diversity'] = self.compute_lexical_diversity(candidate)
        
        if 'extractive_coverage' in self.enabled_features:
            features['extractive_coverage'] = self.compute_extractive_coverage(candidate, context)
        
        # новые признаки
        if 'length_candidate' in self.enabled_features:
            features['length_candidate'] = self.compute_length_candidate(candidate)
        
        if 'length_context' in self.enabled_features:
            features['length_context'] = self.compute_length_context(context)
        
        if 'length_ratio' in self.enabled_features:
            features['length_ratio'] = self.compute_length_ratio(candidate, context)
        
        if 'char_length' in self.enabled_features:
            features['char_length'] = self.compute_char_length_candidate(candidate)
        
        if 'embedding_similarity' in self.enabled_features:
            features['embedding_similarity'] = self.compute_embedding_similarity(candidate, context)
        
        if 'rouge_l' in self.enabled_features:
            features['rouge_l'] = self.compute_rouge_l(candidate, context)
        
        if 'rouge_2' in self.enabled_features:
            features['rouge_2'] = self.compute_rouge_2(candidate, context)
        
        if 'word_overlap' in self.enabled_features:
            features['word_overlap'] = self.compute_word_overlap(candidate, context)
        
        if 'unique_words_ratio' in self.enabled_features:
            features['unique_words_ratio'] = self.compute_unique_words_ratio(candidate)
        
        if 'avg_word_length' in self.enabled_features:
            features['avg_word_length'] = self.compute_avg_word_length(candidate)
        
        if 'sentence_count' in self.enabled_features:
            features['sentence_count'] = self.compute_sentence_count(candidate)
        
        return features
    
    def extract_features_batch(self, candidates, context=None, feature_params=None):
        """извлечь признаки для батча кандидатов"""
        return [
            self.extract_features(candidate, context, feature_params)
            for candidate in candidates
        ]


class CatBoostReRanker:
    """re-ranker на основе CatBoost"""
    
    # предустановленные наборы признаков
    FEATURE_SETS = {
        'minimal': ['length_simple', 'repetition_penalty'],
        
        'basic': [
            'coverage', 'length_simple', 'repetition_penalty', 
            'grammar', 'lexical_diversity'
        ],
        
        'extended': [
            'coverage', 'fluency', 'grammar', 'length_simple',
            'repetition_penalty', 'lexical_diversity',
            'length_candidate', 'length_ratio', 'word_overlap',
            'unique_words_ratio'
        ],
        
        'full': [
            'coverage', 'fluency', 'grammar', 'length_simple',
            'repetition_penalty', 'lexical_diversity', 'extractive_coverage',
            'length_candidate', 'length_context', 'length_ratio', 
            'char_length', 'embedding_similarity', 'rouge_l', 'rouge_2',
            'word_overlap', 'unique_words_ratio', 'avg_word_length',
            'sentence_count'
        ]
    }
    
    def __init__(
        self,
        features='basic',  # 'minimal', 'basic', 'extended', 'full' или список признаков
        feature_params=None,
        catboost_params=None,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        cache_dir='./reranker_cache'
    ):
        """
        args:
            features: str или list
                - str: название preset'а ('minimal', 'basic', 'extended', 'full')
                - list: список названий признаков
            feature_params: dict - параметры для признаков
            catboost_params: dict - параметры для CatBoost
            device: str - устройство для вычислений
            cache_dir: str - директория для кэша
        """
        self.device = device
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        
        # инициализация extractor
        self.extractor = FeaturesExtractor(device=device)
        
        # определяем признаки
        if isinstance(features, str):
            if features not in self.FEATURE_SETS:
                raise ValueError(f"unknown feature set: {features}. available: {list(self.FEATURE_SETS.keys())}")
            self.feature_names = self.FEATURE_SETS[features]
        elif isinstance(features, list):
            self.feature_names = features
        else:
            raise ValueError("features must be str or list")
        
        # включаем признаки в extractor
        self.extractor.enable_features(*self.feature_names)
        
        self.feature_params = feature_params or {}
        
        # параметры CatBoost по умолчанию
        default_catboost_params = {
            'iterations': 500,
            'depth': 6,
            'learning_rate': 0.03,
            'loss_function': 'YetiRank',
            'verbose': False,
            'random_seed': 42,
            'task_type': 'GPU' if device == 'cuda' else 'CPU',
        }
        
        if catboost_params:
            default_catboost_params.update(catboost_params)
        
        self.catboost_params = default_catboost_params
        self.model = None
    
    def _prepare_data(self, candidates_list, contexts, y_texts, target_metric='rouge2'):
        """подготовить данные для CatBoost"""
        print(f"извлечение признаков для {len(candidates_list)} примеров...")
        
        all_features = []
        all_labels = []
        all_groups = []
        
        # для вычисления целевой метрики
        scorer = rouge_scorer.RougeScorer([target_metric], use_stemmer=True)
        
        for group_id, (candidates, context, y_text) in enumerate(tqdm(
            zip(candidates_list, contexts, y_texts),
            total=len(candidates_list),
            desc="preparing data"
        )):
            for candidate in candidates:
                # извлекаем признаки
                features = self.extractor.extract_features(candidate, context, self.feature_params)
                
                # преобразуем в список (сохраняем порядок)
                feature_vector = [features.get(name, 0.0) for name in self.feature_names]
                
                # целевая метрика
                target_score = scorer.score(y_text, candidate)[target_metric].fmeasure
                
                all_features.append(feature_vector)
                all_labels.append(target_score)
                all_groups.append(group_id)
        
        return np.array(all_features), np.array(all_labels), np.array(all_groups)
    
    def fit(
        self,
        candidates_list,
        contexts,
        y_texts,
        val_candidates_list=None,
        val_contexts=None,
        val_y_texts=None,
        target_metric='rouge2',
        cache_name=None,
        use_cache=True,
        early_stopping_rounds=50
    ):
        """
        обучить модель
        
        args:
            candidates_list: list of list of str - train кандидаты
            contexts: list of str - train контексты
            y_texts: list of str - train референсы
            val_candidates_list: list of list of str - val кандидаты (опционально)
            val_contexts: list of str - val контексты (опционально)
            val_y_texts: list of str - val референсы (опционально)
            target_metric: str - целевая метрика ('rouge2', 'rougeL')
            cache_name: str - имя кэша
            use_cache: bool - использовать кэш
            early_stopping_rounds: int - ранняя остановка
        
        returns:
            dict: метрики обучения
        """
        # кэширование
        cache_path = None
        cache_loaded = False
        
        if cache_name and use_cache:
            feature_hash = hash(tuple(sorted(self.feature_names)))
            cache_path = self.cache_dir / f"{cache_name}_catboost_{feature_hash}.pkl"
            
            if cache_path.exists():
                print(f"загрузка данных из кэша: {cache_path}")
                with open(cache_path, 'rb') as f:
                    cached_data = pickle.load(f)
                X_train, y_train, groups_train = cached_data['train']
                if 'val' in cached_data:
                    X_val, y_val, groups_val = cached_data['val']
                else:
                    X_val = y_val = groups_val = None
                cache_loaded = True
        
        if not cache_loaded:
            print("\nподготовка train данных...")
            X_train, y_train, groups_train = self._prepare_data(
                candidates_list, contexts, y_texts, target_metric
            )
            
            if val_candidates_list is not None:
                print("подготовка val данных...")
                X_val, y_val, groups_val = self._prepare_data(
                    val_candidates_list, val_contexts, val_y_texts, target_metric
                )
            else:
                X_val = y_val = groups_val = None
            
            if cache_path:
                cache_data = {'train': (X_train, y_train, groups_train)}
                if X_val is not None:
                    cache_data['val'] = (X_val, y_val, groups_val)
                
                with open(cache_path, 'wb') as f:
                    pickle.dump(cache_data, f)
                print(f"данные сохранены в кэш: {cache_path}")
        
        # Pool
        train_pool = Pool(
            data=X_train,
            label=y_train,
            group_id=groups_train,
            feature_names=self.feature_names
        )
        
        eval_pool = None
        if X_val is not None:
            eval_pool = Pool(
                data=X_val,
                label=y_val,
                group_id=groups_val,
                feature_names=self.feature_names
            )
        
        # параметры для fit
        fit_params = {}
        if eval_pool is not None:
            fit_params['verbose_eval'] = 50
            if early_stopping_rounds:
                fit_params['early_stopping_rounds'] = early_stopping_rounds
                print(f"\nранняя остановка: {early_stopping_rounds} итераций без улучшения")
        
        # обучение
        print(f"\nобучение CatBoost с {len(self.feature_names)} признаками...")
        print(f"признаки: {self.feature_names}")
        print(f"train: {len(set(groups_train))} примеров")
        if eval_pool:
            print(f"val:   {len(set(groups_val))} примеров")
        
        self.model = CatBoostRanker(**self.catboost_params)
        
        self.model.fit(
            train_pool,
            eval_set=eval_pool,
            **fit_params
        )
        
        # feature importance
        try:
            feature_importance = self.model.get_feature_importance(type='PredictionValuesChange')
        except:
            feature_importance = self.model.get_feature_importance(
                data=train_pool,
                type='LossFunctionChange'
            )
        
        sorted_importance = sorted(
            zip(self.feature_names, feature_importance),
            key=lambda x: x[1],
            reverse=True
        )
        
        # НОВОЕ: валидация на val по целевой метрике (rouge2)
        val_rouge_score = None
        if val_candidates_list is not None:
            print(f"\nвалидация на val set по {target_metric}...")
            
            from rouge_score import rouge_scorer
            scorer = rouge_scorer.RougeScorer([target_metric], use_stemmer=True)
            
            val_predictions = []
            for candidates, context, y_text in tqdm(
                zip(val_candidates_list, val_contexts, val_y_texts),
                total=len(val_candidates_list),
                desc="validation"
            ):
                best = self.get_best_candidate(candidates, context)
                val_predictions.append(best)
            
            # вычисляем rouge
            rouge_scores = []
            for pred, ref in zip(val_predictions, val_y_texts):
                score = scorer.score(ref, pred)[target_metric].fmeasure
                rouge_scores.append(score)
            
            val_rouge_score = np.mean(rouge_scores)
            print(f"val {target_metric}: {val_rouge_score:.4f}")
        
        # итоги
        print("\n" + "="*60)
        print("РЕЗУЛЬТАТЫ ОБУЧЕНИЯ")
        print("="*60)
        print(f"количество признаков: {len(self.feature_names)}")
        print(f"количество итераций: {self.model.tree_count_}")
        if val_rouge_score is not None:
            print(f"val {target_metric}: {val_rouge_score:.4f}")
        print(f"\nтоп-10 важных признаков:")
        for name, importance in sorted_importance[:10]:
            print(f"  {name:25s}: {importance:.4f}")
        print("="*60)
        
        return {
            'feature_importance': dict(sorted_importance),
            'train_size': len(set(groups_train)),
            'val_size': len(set(groups_val)) if groups_val is not None else 0,
            'iterations': self.model.tree_count_,
            f'val_{target_metric}': val_rouge_score
        }

    def rank_candidates(self, candidates, context=None):
        """
        ранжировать кандидатов
        
        returns:
            list of tuple: [(idx, score, features), ...]
        """
        if self.model is None:
            raise ValueError("model not trained. call fit() first")
        
        if not candidates:
            return []
        
        # извлечение признаков
        features_list = self.extractor.extract_features_batch(
            candidates, context, self.feature_params
        )
        
        # преобразование в матрицу
        X = np.array([
            [f.get(name, 0.0) for name in self.feature_names]
            for f in features_list
        ])
        
        # предсказание
        scores = self.model.predict(X)
        
        # сортировка
        results = []
        for idx, (score, features) in enumerate(zip(scores, features_list)):
            results.append((idx, score, features))
        
        results.sort(key=lambda x: x[1], reverse=True)
        return results
    
    def get_best_candidate(self, candidates, context=None):
        """получить лучшего кандидата"""
        results = self.rank_candidates(candidates, context)
        if not results:
            return None
        best_idx = results[0][0]
        return candidates[best_idx]
    
    def save(self, path):
        """сохранить ranker"""
        save_data = {
            'feature_names': self.feature_names,
            'feature_params': self.feature_params,
            'catboost_params': self.catboost_params,
            'device': self.device,
        }
        
        # сохраняем модель отдельно
        model_path = Path(path).with_suffix('.cbm')
        if self.model is not None:
            self.model.save_model(str(model_path))
            save_data['model_path'] = str(model_path)
        
        with open(path, 'wb') as f:
            pickle.dump(save_data, f)
        
        print(f"ranker сохранен: {path}")
        if self.model is not None:
            print(f"модель сохранена: {model_path}")
    
    @staticmethod
    def load(path):
        """загрузить ranker"""
        with open(path, 'rb') as f:
            save_data = pickle.load(f)
        
        ranker = CatBoostReRanker(
            features=save_data['feature_names'],
            feature_params=save_data['feature_params'],
            catboost_params=save_data['catboost_params'],
            device=save_data['device']
        )
        
        if 'model_path' in save_data:
            ranker.model = CatBoostRanker()
            ranker.model.load_model(save_data['model_path'])
        
        print(f"ranker загружен: {path}")
        return ranker

In [None]:
# инициализация для QA
ranker_qa = CatBoostReRanker(
    features=[
        # кастомный набор признаков для QA
        'length_simple',          # ответ должен быть коротким
        'embedding_similarity',   # ответ должен быть семантически близок к вопросу
        'extractive_coverage',    # хороший ответ часто является цитатой из контекста
        'repetition_penalty',     # избегаем повторов
        'unique_words_ratio',     # поощряем лаконичные ответы
        'grammar'                 # базовый контроль грамматики
    ],
    feature_params={
        # устанавливаем целевую длину для коротких ответов
        'length_simple': {'target_min': 1, 'target_max': 15}
    },
    catboost_params={
        'iterations': 1000,
        'depth': 6,
        'learning_rate': 0.03,
        'loss_function': 'YetiRank',
        'eval_metric': 'NDCG:top=5',  # понятная метрика для логов
        'verbose': 100,
        'random_seed': 42
    },
    device='cuda'
)

# обучение (dev_contexts - это исходные параграфы из SQuAD)
results_qa = ranker_qa.fit(
    candidates_list=dev_candidates,
    contexts=dev_contexts,
    y_texts=dev_y_texts,
    val_candidates_list=val_candidates,  # валидационный сет
    val_contexts=val_contexts,
    val_y_texts=val_y_texts,
    target_metric='rougeL',  # rougeL лучше для коротких ответов
    cache_name='squad_custom',
    early_stopping_rounds=50
)

In [None]:
# инициализация для CommonGen
ranker_commongen = CatBoostReRanker(
    features='extended',  # preset 'extended' хорошо подходит
    feature_params={
        # устанавливаем целевую длину для предложений
        'length_simple': {'target_min': 8, 'target_max': 25}
    },
    catboost_params={
        'iterations': 800,
        'depth': 8,  # можно глубже, т.к. больше признаков
        'learning_rate': 0.02,
        'loss_function': 'YetiRank',
        'eval_metric': 'NDCG:top=10',
        'verbose': 100,
        'random_seed': 42
    },
    device='cuda'
)

# обучение (dev_contexts - это строки со словами через запятую)
results_commongen = ranker_commongen.fit(
    candidates_list=dev_candidates,
    contexts=dev_contexts,
    y_texts=dev_y_texts,
    val_candidates_list=val_candidates,
    val_contexts=val_contexts,
    val_y_texts=val_y_texts,
    target_metric='rouge2',  # rouge2 хорошо ловит n-граммы
    cache_name='commongen_extended',
    early_stopping_rounds=50
)

In [None]:
# инициализация для CommonGen
ranker_commongen = CatBoostReRanker(
    features='extended',  # preset 'extended' хорошо подходит
    feature_params={
        # устанавливаем целевую длину для предложений
        'length_simple': {'target_min': 8, 'target_max': 25}
    },
    catboost_params={
        'iterations': 800,
        'depth': 8,  # можно глубже, т.к. больше признаков
        'learning_rate': 0.02,
        'loss_function': 'YetiRank',
        'eval_metric': 'NDCG:top=10',
        'verbose': 100,
        'random_seed': 42
    },
    device='cuda'
)

# обучение (dev_contexts - это строки со словами через запятую)
results_commongen = ranker_commongen.fit(
    candidates_list=dev_candidates,
    contexts=dev_contexts,
    y_texts=dev_y_texts,
    val_candidates_list=val_candidates,
    val_contexts=val_contexts,
    val_y_texts=val_y_texts,
    target_metric='rouge2',  # rouge2 хорошо ловит n-граммы
    cache_name='commongen_extended',
    early_stopping_rounds=50
)

In [None]:
# инициализация для Creative Writing
ranker_creative = CatBoostReRanker(
    features=[
        # кастомный набор для стиля
        'lexical_diversity',      # богатство словаря
        'fluency',                # естественность языка
        'repetition_penalty',     # избегаем банальностей
        'grammar',                # базовый стиль
        'length_simple',          # контроль длины (например, для твита)
        'avg_word_length',        # более длинные слова могут указывать на более сложный стиль
        'semantic_coherence'      # логичность (для длинных ответов)
    ],
    feature_params={
        'length_simple': {'target_min': 15, 'target_max': 40}  # например, для поста
    },
    catboost_params={
        'iterations': 500,
        'depth': 4, # неглубокие деревья, т.к. сигналы могут быть слабыми
        'learning_rate': 0.05,
        'loss_function': 'YetiRank',
        'verbose': 100,
        'random_seed': 42
    },
    device='cuda'
)

# обучение (dev_contexts - это промпты, например, "Напиши твит о...")
results_creative = ranker_creative.fit(
    candidates_list=dev_candidates,
    contexts=dev_contexts,
    y_texts=dev_y_texts,
    val_candidates_list=val_candidates,
    val_contexts=val_contexts,
    val_y_texts=val_y_texts,
    target_metric='rougeL', # здесь rouge - это прокси, лучше использовать human evaluation
    cache_name='creative_style',
    early_stopping_rounds=30
)

# ReRanker с подбором весов

In [None]:
import re
import numpy as np
from scipy.optimize import differential_evolution
from scipy.stats import spearmanr
import pickle
from pathlib import Path
from tqdm.auto import tqdm
from collections import Counter
import torch
from rouge_score import rouge_scorer
from transformers import AutoModelForCausalLM, AutoTokenizer


class MetricsComputer:
    """класс для вычисления метрик качества текста"""
    
    ALL_METRICS = [
        'coverage',
        'compression_ratio',
        'extractive_coverage',
        'fluency',
        'grammar',
        'length_simple',
        'lexical_diversity',
        'repetition_penalty',
        'rouge_with_source',
        'semantic_coherence'
    ]
    
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self._fluency_model = None
        self._fluency_tokenizer = None
        self._rouge_scorer = None
    
    def __getstate__(self):
        state = self.__dict__.copy()
        state['_fluency_model'] = None
        state['_fluency_tokenizer'] = None
        state['_rouge_scorer'] = None
        return state
    
    def __setstate__(self, state):
        self.__dict__.update(state)
    
    def _load_rouge_scorer(self):
        """загрузка rouge scorer"""
        if self._rouge_scorer is None:
            self._rouge_scorer = rouge_scorer.RougeScorer(
                ['rouge2', 'rougeL'], 
                use_stemmer=True
            )
        return self._rouge_scorer
    
    def _load_fluency_model(self):
        """загрузка модели для fluency"""
        if self._fluency_model is None:
            print("загрузка модели для fluency (gpt-2)...")
            self._fluency_tokenizer = AutoTokenizer.from_pretrained('gpt2')
            self._fluency_model = AutoModelForCausalLM.from_pretrained('gpt2').to(self.device)
            self._fluency_model.eval()
        return self._fluency_model, self._fluency_tokenizer
    
    @staticmethod
    def _get_ngrams(text, n):
        """извлечение n-грамм"""
        words = text.lower().split()
        return Counter([' '.join(words[i:i+n]) for i in range(len(words)-n+1)])
    
    def compute_coverage(self, candidates, context, params=None):
        """
        покрытие обязательных элементов
        
        применение:
        - commongen: все ли слова использованы
        - keyword-to-text: присутствуют ли ключевые слова
        
        значения: 0.0 - 1.0
        """
        if context is None:
            return [0.0] * len(candidates)
        
        words = context.replace(',', '').lower().split()
        if not words:
            return [0.0] * len(candidates)
        
        scores = []
        for candidate in candidates:
            candidate_lower = candidate.lower()
            covered = sum(1 for word in words if word in candidate_lower)
            scores.append(covered / len(words))
        
        return scores
    
    def compute_compression_ratio(self, candidates, source_text, params=None):
        """
        соответствие целевому коэффициенту сжатия
        
        применение:
        - суммаризация: контроль длины
        - compression tasks
        
        значения: 0.0 - 1.0
        """
        if source_text is None:
            return [0.0] * len(candidates)
        
        params = params or {'optimal_ratio': 0.15, 'sigma': 0.05}
        source_len = len(source_text.split())
        optimal = params['optimal_ratio']
        sigma = params['sigma']
        
        scores = []
        for candidate in candidates:
            cand_len = len(candidate.split())
            ratio = cand_len / source_len if source_len > 0 else 0
            score = np.exp(-((ratio - optimal) ** 2) / (2 * sigma ** 2))
            scores.append(score)
        
        return scores
    
    def compute_extractive_coverage(self, candidates, source_text, params=None):
        """
        покрытие n-грамм из исходного текста
        
        применение:
        - extractive summarization
        - faithful generation
        - qa: ответ из контекста
        
        значения: 0.0 - 1.0
        """
        if source_text is None:
            return [0.0] * len(candidates)
        
        source_ngrams = {
            n: self._get_ngrams(source_text, n)
            for n in [1, 2, 3, 4]
        }
        
        scores = []
        for candidate in candidates:
            ngram_scores = []
            for n in [1, 2, 3, 4]:
                cand_ngrams = self._get_ngrams(candidate, n)
                if not cand_ngrams:
                    continue
                overlap = sum((cand_ngrams & source_ngrams[n]).values())
                total = sum(cand_ngrams.values())
                ngram_scores.append(overlap / total)
            
            scores.append(np.mean(ngram_scores) if ngram_scores else 0.0)
        
        return scores
    
    def compute_fluency(self, candidates, params=None):
        """
        беглость текста на основе perplexity
        
        применение:
        - все задачи генерации
        
        значения: 0.0 - 1.0
        """
        model, tokenizer = self._load_fluency_model()
        
        scores = []
        for candidate in candidates:
            inputs = tokenizer(candidate, return_tensors='pt', truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = model(**inputs, labels=inputs['input_ids'])
                loss = outputs.loss
            
            perplexity = torch.exp(loss).item()
            fluency_score = 1.0 / (1.0 + perplexity / 30.0)
            scores.append(fluency_score)
        
        return scores
    
    def compute_grammar(self, candidates, params=None):
        """
        грамматическая корректность (эвристика)
        
        применение:
        - все задачи генерации
        
        значения: 0.0 - 1.0
        """
        scores = []
        for candidate in candidates:
            score = 1.0
            
            if candidate and not candidate[0].isupper():
                score -= 0.2
            
            if candidate and not candidate.rstrip()[-1] in '.!?':
                score -= 0.2
            
            if '  ' in candidate:
                score -= 0.1
            
            if candidate.count('"') % 2 != 0:
                score -= 0.1
            
            if len(candidate.split()) < 3:
                score -= 0.3
            
            scores.append(max(0.0, score))
        
        return scores
    
    def compute_length_simple(self, candidates, params=None):
        """
        соответствие целевой длине
        
        применение:
        - commongen: 10-20 слов
        - заголовки: 5-10 слов
        - qa: 1-10 слов
        
        значения: 0.0 - 1.0
        """
        params = params or {'target_min': 10, 'target_max': 20}
        target_min = params['target_min']
        target_max = params['target_max']
        
        scores = []
        for candidate in candidates:
            words = len(candidate.split())
            
            if target_min <= words <= target_max:
                score = 1.0
            elif words < target_min:
                score = words / target_min
            else:
                score = max(0, 1 - (words - target_max) / (target_max * 0.5))
            
            scores.append(score)
        
        return scores
    
    def compute_lexical_diversity(self, candidates, params=None):
        """
        лексическое разнообразие
        
        применение:
        - creative writing
        - диалоги
        
        значения: 0.0 - 1.0
        """
        scores = []
        for candidate in candidates:
            words = candidate.lower().split()
            
            if len(words) < 5:
                scores.append(0.5)
                continue
            
            unique_words = len(set(words))
            total_words = len(words)
            ttr = unique_words / total_words
            
            normalized_diversity = min(1.0, ttr * (1 + np.log(total_words) / 5))
            scores.append(normalized_diversity)
        
        return scores
    
    def compute_repetition_penalty(self, candidates, params=None):
        """
        штраф за повторы
        
        применение:
        - все задачи генерации
        - борьба с вырожденными генерациями
        
        значения: 0.0 - 1.0
        """
        scores = []
        for candidate in candidates:
            words = candidate.lower().split()
            
            if len(words) < 3:
                scores.append(1.0)
                continue
            
            unique_words = len(set(words))
            total_words = len(words)
            unigram_diversity = unique_words / total_words if total_words > 0 else 0
            
            bigrams = self._get_ngrams(candidate, 2)
            unique_bigrams = len(bigrams)
            total_bigrams = sum(bigrams.values())
            bigram_diversity = unique_bigrams / total_bigrams if total_bigrams > 0 else 0
            
            score = 0.5 * unigram_diversity + 0.5 * bigram_diversity
            scores.append(score)
        
        return scores
    
    def compute_rouge_with_source(self, candidates, source_text, params=None):
        """
        rouge overlap с исходным текстом
        
        применение:
        - extractive summarization
        - перефразирование
        
        значения: 0.0 - 1.0
        """
        if source_text is None:
            return [0.0] * len(candidates)
        
        scorer = self._load_rouge_scorer()
        scores = []
        
        for candidate in candidates:
            result = scorer.score(source_text, candidate)
            scores.append(result['rougeL'].fmeasure)
        
        return scores
    
    def compute_semantic_coherence(self, candidates, params=None):
        """
        семантическая связность (эвристика)
        
        применение:
        - длинные генерации
        - диалоги
        
        значения: 0.0 - 1.0
        """
        scores = []
        for candidate in candidates:
            sentences = re.split(r'[.!?]+', candidate)
            sentences = [s.strip() for s in sentences if s.strip()]
            
            if len(sentences) <= 1:
                scores.append(1.0)
                continue
            
            score = 1.0
            
            connectives = [
                'however', 'therefore', 'moreover', 'furthermore',
                'additionally', 'consequently', 'thus', 'hence',
                'because', 'since', 'although', 'while', 'but', 'and'
            ]
            
            candidate_lower = candidate.lower()
            has_connectives = any(conn in candidate_lower for conn in connectives)
            
            if len(sentences) > 2 and not has_connectives:
                score -= 0.2
            
            lengths = [len(s.split()) for s in sentences]
            if len(lengths) > 1:
                length_variance = np.std(lengths) / (np.mean(lengths) + 1e-6)
                if length_variance > 1.5:
                    score -= 0.2
            
            scores.append(max(0.0, score))
        
        return scores
    
    def compute_metric(self, metric_name, candidates, context, params=None):
        """вычислить одну метрику по имени"""
        method_name = f'compute_{metric_name}'
        if not hasattr(self, method_name):
            raise ValueError(f"unknown metric: {metric_name}")
        
        method = getattr(self, method_name)
        
        # метрики требующие context
        if metric_name in ['coverage', 'compression_ratio', 'extractive_coverage', 'rouge_with_source']:
            return method(candidates, context, params)
        else:
            return method(candidates, params)


class ReRankingModel:
    """линейная модель для ранжирования"""
    
    def __init__(self, metric_names, weights=None):
        self.metric_names = metric_names
        self.weights = weights or self._init_uniform_weights()
    
    def _init_uniform_weights(self):
        """равномерные веса"""
        n = len(self.metric_names)
        return {name: 1.0 / n for name in self.metric_names}
    
    def predict(self, metrics_dict):
        """вычислить взвешенный скор"""
        return sum(metrics_dict.get(m, 0) * self.weights.get(m, 0) for m in self.metric_names)
    
    def predict_batch(self, metrics_list):
        """вычислить скоры для батча"""
        return [self.predict(m) for m in metrics_list]
    
    def set_weights(self, weights):
        """установить веса"""
        self.weights = weights


class UniversalTextReRanker:
    """универсальный re-ranker для генераций текста"""
    
    def __init__(
        self,
        use_coverage=False,
        use_compression_ratio=False,
        use_extractive_coverage=False,
        use_fluency=False,
        use_grammar=False,
        use_length_simple=False,
        use_lexical_diversity=False,
        use_repetition_penalty=False,
        use_rouge_with_source=False,
        use_semantic_coherence=False,
        coverage_params=None,
        compression_params=None,
        length_params=None,
        weights=None,
        metric_selection='manual',
        min_correlation=0.15,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        cache_dir='./reranker_cache'
    ):
        """
        args:
            use_*: флаги для включения метрик
            *_params: параметры для метрик
            weights: предустановленные веса
            metric_selection: 'manual', 'auto' или 'auto_all'
                - 'manual': используем только метрики с use_*=True
                - 'auto': из метрик с use_*=True отбираем по корреляции
                - 'auto_all': тестируем все метрики, отбираем по корреляции
            min_correlation: порог корреляции для auto режимов
            device: устройство для вычислений
            cache_dir: директория для кэша
        """
        self.use_coverage = use_coverage
        self.use_compression_ratio = use_compression_ratio
        self.use_extractive_coverage = use_extractive_coverage
        self.use_fluency = use_fluency
        self.use_grammar = use_grammar
        self.use_length_simple = use_length_simple
        self.use_lexical_diversity = use_lexical_diversity
        self.use_repetition_penalty = use_repetition_penalty
        self.use_rouge_with_source = use_rouge_with_source
        self.use_semantic_coherence = use_semantic_coherence
        
        self.coverage_params = coverage_params or {}
        self.compression_params = compression_params or {'optimal_ratio': 0.15, 'sigma': 0.05}
        self.length_params = length_params or {'target_min': 10, 'target_max': 20}
        
        if metric_selection not in ['manual', 'auto', 'auto_all']:
            raise ValueError(f"metric_selection must be 'manual', 'auto' or 'auto_all', got {metric_selection}")
        
        self.metric_selection = metric_selection
        self.min_correlation = min_correlation
        self.device = device
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        
        self.metrics_computer = MetricsComputer(device=device)
        
        active_metrics = self._get_active_metrics()
        self.model = ReRankingModel(active_metrics, weights)
    
    def __getstate__(self):
        state = self.__dict__.copy()
        return state
    
    def __setstate__(self, state):
        self.__dict__.update(state)
    
    def _get_active_metrics(self):
        """получить список активных метрик"""
        metrics = []
        if self.use_coverage:
            metrics.append('coverage')
        if self.use_compression_ratio:
            metrics.append('compression_ratio')
        if self.use_extractive_coverage:
            metrics.append('extractive_coverage')
        if self.use_fluency:
            metrics.append('fluency')
        if self.use_grammar:
            metrics.append('grammar')
        if self.use_length_simple:
            metrics.append('length_simple')
        if self.use_lexical_diversity:
            metrics.append('lexical_diversity')
        if self.use_repetition_penalty:
            metrics.append('repetition_penalty')
        if self.use_rouge_with_source:
            metrics.append('rouge_with_source')
        if self.use_semantic_coherence:
            metrics.append('semantic_coherence')
        return metrics
    
    def _get_metric_params(self, metric_name):
        """получить параметры для метрики"""
        if metric_name == 'coverage':
            return self.coverage_params
        elif metric_name == 'compression_ratio':
            return self.compression_params
        elif metric_name == 'length_simple':
            return self.length_params
        else:
            return None
    
    def compute_batch(self, candidates, context=None):
        """
        вычислить все метрики для батча кандидатов
        
        args:
            candidates: list of str
            context: str или None
        
        returns:
            list of dict: [{metric_name: score}, ...]
        """
        if not candidates:
            return []
        
        all_scores = {}
        
        for metric_name in self.model.metric_names:
            params = self._get_metric_params(metric_name)
            all_scores[metric_name] = self.metrics_computer.compute_metric(
                metric_name, candidates, context, params
            )
        
        results = []
        for i in range(len(candidates)):
            candidate_scores = {
                metric: all_scores[metric][i] 
                for metric in all_scores
            }
            results.append(candidate_scores)
        
        return results
    
    def compute(self, text, context=None, return_individual=True, return_weighted=True):
        """вычислить метрики для одного текста"""
        batch_results = self.compute_batch([text], context)
        scores = batch_results[0]
        
        result = {}
        if return_individual:
            result['scores'] = scores
        
        if return_weighted:
            result['weighted_score'] = self.model.predict(scores)
        
        return result
    
    def rank_candidates(self, candidates, context=None):
        """
        ранжировать кандидатов
        
        returns:
            list of tuple: [(idx, weighted_score, individual_scores), ...]
        """
        if not candidates:
            return []
        
        batch_scores = self.compute_batch(candidates, context)
        
        results = []
        for idx, scores in enumerate(batch_scores):
            weighted_score = self.model.predict(scores)
            results.append((idx, weighted_score, scores))
        
        results.sort(key=lambda x: x[1], reverse=True)
        return results
    
    def get_best_candidate(self, candidates, context=None):
        """получить лучшего кандидата"""
        results = self.rank_candidates(candidates, context)
        if not results:
            return None
        best_idx = results[0][0]
        return candidates[best_idx]
    
    def fit(
        self,
        candidates_list,
        contexts,
        y_texts,
        metric='rouge2',
        cache_name=None,
        use_cache=True,
        max_iter=50,
        popsize=15,
        seed=42,
        n_workers=1,
        print_correlations=True
    ):
        """
        оптимизировать веса метрик
        
        args:
            candidates_list: list of list of str
            contexts: list
            y_texts: list of str
            metric: str - целевая метрика ('rouge2', 'rougeL')
        
        returns:
            dict, float: оптимальные веса и достигнутый скор
        """
        if len(candidates_list) != len(contexts) or len(candidates_list) != len(y_texts):
            raise ValueError(
                f"length mismatch: candidates_list={len(candidates_list)}, "
                f"contexts={len(contexts)}, y_texts={len(y_texts)}"
            )
        
        # определяем метрики для вычисления
        if self.metric_selection == 'auto_all':
            # вычисляем ВСЕ метрики
            metrics_to_compute = MetricsComputer.ALL_METRICS
            print(f"\nрежим auto_all: вычисляем все {len(metrics_to_compute)} метрик")
        else:
            # вычисляем только активные (use_*=True)
            metrics_to_compute = self._get_active_metrics()
            if self.metric_selection == 'auto':
                print(f"\nрежим auto: из {len(metrics_to_compute)} активных метрик отберём по корреляции >= {self.min_correlation}")
            else:
                print(f"\nрежим manual: используем {len(metrics_to_compute)} метрик")
        
        # предвычисление метрик
        precomputed_metrics = self._load_or_compute_metrics(
            candidates_list, contexts, cache_name, use_cache, metrics_to_compute
        )
        
        # предвычисление целевой метрики
        print(f"предвычисление целевой метрики ({metric})...")
        target_scores_cache = self._precompute_target_scores(
            candidates_list, y_texts, metric
        )
        
        metric_names = list(precomputed_metrics[0][0].keys())
        
        print(f"целевая метрика: {metric}")
        
        # вычисление корреляций
        correlations_dict = {}
        if print_correlations or self.metric_selection in ['auto', 'auto_all']:
            correlations_dict = self._compute_correlations(
                precomputed_metrics, target_scores_cache, metric_names
            )
        
        if print_correlations:
            self._print_correlations(correlations_dict)
        
        # автоматический отбор метрик
        if self.metric_selection in ['auto', 'auto_all']:
            selected_metrics = self._select_metrics_by_correlation(
                correlations_dict, self.min_correlation
            )
            
            if not selected_metrics:
                print(f"warning: ни одна метрика не прошла порог {self.min_correlation}")
                print(f"используем все {len(metric_names)} метрик")
                selected_metrics = metric_names
            else:
                dropped = set(metric_names) - set(selected_metrics)
                if dropped:
                    print(f"\nотброшены ({len(dropped)}):")
                    for m in sorted(dropped):
                        corr = correlations_dict[m]['mean']
                        print(f"  {m:25s}: корреляция {corr:+.3f}")
                
                print(f"\nотобрано: {len(selected_metrics)}/{len(metric_names)} метрик")
                print(f"метрики: {selected_metrics}")
            
            # фильтруем precomputed_metrics
            precomputed_metrics = [
                [{k: v for k, v in m.items() if k in selected_metrics} for m in batch]
                for batch in precomputed_metrics
            ]
            metric_names = selected_metrics
            
            # обновляем активные метрики в объекте (для auto_all)
            if self.metric_selection == 'auto_all':
                self._update_active_metrics(selected_metrics)
        
        n_metrics = len(metric_names)
        
        # оптимизация весов
        best_weights, best_score = self._optimize_weights(
            precomputed_metrics, target_scores_cache,
            metric_names, max_iter, popsize, seed, n_workers
        )
        
        # обновление модели
        weights_dict = {
            metric_names[i]: best_weights[i]
            for i in range(n_metrics)
        }
        
        # добавляем нулевые веса для неиспользуемых метрик
        for metric in MetricsComputer.ALL_METRICS:
            if metric not in weights_dict:
                weights_dict[metric] = 0.0
        
        # обновляем модель с новыми метриками
        self.model = ReRankingModel(metric_names, weights_dict)
        
        self._print_results(metric, best_score, weights_dict)
        
        return weights_dict, best_score
    
    def _update_active_metrics(self, selected_metrics):
        """обновить флаги use_* (для auto_all режима)"""
        for metric in MetricsComputer.ALL_METRICS:
            attr_name = f'use_{metric}'
            setattr(self, attr_name, metric in selected_metrics)
    
    def _compute_correlations(self, precomputed_metrics, target_scores_cache, metric_names):
        """вычислить корреляции метрик с целевой метрикой"""
        metric_correlations = {name: [] for name in metric_names}
        
        for i in range(len(precomputed_metrics)):
            target_values = target_scores_cache[i]
            
            for metric_name in metric_names:
                metric_values = [
                    precomputed_metrics[i][j][metric_name]
                    for j in range(len(precomputed_metrics[i]))
                ]
                
                if np.std(metric_values) > 1e-10 and np.std(target_values) > 1e-10:
                    corr, _ = spearmanr(metric_values, target_values)
                    metric_correlations[metric_name].append(corr)
        
        # усреднение
        avg_correlations = {}
        for metric_name in metric_names:
            if metric_correlations[metric_name]:
                avg_correlations[metric_name] = {
                    'mean': np.mean(metric_correlations[metric_name]),
                    'std': np.std(metric_correlations[metric_name])
                }
        
        return avg_correlations
    
    def _select_metrics_by_correlation(self, correlations_dict, min_correlation):
        """отобрать метрики по корреляции"""
        selected = []
        for metric_name, stats in correlations_dict.items():
            if abs(stats['mean']) >= min_correlation:
                selected.append(metric_name)
        return selected
    
    def _print_correlations(self, correlations_dict):
        """вывести корреляции"""
        print("\n" + "="*60)
        print("корреляция метрик с целевой метрикой (spearman)")
        print("="*60)
        
        correlations = [
            (name, stats['mean'], stats['std'])
            for name, stats in correlations_dict.items()
        ]
        correlations.sort(key=lambda x: abs(x[1]), reverse=True)
        
        for metric_name, avg_corr, std_corr in correlations:
            if abs(avg_corr) > 0.4:
                status = "сильная"
            elif abs(avg_corr) > 0.25:
                status = "средняя"
            elif abs(avg_corr) > 0.15:
                status = "слабая"
            else:
                status = "очень слабая"
            
            print(f"{status:15s} {metric_name:25s}: {avg_corr:+.3f} (±{std_corr:.3f})")
        
        print("="*60)
    
    def _print_results(self, metric, best_score, weights_dict):
        """вывести результаты калибровки"""
        print("\n" + "="*60)
        print("результаты калибровки")
        print("="*60)
        print(f"best {metric}: {best_score:.4f}")
        print(f"\nоптимальные веса:")
        
        sorted_weights = sorted(
            weights_dict.items(), 
            key=lambda x: abs(x[1]), 
            reverse=True
        )
        
        for name, weight in sorted_weights:
            if abs(weight) > 1e-6:
                print(f"  {name:25s}: {weight:+.4f}")
        
        print("="*60)
    
    def _load_or_compute_metrics(self, candidates_list, contexts, cache_name, use_cache, metrics_to_compute):
        """загрузить или вычислить метрики"""
        cache_path = None
        if cache_name and use_cache:
            # добавляем метрики в имя кэша для различения
            metrics_hash = hash(tuple(sorted(metrics_to_compute)))
            cache_path = self.cache_dir / f"{cache_name}_metrics_{metrics_hash}.pkl"
        
        if cache_path and cache_path.exists():
            print(f"загрузка метрик из кэша: {cache_path}")
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        
        print(f"предвычисление {len(metrics_to_compute)} метрик...")
        precomputed = []
        
        for candidates, context in tqdm(
            list(zip(candidates_list, contexts)), 
            desc="computing metrics"
        ):
            candidate_metrics = []
            for candidate in candidates:
                metrics_dict = {}
                for metric_name in metrics_to_compute:
                    params = self._get_metric_params(metric_name)
                    scores = self.metrics_computer.compute_metric(
                        metric_name, [candidate], context, params
                    )
                    metrics_dict[metric_name] = scores[0]
                candidate_metrics.append(metrics_dict)
            
            precomputed.append(candidate_metrics)
        
        if cache_path:
            with open(cache_path, 'wb') as f:
                pickle.dump(precomputed, f)
            print(f"метрики сохранены в кэш: {cache_path}")
        
        return precomputed
    
    def _precompute_target_scores(self, candidates_list, y_texts, metric='rouge2'):
        """предвычислить целевую метрику"""
        scorer = self.metrics_computer._load_rouge_scorer()
        target_scores = []
        
        for candidates, y_text in tqdm(
            zip(candidates_list, y_texts),
            total=len(candidates_list),
            desc="target metric"
        ):
            candidate_scores = []
            for candidate in candidates:
                scores = scorer.score(y_text, candidate)
                candidate_scores.append(scores[metric].fmeasure)
            
            target_scores.append(candidate_scores)
        
        return target_scores
    
    def _optimize_weights(
        self, 
        precomputed_metrics, 
        target_scores_cache, 
        metric_names,
        max_iter, 
        popsize, 
        seed, 
        n_workers
    ):
        """оптимизация весов"""
        n_metrics = len(metric_names)
        
        # сохраняем для objective function
        self._opt_precomputed_metrics = precomputed_metrics
        self._opt_target_scores_cache = target_scores_cache
        self._opt_metric_names = metric_names
        self._opt_n_metrics = n_metrics
        
        print(f"\nоптимизация весов (max_iter={max_iter}, popsize={popsize}, workers={n_workers})...")
        
        result = differential_evolution(
            self._objective_function,
            bounds=[(-1, 1)] * n_metrics,
            strategy='best1bin',
            maxiter=max_iter,
            popsize=popsize,
            tol=0.001,
            mutation=(0.5, 1.5),
            recombination=0.7,
            seed=seed,
            workers=n_workers,
            updating='deferred' if n_workers > 1 else 'immediate',
            polish=True,
            disp=True
        )
        
        # очистка
        del self._opt_precomputed_metrics
        del self._opt_target_scores_cache
        del self._opt_metric_names
        del self._opt_n_metrics
        
        weights_raw = result.x
        weights_sum = np.sum(np.abs(weights_raw))
        
        if weights_sum < 1e-10:
            print(f"warning: все веса близки к нулю")
            weights_normalized = np.ones(n_metrics) / n_metrics
        else:
            weights_normalized = weights_raw / weights_sum
        
        best_score = -result.fun
        return weights_normalized, best_score
    
    def _objective_function(self, weights):
        """функция для оптимизации"""
        total_score = 0.0
        
        for metrics_list, target_scores in zip(
            self._opt_precomputed_metrics,
            self._opt_target_scores_cache
        ):
            # вычисляем взвешенный скор
            weighted_scores = [
                sum(
                    weights[i] * metrics_dict.get(self._opt_metric_names[i], 0)
                    for i in range(self._opt_n_metrics)
                )
                for metrics_dict in metrics_list
            ]
            
            # выбираем лучшего
            best_idx = np.argmax(weighted_scores)
            total_score += target_scores[best_idx]
        
        avg_score = total_score / len(self._opt_precomputed_metrics)
        return -avg_score
    
    def save(self, path):
        """сохранить re-ranker"""
        with open(path, 'wb') as f:
            pickle.dump(self, f)
        print(f"re-ranker сохранен: {path}")
    
    @staticmethod
    def load(path):
        """загрузить re-ranker"""
        with open(path, 'rb') as f:
            ranker = pickle.load(f)
        print(f"re-ranker загружен: {path}")
        return ranker

In [None]:
# инициализация: вручную выбираем метрики
ranker = UniversalTextReRanker(
    use_coverage=True,
    use_length_simple=True,
    use_repetition_penalty=True,
    use_grammar=True,
    length_params={'target_min': 8, 'target_max': 20},
    metric_selection='manual',
    device='cuda',
    cache_dir='./cache_commongen'
)

# обучение
optimal_weights, best_score = ranker.fit(
    candidates_list=dev_candidates,     # list of list of str
    contexts=dev_contexts,              # list of str
    y_texts=dev_y_texts,                # list of str
    metric='rouge2',
    cache_name='commongen_dev',
    use_cache=True,
    max_iter=50,
    popsize=30,
    print_correlations=True
)

# использование
best_candidate = ranker.get_best_candidate(candidates, context)

In [None]:
# инициализация: указываем потенциально полезные метрики
# auto отберёт только сильно коррелирующие
ranker = UniversalTextReRanker(
    use_coverage=True,
    use_fluency=True,
    use_grammar=True,
    use_length_simple=True,
    use_repetition_penalty=True,
    use_lexical_diversity=True,
    length_params={'target_min': 1, 'target_max': 15},
    metric_selection='auto',            # автоотбор из указанных
    min_correlation=0.15,               # порог отбора
    device='cuda',
    cache_dir='./cache_qa'
)

# обучение (автоматически отберёт метрики с |корреляцией| >= 0.15)
optimal_weights, best_score = ranker.fit(
    candidates_list=dev_candidates,
    contexts=dev_contexts,
    y_texts=dev_y_texts,
    metric='rougeL',
    cache_name='qa_dev',
    max_iter=50,
    popsize=30,
    print_correlations=True
)

# использование
best_candidate = ranker.get_best_candidate(candidates, context)

In [None]:
# инициализация: тестирует ВСЕ 10 метрик автоматически
ranker = UniversalTextReRanker(
    length_params={'target_min': 10, 'target_max': 30},
    compression_params={'optimal_ratio': 0.15, 'sigma': 0.05},
    metric_selection='auto_all',       # тестирует все метрики
    min_correlation=0.1,               # низкий порог → больше метрик
    device='cuda',
    cache_dir='./cache_xsum'
)

# обучение (вычислит все 10 метрик, отберёт лучшие)
optimal_weights, best_score = ranker.fit(
    candidates_list=dev_candidates,
    contexts=dev_contexts,
    y_texts=dev_y_texts,
    metric='rougeL',
    cache_name='xsum_dev',
    max_iter=50,
    popsize=30,
    print_correlations=True
)

# использование
best_candidate = ranker.get_best_candidate(candidates, context)

# Метрики

In [6]:
!pip -q install bert_score rouge_score

from collections import Counter
import numpy as np
from rouge_score import rouge_scorer
from bert_score import score as bert_score
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
import warnings

warnings.filterwarnings('ignore')

def exact_match(prediction, reference):
    return prediction.strip().lower() == reference.strip().lower()
# Higher is better (0-1)

def token_f1(prediction, reference):
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    
    if num_same == 0:
        return 0.0
    
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    
    return f1
# Higher is better (0-1)

def compute_bleu(prediction, reference):
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    
    smoothing = SmoothingFunction().method1
    return sentence_bleu([ref_tokens], pred_tokens, smoothing_function=smoothing)
# Higher is better (0-1)

def compute_rouge(prediction, reference):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = scorer.score(reference, prediction)
    
    return {
        'rouge1': scores['rouge1'].fmeasure,
        'rouge2': scores['rouge2'].fmeasure,
        'rougeL': scores['rougeL'].fmeasure
    }
# Higher is better (0-1)

def compute_bertscore(predictions, references):
    P, R, F1 = bert_score(predictions, references, lang='en', verbose=False)
    return {
        'precision': P.mean().item(),
        'recall': R.mean().item(),
        'f1': F1.mean().item()
    }
# Higher is better (0-1)

def compute_meteor(prediction, reference):
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    
    return meteor_score([ref_tokens], pred_tokens)
# Higher is better (0-1)

def perplexity(model, tokenizer, texts):
    total_loss = 0
    total_tokens = 0
    
    for text in texts:
        inputs = tokenizer(text, return_tensors='pt').to(model.device)
        
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs.input_ids)
            loss = outputs.loss
        
        total_loss += loss.item() * inputs.input_ids.size(1)
        total_tokens += inputs.input_ids.size(1)
    
    return np.exp(total_loss / total_tokens)
# Lower is better

def distinct_n(texts, n=2):
    all_ngrams = []
    
    for text in texts:
        tokens = text.lower().split()
        ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]
        all_ngrams.extend(ngrams)
    
    if not all_ngrams:
        return 0.0
    
    return len(set(all_ngrams)) / len(all_ngrams)
# Higher is better (0-1, measures diversity)

def self_bleu(texts):
    if len(texts) < 2:
        return 0.0
    
    scores = []
    smoothing = SmoothingFunction().method1
    
    for i, text in enumerate(texts):
        others = texts[:i] + texts[i+1:]
        if not others:
            continue
        
        text_tokens = text.lower().split()
        if not text_tokens:
            continue
        
        refs_tokens = [other.lower().split() for other in others if other.strip()]
        refs_tokens = [ref for ref in refs_tokens if ref]
        
        if not refs_tokens:
            continue
        
        try:
            score = sentence_bleu(refs_tokens, text_tokens, smoothing_function=smoothing)
            scores.append(score)
        except:
            continue
    
    return np.mean(scores) if scores else 0.0
# Lower is better (0-1, measures diversity - lower means more diverse)

predictions = [
    "The cat sat on the mat",
    "A dog runs in the park",
    "She loves reading books"
]

references = [
    "A cat was sitting on the mat",
    "The dog is running",
    "She loves reading books"
]

em_scores = [exact_match(p, r) for p, r in zip(predictions, references)]
f1_scores = [token_f1(p, r) for p, r in zip(predictions, references)]
bleu_scores = [compute_bleu(p, r) for p, r in zip(predictions, references)]

rouge_scores = [compute_rouge(p, r) for p, r in zip(predictions, references)]
rouge_avg = {
    'rouge1': np.mean([s['rouge1'] for s in rouge_scores]),
    'rouge2': np.mean([s['rouge2'] for s in rouge_scores]),
    'rougeL': np.mean([s['rougeL'] for s in rouge_scores])
}

bertscore = compute_bertscore(predictions, references)
meteor_scores = [compute_meteor(p, r) for p, r in zip(predictions, references)]

distinct = distinct_n(predictions, n=2)
sbleu = self_bleu(predictions)

print(f"EM: {np.mean(em_scores):.3f}")
print(f"F1: {np.mean(f1_scores):.3f}")
print(f"BLEU: {np.mean(bleu_scores):.3f}")
print(f"ROUGE-1: {rouge_avg['rouge1']:.3f}")
print(f"ROUGE-2: {rouge_avg['rouge2']:.3f}")
print(f"ROUGE-L: {rouge_avg['rougeL']:.3f}")
print(f"BERTScore F1: {bertscore['f1']:.3f}")
print(f"METEOR: {np.mean(meteor_scores):.3f}")
print(f"Distinct-2: {distinct:.3f}")
print(f"Self-BLEU: {sbleu:.3f}")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EM: 0.333
F1: 0.672
BLEU: 0.411
ROUGE-1: 0.738
ROUGE-2: 0.455
ROUGE-L: 0.672
BERTScore F1: 0.966
METEOR: 0.684
Distinct-2: 1.000
Self-BLEU: 0.027


# работа с A100

In [None]:
import torch
from transformers import AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,  # BF16
    bnb_4bit_use_double_quant=True
)

model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen2.5-7B-Instruct',
    quantization_config=bnb_config,
    device_map="auto",  # одна GPU
    attn_implementation="flash_attention_2"  # Flash Attention, ускорение модели, качество не ухудшается
)

args = TrainingArguments(
    output_dir='./output',
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    fp16=False,
    bf16=True,  # BF16
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

# RAG

In [None]:
class RAG:
    def __init__(self, checkpoint='BAAI/bge-base-en-v1.5', device='cuda'):
        self.model = AutoModel.from_pretrained(checkpoint).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.device = device
        
        self.x_texts = None
        self.y_texts = None
        self.embeddings = None

    def fit(self, x_texts, y_texts, batch_size=32):
        self.x_texts = x_texts
        self.y_texts = y_texts
        all_embeddings = []

        for i in tqdm(range(0, len(x_texts), batch_size), desc='RAG fitting'):
            batch = x_texts[i:i + batch_size]
            
            inputs = self.tokenizer(
                batch,
                max_length=512,
                truncation=True,
                padding='longest',
                return_tensors='pt'
            ).to(self.device)
            
            with torch.no_grad():
                embeddings = self.model(**inputs).last_hidden_state[:, 0]
            
            all_embeddings.append(embeddings.cpu())

        self.embeddings = torch.cat(all_embeddings, dim=0).numpy()
        self.embeddings = self.embeddings / np.linalg.norm(
            self.embeddings, axis=1, keepdims=True
        )

    def predict(self, x_texts, k=3, batch_size=32):
        if isinstance(x_texts, str):
            x_texts = [x_texts]
            single = True
        else:
            single = False

        all_results = []
        
        for i in range(0, len(x_texts), batch_size):
            batch = x_texts[i:i + batch_size]
            
            inputs = self.tokenizer(
                batch,
                max_length=512,
                truncation=True,
                padding='longest',
                return_tensors='pt'
            ).to(self.device)
            
            with torch.no_grad():
                query_embs = self.model(**inputs).last_hidden_state[:, 0]
            
            query_embs = query_embs.cpu().numpy()
            query_embs = query_embs / np.linalg.norm(query_embs, axis=1, keepdims=True)
            
            similarities = np.dot(query_embs, self.embeddings.T)
            
            for j, sims in enumerate(similarities):
                top_k = np.argsort(sims)[-k - len(x_texts):][::-1]
                
                results = []
                for idx in top_k:
                    if self.x_texts[idx] == batch[j]:
                        continue

                    results.append({
                        'x': self.x_texts[idx],
                        'y': self.y_texts[idx],
                        'similarity': float(sims[idx]),
                        'index': int(idx)
                    })
                
                all_results.append(results[:k])
        
        return all_results[0] if single else all_results

# target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'] для LLaMA / Llama-2 / Llama-3 / Mistral / Qwen / Yi

# Форматы данных для LLM

1. Instruction-following (самый популярный)

In [None]:
# Формат Alpaca
template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{response}"""

# Формат ChatML (для chat-моделей)
template = """<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
{assistant_response}<|im_end|>"""

# Формат Llama/Mistral Instruct
template = """<s>[INST] {instruction} [/INST] {response}</s>"""

2. Структура датасета

In [None]:
# Вариант 1: Простой (для SFTTrainer)
dataset = [
    {"text": "### Вопрос: Столица России?\n### Ответ: Москва"},
    {"text": "### Вопрос: 2+2=?\n### Ответ: 4"},
]

# Вариант 2: Разделенный (лучше для контроля)
dataset = [
    {
        "instruction": "Столица России?",
        "response": "Москва"
    },
    {
        "instruction": "2+2=?",
        "response": "4"
    },
]

# Вариант 3: С контекстом
dataset = [
    {
        "instruction": "Суммаризируй текст",
        "input": "Длинный текст...",
        "output": "Краткое содержание"
    }
]

# Маскирование промпта

Способ 1: Автоматический (SFTTrainer)

In [None]:
from trl import SFTTrainer
from datasets import Dataset

# 1. Подготовка данных
data = [
    {"instruction": "Столица России?", "response": "Москва"},
    {"instruction": "Автор 'Война и мир'?", "response": "Лев Толстой"},
]

dataset = Dataset.from_list(data)

# 2. Функция форматирования
def formatting_func(example):
    return f"### Инструкция:\n{example['instruction']}\n\n### Ответ:\n{example['response']}"

# 3. Обучение
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # эффективный batch = 4*4 = 16
    learning_rate=2e-4,
    fp16=True,  # или bf16=True для новых GPU
    logging_steps=10,
    save_strategy="epoch",
    optim="paged_adamw_8bit",  # экономия памяти
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_func,
    max_seq_length=512,  # максимальная длина последовательности
    peft_config=lora_config,  # из прошлого блока
    args=training_args,
)

trainer.train()

Способ 2: Продвинутый (кастомный Data Collator)

In [None]:
from dataclasses import dataclass
from typing import Dict, List
from transformers import DataCollatorForLanguageModeling

@dataclass
class DataCollatorForCompletionOnlyLM:
    tokenizer: any
    response_template: str = "### Ответ:\n"
    mlm: bool = False
    
    def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, any]:
        batch = {
            "input_ids": [],
            "attention_mask": [],
            "labels": []
        }
        
        for example in examples:
            # Токенизируем полный текст
            full_text = example["text"]
            tokenized = self.tokenizer(
                full_text,
                truncation=True,
                max_length=512,
                padding=False,
            )
            
            input_ids = tokenized["input_ids"]
            
            # Находим где начинается ответ
            response_token_ids = self.tokenizer.encode(
                self.response_template, 
                add_special_tokens=False
            )
            
            # Ищем шаблон в input_ids
            labels = [-100] * len(input_ids)
            
            for i in range(len(input_ids) - len(response_token_ids)):
                if input_ids[i:i+len(response_token_ids)] == response_token_ids:
                    # Нашли начало ответа
                    response_start = i + len(response_token_ids)
                    labels[response_start:] = input_ids[response_start:]
                    break
            
            batch["input_ids"].append(input_ids)
            batch["attention_mask"].append(tokenized["attention_mask"])
            batch["labels"].append(labels)
        
        # Padding
        from torch.nn.utils.rnn import pad_sequence
        import torch
        
        batch["input_ids"] = pad_sequence(
            [torch.tensor(x) for x in batch["input_ids"]], 
            batch_first=True, 
            padding_value=self.tokenizer.pad_token_id
        )
        batch["attention_mask"] = pad_sequence(
            [torch.tensor(x) for x in batch["attention_mask"]], 
            batch_first=True, 
            padding_value=0
        )
        batch["labels"] = pad_sequence(
            [torch.tensor(x) for x in batch["labels"]], 
            batch_first=True, 
            padding_value=-100
        )
        
        return batch

Способ 3: Готовый инструмент из trl

In [None]:
from trl import DataCollatorForCompletionOnlyLM

# Указываем шаблон ответа
response_template = "### Ответ:\n"

collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
    mlm=False
)

# Используем в Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    data_collator=collator,
    # ...
)

# TrainingArguments: Что важно

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    # === Основное ===
    output_dir="./results",
    num_train_epochs=3,
    
    # === Batch size ===
    per_device_train_batch_size=4,  # на 1 GPU
    gradient_accumulation_steps=4,   # накапливаем градиенты
    # Реальный batch = 4 * 4 = 16
    
    # === Learning rate ===
    learning_rate=2e-4,  # для LoRA обычно выше: 1e-4 до 5e-4
    lr_scheduler_type="cosine",  # или "linear"
    warmup_steps=100,  # или warmup_ratio=0.1
    
    # === Оптимизация памяти ===
    fp16=True,  # для старых GPU (V100, RTX 2080)
    # bf16=True,  # для новых GPU (A100, RTX 3090+) - лучше чем fp16
    gradient_checkpointing=True,  # экономия памяти за счет скорости
    optim="paged_adamw_8bit",  # 8-bit optimizer от bitsandbytes
    
    # === Логирование ===
    logging_steps=10,
    logging_dir="./logs",
    report_to="tensorboard",  # или "wandb"
    
    # === Сохранение ===
    save_strategy="epoch",  # "steps", "epoch", "no"
    save_total_limit=2,  # храним только 2 последних чекпоинта
    
    # === Eval (если есть val set) ===
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    
    # === Прочее ===
    remove_unused_columns=False,  # важно для SFTTrainer
    dataloader_num_workers=4,  # параллельная загрузка данных
)

# Полный пример: От данных до обученной модели

In [1]:
!pip -q install transformers>=4.38.0 trl>=0.8.0 peft>=0.9.0 bitsandbytes

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer

checkpoint = "mistralai/Mistral-7B-v0.1"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_type=torch.float16,
    bnb_4bit_use_double_quant=True
)

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map='auto'
)

prepare_model_for_kbit_training(model)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    lora_dropout=0.05,
    bias='none',
    task_type='CAUSAL_LM'
)

train_data = [
    {
        "question": "Переведи на английский: Привет, как дела?",
        "answer": "Hello, how are you?"
    },
    {
        "question": "Реши: 15 * 8",
        "answer": "120"
    }
]
dataset = Dataset.from_list([
    {"text": f"<s>[INST] {item['question']} [/INST] {item['answer']}</s>"}
    for item in train_data
])

args = TrainingArguments(
    optim='paged_adamw_8bit',
    report_to='none',
    output_dir='./result',
    fp16=torch.cuda.is_available()
)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    args=args,
    peft_config=lora_config
)
trainer.train()
trainer.save_model("./qa_model")

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
pylibcudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
cudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
bigframes 2.12.0 requires rich<14,>=12.4.4, but you have rich 14.2.0 which is incompatible.
libcugraph-cu12 25.6.0 requires libraft-cu12==25.6.*, but you have libraft-cu12 25.2.0 which is incompatible.
cudf-polars-cu12 25.6.0 requires pylibcudf-cu12==25.6.*, but you have pylibcudf-cu12 25.2.2 which is incompatible.
pylibcugraph-cu12 25.6.0 requires pylibraft-cu12==25.6.*, but you have pylibraft-cu12 25.2.0 which is incompatible.
pylibcugraph-cu12 

2025-11-09 07:06:42.452891: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762672002.680153      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762672002.744361      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'



config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

Adding EOS to train dataset:   0%|          | 0/2 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/2 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/2 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.
  return fn(*args, **kwargs)


Step,Training Loss


In [2]:
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map="auto"
)

model = PeftModel.from_pretrained(model, "./qa_model")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_pro

In [None]:
# Подготовка промпта
prompt = "<s>[INST] Привет, как дела? [/INST]"

# Токенизация
inputs = tokenizer(prompt, return_tensors="pt").to(trainer.model.device)

# Генерация
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.2
)

# Декодирование
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

In [None]:
inputs = tokenizer('<s>[INST] Приветики-пистолетики! [/INST]</s>', return_tensors='pt').to(trainer.model.device)
outputs = model(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetitiin_penalty=1.2
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
# Подготовка промпта
prompt = "<s>[INST] Привет, как дела? [/INST]"

# Токенизация
inputs = tokenizer(prompt, return_tensors="pt").to(trainer.model.device)

# Генерация
outputs = model.generate(
    **inputs,
    
    # === Длина ===
    max_new_tokens=256,      # максимум новых токенов
    min_new_tokens=10,       # минимум (опционально)
    max_length=512,          # альтернатива: общая длина (prompt + generation)
    
    # === Стратегия декодирования ===
    do_sample=True,          # False = greedy, True = sampling
    
    # === Temperature (креативность) ===
    temperature=0.7,         # 0.1 = консервативно, 1.0 = нормально, 2.0 = креативно
                             # <0.7: факты, код, переводы
                             # 0.7-1.0: обычная генерация
                             # >1.0: креативное письмо
    
    # === Top-p (nucleus sampling) ===
    top_p=0.9,              # рассматриваем топ токенов с суммарной вероятностью 90%
                            # 0.9-0.95: хороший баланс
                            # 0.5: более консервативно
                            # 0.99: почти все токены
    
    # === Top-k sampling ===
    top_k=50,               # рассматриваем только топ-50 токенов
                            # обычно 40-100
                            # 0 = выключено
    
    # === Repetition penalty ===
    repetition_penalty=1.2, # штраф за повторения
                            # 1.0 = нет штрафа
                            # 1.1-1.5: легкий штраф (обычно хорошо)
                            # >1.5: сильный штраф
    
    # # === Stopping criteria ===
    # eos_token_id=tokenizer.eos_token_id,
    # pad_token_id=tokenizer.pad_token_id,
    
    # === Другое ===
    num_return_sequences=1,  # сколько вариантов генерировать
    num_beams=1,            # beam search (1 = выключен)
)

# Декодирование
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

# Если не хватает памяти

In [None]:
# Уменьшите:
per_device_train_batch_size=2  # было 4
max_seq_length=256  # было 512
r=8  # было 16 в LoRA

# Добавьте:
gradient_checkpointing=True
optim="paged_adamw_8bit"

# Блок 3: Inference, генерация и валидация

Часть 1: Загрузка обученной модели

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map="auto"
)

model = PeftModel.from_pretrained(model, "./qa_model")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()

Часть 2: Генерация текста

In [6]:
# Подготовка промпта
prompt = "<s>[INST] Hello! [/INST]"

# Токенизация
inputs = tokenizer(prompt, return_tensors="pt").to(trainer.model.device)

# Генерация
outputs = trainer.model.generate(
    **inputs,
    max_new_tokens=100,  # сколько токенов сгенерировать
    do_sample=False,     # greedy decoding
)

# Декодирование
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[INST] Hello! [/INST]

Hi there, I'm new to this forum. My name is Kieran and I am from the United Kingdom (hence my accent). I have been playing TF2 for 5 years now, and have always loved it. It has never failed to give me fun times with friends or strangers alike.
I've just recently got back into trading after a long break, so if you want to trade with someone who knows what they are doing then feel free to send me an offer. Also, any tips would be appreciated as I'm not very good at trading myself xD .
I will also like to join your group, as I think that we could help each other out in trades.


Параметры генерации

In [4]:
outputs = model.generate(
    **inputs,
    
    # === Длина ===
    max_new_tokens=256,      # максимум новых токенов
    min_new_tokens=10,       # минимум (опционально)
    max_length=512,          # альтернатива: общая длина (prompt + generation)
    
    # === Стратегия декодирования ===
    do_sample=True,          # False = greedy, True = sampling
    
    # === Temperature (креативность) ===
    temperature=0.7,         # 0.1 = консервативно, 1.0 = нормально, 2.0 = креативно
                             # <0.7: факты, код, переводы
                             # 0.7-1.0: обычная генерация
                             # >1.0: креативное письмо
    
    # === Top-p (nucleus sampling) ===
    top_p=0.9,              # рассматриваем топ токенов с суммарной вероятностью 90%
                            # 0.9-0.95: хороший баланс
                            # 0.5: более консервативно
                            # 0.99: почти все токены
    
    # === Top-k sampling ===
    top_k=50,               # рассматриваем только топ-50 токенов
                            # обычно 40-100
                            # 0 = выключено
    
    # === Repetition penalty ===
    repetition_penalty=1.2, # штраф за повторения
                            # 1.0 = нет штрафа
                            # 1.1-1.5: легкий штраф (обычно хорошо)
                            # >1.5: сильный штраф
    
    # === Stopping criteria ===
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    
    # === Другое ===
    num_return_sequences=1,  # сколько вариантов генерировать
    num_beams=1,            # beam search (1 = выключен)
)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Комбинации параметров для разных задач

In [None]:
# 1. Факты, QA, перевод (нужна точность)
generation_config_precise = {
    "max_new_tokens": 100,
    "do_sample": True,
    "temperature": 0.3,
    "top_p": 0.85,
    "repetition_penalty": 1.1,
}

# 2. Обычная генерация (баланс)
generation_config_balanced = {
    "max_new_tokens": 200,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.2,
}

# 3. Креативное письмо
generation_config_creative = {
    "max_new_tokens": 300,
    "do_sample": True,
    "temperature": 1.0,
    "top_p": 0.95,
    "repetition_penalty": 1.15,
}

# 4. Детерминированная генерация (для debug)
generation_config_greedy = {
    "max_new_tokens": 100,
    "do_sample": False,  # greedy decoding
}

# Использование
outputs = model.generate(**inputs, **generation_config_balanced)

Batch генерация

In [None]:
prompts = [
    "<s>[INST] Столица России? [/INST]",
    "<s>[INST] 2+2=? [/INST]",
    "<s>[INST] Кто написал 'Евгений Онегин'? [/INST]"
]

# Токенизация с padding
inputs = tokenizer(
    prompts, 
    return_tensors="pt", 
    padding=True,  # важно!
    truncation=True,
    max_length=512
).to(model.device)

# Генерация
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    temperature=0.7,
    top_p=0.9,
    pad_token_id=tokenizer.pad_token_id
)

# Декодирование
for i, output in enumerate(outputs):
    text = tokenizer.decode(output, skip_special_tokens=True)
    print(f"Prompt {i}: {text}\n")

In [None]:
def generate_response(model, tokenizer, promts):
    inputs = tokenizer(
        promts,
        return_tensors='pt',
        padding='longest',
        truncation=True,
        max_length=512
    ).to(model.device)
    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        max_new_tokens=256
    )

    texts = []
    for output in outputs:
        output = output[inputs.input_ids.shape[1]:]
        text = tokenizer.decode(output, skip_special_tokens=True)
        texts.append(text)

    return texts

In [None]:
def template_processing(question, answer=None):
    if answer is None:
        return tokenizer.apply_chat_template(
            [{'role': 'user', 'content': f'здесь задача модели:\n\n{question}'}],
            tokenize=False,
            add_generation_promt=True
        )
    else:
        return tokenizer.apply_chat_template(
            [{'role': 'user', 'content': f'здесь задача модели:\n\n{question}'},
             {'role': 'assistant', 'content': answer}],
            tokenize=False
        )

# Часть 3: Валидация в процессе обучения

Вариант 1: Простой callback для генерации примеров

In [None]:
from transformers import TrainerCallback

class GenerationCallback(TrainerCallback):
    def __init__(self, tokenizer, test_prompts, every_n_steps=100):
        self.tokenizer = tokenizer
        self.test_prompts = test_prompts
        self.every_n_steps = every_n_steps
    
    def on_step_end(self, args, state, control, model=None, **kwargs):
        if state.global_step % self.every_n_steps == 0:
            print(f"\n{'='*50}")
            print(f"Generation at step {state.global_step}")
            print(f"{'='*50}")
            
            model.eval()
            for prompt in self.test_prompts:
                inputs = self.tokenizer(prompt, return_tensors="pt").to(model.device)
                
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=50,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True
                    )
                
                generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                print(f"\nPrompt: {prompt}")
                print(f"Generated: {generated}")
            
            model.train()
            print(f"{'='*50}\n")

# Использование
test_prompts = [
    "<s>[INST] Столица России? [/INST]",
    "<s>[INST] Что такое Python? [/INST]",
]

trainer = SFTTrainer(
    model=model,
    # ... остальные параметры
    callbacks=[GenerationCallback(tokenizer, test_prompts, every_n_steps=50)]
)

Вариант 2: Validation set с метриками

In [None]:
from datasets import Dataset

# Разделяем данные
train_data = data[:800]
val_data = data[800:]

train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)

# В TrainingArguments добавляем
training_args = TrainingArguments(
    # ...
    evaluation_strategy="steps",  # или "epoch"
    eval_steps=100,               # оценивать каждые 100 шагов
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,  # добавили validation set
    # ...
)

Вариант 3: Кастомные метрики (ROUGE, BLEU)

In [None]:
import numpy as np
from evaluate import load

# Загружаем метрики
rouge = load('rouge')
bleu = load('bleu')

# Сам генерируй и считай метрики
def evaluate_model(model, val_data):
    rouge = load('rouge')
    predictions = []
    references = []
    
    for item in val_data[:50]:
        prompt = f"<s>[INST] {item['question']} [/INST]"
        
        # Генерируем
        generated = generate_response(model, tokenizer, prompt)
        
        predictions.append(generated)
        references.append(item['answer'])
    
    scores = rouge.compute(predictions=predictions, references=references)
    return scores  # чем больше, тем лучше

########################################## ИЛИ

rouge = evaluate.load('rouge')
def evaluate_model(model, val_data):
    predictions = generate_response(model, tokenizer, [f"<s>[INST] {item['question']} [/INST]" for item in val_data])
    references = [item['answer'] for item in val_data]

    return rouge.compute(predictions=predictions, references=references)

########################################## ИЛИ

rouge = evaluate.load('rouge')
def evaluate_model(model, tokenizer, val_data, batch_size=8):
    predictions = []
    for i in tqdm(range(0, len(val_data), batch_size)):
        batch = [val_data[j] for j in range(i, min(i+batch_size, len(val_data)))]
        batch_predictions = generate_response(
            model, tokenizer,
            [tokenizer.apply_chat_template(
                [{'role': 'user', 'content': item['text']}],
                add_generation_prompt=True, tokenize=False
            ) for item in batch]
        )
        predictions.extend(batch_predictions)
    references = [item['summary'] for item in val_data]
    return rouge.compute(predictions=predictions, references=references)

Вариант 4: Полноценная валидация с генерацией (для олимпиады)

In [None]:
def evaluate_model(model, tokenizer, val_data, num_samples=50):
    """
    Оценка модели на validation set с реальной генерацией
    """
    model.eval()
    results = {
        "rouge1": [],
        "rouge2": [],
        "rougeL": [],
        "exact_match": 0,
    }
    
    rouge_metric = load('rouge')
    
    for i, example in enumerate(val_data[:num_samples]):
        # Формируем промпт
        prompt = f"<s>[INST] {example['instruction']} [/INST]"
        true_response = example['response']
        
        # Генерируем
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
        
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Убираем промпт из сгенерированного текста
        generated = generated.replace(prompt, "").strip()
        
        # Считаем метрики
        rouge_scores = rouge_metric.compute(
            predictions=[generated],
            references=[true_response]
        )
        
        results["rouge1"].append(rouge_scores["rouge1"])
        results["rouge2"].append(rouge_scores["rouge2"])
        results["rougeL"].append(rouge_scores["rougeL"])
        
        # Exact match (для простых задач типа QA)
        if generated.strip().lower() == true_response.strip().lower():
            results["exact_match"] += 1
    
    # Усредняем
    final_results = {
        "rouge1": np.mean(results["rouge1"]),
        "rouge2": np.mean(results["rouge2"]),
        "rougeL": np.mean(results["rougeL"]),
        "exact_match": results["exact_match"] / num_samples,
    }
    
    model.train()
    return final_results

# Использование
val_results = evaluate_model(model, tokenizer, val_data)
print(val_results)

# Часть 4: Debugging и типичные проблемы

Проблема 1: Модель повторяет промпт

In [None]:
# Проблема:
prompt = "Вопрос: Столица России?"
# Генерация: "Вопрос: Столица России? Вопрос: Столица России? Вопрос..."

# Решение 1: Увеличить repetition_penalty
outputs = model.generate(
    **inputs,
    repetition_penalty=1.5,  # было 1.2
)

# Решение 2: Правильно форматировать промпт (использовать тот же формат что при обучении)
prompt = "<s>[INST] Столица России? [/INST]"  # как в обучении!

# Решение 3: Убрать промпт из вывода
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = generated.replace(prompt, "").strip()

Проблема 2: Модель генерирует бессмыслицу

In [None]:
# Причины:
# 1. Слишком высокий temperature
temperature=0.5  # вместо 1.5

# 2. Модель недообучена
# Проверьте loss, обучите дольше

# 3. Слишком мало данных
# Нужно минимум 100-500 качественных примеров

# 4. Неправильный формат промпта
# Используйте ТОТ ЖЕ формат что и при обучении!

Проблема 3: Модель обрывается на середине

In [None]:
# Проблема: генерация заканчивается слишком рано

# Решение 1: Увеличить max_new_tokens
max_new_tokens=256  # было 50

# Решение 2: Проверить eos_token
print(f"EOS token: {tokenizer.eos_token}")
print(f"EOS token ID: {tokenizer.eos_token_id}")

# Решение 3: Добавить min_new_tokens
min_new_tokens=20

Проблема 4: Медленная генерация

In [None]:
# Решение 1: Использовать квантизацию
# (уже покрыто выше)

# Решение 2: Уменьшить max_new_tokens
max_new_tokens=100  # было 512

# Решение 3: Использовать greedy вместо sampling
do_sample=False  # быстрее, но менее разнообразно

# Решение 4: Batch inference
# (покрыто выше)

Часть 6: Быстрый inference для олимпиады

In [None]:
class LLMInference:
    def __init__(self, base_model_name, adapter_path, use_4bit=True):
        """Класс для быстрого inference"""
        
        if use_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                quantization_config=bnb_config,
                device_map="auto"
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        
        self.model = PeftModel.from_pretrained(self.model, adapter_path)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        self.model.eval()
    
    def generate(self, prompt, max_new_tokens=100, temperature=0.7, **kwargs):
        """Генерация одного ответа"""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=kwargs.get('top_p', 0.9),
                do_sample=True,
                repetition_penalty=kwargs.get('repetition_penalty', 1.2),
                pad_token_id=self.tokenizer.pad_token_id,
            )
        
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Убираем промпт
        response = generated.replace(prompt, "").strip()
        return response
    
    def batch_generate(self, prompts, **kwargs):
        """Batch генерация"""
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=kwargs.get('max_new_tokens', 100),
                temperature=kwargs.get('temperature', 0.7),
                top_p=kwargs.get('top_p', 0.9),
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        
        results = []
        for i, output in enumerate(outputs):
            generated = self.tokenizer.decode(output, skip_special_tokens=True)
            response = generated.replace(prompts[i], "").strip()
            results.append(response)
        
        return results

# Использование
inferencer = LLMInference(
    base_model_name="mistralai/Mistral-7B-v0.1",
    adapter_path="./final_model"
)

response = inferencer.generate(
    "<s>[INST] Столица России? [/INST]",
    max_new_tokens=50,
    temperature=0.5
)
print(response)

Часть 7: Сохранение и загрузка для submission

In [None]:
# === После обучения ===

# Вариант 1: Сохранить только LoRA адаптер (маленький размер)
trainer.save_model("./lora_adapter")
# Размер: ~10-50 MB

# Вариант 2: Объединить и сохранить полную модель
model = model.merge_and_unload()
model.save_pretrained("./full_model")
tokenizer.save_pretrained("./full_model")
# Размер: ~13 GB для 7B модели

# === Для загрузки ===

# Вариант 1: Загрузить LoRA адаптер
base_model = AutoModelForCausalLM.from_pretrained(...)
model = PeftModel.from_pretrained(base_model, "./lora_adapter")

# Вариант 2: Загрузить полную модель
model = AutoModelForCausalLM.from_pretrained("./full_model")
tokenizer = AutoTokenizer.from_pretrained("./full_model")