In [45]:
!pip install transformers torch sentence-transformers datasets -q
!pip install mteb faiss-cpu sentence-transformers -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m45.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m45.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m304.8/304.8 kB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [10]:
# ==============================================================================
# Step 2: Imports and the Corrected, Definitive Custom Model
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr
import numpy as np
from typing import Optional, List, Dict, Any
import logging
from dataclasses import dataclass
import json
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
print("Libraries installed and imported successfully.")

Libraries installed and imported successfully.


In [2]:
@dataclass
class ModelConfig:
    """Configuration for custom BGE model"""
    model_name: str = "BAAI/bge-small-en-v1.5"  # Base BGE model
    hidden_size: int = 384
    projection_dim: int = 768
    num_layers_to_average: int = 4
    pooling_strategy: str = "mean"  # "cls", "mean", "attention"
    use_projection_head: bool = True
    dropout_rate: float = 0.1
    activation: str = "gelu"

In [3]:
class ProjectionHead(nn.Module):
    """Projection head with L2 normalization"""
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
                 activation: str = "gelu", dropout: float = 0.1):
        super().__init__()

        self.projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            self._get_activation(activation),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def _get_activation(self, activation: str):
        activations = {
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
            "tanh": nn.Tanh(),
            "silu": nn.SiLU()
        }
        return activations.get(activation, nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.projection(x)
        # L2 normalization
        return F.normalize(projected, p=2, dim=-1)

In [4]:
class AttentionPooling(nn.Module):
    """Attention-based pooling mechanism"""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.attention_weights = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: [batch_size, seq_len, hidden_size]
        # attention_mask: [batch_size, seq_len]

        # Compute attention scores
        attention_scores = self.attention_weights(hidden_states).squeeze(-1)  # [batch_size, seq_len]

        # Mask out padding tokens
        attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))

        # Apply softmax
        attention_weights = F.softmax(attention_scores, dim=-1)  # [batch_size, seq_len]

        # Weighted sum
        pooled = torch.sum(hidden_states * attention_weights.unsqueeze(-1), dim=1)
        return pooled


In [5]:
class LayerwiseWeightedSum(nn.Module):
    """Weighted sum of multiple layers"""
    def __init__(self, num_layers: int, learnable: bool = True):
        super().__init__()
        self.num_layers = num_layers
        self.learnable = learnable

        if learnable:
            # Initialize with equal weights
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            self.register_buffer('layer_weights', torch.ones(num_layers) / num_layers)

    def forward(self, layer_outputs: List[torch.Tensor]) -> torch.Tensor:
        # layer_outputs: List of tensors [batch_size, seq_len, hidden_size]
        if self.learnable:
            weights = F.softmax(self.layer_weights, dim=0)
        else:
            weights = self.layer_weights

        weighted_sum = torch.zeros_like(layer_outputs[0])
        for i, layer_output in enumerate(layer_outputs):
            weighted_sum += weights[i] * layer_output

        return weighted_sum

In [6]:
class CustomBGEModel(nn.Module):
    """Custom BGE model with enhanced features"""
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Load base BGE model
        self.backbone = AutoModel.from_pretrained(config.model_name)
        self.backbone.config.output_hidden_states = True

        # Get actual hidden size from model
        actual_hidden_size = self.backbone.config.hidden_size

        # Layerwise weighted sum
        self.layerwise_pooling = LayerwiseWeightedSum(
            num_layers=config.num_layers_to_average,
            learnable=True
        )

        # Pooling strategy
        if config.pooling_strategy == "attention":
            self.attention_pooling = AttentionPooling(actual_hidden_size)

        # Projection head
        if config.use_projection_head:
            self.projection_head = ProjectionHead(
                input_dim=actual_hidden_size,
                hidden_dim=config.projection_dim,
                output_dim=config.projection_dim,
                activation=config.activation,
                dropout=config.dropout_rate
            )

    def _pool_embeddings(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Apply pooling strategy"""
        if self.config.pooling_strategy == "cls":
            # Use CLS token (first token)
            return hidden_states[:, 0, :]

        elif self.config.pooling_strategy == "mean":
            # Mean pooling with attention mask
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
            sum_mask = input_mask_expanded.sum(1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            return sum_embeddings / sum_mask

        elif self.config.pooling_strategy == "attention":
            # Attention pooling
            return self.attention_pooling(hidden_states, attention_mask)

        else:
            raise ValueError(f"Unknown pooling strategy: {self.config.pooling_strategy}")

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Get outputs from backbone
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        # Get last N layers
        all_hidden_states = outputs.hidden_states
        last_n_layers = all_hidden_states[-self.config.num_layers_to_average:]

        # Apply layerwise weighted sum
        weighted_hidden_states = self.layerwise_pooling(last_n_layers)

        # Apply pooling
        pooled_embeddings = self._pool_embeddings(weighted_hidden_states, attention_mask)

        # Apply projection head if enabled
        if self.config.use_projection_head:
            embeddings = self.projection_head(pooled_embeddings)
        else:
            # Apply L2 normalization directly
            embeddings = F.normalize(pooled_embeddings, p=2, dim=-1)

        return embeddings

In [7]:
class STSBenchmarkDataset(Dataset):
    """STS Benchmark dataset for evaluation"""
    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str):
        """Load STS benchmark data"""
        # For demo purposes, creating synthetic data
        # In practice, load from actual STS benchmark files
        data = []

        # Synthetic data for demonstration
        samples = [
            ("The cat sat on the mat", "A cat was sitting on a mat", 4.5),
            ("I love pizza", "Pizza is my favorite food", 3.8),
            ("The weather is nice", "It's raining heavily", 1.2),
            ("Machine learning is fascinating", "AI and ML are interesting topics", 4.2),
            ("The dog ran quickly", "A fast dog was running", 4.0),
            ("Books are educational", "Reading helps you learn", 3.5),
            ("The sun is shining", "It's a cloudy day", 1.8),
            ("Programming is fun", "Coding can be enjoyable", 4.1)
        ] * 100  # Repeat to create larger dataset

        for sent1, sent2, score in samples:
            data.append({
                'sentence1': sent1,
                'sentence2': sent2,
                'score': score / 5.0  # Normalize to [0, 1]
            })

        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Tokenize sentences
        encoding1 = self.tokenizer(
            item['sentence1'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        encoding2 = self.tokenizer(
            item['sentence2'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids_1': encoding1['input_ids'].squeeze(0),
            'attention_mask_1': encoding1['attention_mask'].squeeze(0),
            'input_ids_2': encoding2['input_ids'].squeeze(0),
            'attention_mask_2': encoding2['attention_mask'].squeeze(0),
            'score': torch.tensor(item['score'], dtype=torch.float)
        }

In [8]:
class CustomBGETrainer:
    """Trainer for custom BGE model"""
    def __init__(self, model: CustomBGEModel, tokenizer, device: str = "cuda"):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device

    def compute_similarity_loss(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor,
                               scores: torch.Tensor) -> torch.Tensor:
        """Compute cosine similarity loss"""
        # Compute cosine similarity
        cos_sim = F.cosine_similarity(embeddings1, embeddings2, dim=-1)

        # MSE loss between cosine similarity and ground truth scores
        loss = F.mse_loss(cos_sim, scores)
        return loss

    def train_epoch(self, dataloader: DataLoader, optimizer, scheduler=None):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(dataloader):
            # Move to device
            for key in batch:
                batch[key] = batch[key].to(self.device)

            # Forward pass
            embeddings1 = self.model(batch['input_ids_1'], batch['attention_mask_1'])
            embeddings2 = self.model(batch['input_ids_2'], batch['attention_mask_2'])

            # Compute loss
            loss = self.compute_similarity_loss(embeddings1, embeddings2, batch['score'])

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if scheduler:
                scheduler.step()

            total_loss += loss.item()

            if batch_idx % 50 == 0:
                logger.info(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

        return total_loss / len(dataloader)

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate model performance"""
        self.model.eval()
        predictions = []
        ground_truth = []

        with torch.no_grad():
            for batch in dataloader:
                # Move to device
                for key in batch:
                    batch[key] = batch[key].to(self.device)

                # Forward pass
                embeddings1 = self.model(batch['input_ids_1'], batch['attention_mask_1'])
                embeddings2 = self.model(batch['input_ids_2'], batch['attention_mask_2'])

                # Compute cosine similarity
                cos_sim = F.cosine_similarity(embeddings1, embeddings2, dim=-1)

                predictions.extend(cos_sim.cpu().numpy())
                ground_truth.extend(batch['score'].cpu().numpy())

        # Compute Spearman correlation
        spearman_corr, _ = spearmanr(predictions, ground_truth)

        # Compute MSE
        mse = np.mean((np.array(predictions) - np.array(ground_truth)) ** 2)

        return {
            'spearman_correlation': spearman_corr,
            'mse': mse,
            'predictions': predictions[:10],  # First 10 predictions for inspection
            'ground_truth': ground_truth[:10]  # First 10 ground truth for inspection
        }

In [12]:
"""Main training and evaluation pipeline"""
# Configuration
config = ModelConfig(
    model_name="BAAI/bge-small-en-v1.5",
    projection_dim=768,
    num_layers_to_average=4,
    pooling_strategy="mean",  # Try "cls", "mean", or "attention"
    use_projection_head=True,
    dropout_rate=0.1,
    activation="gelu"
)

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = CustomBGEModel(config)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Model parameters: 34,246,276
Trainable parameters: 34,246,276


In [13]:
# Create datasets
train_dataset = STSBenchmarkDataset("train.tsv", tokenizer)
val_dataset = STSBenchmarkDataset("dev.tsv", tokenizer)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Initialize trainer
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = CustomBGETrainer(model, tokenizer, device)

In [14]:
# Setup optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
num_training_steps = len(train_dataloader) * 3  # 3 epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0.1 * num_training_steps,
    num_training_steps=num_training_steps
)

In [15]:
# Training loop
best_correlation = -1
for epoch in range(3):
    logger.info(f"Epoch {epoch + 1}/3")

    # Train
    train_loss = trainer.train_epoch(train_dataloader, optimizer, scheduler)
    logger.info(f"Training loss: {train_loss:.4f}")

    # Evaluate
    eval_results = trainer.evaluate(val_dataloader)
    logger.info(f"Validation results: {eval_results}")

    # Save best model
    if eval_results['spearman_correlation'] > best_correlation:
        best_correlation = eval_results['spearman_correlation']
        torch.save(model.state_dict(), 'best_custom_bge_model.pt')
        logger.info(f"New best model saved! Spearman correlation: {best_correlation:.4f}")

In [16]:
# Test different configurations
print("\n" + "="*50)
print("TESTING DIFFERENT CONFIGURATIONS")
print("="*50)

configs_to_test = [
    {"pooling_strategy": "cls", "name": "CLS Pooling"},
    {"pooling_strategy": "mean", "name": "Mean Pooling"},
    {"pooling_strategy": "attention", "name": "Attention Pooling"},
    {"use_projection_head": False, "name": "Without Projection Head"},
    {"num_layers_to_average": 6, "name": "6 Layer Average"}
]

base_config = ModelConfig()
results = {}


TESTING DIFFERENT CONFIGURATIONS


In [17]:
for test_config in configs_to_test:
    print(f"\nTesting: {test_config['name']}")

    # Create new config
    new_config = ModelConfig(**{**base_config.__dict__, **{k: v for k, v in test_config.items() if k != 'name'}})

    # Create and evaluate model
    test_model = CustomBGEModel(new_config)
    test_trainer = CustomBGETrainer(test_model, tokenizer, device)

    # Quick evaluation (no training)
    eval_results = test_trainer.evaluate(val_dataloader)
    results[test_config['name']] = eval_results['spearman_correlation']
    print(f"Spearman Correlation: {eval_results['spearman_correlation']:.4f}")

# Print summary
print("\n" + "="*50)
print("CONFIGURATION COMPARISON")
print("="*50)
for name, correlation in results.items():
    print(f"{name}: {correlation:.4f}")



Testing: CLS Pooling
Spearman Correlation: 0.2619

Testing: Mean Pooling
Spearman Correlation: 0.4048

Testing: Attention Pooling
Spearman Correlation: 0.2857

Testing: Without Projection Head
Spearman Correlation: 0.3810

Testing: 6 Layer Average
Spearman Correlation: 0.2857

CONFIGURATION COMPARISON
CLS Pooling: 0.2619
Mean Pooling: 0.4048
Attention Pooling: 0.2857
Without Projection Head: 0.3810
6 Layer Average: 0.2857
