In [1]:
import json
import logging
import math
import os
from datetime import datetime
from random import shuffle
from tqdm import tqdm
import pickle

import datasets
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from huggingface_hub import HfFolder, Repository, create_repo
from pyserini.search.lucene import LuceneSearcher

def load_miracl_data(lang='fr'):
    """Fetch the MIRACL training split for one language.

    Args:
        lang: Configuration name handed to the "miracl/miracl" dataset
            (defaults to 'fr').

    Returns:
        The object produced by ``datasets.load_dataset`` for ``split='train'``.
        Downloads are cached under the local ``hf_datasets_cache`` directory.
    """
    load_kwargs = {"split": 'train', "cache_dir": "hf_datasets_cache"}
    return datasets.load_dataset("miracl/miracl", lang, **load_kwargs)

def get_cache_path(lang, num_negatives, cache_dir="miracl_cache"):
    """Return the pickle path used to cache processed training samples.

    The directory is created on demand (including missing parents) so callers
    can write to the returned path straight away.

    Args:
        lang: Language code; becomes part of the file name.
        num_negatives: Negatives-per-query setting; becomes part of the file
            name so different settings never share a cache file.
        cache_dir: Directory that holds the cache files. Defaults to the
            relative ``miracl_cache`` directory.

    Returns:
        ``<cache_dir>/train_samples_<lang>_neg<num_negatives>.pkl``
    """
    os.makedirs(cache_dir, exist_ok=True)
    return os.path.join(cache_dir, f"train_samples_{lang}_neg{num_negatives}.pkl")

def save_to_cache(train_samples, cache_path):
    """Pickle ``train_samples`` to ``cache_path`` without ever leaving a torn file.

    The data is first written to ``<cache_path>.tmp`` and then moved over the
    target with ``os.replace``. If pickling fails or the process is interrupted
    mid-write, any previous file at ``cache_path`` is left untouched instead of
    being truncated into an unreadable cache.

    Args:
        train_samples: Any picklable object (normally the list of samples).
        cache_path: Destination file path.

    Raises:
        Whatever ``open``/``pickle.dump``/``os.replace`` raise; the temporary
        file is removed before the exception propagates.
    """
    logging.info(f"Saving processed samples to {cache_path}")
    tmp_path = f"{cache_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(train_samples, f)
        # Same-directory rename: the target switches from old to new content in one step.
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Do not leave a half-written temp file behind (also on KeyboardInterrupt).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_from_cache(cache_path):
    """Read back an object previously pickled to ``cache_path``.

    Args:
        cache_path: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    logging.info(f"Loading processed samples from {cache_path}")
    # NOTE: pickle executes arbitrary code on load; only point this at cache
    # files produced locally by this notebook.
    with open(cache_path, 'rb') as cache_file:
        cached_samples = pickle.load(cache_file)
    return cached_samples

def create_training_samples_with_hard_negatives(dataset, lang='fr', num_negatives=30, use_cache=True):
    """Build cross-encoder training pairs with BM25-mined hard negatives.

    For every row of ``dataset`` the function emits one ``InputExample`` with
    label 1.0 per positive passage, plus up to ``num_negatives`` examples with
    label 0.0 drawn at random from the row's own negatives combined with the
    top-100 BM25 hits that are not already judged for that query.

    Results are pickled to the path given by ``get_cache_path``. While running,
    a ``.partial`` checkpoint is written every 1000 queries; it records how many
    queries were processed so that an interrupted run resumes from there rather
    than starting over.

    Args:
        dataset: Iterable of rows with ``'query'``, ``'positive_passages'`` and
            ``'negative_passages'``; each passage has ``'docid'`` and ``'text'``.
        lang: Language code used for the prebuilt index name, the searcher
            language and the cache file name.
        num_negatives: Maximum number of negative examples per query.
        use_cache: When true, reuse a finished cache file or a resumable
            checkpoint if one exists. The final result is written to the cache
            either way.

    Returns:
        List of ``InputExample`` objects (positives and negatives interleaved
        per query).
    """
    cache_path = get_cache_path(lang, num_negatives)

    # Try to load from cache first
    if use_cache and os.path.exists(cache_path):
        return load_from_cache(cache_path)

    logging.info("Cache not found or disabled. Creating new training samples...")

    # Checkpoint file so a runtime failure does not cost the whole run
    intermediate_cache_path = cache_path + ".partial"
    intermediate_save_frequency = 1000  # Save every 1000 queries

    train_samples = []
    resume_index = 0  # number of leading queries already covered by a checkpoint

    if use_cache and os.path.exists(intermediate_cache_path):
        try:
            checkpoint = load_from_cache(intermediate_cache_path)
        except Exception as exc:
            # A torn or foreign checkpoint must not block a fresh run.
            logging.warning(f"Could not read checkpoint {intermediate_cache_path} ({exc}); starting from scratch")
            checkpoint = None
        # Only the {'next_index', 'samples'} layout says how far the run got;
        # anything else (e.g. a bare list of samples) cannot be resumed safely.
        if isinstance(checkpoint, dict) and {'next_index', 'samples'} <= checkpoint.keys():
            resume_index = checkpoint['next_index']
            train_samples = checkpoint['samples']
            logging.info(f"Resuming from checkpoint: skipping the first {resume_index} queries")
        elif checkpoint is not None:
            logging.warning(f"Ignoring checkpoint {intermediate_cache_path}: unrecognised format")

    logging.info("Initializing BM25 searcher")
    searcher = LuceneSearcher.from_prebuilt_index(f'miracl-v1.0-{lang}')
    searcher.set_language(lang)

    logging.info("Creating training samples with hard negatives")
    for idx, data in enumerate(tqdm(dataset)):
        if idx < resume_index:
            # Already contained in the samples restored from the checkpoint.
            continue

        query = data['query']
        positives = data['positive_passages']
        negatives = data['negative_passages']

        # Every docid that already has a judgement for this query; BM25 hits
        # among them are either positives or negatives we already have.
        judged_ids = {p['docid'] for p in positives}
        judged_ids.update(n['docid'] for n in negatives)

        # Search using BM25 to get hard negatives
        hits = searcher.search(query, k=100)
        bm25_negatives = []
        for hit in hits:
            info = json.loads(hit.lucene_document.get('raw'))
            if info['docid'] not in judged_ids:
                bm25_negatives.append(info)

        # Combine original negatives with BM25 negatives and shuffle.
        # The concatenation is a new list, so the dataset row is not reordered.
        all_negatives = negatives + bm25_negatives
        shuffle(all_negatives)
        selected_negatives = all_negatives[:num_negatives]

        # Create positive examples
        train_samples.extend(
            InputExample(texts=[query, pos['text']], label=1.0)
            for pos in positives
        )

        # Create negative examples
        train_samples.extend(
            InputExample(texts=[query, neg['text']], label=0.0)
            for neg in selected_negatives
        )

        # Save intermediate results together with the resume position
        if (idx + 1) % intermediate_save_frequency == 0:
            logging.info(f"Saving intermediate results after {idx + 1} queries...")
            save_to_cache(
                {'next_index': idx + 1, 'samples': train_samples},
                intermediate_cache_path
            )

    # Save final results to cache
    save_to_cache(train_samples, cache_path)

    # Clean up intermediate cache if it exists
    if os.path.exists(intermediate_cache_path):
        os.remove(intermediate_cache_path)

    return train_samples

def push_to_hub(model_path, repo_name, username, lang, num_negatives=30):
    """Publish the fine-tuned model directory to the HuggingFace Hub.

    Creates the repository if needed, writes a model card to
    ``<model_path>/README.md`` and uploads the whole directory in one commit.
    Authentication relies on the locally stored HuggingFace token.

    Args:
        model_path: Local directory containing the saved model.
        repo_name: Repository name without the namespace.
        username: Hub namespace (user or organisation) owning the repository.
        lang: Language code mentioned in the model card.
        num_negatives: Negatives-per-query setting quoted in the model card.
            Defaults to 30, the value previously hard-coded in the card text.
    """
    # Local import: only this function needs it, and it keeps the edit self-contained.
    from huggingface_hub import upload_folder

    full_repo_name = f"{username}/{repo_name}"

    # Create repository
    logging.info(f"Creating repository: {full_repo_name}")
    create_repo(full_repo_name, exist_ok=True)

    # Create model card
    model_card = f"""
# MIRACL Cross-Encoder ({lang})

This model is a fine-tuned version of [antoinelouis/crossencoder-camembert-base-mmarcoFR](https://huggingface.co/antoinelouis/crossencoder-camembert-base-mmarcoFR) on the MIRACL dataset for {lang} language. It uses hard negative mining with BM25 for better training data.

## Training
- The model was trained on MIRACL {lang} dataset
- Hard negative mining was performed using BM25
- For each query, we used all positive passages and up to {num_negatives} negative passages (combination of original negatives and BM25 hard negatives)

## Usage
```python
from sentence_transformers.cross_encoder import CrossEncoder

model = CrossEncoder("{full_repo_name}")
scores = model.predict([["query", "document_text"]])
```
    """

    with open(os.path.join(model_path, "README.md"), "w", encoding="utf-8") as f:
        f.write(model_card)

    # Upload over HTTP instead of cloning with the git-based `Repository` helper:
    # model_path already holds the saved model and README and is not a git
    # checkout, so cloning the repo into it is not expected to work.
    logging.info("Pushing model to HuggingFace Hub...")
    upload_folder(folder_path=model_path, repo_id=full_repo_name)

    logging.info(f"Model successfully pushed to: https://huggingface.co/{full_repo_name}")

def main(push_to_hf=False):
    """Fine-tune the base cross-encoder on MIRACL and return the trained model.

    Steps: configure logging, load the MIRACL training split, build (or load
    from cache) the training pairs with hard negatives, fine-tune
    ``antoinelouis/crossencoder-camembert-base-mmarcoFR`` and save it under a
    timestamped ``output/`` directory.

    Args:
        push_to_hf: When true, also upload the saved model to the HuggingFace
            Hub via ``push_to_hub``. Off by default so a plain ``main()`` call
            only trains and saves locally.

    Returns:
        The fine-tuned ``CrossEncoder`` instance.
    """
    # One timestamp for both the log file and the output directory of this run.
    run_started = datetime.now()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(f'fine_tuning_{run_started.strftime("%Y%m%d_%H%M%S")}.log'),
            logging.StreamHandler()
        ],
        # Without force, basicConfig silently does nothing when the root logger
        # already has handlers (an earlier main() call in the same kernel, or a
        # library that configured logging on import). Requires Python 3.8+.
        force=True
    )

    # Parameters
    lang = 'fr'  # language code
    train_batch_size = 16
    num_epochs = 2
    num_negatives = 30  # number of negative examples per query
    use_cache = True  # whether to use cached dataset
    model_save_path = f"output/training_miracl_{lang}-" + run_started.strftime("%Y-%m-%d_%H-%M-%S")

    # HuggingFace parameters (only used when push_to_hf is true)
    username = "azat-serikbayev"  # Replace with your HF username
    repo_name = f"crossencoder-camembert-miracl-{lang}"

    # Load MIRACL dataset
    logging.info("Loading MIRACL dataset")
    dataset = load_miracl_data(lang=lang)

    # Create or load training samples with hard negatives
    train_samples = create_training_samples_with_hard_negatives(
        dataset,
        lang=lang,
        num_negatives=num_negatives,
        use_cache=use_cache
    )

    logging.info(f"Working with {len(train_samples)} training samples")

    # Create data loader
    train_dataloader = DataLoader(
        train_samples,
        shuffle=True,
        batch_size=train_batch_size
    )

    # Warm up the learning rate over the first 10% of all training steps
    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
    logging.info(f"Warmup-steps: {warmup_steps}")

    # Load base model
    model = CrossEncoder(
        "antoinelouis/crossencoder-camembert-base-mmarcoFR",
        num_labels=1
    )

    # Train the model
    logging.info("Starting training")
    model.fit(
        train_dataloader=train_dataloader,
        evaluator=None,
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        output_path=model_save_path
    )

    # NOTE(review): with evaluator=None, fit() presumably does not write to
    # output_path on its own, hence the explicit save — confirm.
    model.save(model_save_path)

    # Optional publishing step; must run before the return to be reachable.
    if push_to_hf:
        push_to_hub(model_save_path, repo_name, username, lang)

    return model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run the whole pipeline (data preparation, fine-tuning, saving) and keep the
# returned model so later cells can use it.
finetuned_model = main()

Iteration: 100%|██████████| 2289/2289 [38:41<00:00,  1.01s/it]
Epoch: 100%|██████████| 1/1 [38:41<00:00, 2321.10s/it]
