# SPLADE v2 PT-BR - Training Notebook

This notebook provides a complete workflow for training the SPLADE model for Portuguese (PT-BR) text retrieval.

## Overview

This notebook trains a SPLADE (Sparse Lexical AnD Expansion) model based on BERTimbau for Portuguese information retrieval. The model learns to expand queries and documents with contextually relevant terms while maintaining sparsity (~99%).
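For readers who want the expansion step in symbols, the formulation below follows the SPLADE v2 papers (max pooling over token positions). It is a reference sketch only: the notation ($w_{ij}$, $V$, $t$) is introduced here for illustration, and this notebook does not state which pooling variant the cloned repository actually uses.

```latex
% w_{ij}: masked-language-model logit for vocabulary term j at input token position i
% t: the tokenized query or document, V: the model vocabulary
w_j = \max_{i \in t} \log\bigl(1 + \mathrm{ReLU}(w_{ij})\bigr),
\qquad
s(q, d) = \sum_{j=1}^{|V|} w_j^{(q)} \, w_j^{(d)}
```

Most $w_j$ end up at zero after training with sparsity regularization, which is what makes the resulting vectors usable in an inverted index.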

## Features

- **Base Model**: `neuralmind/bert-base-portuguese-cased` (BERTimbau)
- **Training Dataset**: mMARCO Portuguese (`unicamp-dl/mmarco`)
- **Validation Dataset**: mRobust (`unicamp-dl/mrobust`)
- **Training Iterations**: 150,000
- **Output**: Sparse vectors with ~100-150 active dimensions per vector
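The relationship between the number of active dimensions and the sparsity figure quoted in the overview can be checked with a few lines of arithmetic. This is a minimal sketch: the helper `sparsity` is hypothetical, and the vocabulary size of 29,794 is an assumption about the BERTimbau tokenizer that should be confirmed against the loaded tokenizer.

```python
def sparsity(active_dims: int, vocab_size: int) -> float:
    """Return the fraction of vocabulary dimensions that are zero."""
    if not 0 <= active_dims <= vocab_size:
        raise ValueError("active_dims must be between 0 and vocab_size")
    return 1.0 - active_dims / vocab_size


# Assumption: vocabulary size of neuralmind/bert-base-portuguese-cased.
VOCAB_SIZE = 29_794

for active in (100, 150):
    print(f"{active} active dims -> {sparsity(active, VOCAB_SIZE):.2%} sparse")
```

Under that assumption, 100 to 150 active dimensions correspond to a sparsity of roughly 99.5% or higher, consistent with the ~99% figure above.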

## Requirements

- **GPU**: Recommended (T4, A100, or similar) for faster training. CPU is supported but will be significantly slower.
- **Memory**: At least 16GB RAM recommended
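- **Disk Space**: ~13GB or more. The downloaded files total about 6.3GB (see the file sizes in the Step 4 output) and are copied into the SPLADE directory structure; the HuggingFace cache and checkpoints need additional space.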
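Before starting the downloads, it may be worth checking the free disk space. The sketch below uses only the standard library; the helper names `free_disk_gb` and `has_free_space` are hypothetical, and the threshold is left as a parameter so you can set it to your own budget.

```python
import shutil


def free_disk_gb(path: str = ".") -> float:
    """Return the free space, in GiB, on the filesystem containing `path`."""
    return shutil.disk_usage(path).free / (1024 ** 3)


def has_free_space(path: str = ".", required_gb: float = 5.0) -> bool:
    """Check whether at least `required_gb` GiB are free at `path`."""
    return free_disk_gb(path) >= required_gb


print(f"Free space in working directory: {free_disk_gb('.'):.1f} GiB")
```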

## Notebook Structure

1. **Setup**: Install dependencies and apply compatibility patches
2. **Data Download**: Download and prepare training/validation datasets
3. **Configuration**: Generate Hydra configuration files
4. **Training**: Execute model training with progress monitoring

**Note**: This notebook includes fixes for library incompatibilities (AdamW/Hydra) and file dependencies that were resolved during development.

## Step 0: Install Dependencies

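Install the required Python packages. The first cell below imports the libraries used throughout the notebook; the second contains the install commands. In a fresh environment, uncomment and run the install commands first.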

**Note**: If using a pre-configured environment, these may already be installed.

In [None]:
# Standard library imports
import os
import sys
import json
import shutil
import subprocess
from collections import defaultdict
from pathlib import Path

# Third-party imports
from huggingface_hub import hf_hub_download

print("✅ All required libraries imported successfully")
print("\n📦 Libraries used in this notebook:")
print("  - os: File and directory operations")
print("  - sys: System-specific parameters and functions")
print("  - json: JSON file handling")
print("  - shutil: File operations (copy, move)")
print("  - subprocess: Execute external processes")
print("  - collections.defaultdict: Dictionary with default values")
print("  - pathlib.Path: Object-oriented filesystem paths")
print("  - huggingface_hub: Download datasets from HuggingFace Hub")


In [None]:
# !pip install pytrec_eval
# !pip install git+https://github.com/leobavila/splade.git -q
# !pip install hydra-core --upgrade

## Step 1: Set Working Directory

This cell ensures the notebook runs from the project root directory, regardless of where the notebook file is located.

In [None]:
# Project root marker (file that should exist in the project root)
PROJECT_MARKER = "pyproject.toml"

def find_project_root(start_path=None):
    """Find the project root by looking for pyproject.toml."""
    if start_path is None:
        start_path = Path.cwd()
    else:
        start_path = Path(start_path)
    
    # Check current directory and parent directories
    current = start_path.resolve()
    
    while current != current.parent:  # Stop at filesystem root
        # Check if pyproject.toml exists in this directory
        if (current / PROJECT_MARKER).exists():
            return current
        current = current.parent
    
    # If not found, try going up from notebooks/ if we're in it
    if start_path.name == "notebooks":
        parent = start_path.parent
        if (parent / PROJECT_MARKER).exists():
            return parent
    
    # Fallback: assume current directory is project root
    return start_path.resolve()

# Find and change to project root
current_dir = Path.cwd()
project_root = find_project_root()

if project_root != current_dir:
    os.chdir(project_root)
    print(f"✅ Changed to project root: {project_root}")
    print(f"   (was in: {current_dir})")
else:
    print(f"✅ Already in project root: {project_root}")

print(f"📁 Current working directory: {os.getcwd()}")

# Verify we're in the right place
if not (Path.cwd() / PROJECT_MARKER).exists():
    print(f"⚠️  Warning: {PROJECT_MARKER} not found. Make sure you're in the project root!")

## Step 3: Clone Repository and Apply Compatibility Patches

This step:
1. Clones the SPLADE repository if it doesn't exist
2. Applies a compatibility patch for AdamW optimizer (fixes import error between old and new Transformers versions)

**Why this is needed**: The original code imports `AdamW` from `transformers.optimization`. That implementation was deprecated and is no longer available in newer versions of Transformers, so the patch imports `torch.optim.AdamW` instead.
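The patch is a plain string replacement, so it helps to see it in isolation. The sketch below applies the same replacement to an in-memory string and reports whether anything changed, which makes it safe to run more than once. The function name `patch_adamw_import` is hypothetical; the two import strings are the ones used in the cell that follows.

```python
from typing import Tuple

OLD_IMPORT = "from transformers.optimization import AdamW, get_linear_schedule_with_warmup"
NEW_IMPORT = "from transformers import get_linear_schedule_with_warmup; from torch.optim import AdamW"


def patch_adamw_import(source: str) -> Tuple[str, bool]:
    """Replace the old AdamW import; return the new source and whether it changed."""
    if OLD_IMPORT not in source:
        # Already patched, or the file does not contain the expected import.
        return source, False
    return source.replace(OLD_IMPORT, NEW_IMPORT), True


sample = OLD_IMPORT + "\n\noptimizer_cls = AdamW\n"
patched, changed = patch_adamw_import(sample)
print("changed on first pass:", changed)
print("changed on second pass:", patch_adamw_import(patched)[1])
```

Returning the `changed` flag lets the caller distinguish "patch applied" from "nothing to patch" instead of reporting success unconditionally.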

In [None]:
# Clone repository if it doesn't exist
if not os.path.exists("splade"):
    print("📦 Cloning SPLADE repository...")
    # check=True raises CalledProcessError if the clone fails
    subprocess.run(["git", "clone", "https://github.com/leobavila/splade.git"], check=True)
    print("✅ Repository cloned")
else:
    print("✅ Repository already exists")

# Apply compatibility patch for AdamW optimizer
# Fixes import error between old and new Transformers versions
file_path = "splade/splade/optim/bert_optim.py"
old_import = "from transformers.optimization import AdamW, get_linear_schedule_with_warmup"
new_import = "from transformers import get_linear_schedule_with_warmup; from torch.optim import AdamW"

if os.path.exists(file_path):
    with open(file_path, "r") as f:
        content = f.read()

    if old_import in content:
        # Replace old import with new compatible imports
        with open(file_path, "w") as f:
            f.write(content.replace(old_import, new_import))
        print("✅ Patch applied: bert_optim.py fixed")
    else:
        # Report honestly when there is nothing to replace
        print("ℹ️ Old AdamW import not found in bert_optim.py; nothing to patch")
else:
    print("❌ Error: bert_optim.py file not found")

✅ Patch applied: bert_optim.py fixed.


## Step 4: Download and Prepare Datasets

This step downloads and prepares the training and validation datasets:

- **mMARCO** (`unicamp-dl/mmarco`): Training dataset with Portuguese queries, documents, and triplets
- **mRobust** (`unicamp-dl/mrobust`): Validation dataset with Portuguese queries, documents, and relevance judgments

The datasets are downloaded from HuggingFace Hub and organized into the SPLADE directory structure.

**Note**: Downloads are skipped if files already exist (checks file size > 100 bytes).
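The skip rule can be expressed as a small predicate. This is a minimal sketch of that check; the helper name `is_valid_download` is hypothetical, and the 100-byte threshold is the one stated in the note above.

```python
import os


def is_valid_download(path: str, min_bytes: int = 100) -> bool:
    """Return True if `path` is an existing file larger than `min_bytes`."""
    return os.path.isfile(path) and os.path.getsize(path) > min_bytes


# A path that does not exist is never considered valid.
print(is_valid_download("./data/m_marco/does_not_exist.tsv"))
```

A size threshold only guards against empty or near-empty files, such as a saved error page; it does not detect a truncated multi-gigabyte download.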

In [None]:
print("⏳ Downloading public datasets from HuggingFace Hub... (This may take several minutes)")

# Create base directories (using relative paths)
data_dir = "./data"
os.makedirs(f"{data_dir}/m_marco", exist_ok=True)
os.makedirs(f"{data_dir}/m_robust", exist_ok=True)

# Create SPLADE destination directories
os.makedirs("splade/data/pt/triplets", exist_ok=True)
os.makedirs("splade/data/pt/val_retrieval/collection", exist_ok=True)
os.makedirs("splade/data/pt/val_retrieval/queries", exist_ok=True)

def download_from_hf(repo_id, filename, output_path, description):
    """Download file from HuggingFace Hub"""
    # Skip if file already exists and has valid size (> 100 bytes)
    if os.path.exists(output_path) and os.path.getsize(output_path) > 100:
        size_mb = os.path.getsize(output_path) / (1024 * 1024)
        print(f"✅ {description} already exists ({size_mb:.1f} MB), skipping download.")
        return True

    print(f"📥 Downloading {description}...")
    try:
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            local_dir=None
        )
        # Copy to desired destination
        shutil.copy(downloaded_path, output_path)
        if os.path.exists(output_path) and os.path.getsize(output_path) > 100:
            size_mb = os.path.getsize(output_path) / (1024 * 1024)
            print(f"✅ {description} downloaded successfully ({size_mb:.1f} MB)")
            return True
        else:
            print(f"❌ {description} failed: file too small or empty")
            return False
    except Exception as e:
        print(f"❌ Error downloading {description}: {e}")
        return False

# --- mMARCO (Training Dataset) ---
print("\n📦 Downloading mMARCO datasets...")
download_from_hf(
    "unicamp-dl/mmarco",
    "data/google/queries/train/portuguese_queries.train.tsv",
    f"{data_dir}/m_marco/queries_train.tsv",
    "queries_train.tsv"
)

download_from_hf(
    "unicamp-dl/mmarco",
    "data/google/collections/portuguese_collection.tsv",
    f"{data_dir}/m_marco/corpus.tsv",
    "corpus.tsv"
)

download_from_hf(
    "unicamp-dl/mmarco",
    "data/triples.train.ids.small.tsv",
    f"{data_dir}/m_marco/triples.train.ids.small.tsv",
    "triples.train.ids.small.tsv"
)

# Verify mMARCO files were downloaded correctly
mmarco_files = ["queries_train.tsv", "corpus.tsv", "triples.train.ids.small.tsv"]
if all(os.path.exists(f"{data_dir}/m_marco/{f}") and os.path.getsize(f"{data_dir}/m_marco/{f}") > 100 
       for f in mmarco_files):
    # Copy to SPLADE structure
    shutil.copy(f"{data_dir}/m_marco/corpus.tsv", "splade/data/pt/triplets/corpus.tsv")
    shutil.copy(f"{data_dir}/m_marco/queries_train.tsv", "splade/data/pt/triplets/queries_train.tsv")
    shutil.copy(f"{data_dir}/m_marco/triples.train.ids.small.tsv", "splade/data/pt/triplets/raw.tsv")
    print("✅ mMARCO files copied to SPLADE structure")
else:
    print("❌ Error: Some mMARCO files were not downloaded correctly")
    print("Checking files:")
    for f in mmarco_files:
        path = f"{data_dir}/m_marco/{f}"
        if os.path.exists(path):
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f"  - {f}: {size_mb:.1f} MB")
        else:
            print(f"  - {f}: NOT FOUND")

# --- mRobust (Validation Dataset) ---
print("\n📦 Downloading mRobust datasets...")
download_from_hf(
    "unicamp-dl/mrobust",
    "data/queries/portuguese_queries.tsv",
    f"{data_dir}/m_robust/queries.tsv",
    "mrobust queries.tsv"
)

download_from_hf(
    "unicamp-dl/mrobust",
    "data/collections/portuguese_collection.tsv",
    f"{data_dir}/m_robust/corpus.tsv",
    "mrobust corpus.tsv"
)

download_from_hf(
    "unicamp-dl/mrobust",
    "qrels.robust04.txt",
    f"{data_dir}/m_robust/qrels.robust04.txt",
    "qrels.robust04.txt"
)

# Verify and copy mRobust files
mrobust_files = ["queries.tsv", "corpus.tsv", "qrels.robust04.txt"]
if all(os.path.exists(f"{data_dir}/m_robust/{f}") and os.path.getsize(f"{data_dir}/m_robust/{f}") > 100 
       for f in mrobust_files):
    shutil.copy(f"{data_dir}/m_robust/corpus.tsv", "splade/data/pt/val_retrieval/collection/raw.tsv")
    shutil.copy(f"{data_dir}/m_robust/queries.tsv", "splade/data/pt/val_retrieval/queries/raw.tsv")
    print("✅ mRobust files copied to SPLADE structure")
else:
    print("❌ Error: Some mRobust files were not downloaded correctly")
    print("Checking files:")
    for f in mrobust_files:
        path = f"{data_dir}/m_robust/{f}"
        if os.path.exists(path):
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f"  - {f}: {size_mb:.1f} MB")
        else:
            print(f"  - {f}: NOT FOUND")

print("\n✅ Download process completed.")

⏳ Downloading public datasets using HuggingFace Hub... (This may take a few minutes)

📦 Downloading mMARCO datasets...
✅ queries_train.tsv already exists (39281063 bytes), skipping download.
✅ corpus.tsv already exists (3431011785 bytes), skipping download.
✅ triples.train.ids.small.tsv already exists (905211990 bytes), skipping download.
✅ mMARCO files copied to SPLADE structure

📦 Downloading mRobust datasets...
✅ mrobust queries.tsv already exists (28418 bytes), skipping download.
✅ mrobust corpus.tsv already exists (1914138316 bytes), skipping download.
✅ qrels.robust04.txt already exists (6543541 bytes), skipping download.
✅ mRobust files copied to SPLADE structure

✅ Download process completed.


In [None]:
# Convert TREC-format QRELS to JSON
# Format: query_id 0 doc_id relevance_score
qrel = defaultdict(dict)
data_dir = "./data"
qrel_path = f"{data_dir}/m_robust/qrels.robust04.txt"

if os.path.exists(qrel_path):
    with open(qrel_path, 'r') as file:
        for line in file:
            fields = line.split()
            if len(fields) >= 4:
                q_id = fields[0]
                doc_id = fields[2]
                rel = fields[3]
                qrel[q_id][doc_id] = int(rel)

    # Save as JSON for SPLADE evaluation
    output_path = 'splade/data/pt/val_retrieval/qrel.json'
    with open(output_path, 'w') as file:
        json.dump(qrel, file)
    print(f"✅ QREL converted to JSON: {output_path}")
    print(f"   Total queries: {len(qrel)}")
else:
    print(f"❌ Error: qrels.robust04.txt not found at {qrel_path}")

✅ QREL converted to JSON.
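After the conversion, a quick summary of the nested `{query_id: {doc_id: relevance}}` structure can catch an empty or malformed file before training starts. The sketch below is illustrative: the function `summarize_qrels` and the sample identifiers are hypothetical, and treating any relevance greater than zero as "relevant" is an assumption.

```python
from typing import Dict


def summarize_qrels(qrels: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    """Count queries, judged query-document pairs, and relevant pairs."""
    judged = sum(len(docs) for docs in qrels.values())
    relevant = sum(
        1 for docs in qrels.values() for rel in docs.values() if rel > 0
    )
    return {"queries": len(qrels), "judged": judged, "relevant": relevant}


# Hypothetical sample in the same shape as the converted qrel.json
sample = {
    "301": {"DOC-A": 1, "DOC-B": 0},
    "302": {"DOC-C": 2},
}
print(summarize_qrels(sample))
```

To run it on the real file, load `splade/data/pt/val_retrieval/qrel.json` with `json.load` and pass the resulting dictionary to the function.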


In [None]:
# Create directory structure for configuration files
os.makedirs("splade/conf/train/config", exist_ok=True)
os.makedirs("splade/conf/train/data", exist_ok=True)
os.makedirs("splade/conf/train/model", exist_ok=True)
os.makedirs("splade/conf/index", exist_ok=True)
os.makedirs("splade/conf/retrieve_evaluate", exist_ok=True)
os.makedirs("splade/conf/flops", exist_ok=True)

# 5.1 Model Configuration
with open("splade/conf/train/model/splade_bertimbau_base.yaml", "w") as f:
    f.write("""
_target_: splade.models.transformer_rep.Splade
# Note: The actual parameter will be read from init_dict below
model_type_or_dir: neuralmind/bert-base-portuguese-cased
    """)

# 5.2 Data Configuration
with open("splade/conf/train/data/pt.yaml", "w") as f:
    f.write(f"""
# @package _global_
data:
    type: triplets
    TRAIN_DATA_DIR: {os.getcwd()}/splade/data/pt/triplets
    VALIDATION_DATA_DIR: {os.getcwd()}/splade/data/pt/val_retrieval
    QREL_PATH: {os.getcwd()}/splade/data/pt/val_retrieval/qrel.json
    """)

# 5.3 Training Configuration
# IMPORTANT: loss parameter is required to avoid ConfigKeyError
with open("splade/conf/train/config/splade_pt.yaml", "w") as f:
    f.write("""
# @package _global_
config:
    lr: 2e-5
    seed: 123
    gradient_accumulation_steps: 1
    weight_decay: 0.01
    validation_metrics: [MRR@10]
    pretrained_no_yaml_config: false
    nb_iterations: 150000
    train_batch_size: 32
    eval_batch_size: 32
    index_retrieval_batch_size: 32
    record_frequency: 1000
    train_monitoring_freq: 500
    warmup_steps: 6000
    max_length: 256
    fp16: true
    matching_type: splade
    monitoring_ckpt: true
    tokenizer_type: neuralmind/bert-base-portuguese-cased

    # Required loss parameter (fixes ConfigKeyError)
    loss: InBatchPairwiseNLL

    # Required keys for Hydra (will be overridden at runtime)
    checkpoint_dir: ""
    index_dir: ""
    out_dir: ""

    regularization:
        FLOPS:
            lambda_q: 0.0003
            lambda_d: 0.0001
            T: 50000
    """)

# 5.4 Main Configuration
with open("splade/conf/config_splade_pt.yaml", "w") as f:
    f.write("""
defaults:
  - train/data: pt
  - train/model: splade_bertimbau_base
  - train/config: splade_pt
  - index: pt
  - retrieve_evaluate: pt
  - flops: pt
  - _self_

# init_dict with previous corrections
init_dict:
  model_type_or_dir: neuralmind/bert-base-portuguese-cased
  fp16: true

hydra:
  run:
    dir: experiments/pt/out
  job:
    chdir: true
    """)

# 5.5 Placeholder Configurations (for indexing and evaluation)
with open("splade/conf/index/pt.yaml", "w") as f: 
    f.write("# Placeholder")
with open("splade/conf/retrieve_evaluate/pt.yaml", "w") as f: 
    f.write("# Placeholder")
with open("splade/conf/flops/pt.yaml", "w") as f: 
    f.write("# Placeholder")

print("✅ Configuration files created successfully (loss: InBatchPairwiseNLL included).")

✅ Configuration files recreated successfully (loss: InBatchPairwiseNLL added).
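The `regularization.FLOPS` block in the training configuration sets `lambda_q`, `lambda_d`, and `T`. In the SPLADE papers, the regularization weight is not applied at full strength from the start; it ramps up quadratically until step `T` and then stays constant. The sketch below illustrates that schedule under this assumption. The function `flops_lambda` is hypothetical and is not the repository's scheduler, which may differ in detail.

```python
def flops_lambda(step: int, lambda_max: float, warmup_steps: int) -> float:
    """Quadratic warm-up of the regularization weight, constant after warmup_steps."""
    if warmup_steps <= 0:
        return lambda_max
    # Clamp the progress ratio to [0, 1] so the weight never exceeds lambda_max.
    ratio = min(max(step, 0), warmup_steps) / warmup_steps
    return lambda_max * ratio ** 2


# Values taken from the training configuration above
LAMBDA_Q, LAMBDA_D, T = 0.0003, 0.0001, 50_000

for step in (0, 10_000, 25_000, 50_000, 150_000):
    print(
        f"step {step:>7}: "
        f"lambda_q={flops_lambda(step, LAMBDA_Q, T):.2e}, "
        f"lambda_d={flops_lambda(step, LAMBDA_D, T):.2e}"
    )
```

With `T: 50000` and `nb_iterations: 150000`, a schedule of this shape would reach full regularization strength one third of the way through training.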


In [None]:
# Configure environment
# Build PYTHONPATH without a leading separator when the variable is unset or empty
splade_path = os.path.join(os.getcwd(), 'splade')
existing_pythonpath = os.environ.get('PYTHONPATH', '')
os.environ['PYTHONPATH'] = os.pathsep.join(p for p in (existing_pythonpath, splade_path) if p)
os.environ['SPLADE_CONFIG_NAME'] = "config_splade_pt.yaml"
os.environ['PYTHONUNBUFFERED'] = '1'  # Disable buffering to see real-time logs

print("🚀 Starting training... Follow the logs below.")
print("Note: Ignore warnings about 'Unable to register cuFFT/cuDNN' from TensorFlow/JAX.")
print("=" * 80)
print()

# Execute training with real-time output using subprocess
cmd = [
    sys.executable,  # Use Python from virtual environment
    '-m', 'splade.train_from_triplets_ids',
    'config.checkpoint_dir=experiments/pt/checkpoint',
    'config.index_dir=experiments/pt/index',
    'config.out_dir=experiments/pt/out'
]

# Change to splade directory
original_dir = os.getcwd()
os.chdir('splade')

process = None  # Defined up front so the interrupt handler can check it safely
try:
    # Execute with real-time output
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,  # Decode output as text (modern alias of universal_newlines)
        bufsize=1,  # Line buffered
        env=os.environ.copy()
    )

    # Print output in real-time
    for line in process.stdout:
        print(line, end='', flush=True)

    process.wait()

    if process.returncode != 0:
        print(f"\n❌ Training finished with exit code: {process.returncode}")
    else:
        print("\n✅ Training completed successfully!")
        print("\n📁 Checkpoints saved in: experiments/pt/checkpoint/")
        print("📊 Training logs in: experiments/pt/out/")

except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user.")
    print("💡 Training can be resumed from the last checkpoint.")
    if process is not None:
        process.terminate()
except Exception as e:
    print(f"\n❌ Error during training: {e}")
    import traceback
    traceback.print_exc()
finally:
    os.chdir(original_dir)

🚀 Starting training... Follow the logs below.
Note: Ignore 'Unable to register cuFFT/cuDNN' warnings from TensorFlow/JAX.

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
data:
  type: triplets
  TRAIN_DATA_DIR: /home/user/Projects/SPLADE-PT-BR/splade/data/pt/triplets
  VALIDATION_DATA_DIR: /home/user/Projects/SPLADE-PT-BR/splade/data/pt/val_retrieval
  QREL_PATH: /home/user/Projects/SPLADE-PT-BR/splade/data/pt/val_retrieval/qrel.json
train:
  model:
    _target_: splade.models.transformer_rep.Splade
    model_type_or_dir: neuralmind/bert-base-portuguese-cased
config:
  lr: 2.0e-05
  seed: 123
  gradient_accumulation_steps: 4
  weight_decay: 0.01
  validation_metrics:
  - MRR@10
  pretrained_no_yaml_config: false
  nb_iterations: 150000
  train_batch_size: 8
  eval_batch_size: 16
  index_retrieval_batch_size: 16
  record_frequen