In [None]:
!pip install datasets transformers sentence-transformers
!pip install beir mteb

In [1]:
import json
import os
import datasets
import torch
import logging
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from typing import Optional, List, Dict, Any, Tuple
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# Retrieval-benchmark dependencies are optional. If any of them fails to import,
# BEIR_AVAILABLE is False and the BEIR/MTEB evaluation code that checks the flag
# skips its work instead of the whole notebook failing here.
try:
    from beir import util, LoggingHandler
    from beir.datasets.data_loader import GenericDataLoader
    from beir.retrieval.evaluation import EvaluateRetrieval
    from beir.retrieval import models
    import mteb
    from sentence_transformers import SentenceTransformer
    BEIR_AVAILABLE = True
    _benchmark_import_error = None
except ImportError as exc:
    BEIR_AVAILABLE = False
    _benchmark_import_error = exc

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if not BEIR_AVAILABLE:
    logger.warning(f"BEIR/MTEB dependencies unavailable ({_benchmark_import_error}); "
                   f"retrieval benchmarks will be skipped.")

In [3]:
@dataclass
class ModelConfig:
    """Configuration for custom BGE model.

    Read by CustomBGEModel (architecture), CustomBGEWrapper (max_seq_length) and the
    configuration-comparison cells, which rebuild configs from ``__dict__``.
    Field order is the positional constructor order, so append new fields at the end.
    """
    # Hugging Face model id; passed to AutoModel.from_pretrained / AutoTokenizer.from_pretrained.
    model_name: str = "BAAI/bge-small-en-v1.5"
    # NOTE(review): not read by CustomBGEModel, which takes the width from
    # backbone.config.hidden_size instead; kept for reference only.
    hidden_size: int = 384
    # Used as both the hidden and the output width of the ProjectionHead.
    projection_dim: int = 768
    # How many of the backbone's final hidden states are mixed by LayerwiseWeightedSum.
    num_layers_to_average: int = 4
    pooling_strategy: str = "mean"  # "cls", "mean" or "attention"; other values are rejected with ValueError by CustomBGEModel
    # When False, pooled embeddings are only L2-normalised (no projection).
    use_projection_head: bool = True
    # Dropout probability inside the ProjectionHead.
    dropout_rate: float = 0.1
    # ProjectionHead activation: "relu", "gelu", "tanh" or "silu"; unknown names fall back to GELU.
    activation: str = "gelu"
    # Tokenizer truncation length used by CustomBGEWrapper.encode.
    max_seq_length: int = 512

In [4]:
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> activation -> Dropout -> Linear) with L2-normalised output.

    Args:
        input_dim: width of the incoming embeddings.
        hidden_dim: width of the intermediate layer.
        output_dim: width of the returned embeddings.
        activation: one of "relu", "gelu", "tanh", "silu"; any other name falls back to GELU.
        dropout: dropout probability applied after the activation.
    """

    # Activation name -> module class; instantiated on demand by _get_activation.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
        "silu": nn.SiLU,
    }

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 activation: str = "gelu", dropout: float = 0.1):
        super().__init__()

        # Keep this exact layer order: state_dict keys are projection.0.* and projection.3.*.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            self._get_activation(activation),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.projection = nn.Sequential(*layers)

    def _get_activation(self, activation: str):
        """Return a fresh activation module for ``activation`` (GELU if the name is unknown)."""
        activation_cls = ProjectionHead._ACTIVATIONS.get(activation, nn.GELU)
        return activation_cls()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and L2-normalise along the last dimension."""
        return F.normalize(self.projection(x), p=2, dim=-1)

In [5]:
class AttentionPooling(nn.Module):
    """Mask-aware attention pooling over token states.

    A linear layer scores every position; padded positions (mask == 0) receive zero
    weight, and the attention-weighted sum of hidden states is returned.
    """
    def __init__(self, hidden_size: int):
        super().__init__()
        self.attention_weights = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_states`` with learned attention weights.

        Args:
            hidden_states: presumably (batch, seq_len, hidden) -- the squeeze(-1) and
                sum over dim=1 below assume that layout.
            attention_mask: (batch, seq_len); 0 marks padding.

        Returns:
            (batch, hidden) pooled states. A row whose mask is entirely zero pools to a
            zero vector instead of NaN.
        """
        attention_scores = self.attention_weights(hidden_states).squeeze(-1)
        keep = attention_mask != 0
        # Use the dtype's most negative finite value instead of -inf: softmax over a row
        # that is entirely -inf is NaN, which would poison the whole batch's loss.
        attention_scores = attention_scores.masked_fill(~keep, torch.finfo(attention_scores.dtype).min)
        # Zero padded weights explicitly so a fully masked row (uniform softmax) pools to 0.
        attention_weights = F.softmax(attention_scores, dim=-1) * keep.to(attention_scores.dtype)
        return torch.sum(hidden_states * attention_weights.unsqueeze(-1), dim=1)

In [6]:
class LayerwiseWeightedSum(nn.Module):
    """Mix a fixed number of same-shaped layer outputs with scalar weights.

    With ``learnable=True`` the raw weights are a Parameter passed through a softmax, so
    the mixing coefficients are positive and sum to 1. Otherwise a fixed buffer of
    uniform weights ``1 / num_layers`` is used as-is.
    """
    def __init__(self, num_layers: int, learnable: bool = True):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        self.learnable = learnable

        initial_weights = torch.ones(num_layers) / num_layers
        if learnable:
            self.layer_weights = nn.Parameter(initial_weights)
        else:
            self.register_buffer('layer_weights', initial_weights)

    def forward(self, layer_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Return the weighted sum of ``layer_outputs`` (a list or tuple of equal-shape tensors).

        Raises:
            ValueError: if the number of tensors differs from ``num_layers``. Silently
                using a subset of the weights would give coefficients that no longer sum to 1.
        """
        if len(layer_outputs) != self.num_layers:
            raise ValueError(
                f"Expected {self.num_layers} layer outputs, got {len(layer_outputs)}"
            )

        weights = F.softmax(self.layer_weights, dim=0) if self.learnable else self.layer_weights

        stacked = torch.stack(tuple(layer_outputs), dim=0)  # (num_layers, *layer_shape)
        # Broadcast one scalar per layer; cast to the inputs' dtype so e.g. fp16 stays fp16.
        weights = weights.to(stacked.dtype).view(-1, *([1] * (stacked.dim() - 1)))
        return (weights * stacked).sum(dim=0)

In [7]:
class CustomBGEModel(nn.Module):
    """BGE encoder with learnable layer mixing, configurable pooling and an optional projection head.

    forward() returns L2-normalised sentence embeddings: width ``config.projection_dim``
    when the projection head is enabled, otherwise the backbone's hidden size.

    Raises:
        ValueError: at construction, for an unknown ``pooling_strategy`` or a
            ``num_layers_to_average`` outside what the backbone provides.
    """

    _POOLING_STRATEGIES = ("cls", "mean", "attention")

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        # Validate cheap config fields before loading the backbone weights.
        if config.pooling_strategy not in self._POOLING_STRATEGIES:
            raise ValueError(f"Unknown pooling strategy: {config.pooling_strategy}")
        self.config = config

        self.backbone = AutoModel.from_pretrained(config.model_name)
        self.backbone.config.output_hidden_states = True

        actual_hidden_size = self.backbone.config.hidden_size

        # Hugging Face encoders return num_hidden_layers + 1 hidden states (embedding
        # output plus one per layer). A value of 0 would make the [-0:] slice below select
        # every state, and a value that is too large would hand LayerwiseWeightedSum
        # fewer tensors than it has weights.
        num_backbone_layers = getattr(self.backbone.config, "num_hidden_layers", None)
        if num_backbone_layers is not None:
            max_states = num_backbone_layers + 1
            if not 1 <= config.num_layers_to_average <= max_states:
                raise ValueError(
                    f"num_layers_to_average must be in [1, {max_states}], "
                    f"got {config.num_layers_to_average}"
                )

        self.layerwise_pooling = LayerwiseWeightedSum(
            num_layers=config.num_layers_to_average,
            learnable=True
        )

        if config.pooling_strategy == "attention":
            self.attention_pooling = AttentionPooling(actual_hidden_size)

        if config.use_projection_head:
            self.projection_head = ProjectionHead(
                input_dim=actual_hidden_size,
                hidden_dim=config.projection_dim,
                output_dim=config.projection_dim,
                activation=config.activation,
                dropout=config.dropout_rate
            )

    def _pool_embeddings(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Reduce token states to one vector per sequence using ``config.pooling_strategy``."""
        if self.config.pooling_strategy == "cls":
            return hidden_states[:, 0, :]
        elif self.config.pooling_strategy == "mean":
            # Average only over non-padding tokens; clamp guards against division by zero.
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            return sum_embeddings / sum_mask
        elif self.config.pooling_strategy == "attention":
            return self.attention_pooling(hidden_states, attention_mask)
        else:
            raise ValueError(f"Unknown pooling strategy: {self.config.pooling_strategy}")

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode a tokenized batch into L2-normalised embeddings."""
        # Request hidden states explicitly rather than relying only on the mutated config.
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_n_layers = outputs.hidden_states[-self.config.num_layers_to_average:]
        weighted_hidden_states = self.layerwise_pooling(last_n_layers)
        pooled_embeddings = self._pool_embeddings(weighted_hidden_states, attention_mask)

        if self.config.use_projection_head:
            # ProjectionHead already L2-normalises its output.
            return self.projection_head(pooled_embeddings)
        return F.normalize(pooled_embeddings, p=2, dim=-1)

In [8]:
class CustomBGEWrapper:
    """Expose a CustomBGEModel through an ``encode(sentences) -> np.ndarray`` interface.

    This is the interface BEIRRetriever and the MTEB adapter in this notebook call.
    The model is expected to already live on ``device``.
    """
    def __init__(self, model: "CustomBGEModel", tokenizer, device: str = "cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    def encode(self, sentences: List[str], batch_size: int = 32,
               show_progress_bar: bool = True, **kwargs) -> np.ndarray:
        """Encode sentences into a ``(len(sentences), dim)`` array of embeddings.

        Args:
            sentences: a single string or any iterable of strings.
            batch_size: number of sentences tokenized and encoded per forward pass.
            show_progress_bar: wrap the batch loop in a tqdm progress bar.
            **kwargs: accepted for BEIR/MTEB call compatibility and ignored.

        Returns:
            The stacked embeddings. Empty input returns an empty ``(0, 0)`` float32 array.
        """
        # A bare string would otherwise be sliced into single characters.
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)
        if not sentences:
            return np.empty((0, 0), dtype=np.float32)

        # Re-assert inference mode every call: a trainer may have switched the shared
        # model back to train() (dropout on) after this wrapper was created.
        self.model.eval()

        batch_starts = range(0, len(sentences), batch_size)
        if show_progress_bar:
            batch_starts = tqdm(batch_starts, desc="Encoding")

        embeddings = []
        with torch.no_grad():
            for start in batch_starts:
                encoded = self.tokenizer(
                    sentences[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=self.model.config.max_seq_length,
                    return_tensors="pt"
                )
                input_ids = encoded['input_ids'].to(self.device)
                attention_mask = encoded['attention_mask'].to(self.device)
                embeddings.append(self.model(input_ids, attention_mask).cpu().numpy())

        return np.vstack(embeddings)

In [9]:
class STSDataset(Dataset):
    """Scored sentence-pair dataset for STS training/evaluation.

    Each item holds both sentences tokenized to ``max_length`` plus the gold score
    rescaled from [0, 5] to [0, 1].

    Args:
        tokenizer: Hugging Face-style tokenizer called once per sentence.
        max_length: pad/truncate length. Every sentence is padded to this length, so large
            values cost compute on short STS sentences.
        dataset_name: "stsb_multi_mt" (English config). Any other value loads GLUE STS-B.
        split: split to load. Defaults to "train" (the previous fixed behaviour). Use
            "dev"/"validation" for held-out data; both spellings are accepted.
    """
    def __init__(self, tokenizer, max_length: int = 512, dataset_name: str = "stsb_multi_mt",
                 split: str = "train"):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split
        self.data = self._load_sts_data(dataset_name, split)

    @staticmethod
    def _resolve_split(dataset_name: str, split: str) -> str:
        """Translate ``split`` into the name the chosen hub dataset uses for it.

        stsb_multi_mt names its held-out split "dev". The GLUE fallback is assumed to name
        it "validation" (NOTE(review): confirm against the GLUE dataset card).
        """
        if dataset_name == "stsb_multi_mt":
            return {"validation": "dev"}.get(split, split)
        return {"dev": "validation"}.get(split, split)

    def _load_sts_data(self, dataset_name: str, split: str = "train"):
        """Load ``split`` of the STS data from Hugging Face datasets.

        On any loading failure, logs a warning and returns a small synthetic list (the
        same data regardless of ``split``) so the notebook keeps running.
        """
        try:
            hf_split = self._resolve_split(dataset_name, split)
            if dataset_name == "stsb_multi_mt":
                # Load multilingual STS benchmark (English configuration)
                dataset = datasets.load_dataset("stsb_multi_mt", "en", split=hf_split)
                data = [
                    {
                        'sentence1': item['sentence1'],
                        'sentence2': item['sentence2'],
                        'score': item['similarity_score'] / 5.0  # Normalize to [0,1]
                    }
                    for item in dataset
                ]
            else:
                # Fallback to original STS-B. Negative labels mark unlabeled rows
                # (presumably the hidden test split), so they are dropped.
                dataset = datasets.load_dataset("glue", "stsb", split=hf_split)
                data = [
                    {
                        'sentence1': item['sentence1'],
                        'sentence2': item['sentence2'],
                        'score': item['label'] / 5.0
                    }
                    for item in dataset
                    if item['label'] >= 0
                ]

            logger.info(f"Loaded {len(data)} samples from {dataset_name}")
            return data

        except Exception as e:
            logger.warning(f"Failed to load {dataset_name}: {e}")
            # Create minimal synthetic data as fallback
            return [
                {"sentence1": "The cat sat on the mat", "sentence2": "A cat was on the mat", "score": 0.9},
                {"sentence1": "I love pizza", "sentence2": "Pizza is great", "score": 0.8},
                {"sentence1": "The weather is nice", "sentence2": "It's raining", "score": 0.2},
            ] * 100

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return 1-D ``input_ids``/``attention_mask`` tensors for both sentences plus a float score."""
        item = self.data[idx]

        encoding1 = self.tokenizer(
            item['sentence1'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        encoding2 = self.tokenizer(
            item['sentence2'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # squeeze(0) drops the batch dimension added by return_tensors='pt'.
        return {
            'input_ids_1': encoding1['input_ids'].squeeze(0),
            'attention_mask_1': encoding1['attention_mask'].squeeze(0),
            'input_ids_2': encoding2['input_ids'].squeeze(0),
            'attention_mask_2': encoding2['attention_mask'].squeeze(0),
            'score': torch.tensor(item['score'], dtype=torch.float)
        }


In [10]:
class BEIRRetriever:
    """Exact (brute-force) dense retriever producing the result shape BEIR's evaluator takes.

    Every query is scored against every document, so memory grows with
    len(queries) * len(corpus).
    """
    def __init__(self, model_wrapper: "CustomBGEWrapper"):
        self.model = model_wrapper

    @staticmethod
    def _doc_text(doc: Dict[str, Any]) -> str:
        """Join a corpus entry's title and text with a space, skipping missing or empty parts."""
        return " ".join(part for part in (doc.get("title"), doc.get("text")) if part)

    def retrieve(self, corpus, queries, top_k: int, score_function: str = "dot", **kwargs):
        """Retrieve the ``top_k`` highest-scoring documents for each query.

        Args:
            corpus: ``{doc_id: {"title": ..., "text": ...}}``; a missing or ``None``
                title is tolerated.
            queries: ``{query_id: query_text}``.
            top_k: documents kept per query (capped at the corpus size).
            score_function: "dot" for inner product; any other value uses cosine similarity.

        Returns:
            ``{query_id: {doc_id: score}}`` with each inner dict ordered by descending score.
        """
        query_ids = list(queries.keys())
        corpus_ids = list(corpus.keys())

        query_embeddings = self.model.encode([queries[qid] for qid in query_ids])
        corpus_embeddings = self.model.encode([self._doc_text(corpus[doc_id]) for doc_id in corpus_ids])

        if score_function == "dot":
            scores = np.dot(query_embeddings, corpus_embeddings.T)
        else:  # cosine
            scores = cosine_similarity(query_embeddings, corpus_embeddings)

        k = min(top_k, len(corpus_ids))
        results = {}
        for qid, query_scores in zip(query_ids, scores):
            if k <= 0:
                results[qid] = {}
                continue
            # argpartition finds the k best in O(N); only those k are then fully sorted.
            candidates = np.argpartition(-query_scores, k - 1)[:k]
            ranked = candidates[np.argsort(-query_scores[candidates], kind="stable")]
            results[qid] = {corpus_ids[idx]: float(query_scores[idx]) for idx in ranked}

        return results


In [11]:
class CustomBGETrainer:
    """Fine-tune a CustomBGEModel on scored sentence pairs and evaluate it on STS, BEIR and MTEB.

    Relies on module-level names defined elsewhere in the notebook: ``logger``,
    ``BEIR_AVAILABLE``, ``util``, ``GenericDataLoader``, ``EvaluateRetrieval``,
    ``BEIRRetriever`` and ``mteb``.
    """
    def __init__(self, model: "CustomBGEModel", tokenizer, device: str = "cuda"):
        # nn.Module.to moves parameters in place and returns the same module.
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device

    def _move_to_device(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a copy of ``batch`` with every tensor moved to ``self.device``."""
        return {key: value.to(self.device) for key, value in batch.items()}

    def compute_similarity_loss(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor,
                               scores: torch.Tensor) -> torch.Tensor:
        """Mean squared error between pairwise cosine similarity and target ``scores``.

        STSDataset in this notebook scales gold scores to [0, 1], while cosine
        similarity ranges over [-1, 1].
        """
        cos_sim = F.cosine_similarity(embeddings1, embeddings2, dim=-1)
        return F.mse_loss(cos_sim, scores)

    def train_epoch(self, dataloader: DataLoader, optimizer, scheduler=None) -> float:
        """Run one optimisation pass over ``dataloader``.

        Steps ``scheduler`` (when given) once per batch and logs the loss every 50 batches.

        Returns:
            Mean batch loss, or 0.0 if the dataloader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(dataloader):
            batch = self._move_to_device(batch)

            embeddings1 = self.model(batch['input_ids_1'], batch['attention_mask_1'])
            embeddings2 = self.model(batch['input_ids_2'], batch['attention_mask_2'])
            loss = self.compute_similarity_loss(embeddings1, embeddings2, batch['score'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 50 == 0:
                logger.info(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

        # Divide by batches actually seen: no ZeroDivisionError on an empty loader.
        return total_loss / num_batches if num_batches else 0.0

    def evaluate_sts(self, dataloader: DataLoader) -> Dict[str, float]:
        """Score each pair by cosine similarity and compare against the gold scores.

        Returns:
            Spearman and Pearson correlations, plus the MSE between predictions and gold scores.
        """
        self.model.eval()
        predictions = []
        ground_truth = []

        with torch.no_grad():
            for batch in dataloader:
                batch = self._move_to_device(batch)

                embeddings1 = self.model(batch['input_ids_1'], batch['attention_mask_1'])
                embeddings2 = self.model(batch['input_ids_2'], batch['attention_mask_2'])
                cos_sim = F.cosine_similarity(embeddings1, embeddings2, dim=-1)

                predictions.extend(cos_sim.cpu().numpy())
                ground_truth.extend(batch['score'].cpu().numpy())

        spearman_corr, _ = spearmanr(predictions, ground_truth)
        pearson_corr, _ = pearsonr(predictions, ground_truth)

        return {
            'spearman_correlation': spearman_corr,
            'pearson_correlation': pearson_corr,
            'mse': np.mean((np.array(predictions) - np.array(ground_truth)) ** 2)
        }

    def evaluate_beir(self, model_wrapper: "CustomBGEWrapper", datasets: List[str] = None) -> Dict[str, Dict]:
        """Evaluate on BEIR datasets (test split).

        Each dataset is downloaded under ``datasets/<name>``. A failure on one dataset is
        recorded as ``{"error": message}`` and the remaining datasets still run.
        """
        if not BEIR_AVAILABLE:
            logger.warning("BEIR not available. Skipping BEIR evaluation.")
            return {}

        if datasets is None:
            # Use a subset of BEIR datasets for faster evaluation
            datasets = ["scifact", "nfcorpus", "arguana"]

        results = {}

        for dataset in datasets:
            try:
                logger.info(f"Evaluating on BEIR dataset: {dataset}")

                url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
                data_path = util.download_and_unzip(url, f"datasets/{dataset}")
                corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

                retriever = BEIRRetriever(model_wrapper)
                retrieval_results = retriever.retrieve(corpus, queries, top_k=100)

                evaluator = EvaluateRetrieval()
                ndcg, map_score, recall, precision = evaluator.evaluate(qrels, retrieval_results, [1, 3, 5, 10])

                results[dataset] = {
                    'NDCG@10': ndcg['NDCG@10'],
                    'MAP@10': map_score['MAP@10'],
                    'Recall@10': recall['Recall@10']
                }

                logger.info(f"{dataset} - NDCG@10: {ndcg['NDCG@10']:.4f}, MAP@10: {map_score['MAP@10']:.4f}")

            except Exception as e:
                logger.error(f"Failed to evaluate {dataset}: {e}")
                results[dataset] = {"error": str(e)}

        return results

    def evaluate_mteb(self, model_wrapper: "CustomBGEWrapper", tasks: List[str] = None) -> Dict[str, Dict]:
        """Evaluate on MTEB tasks, writing per-task results to ``mteb_results/``.

        Returns whatever ``MTEB.run`` returns, or ``{"error": message}`` if building or
        running the benchmark fails.
        """
        if not BEIR_AVAILABLE:
            logger.warning("MTEB not available. Skipping MTEB evaluation.")
            return {}

        if tasks is None:
            # Use a subset of MTEB tasks for faster evaluation
            tasks = ["STSBenchmark", "SummEval", "EmotionClassification"]

        try:
            # Built inside the try so unknown task names are reported like any other
            # MTEB failure instead of escaping the error handling.
            evaluation = mteb.MTEB(tasks=tasks)

            class MTEBWrapper:
                """Minimal adapter exposing only ``encode`` to MTEB."""
                def __init__(self, custom_model):
                    self.custom_model = custom_model

                def encode(self, sentences, **kwargs):
                    return self.custom_model.encode(sentences, **kwargs)

            results = evaluation.run(MTEBWrapper(model_wrapper), output_folder="mteb_results")

        except Exception as e:
            logger.error(f"MTEB evaluation failed: {e}")
            results = {"error": str(e)}

        return results

In [12]:
print("="*60)
print("CUSTOM BGE MODEL WITH BEIR/MTEB EVALUATION")
print("="*60)

# Configuration
config = ModelConfig(
    model_name="BAAI/bge-small-en-v1.5",
    projection_dim=768,
    num_layers_to_average=4,
    pooling_strategy="mean",
    use_projection_head=True,
    dropout_rate=0.1,
    activation="gelu"
)

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = CustomBGEModel(config)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

CUSTOM BGE MODEL WITH BEIR/MTEB EVALUATION


tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]

Model parameters: 34,246,276
Trainable parameters: 34,246,276


In [13]:
# Create datasets.
# STSDataset(tokenizer) loads the training split, so building the validation set the
# same way would pick the "best" checkpoint on the very pairs being trained on.
# Carve a seeded, disjoint hold-out slice out of it instead.
VAL_FRACTION = 0.1
SPLIT_SEED = 42
sts_train_full = STSDataset(tokenizer)
num_val = max(1, int(len(sts_train_full) * VAL_FRACTION))
train_dataset, val_dataset = torch.utils.data.random_split(
    sts_train_full,
    [len(sts_train_full) - num_val, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Initialize trainer
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
trainer = CustomBGETrainer(model, tokenizer, device)

# Setup optimizer and scheduler
NUM_EPOCHS = 2  # demo setting; the training loop must run the same number of epochs
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
num_training_steps = len(train_dataloader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),  # 10% warmup, as an integer step count
    num_training_steps=num_training_steps
)

README.md: 0.00B [00:00, ?B/s]

en/train-00000-of-00001.parquet:   0%|          | 0.00/470k [00:00<?, ?B/s]

en/test-00000-of-00001.parquet:   0%|          | 0.00/108k [00:00<?, ?B/s]

en/dev-00000-of-00001.parquet:   0%|          | 0.00/142k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5749 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1379 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/1500 [00:00<?, ? examples/s]

Using device: cuda


In [None]:
# Training loop
print("\n" + "="*50)
print("TRAINING")
print("="*50)

num_epochs = 2  # Reduced for demo; must match the epoch count used for the LR scheduler's total steps
best_correlation = -1
best_state = None  # in-memory copy of the best checkpoint's weights
for epoch in range(num_epochs):
    logger.info(f"Epoch {epoch + 1}/{num_epochs}")

    train_loss = trainer.train_epoch(train_dataloader, optimizer, scheduler)
    logger.info(f"Training loss: {train_loss:.4f}")

    eval_results = trainer.evaluate_sts(val_dataloader)
    logger.info(f"STS Validation: Spearman={eval_results['spearman_correlation']:.4f}, "
                f"Pearson={eval_results['pearson_correlation']:.4f}")

    if eval_results['spearman_correlation'] > best_correlation:
        best_correlation = eval_results['spearman_correlation']
        torch.save(model.state_dict(), 'best_custom_bge_model.pt')
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
        logger.info(f"New best model saved! Spearman: {best_correlation:.4f}")

# Restore the selected checkpoint so the BEIR/MTEB evaluation below uses the same
# weights that were saved as "best", not whatever the final epoch left in `model`.
if best_state is not None:
    model.load_state_dict(best_state)
    logger.info(f"Restored best checkpoint (Spearman: {best_correlation:.4f}) for evaluation")

# Create model wrapper for evaluation
model_wrapper = CustomBGEWrapper(model, tokenizer, device)


TRAINING


In [None]:
# BEIR Evaluation
print("\n" + "="*50)
print("BEIR EVALUATION")
print("="*50)

beir_results = trainer.evaluate_beir(model_wrapper)
for dataset, metrics in beir_results.items():
    print(f"\n{dataset.upper()}:")
    if "error" not in metrics:
        for metric, score in metrics.items():
            print(f"  {metric}: {score:.4f}")
    else:
        print(f"  Error: {metrics['error']}")

# MTEB Evaluation
print("\n" + "="*50)
print("MTEB EVALUATION")
print("="*50)


def _collect_main_scores(node) -> List[float]:
    """Recursively gather every numeric ``main_score`` found in nested dicts/lists/tuples."""
    if isinstance(node, dict):
        value = node.get("main_score")
        if isinstance(value, (int, float, np.floating)):
            return [float(value)]
        return [s for child in node.values() for s in _collect_main_scores(child)]
    if isinstance(node, (list, tuple)):
        return [s for child in node for s in _collect_main_scores(child)]
    return []


def summarize_mteb_results(results) -> Dict[str, float]:
    """Map task name -> mean ``main_score`` across all splits/subsets found for it.

    Handles a ``{task_name: nested_dict}`` mapping as well as an iterable of result
    objects exposing ``task_name`` and ``scores`` attributes. NOTE(review): which shape
    ``MTEB.run`` returns depends on the installed mteb version; confirm against it.
    Tasks with no ``main_score`` anywhere are omitted.
    """
    if isinstance(results, dict):
        named_payloads = list(results.items())
    else:
        named_payloads = [
            (getattr(r, "task_name", type(r).__name__), getattr(r, "scores", r))
            for r in results
        ]
    summary = {}
    for task_name, payload in named_payloads:
        scores = _collect_main_scores(payload)
        if scores:
            summary[task_name] = sum(scores) / len(scores)
    return summary


mteb_results = trainer.evaluate_mteb(model_wrapper)
if isinstance(mteb_results, dict) and "error" in mteb_results:
    print(f"MTEB Error: {mteb_results['error']}")
else:
    print("MTEB Results saved to mteb_results/")
    # One averaged main score per task
    for task, score in summarize_mteb_results(mteb_results).items():
        print(f"{task}: {score:.4f}")

In [None]:
# Test different configurations
banner = "=" * 50
print("\n" + banner)
print("CONFIGURATION COMPARISON")
print(banner)

# Each variant overrides ModelConfig fields; "name" is only a display label and is
# stripped before the config is built.
_variant_overrides = [
    ("CLS Pooling", {"pooling_strategy": "cls"}),
    ("Attention Pooling", {"pooling_strategy": "attention"}),
    ("No Projection Head", {"use_projection_head": False}),
    ("6 Layer Average", {"num_layers_to_average": 6}),
]
configs_to_test = [{**overrides, "name": label} for label, overrides in _variant_overrides]

base_config = ModelConfig()
comparison_results = {}

In [None]:
# NOTE: every variant below is freshly built (pretrained backbone + newly initialised
# pooling/projection parameters) and evaluated WITHOUT training, so these scores
# compare untrained configurations, not the fine-tuned `model` above.
VARIANT_SEED = 42

for test_config in configs_to_test:
    print(f"\nTesting: {test_config['name']}")

    # Create new config and model
    overrides = {k: v for k, v in test_config.items() if k != 'name'}
    new_config = ModelConfig(**{**base_config.__dict__, **overrides})
    # Start every variant from the same RNG state so its random init is reproducible.
    torch.manual_seed(VARIANT_SEED)
    test_model = CustomBGEModel(new_config)
    test_trainer = CustomBGETrainer(test_model, tokenizer, device)

    # Quick STS evaluation
    sts_results = test_trainer.evaluate_sts(val_dataloader)
    comparison_results[test_config['name']] = sts_results['spearman_correlation']
    print(f"  Spearman Correlation: {sts_results['spearman_correlation']:.4f}")

    # Quick BEIR evaluation on one dataset
    test_wrapper = None
    if BEIR_AVAILABLE:
        test_wrapper = CustomBGEWrapper(test_model, tokenizer, device)
        beir_quick = test_trainer.evaluate_beir(test_wrapper, ["scifact"])
        if "scifact" in beir_quick and "error" not in beir_quick["scifact"]:
            print(f"  BEIR SciFact NDCG@10: {beir_quick['scifact']['NDCG@10']:.4f}")

    # Drop this variant's references before loading the next one so its GPU memory can be reclaimed.
    del test_wrapper, test_trainer, test_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Final summary
summary_rule = "=" * 60
print("\n" + summary_rule)
print("FINAL RESULTS SUMMARY")
print(summary_rule)

# Variants ranked from highest to lowest STS Spearman correlation.
print("\nConfiguration Comparison (STS Spearman Correlation):")
ranked_variants = sorted(comparison_results.items(), key=lambda pair: pair[1], reverse=True)
for variant_name, spearman in ranked_variants:
    print(f"  {variant_name}: {spearman:.4f}")

if beir_results:
    print("\nBest BEIR Performance:")
    successful_runs = {ds: m for ds, m in beir_results.items() if "error" not in m}
    for ds_name, ds_metrics in successful_runs.items():
        print(f"  {ds_name}: NDCG@10={ds_metrics.get('NDCG@10', 'N/A')}")

print("\nModel saved as: best_custom_bge_model.pt")
print("Evaluation complete!")