In [None]:
!pip install mteb

In [None]:
# Position in the grid of candidate interpolation weights (see `blend_weight`
# in the configuration below) that this run of the notebook should use.
IDX_SLERP = 1

In [None]:
import torch
import os
from safetensors.torch import load_file, save_file
from sentence_transformers import SentenceTransformer, models
from transformers import AutoTokenizer
import mteb
import numpy as np
import math
import pandas as pd
import shutil

# Configuration
symm_model_dir = "/kaggle/input/s2-models-embed/epoch_1_model/epoch_1"      # symmetric model folder
asymm_model_dir = "/kaggle/input/s2-models-embed/epoch_1_asym/epoch_1_asym"     # asymmetric model folder
# Interpolation weight, picked from a fixed grid by IDX_SLERP and later passed
# as `alpha` to slerp_weights(sd_symm, sd_asymm, alpha).
# NOTE(review): as wired, alpha=0 keeps the symmetric weights and alpha=1 yields
# the asymmetric ones, i.e. this value is the share of the ASYMMETRIC model.
# The previous comment said "fraction for symmetric model" - confirm the intent.
blend_weight = [0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95][IDX_SLERP]
output_merged = "merged_model"  # directory the merged checkpoint is written to

def slerp_weights(state_dict1, state_dict2, alpha):
    """
    Merge two state dicts tensor-by-tensor with spherical linear interpolation.

    Every pair of same-named tensors is flattened, the angle ``omega`` between
    the two flattened vectors is measured, and the merged tensor is::

        sin((1 - alpha) * omega) / sin(omega) * w1
            + sin(alpha * omega) / sin(omega) * w2

    Args:
        state_dict1: Mapping ``name -> tensor``; reproduced when ``alpha == 0``.
        state_dict2: Mapping with exactly the same keys; reproduced when
            ``alpha == 1``.
        alpha: Interpolation position, i.e. the weight given to
            ``state_dict2`` (0 -> first model, 1 -> second model).

    Returns:
        dict mapping every key to the merged tensor. Each merged tensor keeps
        the shape and dtype of the corresponding tensor in ``state_dict1``.

    Raises:
        ValueError: If the two state dicts do not have the same set of keys.
    """
    # Model compatibility check: both checkpoints must expose the same tensors.
    if set(state_dict1.keys()) != set(state_dict2.keys()):
        raise ValueError("Models have different architectures - key mismatch")

    merged = {}
    for k, t1 in state_dict1.items():
        t2 = state_dict2[k]

        # Shape check: tensors that cannot be aligned are taken from the first model.
        if t1.shape != t2.shape:
            print(f"Warning: Shape mismatch for {k}: {t1.shape} vs {t2.shape}")
            merged[k] = t1
            continue

        # Integer / bool buffers (index tables and the like) cannot be
        # meaningfully interpolated; keep the first model's copy untouched
        # instead of silently turning it into a float tensor.
        if not torch.is_floating_point(t1):
            merged[k] = t1
            continue

        # Do the trigonometry in at least float32 (half precision is too
        # coarse for acos/sin), then cast back to the stored dtype at the end.
        work_dtype = torch.float64 if t1.dtype == torch.float64 else torch.float32
        # reshape (not view) so non-contiguous tensors are accepted as well.
        v1 = t1.to(work_dtype).reshape(-1)
        v2 = t2.to(work_dtype).reshape(-1)

        # Angle between the two flattened weight vectors.
        cos_omega = torch.dot(v1, v2) / (torch.norm(v1) * torch.norm(v2) + 1e-8)
        omega = torch.acos(torch.clamp(cos_omega, -1.0, 1.0))
        sin_omega = torch.sin(omega)

        if torch.abs(sin_omega) < 1e-6:
            # (Anti-)parallel vectors: the SLERP coefficients degenerate to
            # 0/0, so fall back to plain linear interpolation. This still
            # honours alpha when the tensors differ only in magnitude.
            merged_flat = (1 - alpha) * v1 + alpha * v2
        else:
            coef1 = torch.sin((1 - alpha) * omega) / sin_omega
            coef2 = torch.sin(alpha * omega) / sin_omega
            merged_flat = coef1 * v1 + coef2 * v2

        merged[k] = merged_flat.reshape(t1.shape).to(t1.dtype)
    return merged

# Make sure both output locations exist before anything is written to them.
for out_dir in (output_merged, "results"):
    os.makedirs(out_dir, exist_ok=True)

# Read the raw weights of the two checkpoints that will be blended.
weights_name = "model.safetensors"
print("Loading model state dicts...")
sd_symm = load_file(f"{symm_model_dir}/{weights_name}")
sd_asymm = load_file(f"{asymm_model_dir}/{weights_name}")

# Interpolate every tensor between the two checkpoints.
print("Merging models with SLERP...")
alpha = blend_weight
sd_merged = slerp_weights(sd_symm, sd_asymm, alpha)

# Persist the blended weights next to where the config files will be copied.
print("Saving merged model...")
save_file(sd_merged, f"{output_merged}/{weights_name}")

# The merged folder reuses the symmetric model's config and tokenizer files;
# anything that is missing there is reported and skipped.
config_files = ['config.json', 'tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
for fname in config_files:
    src = f"{symm_model_dir}/{fname}"
    dst = f"{output_merged}/{fname}"
    if not os.path.exists(src):
        print(f"Warning: {fname} not found in {symm_model_dir}")
        continue
    shutil.copy(src, dst)
    print(f"Copied {fname}")

print("Building SentenceTransformer with merged weights...")
# Assemble a two-stage SBERT pipeline on top of the merged checkpoint:
# the transformer encoder followed by mean pooling over token embeddings.
dir_merged = output_merged
transformer = models.Transformer(
    model_name_or_path=dir_merged,
    tokenizer_name_or_path=dir_merged,
)
embedding_dim = transformer.get_word_embedding_dimension()
pooling = models.Pooling(embedding_dim, pooling_mode_mean_tokens=True)
base_model = SentenceTransformer(modules=[transformer, pooling])

# Wrapper that applies the query / document prefixes before encoding
class PrefixSentenceTransformer:
    """
    Wrapper that prepends a task prefix to every text before encoding it.

    Queries are prefixed with ``query_prefix`` and documents with
    ``doc_prefix`` (defaults ``search_query:`` / ``search_document:``); the
    prefixed texts are then handed to ``base_model.encode``.
    """

    def __init__(self, base_model, query_prefix="search_query:", doc_prefix="search_document:"):
        """Store the wrapped model and the two prefixes."""
        self.base_model = base_model
        self.query_prefix = query_prefix
        self.doc_prefix = doc_prefix

    def encode(self, sentences, **kwargs):
        """
        Encode sentences after prefixing them.

        Inputs are treated as documents unless ``is_query=True`` is passed.
        A single string is wrapped into a one-element list, so the wrapped
        model always receives a list. Every other keyword argument is
        forwarded unchanged to ``base_model.encode``.
        """
        is_query = kwargs.pop('is_query', False)

        if isinstance(sentences, str):
            sentences = [sentences]

        prefix = self.query_prefix if is_query else self.doc_prefix
        prefixed_sentences = [f"{prefix} {sent}" for sent in sentences]

        return self.base_model.encode(prefixed_sentences, **kwargs)

    def encode_queries(self, queries, **kwargs):
        """Encode texts explicitly as queries."""
        return self.encode(queries, is_query=True, **kwargs)

    @staticmethod
    def _doc_to_text(doc):
        """Turn one corpus entry into plain text ("title text" for dict entries)."""
        if isinstance(doc, dict):
            title = doc.get('title', '')
            text = doc.get('text', '')
            return f"{title} {text}".strip() if title else text
        return str(doc)

    def encode_corpus(self, corpus, **kwargs):
        """
        Encode corpus documents explicitly as documents.

        Accepted corpus layouts:
          * a single string or a list of strings;
          * a list of dicts with optional ``'title'`` and ``'text'`` keys;
          * a dict ``doc_id -> document`` (document is a dict as above, or
            any object that is converted with ``str``);
          * a columnar dict ``{'title': [...], 'text': [...]}``.

        For dict documents the title (when non-empty) and the text are joined
        with a space before the document prefix is applied.
        """
        if isinstance(corpus, str):
            return self.encode(corpus, is_query=False, **kwargs)

        if isinstance(corpus, dict):
            column_texts = corpus.get('text')
            if isinstance(column_texts, (list, tuple)):
                # Columnar layout: parallel lists of titles and texts.
                column_titles = corpus.get('title')
                if (not isinstance(column_titles, (list, tuple))
                        or len(column_titles) != len(column_texts)):
                    column_titles = [''] * len(column_texts)
                docs = [
                    {'title': title, 'text': text}
                    for title, text in zip(column_titles, column_texts)
                ]
            else:
                # Mapping doc_id -> document; only the documents are encoded.
                docs = corpus.values()
        else:
            # Any other iterable of documents. Dict entries must be flattened
            # here, otherwise their Python repr would be embedded.
            docs = corpus

        texts = [self._doc_to_text(doc) for doc in docs]
        return self.encode(texts, is_query=False, **kwargs)

    def save(self, path, **kwargs):
        """Delegate saving to the wrapped model."""
        return self.base_model.save(path, **kwargs)

    def push_to_hub(self, *args, **kwargs):
        """Delegate hub upload to the wrapped model."""
        return self.base_model.push_to_hub(*args, **kwargs)

# Wrap the merged model so that every input receives its prefix
wrapped = PrefixSentenceTransformer(base_model)

print("Starting MTEB evaluation...")
# Evaluation - Russian-language tasks
tasks_list = [
    'CEDRClassification', 'GeoreviewClassification',
    'GeoreviewClusteringP2P', 'HeadlineClassification',
    'InappropriatenessClassification', 'KinopoiskClassification', 'RUParaPhraserSTS',
    'RuReviewsClassification', 'RuSTSBenchmarkSTS', 'RuSciBenchGRNTIClassification',
    'RuSciBenchGRNTIClusteringP2P', 'RuSciBenchOECDClassification',
    'RuSciBenchOECDClusteringP2P', 'SensitiveTopicsClassification',
]

try:
    tasks = mteb.get_tasks(tasks=tasks_list)
    evaluator = mteb.MTEB(tasks=tasks)
    results = evaluator.run(
        wrapped,
        output_folder="results/merged_model_mteb",
        eval_splits=["test"],
        verbosity=2,
    )

    print("Evaluation completed successfully!")

    # Gather one row per task; the task type is derived from the task name.
    multilabel_tasks = ('CEDRClassification', 'SensitiveTopicsClassification')
    records = []
    for task_result in results:
        score = task_result.get_score()
        name = task_result.task_name

        # Name markers are checked in priority order; anything left over is
        # either one of the known multilabel tasks or plain classification.
        task_type = next(
            (marker for marker in ('Clustering', 'STS') if marker in name),
            None,
        )
        if task_type is None:
            if name in multilabel_tasks:
                task_type = 'MultilabelClassification'
            else:
                task_type = 'Classification'

        records.append(dict(task=name, score=score, type=task_type))
        print(f"{name}: {score:.4f}")

    df = pd.DataFrame(records)

    # 1) Per-task score table
    print("\n=== Task Scores ===")
    print(df.to_string(index=False))

    # 2) Mean over all tasks
    overall_mean = df['score'].mean()
    print("\n=== Overall Mean Score ===")
    print(f"Overall Mean: {overall_mean:.4f}")

    # 3) Mean per task type
    type_means = df.groupby('type')['score'].mean().reset_index()
    print("\n=== Scores by Task Type ===")
    print(type_means.to_string(index=False))

    # Write all three summaries to disk
    df.to_csv('results/task_scores.csv', index=False)
    type_means.to_csv('results/type_means.csv', index=False)
    with open('results/overall_mean.txt', 'w') as summary_file:
        summary_file.write(f"{overall_mean:.6f}")

    print("\nResults saved to results/ folder")

except Exception as err:
    # Top-level notebook boundary: report the failure with its traceback
    # instead of aborting the cell.
    print(f"Error during evaluation: {err}")
    import traceback
    traceback.print_exc()