In [1]:
!pip install torch faiss-cpu tqdm

!pip install pyarrow==14.0.1 datasets==2.14.6 transformers==4.35.2 accelerate==0.24.1

!pip -q install ipywidgets



In [2]:
import os
import json
from datetime import datetime
import logging
import warnings

import torch
from tqdm.std import tqdm
from datasets import load_dataset, Dataset

from transformers import (
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    default_data_collator
)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

logging.getLogger("transformers").setLevel(logging.ERROR)

# Configurable inputs
DATASET_NAME = "stanfordnlp/web_questions"
SPLIT_TRAIN = "train"
SPLIT_TEST = "test"

USE_DUMMY = True
OUT_ROOT = "WQ_models"

# Training hyperparams
TRAIN_N_DOCS = 1
MAX_Q_LEN = 64
MAX_A_LEN = 32
LR = 1e-5
EPOCHS = 2
BSZ = 1
GRAD_ACC = 8
SAVE_STEPS = 500
LOG_STEPS = 50

os.makedirs(OUT_ROOT, exist_ok=True)
print("Saving all runs under:", OUT_ROOT)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
2026-01-17 19:11:41.731091: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-17 19:11:41.789610: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Running on: cuda
Saving all runs under: WQ_models


In [3]:
def build_wq_flat(split: str, dataset_name: str = DATASET_NAME) -> Dataset:
    ds = load_dataset(dataset_name, split=split)

    questions, answers = [], []
    for ex in tqdm(ds, desc=f"Flattening {dataset_name}:{split}"):
        q = ex["question"]
        for a in ex["answers"]:
            questions.append(q)
            answers.append(a)

    return Dataset.from_dict({"question": questions, "answer": answers})


wq_train_flat = build_wq_flat(split=SPLIT_TRAIN)
wq_test = load_dataset(DATASET_NAME, split=SPLIT_TEST)

print("Flat train:", wq_train_flat)
print("Test:", wq_test)

Flattening stanfordnlp/web_questions:train: 100%|██████████| 3778/3778 [00:00<00:00, 29447.88it/s]


Flat train: Dataset({
    features: ['question', 'answer'],
    num_rows: 8933
})
Test: Dataset({
    features: ['url', 'question', 'answers'],
    num_rows: 2032
})


In [4]:
from transformers import Seq2SeqTrainer

class RagFixedLossTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)

        # outputs can be dict-like or have .loss
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs.loss

        # Make sure it's a scalar
        loss = loss.mean()

        return (loss, outputs) if return_outputs else loss


In [5]:
def make_preprocess_fn(
    tokenizer: RagTokenizer,
    max_q_len: int,
    max_a_len: int,
    decoder_start_token_id: int,
    rag_type: str   # "token" or "sequence"
):
    assert rag_type in {"token", "sequence"}

    gen_tok = tokenizer.generator
    pad_id = gen_tok.pad_token_id

    def preprocess(batch):
        # Question encoder inputs
        q_enc = tokenizer(
            batch["question"],
            padding="max_length",
            truncation=True,
            max_length=max_q_len,
        )

        # Generator target ids (answer_ids)
        a_enc = gen_tok(
            batch["answer"],
            padding="max_length",
            truncation=True,
            max_length=max_a_len,
        )
        answer_ids = a_enc["input_ids"]

        # Decoder_input_ids (shift-right, keep pad ids)
        answer_ids = a_enc["input_ids"]

        decoder_input_ids = [[decoder_start_token_id] + seq[:-1] for seq in answer_ids]
        decoder_attention_mask = [[0 if t == pad_id else 1 for t in seq] for seq in decoder_input_ids]

        labels = answer_ids

        q_enc["decoder_input_ids"] = decoder_input_ids
        q_enc["decoder_attention_mask"] = decoder_attention_mask
        q_enc["labels"] = labels
        return q_enc

    return preprocess


In [6]:
def finetune_wq(
    base_model_name: str,
    rag_type: str,  # "token" or "sequence"
    train_dataset: Dataset,
    use_dummy: bool,
    out_root: str,
    train_n_docs: int,
):
    assert rag_type in {"token", "sequence"}

    # Timestamped run dir
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"{base_model_name.replace('/','_')}__{rag_type}__wq_ft__nDocs{train_n_docs}__{ts}"
    out_dir = os.path.join(out_root, run_name)
    os.makedirs(out_dir, exist_ok=True)

    print(f"\n=== Fine-tuning {rag_type.upper()} from {base_model_name} ===")
    print(f"Saving to: {out_dir}")

    print("Loading tokenizer.")
    tokenizer = RagTokenizer.from_pretrained(base_model_name)
    print("Tokenizer loaded.")

    print("Loading retriever (downloads/loads Wikipedia index if use_dummy=False).")
    retriever = RagRetriever.from_pretrained(
        base_model_name,
        index_name="exact",
        use_dummy_dataset=use_dummy,
    )
    print("Retriever loaded.")

    print("Loading model weights.")
    model_cls = RagTokenForGeneration if rag_type == "token" else RagSequenceForGeneration
    model = model_cls.from_pretrained(base_model_name, retriever=retriever).to(device)
    print("Model loaded and moved to device.")

    model.config.n_docs = train_n_docs
    model.config.use_cache = False
    if rag_type == "token":
        model.config.reduce_loss = True

    # Preprocess dataset
    decoder_start_id = model.generator.config.decoder_start_token_id
    preprocess_fn = make_preprocess_fn(
        tokenizer=tokenizer,
        max_q_len=MAX_Q_LEN,
        max_a_len=MAX_A_LEN,
        decoder_start_token_id=decoder_start_id,
        rag_type=rag_type
    )
    train_tok = train_dataset.map(
        preprocess_fn,
        batched=True,
        remove_columns=train_dataset.column_names
    )

    train_tok.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
    )


    # Training args
    args = Seq2SeqTrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=BSZ,
        gradient_accumulation_steps=GRAD_ACC,
        learning_rate=LR,
        num_train_epochs=EPOCHS,
        fp16=torch.cuda.is_available(),
        logging_strategy="steps",
        logging_steps=LOG_STEPS,
        logging_first_step=True,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=2,
        report_to="none",
        evaluation_strategy="no",
        predict_with_generate=False,
        remove_unused_columns=False,
        disable_tqdm=False,
    )

    # Trainer
    trainer = RagFixedLossTrainer(
        model=model,
        args=args,
        train_dataset=train_tok,
        data_collator=default_data_collator,
        tokenizer=None
    )

    print("Number train examples:", len(train_tok))
    approx_steps = (len(train_tok) // (BSZ * GRAD_ACC)) * EPOCHS
    print("Expected steps (approx):", approx_steps)
    print("Starting training loop now...")

    train_result = trainer.train()

    print("Training finished.")
    print(train_result)

    # Save model + tokenizer to SAME folder
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)

    # Save meta
    meta = {
        "dataset": DATASET_NAME,
        "split_train": SPLIT_TRAIN,
        "split_test": SPLIT_TEST,
        "base_model_name": base_model_name,
        "rag_type": rag_type,
        "use_dummy_dataset": use_dummy,
        "train_n_docs": train_n_docs,
        "max_q_len": MAX_Q_LEN,
        "max_a_len": MAX_A_LEN,
        "learning_rate": LR,
        "epochs": EPOCHS,
        "batch_size": BSZ,
        "grad_accumulation": GRAD_ACC,
    }
    with open(os.path.join(out_dir, "run_meta.json"), "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Saved fine-tuned checkpoint to {out_dir}")
    return out_dir

In [7]:
token_ckpt_dir = finetune_wq(
    base_model_name="facebook/rag-token-nq",
    rag_type="token",
    train_dataset=wq_train_flat,
    use_dummy=USE_DUMMY,
    out_root=OUT_ROOT,
    train_n_docs=TRAIN_N_DOCS,
)
print("RAG Token checkpoint:", token_ckpt_dir)


=== Fine-tuning TOKEN from facebook/rag-token-nq ===
Saving to: WQ_models/facebook_rag-token-nq__token__wq_ft__nDocs1__20260117_191147
Loading tokenizer.




Tokenizer loaded.
Loading retriever (downloads/loads Wikipedia index if use_dummy=False).
Retriever loaded.
Loading model weights.


  _torch_pytree._register_pytree_node(


Model loaded and moved to device.


Map:   0%|          | 0/8933 [00:00<?, ? examples/s]

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Number train examples: 8933
Expected steps (approx): 2232
Starting training loop now...


Step,Training Loss
1,41.1917


KeyboardInterrupt: 

In [8]:
seq_ckpt_dir = finetune_wq(
    base_model_name="facebook/rag-sequence-nq",
    rag_type="sequence",
    train_dataset=wq_train_flat,
    use_dummy=USE_DUMMY,
    out_root=OUT_ROOT,
    train_n_docs=TRAIN_N_DOCS,
)
print("RAG Sequence checkpoint:", seq_ckpt_dir)


=== Fine-tuning SEQUENCE from facebook/rag-sequence-nq ===
Saving to: WQ_models/facebook_rag-sequence-nq__sequence__wq_ft__nDocs1__20260117_191308
Loading tokenizer.
Tokenizer loaded.
Loading retriever (downloads/loads Wikipedia index if use_dummy=False).
Retriever loaded.
Loading model weights.
Model loaded and moved to device.


Map:   0%|          | 0/8933 [00:00<?, ? examples/s]

Number train examples: 8933
Expected steps (approx): 2232
Starting training loop now...


Step,Training Loss
1,49.4964


OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB. GPU 0 has a total capacity of 14.57 GiB of which 4.75 MiB is free. Process 1135247 has 14.56 GiB memory in use. Of the allocated memory 14.07 GiB is allocated by PyTorch, and 360.95 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)