# SPLADE v2 PT-BR - Fixed Training

This notebook contains all the fixes needed to run SPLADE training in 2025, resolving library incompatibilities (AdamW/Hydra) and file dependencies.

**Note:** Make sure you are using a GPU runtime (T4 or A100).
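
Before installing anything, it can help to confirm that the runtime actually exposes a GPU. The snippet below is a minimal sketch and not part of the original notebook: `gpu_available` is a hypothetical helper, and it assumes the NVIDIA driver tool `nvidia-smi` is on the `PATH`, as it normally is on GPU runtimes.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if nvidia-smi exists and lists at least one GPU."""
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(
            ["nvidia-smi", "-L"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    # `nvidia-smi -L` prints one line per device, each starting with "GPU <index>:".
    return result.returncode == 0 and "GPU" in result.stdout


print("GPU runtime detected" if gpu_available() else "No GPU detected")
```

If no GPU is detected, switch the runtime type before running the cells below.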

In [1]:
# 1. Install libraries
!pip install pytrec_eval
!pip install git+https://github.com/leobavila/splade.git -q
!pip install hydra-core --upgrade

Collecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pytrec_eval
  Building wheel for pytrec_eval (setup.py) ... done
  Created wheel for pytrec_eval: filename=pytrec_eval-0.5-cp312-cp312-linux_x86_64.whl size=309354 sha256=f9393acd84cf0f3805d28ffb1547d96bd4282179fa9daf1ee4777ec1f62bfcf6
  Stored in directory: /root/.cache/pip/wheels/c6/4a/9e/e17f9ea004e1c221bd0ff384732285211c4917b790d598ea51
Successfully built pytrec_eval
Installing collected packages: pytrec_eval
Successfully installed pytrec_eval-0.5
  Preparing metadata (setup.py) ... done

In [2]:
# 2. Clone the repository and apply the AdamW patch
import os

# Clone the repository if it does not exist yet
if not os.path.exists("splade"):
    !git clone https://github.com/leobavila/splade.git

# Fix the AdamW import error (older vs. newer Transformers releases)
file_path = "splade/splade/optim/bert_optim.py"
if os.path.exists(file_path):
    with open(file_path, "r") as f:
        content = f.read()

    new_content = content.replace(
        "from transformers.optimization import AdamW, get_linear_schedule_with_warmup",
        "from transformers import get_linear_schedule_with_warmup; from torch.optim import AdamW"
    )

    with open(file_path, "w") as f:
        f.write(new_content)
    print("✅ Patch applied: bert_optim.py fixed.")
else:
    print("❌ Error: bert_optim.py not found.")

Cloning into 'splade'...
remote: Enumerating objects: 495, done.
remote: Counting objects: 100% (244/244), done.
remote: Compressing objects: 100% (108/108), done.
remote: Total 495 (delta 170), reused 136 (delta 136), pack-reused 251 (from 1)
Receiving objects: 100% (495/495), 3.08 MiB | 4.71 MiB/s, done.
Resolving deltas: 100% (238/238), done.
✅ Patch applied: bert_optim.py fixed.

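One caveat about the patch cell: `str.replace` returns the text unchanged when the target string is absent, so the cell reports success even if nothing was replaced. The sketch below is an illustrative variant, not the notebook's own code; `patch_adamw_import` is a hypothetical helper that also reports whether the old import line was actually found.

```python
OLD_IMPORT = "from transformers.optimization import AdamW, get_linear_schedule_with_warmup"
NEW_IMPORT = "from transformers import get_linear_schedule_with_warmup; from torch.optim import AdamW"


def patch_adamw_import(source: str) -> tuple[str, bool]:
    """Return the patched source and a flag telling whether the old import was found."""
    if OLD_IMPORT not in source:
        return source, False
    return source.replace(OLD_IMPORT, NEW_IMPORT), True


# Illustrative input: a tiny stand-in for the contents of bert_optim.py.
sample = OLD_IMPORT + "\nprint('optimizer module')\n"
patched, changed = patch_adamw_import(sample)
print("patched" if changed else "old import not found, nothing changed")

# Running the patch a second time is a no-op and is reported as such.
_, changed_again = patch_adamw_import(patched)
print("second run changed anything:", changed_again)
```

In the notebook, the same function could be applied to the text read from `file_path`, writing the file back only when the flag is true.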

In [3]:
# 3. Download and prepare the datasets (mMARCO and mRobust)
import shutil

print("⏳ Downloading public datasets... (may take 2-3 minutes)")

# Create base folders
os.makedirs("/content/data/m_marco", exist_ok=True)
os.makedirs("/content/data/m_robust", exist_ok=True)


# Create the SPLADE destination folders
os.makedirs("splade/data/pt/triplets", exist_ok=True)
os.makedirs("splade/data/pt/val_retrieval/collection", exist_ok=True)
os.makedirs("splade/data/pt/val_retrieval/queries", exist_ok=True)

# --- mMARCO (training) ---
!wget -q https://huggingface.co/datasets/unicamp-dl/mmarco/resolve/main/data/google/queries/train/portuguese.tsv -O /content/data/m_marco/queries_train.tsv
!wget -q https://huggingface.co/datasets/unicamp-dl/mmarco/resolve/main/data/google/collections/portuguese.tsv -O /content/data/m_marco/corpus.tsv
!wget -q https://huggingface.co/datasets/unicamp-dl/mmarco/resolve/main/data/triples.train.ids.small.tsv -O /content/data/m_marco/triples.train.ids.small.tsv

# Copy into the SPLADE layout
shutil.copy("/content/data/m_marco/corpus.tsv", "splade/data/pt/triplets/corpus.tsv")
shutil.copy("/content/data/m_marco/queries_train.tsv", "splade/data/pt/triplets/queries_train.tsv")
shutil.copy("/content/data/m_marco/triples.train.ids.small.tsv", "splade/data/pt/triplets/raw.tsv")

# --- mRobust (validation) ---
!wget -q https://huggingface.co/datasets/unicamp-dl/mrobust/resolve/main/data/mrobust/queries.tsv -O /content/data/m_robust/queries.tsv
!wget -q https://huggingface.co/datasets/unicamp-dl/mrobust/resolve/main/data/mrobust/corpus.tsv -O /content/data/m_robust/corpus.tsv
!wget -q https://huggingface.co/datasets/unicamp-dl/mrobust/resolve/main/data/mrobust/qrels.robust04.txt -O /content/data/m_robust/qrels.robust04.txt

# Copy into the SPLADE layout (validation)
shutil.copy("/content/data/m_robust/corpus.tsv", "splade/data/pt/val_retrieval/collection/raw.tsv")
shutil.copy("/content/data/m_robust/queries.tsv", "splade/data/pt/val_retrieval/queries/raw.tsv")

print("✅ Datasets downloaded and organized.")

⏳ Downloading public datasets... (may take 2-3 minutes)
✅ Datasets downloaded and organized.

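Because `wget -q` suppresses its messages, a failed download can go unnoticed until training starts. The following is a minimal sketch, not part of the original notebook, of a sanity check for the downloaded files; `summarize_tsv` is a hypothetical helper, and the demo runs on a throwaway file rather than on the real datasets.

```python
import os
import tempfile


def summarize_tsv(path, max_lines=1000):
    """Report existence, size, and the tab-separated column counts in the first lines."""
    info = {"exists": os.path.isfile(path), "size_bytes": 0, "column_counts": set()}
    if not info["exists"]:
        return info
    info["size_bytes"] = os.path.getsize(path)
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= max_lines:
                break
            info["column_counts"].add(len(line.rstrip("\n").split("\t")))
    return info


# Demo on a throwaway two-column file; in the notebook, pass the downloaded paths instead.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.tsv")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write("1\tfirst passage\n2\tsecond passage\n")
    demo = summarize_tsv(demo_path)
    missing = summarize_tsv(os.path.join(tmp, "absent.tsv"))

print(demo)
print(missing)
```

A file that is missing, has zero bytes, or shows more than one distinct column count is worth inspecting before moving on.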

In [4]:
# 4. Convert the qrels to JSON
import json
import os
from collections import defaultdict

qrel = defaultdict(dict)
qrel_path = "/content/data/m_robust/qrels.robust04.txt"

if os.path.exists(qrel_path):
    with open(qrel_path, 'r') as file:
        for line in file:
            fields = line.split()
            if len(fields) >= 4:
                q_id = fields[0]
                doc_id = fields[2]
                rel = fields[3]
                qrel[q_id][doc_id] = int(rel)

    with open('splade/data/pt/val_retrieval/qrel.json', 'w') as file:
        json.dump(qrel, file)
    print("✅ QREL converted to JSON.")
else:
    print("❌ Error: qrels.robust04.txt not found.")

✅ QREL converted to JSON.

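The conversion above reads four whitespace-separated fields per line (query ID, an unused iteration field, document ID, relevance) and nests them as `{query_id: {doc_id: relevance}}`. Here is a small self-contained sketch of the same logic; `parse_qrels` is a hypothetical helper and the sample lines use made-up IDs, not real mRobust data.

```python
from collections import defaultdict


def parse_qrels(lines):
    """Build {query_id: {doc_id: relevance}} from TREC-style qrels lines."""
    qrel = defaultdict(dict)
    for line in lines:
        fields = line.split()
        if len(fields) < 4:
            continue  # skip blank or malformed lines
        q_id, _iteration, doc_id, rel = fields[:4]
        qrel[q_id][doc_id] = int(rel)
    return dict(qrel)


# Made-up sample lines in the "qid iteration docid relevance" layout.
sample = [
    "q1 0 docA 1",
    "q1 0 docB 0",
    "",
    "q2 0 docC 2",
]
parsed = parse_qrels(sample)
print(parsed)
```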

In [5]:
# 5. Generate configuration files (CRITICAL: includes the 'loss' parameter)
# Fix: added 'loss: InBatchPairwiseNLL' to resolve the ConfigKeyError.

import os

# Create the folder structure
os.makedirs("splade/conf/train/config", exist_ok=True)
os.makedirs("splade/conf/train/data", exist_ok=True)
os.makedirs("splade/conf/train/model", exist_ok=True)
os.makedirs("splade/conf/index", exist_ok=True)
os.makedirs("splade/conf/retrieve_evaluate", exist_ok=True)
os.makedirs("splade/conf/flops", exist_ok=True)

# 5.1 Model
with open("splade/conf/train/model/splade_bertimbau_base.yaml", "w") as f:
    f.write("""
_target_: splade.models.transformer_rep.Splade
# Note: the actual parameter is read from init_dict below
model_type_or_dir: neuralmind/bert-base-portuguese-cased
    """)

# 5.2 Data
with open("splade/conf/train/data/pt.yaml", "w") as f:
    f.write(f"""
# @package _global_
data:
    type: triplets
    TRAIN_DATA_DIR: {os.getcwd()}/splade/data/pt/triplets
    VALIDATION_DATA_DIR: {os.getcwd()}/splade/data/pt/val_retrieval
    QREL_PATH: {os.getcwd()}/splade/data/pt/val_retrieval/qrel.json
    """)

# 5.3 Training config (FIX HERE: added 'loss')
with open("splade/conf/train/config/splade_pt.yaml", "w") as f:
    f.write("""
# @package _global_
config:
    lr: 2e-5
    seed: 123
    gradient_accumulation_steps: 1
    weight_decay: 0.01
    validation_metrics: [MRR@10]
    pretrained_no_yaml_config: false
    nb_iterations: 150000
    train_batch_size: 32
    eval_batch_size: 32
    index_retrieval_batch_size: 32
    record_frequency: 1000
    train_monitoring_freq: 500
    warmup_steps: 6000
    max_length: 256
    fp16: true
    matching_type: splade
    monitoring_ckpt: true
    tokenizer_type: neuralmind/bert-base-portuguese-cased

    # Loss parameter that was missing:
    loss: InBatchPairwiseNLL

    # Keys required by Hydra:
    checkpoint_dir: ""
    index_dir: ""
    out_dir: ""

    regularization:
        FLOPS:
            lambda_q: 0.0003
            lambda_d: 0.0001
            T: 50000
    """)

# 5.4 Main config
with open("splade/conf/config_splade_pt.yaml", "w") as f:
    f.write("""
defaults:
  - train/data: pt
  - train/model: splade_bertimbau_base
  - train/config: splade_pt
  - index: pt
  - retrieve_evaluate: pt
  - flops: pt
  - _self_

# init_dict with the earlier fixes
init_dict:
  model_type_or_dir: neuralmind/bert-base-portuguese-cased
  fp16: true

hydra:
  run:
    dir: experiments/pt/out
  job:
    chdir: true
    """)

# 5.5 Placeholders
with open("splade/conf/index/pt.yaml", "w") as f: f.write("# Placeholder")
with open("splade/conf/retrieve_evaluate/pt.yaml", "w") as f: f.write("# Placeholder")
with open("splade/conf/flops/pt.yaml", "w") as f: f.write("# Placeholder")

print("✅ Configuration files recreated successfully (loss: InBatchPairwiseNLL added).")

✅ Configuration files recreated successfully (loss: InBatchPairwiseNLL added).

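The `regularization.FLOPS` block in the training config sets the target weights `lambda_q` and `lambda_d` together with a step count `T`. The sketch below illustrates how such a weight could evolve, assuming the quadratic warm-up described in the SPLADE paper (the weight grows quadratically until step `T`, then stays constant). Whether this fork's scheduler behaves exactly this way was not checked here, and `flops_lambda` is a hypothetical name.

```python
def flops_lambda(step: int, lambda_max: float, T: int) -> float:
    """Quadratic warm-up of the regularization weight, held constant from step T on."""
    if step >= T:
        return lambda_max
    return lambda_max * (step / T) ** 2


# Values taken from the config above: lambda_q = 0.0003, T = 50000.
for step in (0, 25000, 50000, 150000):
    print(step, flops_lambda(step, lambda_max=0.0003, T=50000))
```

Under this assumption, the sparsity penalty is weak early in training and reaches full strength one third of the way through the 150000 iterations.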

In [6]:
# 6. Run training
import os

# Set up the environment
os.environ['PYTHONPATH'] = os.environ.get('PYTHONPATH', '') + ":" + os.path.join(os.getcwd(), 'splade')
os.environ['SPLADE_CONFIG_NAME'] = "config_splade_pt.yaml"

print("🚀 Starting training... Follow the logs below.")
print("Note: ignore the 'Unable to register cuFFT/cuDNN' warnings from TensorFlow/JAX.")

!cd splade && python3 -m splade.train_from_triplets_ids \
  config.checkpoint_dir=experiments/pt/checkpoint \
  config.index_dir=experiments/pt/index \
  config.out_dir=experiments/pt/out

🚀 Starting training... Follow the logs below.
Note: ignore the 'Unable to register cuFFT/cuDNN' warnings from TensorFlow/JAX.
2025-11-28 11:44:36.047415: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764330276.078778    1203 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764330276.088425    1203 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1764330276.111848    1203 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1764330276.111881    1203 computation_placer.cc:177] computation placer already registered. Please check linkage and avoi