In [1]:
pip install transformers datasets peft trl bitsandbytes accelerate ipywidgets tensorboard optuna

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install unsloth

Note: you may need to restart the kernel to use updated packages.


In [3]:
from unsloth import FastModel
import torch

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-12b-it",
    max_seq_length = 16384,
    load_in_4bit = False,    # 4-bit quantization reduces memory (disabled here)
    load_in_8bit = False,    # 8-bit: a bit more accurate than 4-bit, ~2x the memory
    load_in_16bit = True,    # keep the weights in 16-bit for LoRA training
    full_finetuning = False, # train LoRA adapters only, not the full model
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Your Flash Attention 2 installation seems to be broken?
A possible explanation is you have a new CUDA version which isn't
yet compatible with FA2? Please file a ticket to Unsloth or FA2.
We shall now use Xformers instead, which does not have any performance hits!
We found this negligible impact by benchmarking on 1x A100.
Switching to PyTorch attention since your Xformers is broken.

/home/cosmin/anaconda3/envs/rerank/lib/python3.12/site-packages/flash_attn_2_cuda.cpython-312-x86_64-linux-gnu.so: undefined symbol: _ZNK3c106SymInt6sym_neERKS0_
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.11.6: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    NVIDIA H200 NVL. Num GPUs = 1. Max memory: 139.801 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 9.0. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 


In [4]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Text-only task: keep the vision layers frozen
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Adapt the attention projections
    finetune_mlp_modules       = True,  # Should always leave on!

    r = 128,           # Larger = more adapter capacity, but might overfit
    lora_alpha = 256,  # alpha >= r is recommended; here alpha = 2r
    lora_dropout = 0.05,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias = "none",
    random_state = 3407,
)

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.


Unsloth: Making `base_model.model.model.vision_tower.vision_model` require gradients
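With standard (non-rsLoRA) scaling, PEFT multiplies the LoRA update by `lora_alpha / r`, so `r = 128, lora_alpha = 256` gives a scale of 2. Each adapted linear layer mapping `d_in -> d_out` gains two matrices, A (`r x d_in`) and B (`d_out x r`), i.e. `r * (d_in + d_out)` trainable parameters. The sketch below does that arithmetic; the 4096-wide layer in the example is a hypothetical placeholder, not a value read from the Gemma 3 config.

```python
def lora_scale(r: int, lora_alpha: int) -> float:
    """Scale applied to the B @ A update in standard LoRA (use_rslora=False)."""
    return lora_alpha / r


def lora_params_per_layer(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added to one linear layer: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


# Hypothetical square projection of width 4096, with the notebook's r and alpha
print(lora_scale(r=128, lora_alpha=256))          # 2.0
print(lora_params_per_layer(4096, 4096, r=128))   # 1048576
```

Doubling `r` doubles the adapter size per layer; keeping `lora_alpha / r` fixed keeps the update scale unchanged.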


In [5]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)

In [6]:
from datasets import load_from_disk

dataset = load_from_disk("reranking_dataset_doris_mae")

In [7]:
# Filter out queries with no positive documents
dataset = dataset.filter(
    lambda example: any(score > 0 for score in example["candidate_scores"]),
    desc="Removing zero-positive queries"
)

print(f"Dataset size after filtering: {len(dataset)}")

Dataset size after filtering: 1498


In [8]:
dataset[100]

{'query_text': 'The document should discuss the improvement of interpretability and explainability of machine learning models.',
 'aspect_id': '171',
 'original_query_type': 'ML',
 'candidate_pool': [184834,
  141826,
  94225,
  91667,
  48153,
  84506,
  70171,
  316961,
  357922,
  193061,
  357415,
  48168,
  121899,
  64560,
  96305,
  236595,
  43574,
  152119,
  263743,
  45119,
  190015,
  242239,
  170071,
  167008,
  83044,
  173673,
  166518,
  36476,
  154247,
  261768,
  202383,
  80015,
  38037,
  80538,
  281249,
  138404,
  151208,
  208041,
  253622,
  216765,
  71871,
  44223,
  60102,
  212169,
  251597,
  271055,
  319696,
  156881,
  337109,
  160471,
  89316,
  302822,
  179957,
  207094,
  76023,
  76021,
  298758,
  162060,
  241933,
  160526,
  3852,
  99608,
  138523,
  40234,
  244011,
  69932,
  242478,
  100660,
  170805,
  272183,
  147768,
  123706,
  275272,
  67913,
  277856,
  213344,
  48992,
  166762,
  104811,
  124270,
  111989,
  91004,
  198022,
 

In [9]:
import pandas as pd

corpus_df = pd.read_json("splits/Corpus.json")

In [10]:
corpus_df.head()

Unnamed: 0,masked_abstract,original_abstract,title,url,primary_category,categories,incoming_citations,ss_id,outgoing_citations,abstract_id
0,Neural machine translation ( * ) systems aim t...,Neural machine translation (NMT) systems aim t...,How_do_lexical_semantics_affect_translation?_A...,http://arxiv.org/abs/2201.00075v1,cs.CL,"[cs.CL, cs.LG]",[],8cff7cb7a44672d1108d63ed611e1056acf2ace3,"[157157, 197007, 238926, 280486, 283908, 29077...",0
1,If popular online platforms systematically exp...,If popular online platforms systematically exp...,Engagement_Outweighs_Exposure_to_Partisan_and_...,http://arxiv.org/abs/2201.00074v3,cs.SI,[cs.SI],[],404f412bc6b76f58418604eb5caac077bb25cdac,"[14715, 57120, 84703, 111772, 124625, 263661]",1
2,While neural networks have shown remarkable su...,While neural networks have shown remarkable su...,BARACK:_Partially_Supervised_Group_Robustness_...,http://arxiv.org/abs/2201.00072v2,cs.LG,[cs.LG],[],49750bf1dd5e66025c18adfce5ce7fef445fb9d4,"[5883, 34547, 84663, 89288, 91962, 94275, 9482...",2
3,"In recent times , a large number of people hav...","In recent times, a large number of people have...",A_Deep_Learning_Approach_to_Integrate_Human-Le...,http://arxiv.org/abs/2201.02735v1,cs.CL,"[cs.CL, cs.LG]",[],0be842930627213049f5567fe37d65237e535960,"[137492, 159867, 299491]",3
4,The distributed consensus mechanism is the bac...,The distributed consensus mechanism is the bac...,Confronting_the_Carbon-footprint_Challenge_of_...,http://arxiv.org/abs/2201.06929v1,cs.CY,[cs.CY],[],1dee953d5211578d330a1328fb0033dbe644685b,"[179697, 186061, 232933, 252408]",4


In [11]:
len(corpus_df)

363133

In [12]:
print("Creating abstract_id -> text lookup dictionary...")
# Zipping the columns avoids the per-row overhead of DataFrame.iterrows()
abstract_lookup = {
    int(abstract_id): f"Title: {title.replace('_', ' ')}\nAbstract: {abstract}"
    for abstract_id, title, abstract in zip(
        corpus_df['abstract_id'], corpus_df['title'], corpus_df['original_abstract']
    )
}

print(f"Created lookup for {len(abstract_lookup)} abstracts.")

Creating abstract_id -> text lookup dictionary...
Created lookup for 363133 abstracts.


In [13]:
abstract_lookup[0]

'Title: How do lexical semantics affect translation? An empirical study\nAbstract: Neural machine translation (NMT) systems aim to map text from one language\ninto another. While there are a wide variety of applications of NMT, one of the\nmost important is translation of natural language. A distinguishing factor of\nnatural language is that words are typically ordered according to the rules of\nthe grammar of a given language. Although many advances have been made in\ndeveloping NMT systems for translating natural language, little research has\nbeen done on understanding how the word ordering of and lexical similarity\nbetween the source and target language affect translation performance. Here, we\ninvestigate these relationships on a variety of low-resource language pairs\nfrom the OpenSubtitles2016 database, where the source language is English, and\nfind that the more similar the target language is to English, the greater the\ntranslation performance. In addition, we study the impa

In [14]:
import random

def make_conversation(example):
    query_text = example['query_text']
    candidate_pool = example['candidate_pool']
    candidate_scores = example['candidate_scores']

    # Number each passage as [i] and join them with separators
    passages_str_list = []
    for i, doc_id in enumerate(candidate_pool):
        doc_text = abstract_lookup.get(int(doc_id), "Error: Document text not found.")
        passages_str_list.append(f"[{i+1}] {doc_text}")
    passages_for_prompt = "\n\n---\n\n".join(passages_str_list)

    # Build the user prompt
    user_msg = (
        f"You are an expert academic paper reranker. "
        f"Your task is to re-order the given list of passages (from [1] to [{len(candidate_pool)}]) "
        f"based on their relevance to the query. Respond with only the ranking and nothing else.\n\n"
        f"Example output for 8 passages:\n"
        f"[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\n"
        f"Query: {query_text}\n\n"
        f"Passages:\n{passages_for_prompt}\n\n"
        f"Your ranking (most to least relevant):"
    )

    # Build the assistant response (ground truth ranking)
    indexed_scores = [(i+1, score) for i, score in enumerate(candidate_scores)]
    sorted_indices = sorted(indexed_scores, key=lambda x: x[1], reverse=True)
    sorted_list_str = " > ".join(f"[{idx}]" for idx, _ in sorted_indices)

    conversation = [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": sorted_list_str}
    ]

    return {"conversations": conversation}

def make_conversation_k(example):
    query_text = example['query_text']
    candidate_pool = example['candidate_pool']
    candidate_scores = example['candidate_scores']
    
    positives = [(doc_id, score) for doc_id, score in zip(candidate_pool, candidate_scores) if score > 0]
    negatives = [(doc_id, score) for doc_id, score in zip(candidate_pool, candidate_scores) if score == 0]
    
    TARGET_SIZE = 50
    MAX_POSITIVES = 5 
    
    random.shuffle(positives)
    random.shuffle(negatives)
    
    selected_positives = positives[:MAX_POSITIVES]
    num_negatives_needed = TARGET_SIZE - len(selected_positives)
    selected_negatives = negatives[:num_negatives_needed]
    selected_docs = selected_positives + selected_negatives
    
    if len(selected_docs) < TARGET_SIZE:
        remaining_positives = positives[MAX_POSITIVES:]
        needed = TARGET_SIZE - len(selected_docs)
        selected_docs += remaining_positives[:needed]

    random.shuffle(selected_docs)
    
    pool_k = [doc_id for doc_id, _ in selected_docs]
    scores_k = [score for _, score in selected_docs]
    
    passages_str_list = []
    for i, (doc_id, _) in enumerate(selected_docs):
        doc_text = abstract_lookup.get(int(doc_id), "Error: Document text not found.")
        doc_text = " ".join(doc_text.split()[:300]) 
        passages_str_list.append(f"[{i+1}] {doc_text}")

    passages_for_prompt = "\n\n---\n\n".join(passages_str_list)

    user_msg = (
        f"You are an expert academic paper reranker. "
        f"Your task is to re-order the given list of passages (from [1] to [{len(selected_docs)}]) "
        f"based on their relevance to the query. Respond with only the ranking and nothing else.\n\n"
        f"Example output for 8 passages:\n"
        f"[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\n"
        f"Query: {query_text}\n\n"
        f"Passages:\n{passages_for_prompt}\n\n"
        f"Your ranking (most to least relevant):"
    )

    indexed_scores = [(i+1, score) for i, score in enumerate(scores_k)]
    sorted_indices = sorted(indexed_scores, key=lambda x: (x[1], -x[0]), reverse=True)
    
    sorted_list_str = " > ".join(f"[{idx}]" for idx, _ in sorted_indices)

    conversation = [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": sorted_list_str}
    ]

    return {
        "conversations": conversation,
        "pool_k": pool_k,
        "scores_k": scores_k
    }

def make_conversation_k_v2(example, rng, max_positives=5, target_size=50, hard_neg_top=None, pos_mode="any"):
    query_text = example["query_text"]
    candidate_pool = example["candidate_pool"]
    candidate_scores = example["candidate_scores"]

    positives = [(d,s) for d,s in zip(candidate_pool, candidate_scores) if s > 0]
    negatives = [(d,s) for d,s in zip(candidate_pool, candidate_scores) if s == 0]

    # Optional: focus on "harder" negatives if candidate_pool is ranked
    if hard_neg_top is not None and len(negatives) > hard_neg_top:
        negatives = negatives[:hard_neg_top]

    # Optional: choose which positives to include
    if pos_mode == "only_2":
        positives = [p for p in positives if p[1] == 2]
    elif pos_mode == "only_1":
        positives = [p for p in positives if p[1] == 1]
    # pos_mode == "any": keep all

    rng.shuffle(positives)
    rng.shuffle(negatives)

    selected_pos = positives[:max_positives]
    num_negs = target_size - len(selected_pos)

    # sample without replacement if possible
    if len(negatives) >= num_negs:
        selected_neg = negatives[:num_negs]
    else:
        selected_neg = negatives[:]  # too few negatives: take all; leftover positives fill the gap below

    selected_docs = selected_pos + selected_neg
    if len(selected_docs) < target_size:
        remaining_pos = positives[max_positives:]
        needed = target_size - len(selected_docs)
        selected_docs += remaining_pos[:needed]

    rng.shuffle(selected_docs)

    pool_k  = [d for d,s in selected_docs]
    scores_k = [s for d,s in selected_docs]

    passages_str_list = []
    for i, (doc_id, score) in enumerate(selected_docs):
        doc_text = abstract_lookup.get(int(doc_id), "Error: Document text not found.")
        doc_text = " ".join(doc_text.split()[:300]) 
        passages_str_list.append(f"[{i+1}] {doc_text}")

    passages_for_prompt = "\n\n---\n\n".join(passages_str_list)

    # Build User Prompt
    user_msg = (
        f"You are an expert academic paper reranker. "
        f"Your task is to re-order the given list of passages (from [1] to [{len(selected_docs)}]) "
        f"based on their relevance to the query. Respond with only the ranking and nothing else.\n\n"
        f"Example output for 8 passages:\n"
        f"[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\n"
        f"Query: {query_text}\n\n"
        f"Passages:\n{passages_for_prompt}\n\n"
        f"Your ranking (most to least relevant):"
    )

    # Build Ground Truth
    # Sort by score (descending). 
    # Secondary sort by original index (ascending) ensures stable sorting for ties.
    indexed_scores = [(i+1, score) for i, score in enumerate(scores_k)]
    sorted_indices = sorted(indexed_scores, key=lambda x: (x[1], -x[0]), reverse=True)
    
    sorted_list_str = " > ".join(f"[{idx}]" for idx, _ in sorted_indices)

    conversation = [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": sorted_list_str}
    ]

    return {"conversations": conversation, "pool_k": pool_k, "scores_k": scores_k}


def formatting_prompts_func_k(examples):
    texts = [] 
    new_pools = []
    new_scores = []
    
    num_examples = len(examples['query_text'])
    
    for i in range(num_examples):
        # 1. Reconstruct the single example
        single_example = {key: examples[key][i] for key in examples}
        
        # 2. Process
        processed = make_conversation_k(single_example)
        
        # 3. Apply Chat Template
        formatted_text = tokenizer.apply_chat_template(
            processed['conversations'],
            tokenize=False,
            add_generation_prompt=False
        )
        
        formatted_text = formatted_text.removeprefix('<bos>')
        
        # 4. Append to lists
        texts.append(formatted_text)
        new_pools.append(processed['pool_k'])
        new_scores.append(processed['scores_k'])
    
    # Return dictionary with all new columns
    return { 
        "text": texts, 
        "candidate_pool_k": new_pools, 
        "candidate_scores_k": new_scores 
    }

import random, hashlib

N_VARIANTS = 6
POS_MODES = ["any", "only_2", "only_1"]

def stable_seed(query_text: str, v: int) -> int:
    h = hashlib.md5(f"{query_text}||{v}".encode("utf-8")).hexdigest()
    return int(h[:8], 16)

def formatting_prompts_func_k_v2(examples):
    texts, new_pools, new_scores = [], [], []
    n = len(examples["query_text"])

    for i in range(n):
        base = {k: examples[k][i] for k in examples}
        qtext = base["query_text"]

        for v in range(N_VARIANTS):
            rng = random.Random(stable_seed(qtext, v))
            pos_mode = POS_MODES[v % len(POS_MODES)]

            processed = make_conversation_k_v2(
                base,
                rng=rng,
                max_positives=5,
                target_size=50,
                hard_neg_top=50,
                pos_mode=pos_mode,
            )

            formatted = tokenizer.apply_chat_template(
                processed["conversations"],
                tokenize=False,
                add_generation_prompt=False
            ).removeprefix("<bos>")

            texts.append(formatted)
            new_pools.append(processed["pool_k"])
            new_scores.append(processed["scores_k"])

    return {"text": texts, "candidate_pool_k": new_pools, "candidate_scores_k": new_scores}


def formatting_prompts_func(examples):
    texts = [] # We need a list for the final formatted strings
    
    num_examples = len(examples['query_text'])
    
    for i in range(num_examples):
        # 1. Reconstruct the single example
        single_example = {key: examples[key][i] for key in examples}
        
        # 2. Build the structural conversation (List of Dicts)
        # This returns {'conversations': [{'role': 'user', ...}, ...]}
        processed = make_conversation_k(single_example)
        
        # 3. Apply the chat template
        # This converts the list of dicts into the Gemma 3 prompt string
        formatted_text = tokenizer.apply_chat_template(
            processed['conversations'],
            tokenize=False,
            add_generation_prompt=False
        )
        
        # 4. Strip the leading <bos>: the tokenizer adds it again at training time, and the model expects only one
        formatted_text = formatted_text.removeprefix('<bos>')
        
        texts.append(formatted_text)
    
    # Return the field 'text', which is what the Trainer usually looks for
    return { "text": texts }
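The assistant target is a `>`-separated list of bracketed 1-based indices, ordered by score with ties broken by the lower index. At inference time the model's reply has to be turned back into indices. The sketch below pairs the target construction used above with a hypothetical `parse_ranking` helper (not part of this notebook) that keeps the first occurrence of each valid index, drops repeated or out-of-range ones, and appends any passages the model omitted in input order, a common convention for listwise rerankers.

```python
import re
from typing import List


def build_target(scores: List[int]) -> str:
    """Score-descending ranking of 1-based indices; ties keep the lower index first."""
    indexed = [(i + 1, s) for i, s in enumerate(scores)]
    ordered = sorted(indexed, key=lambda x: (x[1], -x[0]), reverse=True)
    return " > ".join(f"[{idx}]" for idx, _ in ordered)


def parse_ranking(text: str, n: int) -> List[int]:
    """Hypothetical parser: extract [k] markers, dedupe, drop k outside 1..n, append missing."""
    ranking: List[int] = []
    for match in re.findall(r"\[(\d+)\]", text):
        k = int(match)
        if 1 <= k <= n and k not in ranking:
            ranking.append(k)
    # Passages the model never mentioned go last, in input order
    ranking.extend(k for k in range(1, n + 1) if k not in ranking)
    return ranking


target = build_target([0, 2, 1, 2])
print(target)                                     # [2] > [4] > [3] > [1]
print(parse_ranking(target, 4))                   # [2, 4, 3, 1]
print(parse_ranking("[3] > [3] > [9] > [1]", 4))  # [3, 1, 2, 4]
```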

In [15]:
import random, hashlib
from typing import List, Tuple, Dict, Any

# -----------------------------
# CONFIG
# -----------------------------
N_VARIANTS       = 5     # AT MOST this many unique pools per query
TARGET_SIZE      = 50
MAX_POSITIVES    = 5
HARD_NEG_TOP     = 100    # set None to use ALL negatives
MAX_TOTAL_TRIES  = 3000  # per query; stop early if can't find more unique pools

# Uniqueness definition:
# - set-based (order-invariant): the same doc_ids in a different order count as a DUPLICATE
def pool_signature(doc_ids: List[Any]) -> Tuple[int, ...]:
    return tuple(sorted(int(d) for d in doc_ids))

# If you ever want "unique-by-order" instead, use:
# def pool_signature(doc_ids): return tuple(int(d) for d in doc_ids)


def stable_seed(row_id: int, query_text: str) -> int:
    # row_id prevents collisions if query_text repeats
    h = hashlib.md5(f"{row_id}||{query_text}".encode("utf-8")).hexdigest()
    return int(h[:8], 16)


# -----------------------------
# PROMPT + TARGET CREATION
# -----------------------------
def create_single_conversation_entry(
    query_text: str,
    selected_docs: List[Tuple[Any, int]],
    abstract_lookup: Dict[int, str],
) -> Dict[str, Any]:
    """
    selected_docs: list of (doc_id, score), already shuffled in input order
    Returns dict with:
      - conversations
      - pool_k, scores_k
    """
    passages_str_list = []
    scores_k = []
    pool_k = []

    for i, (doc_id, score) in enumerate(selected_docs):
        pool_k.append(doc_id)
        scores_k.append(int(score))

        doc_text = abstract_lookup.get(int(doc_id), "Error: Document text not found.")
        doc_text = " ".join(doc_text.split()[:300])  # truncate to ~300 words
        passages_str_list.append(f"[{i+1}] {doc_text}")

    passages_for_prompt = "\n\n---\n\n".join(passages_str_list)

    user_msg = (
        f"You are an expert academic paper reranker. "
        f"Your task is to re-order the given list of passages (from [1] to [{len(selected_docs)}]) "
        f"based on their relevance to the query. Respond with only the ranking and nothing else.\n\n"
        f"Example output for 8 passages:\n"
        f"[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\n"
        f"Query: {query_text}\n\n"
        f"Passages:\n{passages_for_prompt}\n\n"
        f"Your ranking (most to least relevant):"
    )

    # Ground truth ranking: score desc, tie broken by earlier index (stable)
    indexed_scores = [(i+1, s) for i, s in enumerate(scores_k)]
    sorted_indices = sorted(indexed_scores, key=lambda x: (x[1], -x[0]), reverse=True)
    target_str = " > ".join(f"[{idx}]" for idx, _ in sorted_indices)

    conversation = [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": target_str},
    ]

    return {"conversations": conversation, "pool_k": pool_k, "scores_k": scores_k}


# -----------------------------
# UNIQUE POOL SAMPLING
# -----------------------------
def sample_unique_pools_for_query(
    example: Dict[str, Any],
    *,
    rng: random.Random,
    n_variants: int,
    target_size: int,
    max_positives: int,
    hard_neg_top: int | None,
    max_total_tries: int,
) -> List[List[Tuple[Any, int]]]:
    pool = example["candidate_pool"]
    scores = example["candidate_scores"]

    # Bucket
    neg  = [(d, int(s)) for d, s in zip(pool, scores) if int(s) == 0]
    pos1 = [(d, int(s)) for d, s in zip(pool, scores) if int(s) == 1]
    pos2 = [(d, int(s)) for d, s in zip(pool, scores) if int(s) >= 2]
    all_pos = pos2 + pos1
    all_docs = all_pos + neg

    if hard_neg_top is not None and len(neg) > hard_neg_top:
        neg = neg[:hard_neg_top]

    # Nothing to sample from an empty candidate pool
    if len(all_docs) == 0:
        return []

    # Try structured modes first, then keep sampling "any" for coverage
    modes = ["only_2", "only_1", "mixed", "any", "any", "any"]

    seen = set()
    unique_selected_docs: List[List[Tuple[Any, int]]] = []

    tries = 0
    while len(unique_selected_docs) < n_variants and tries < max_total_tries:
        mode = modes[tries % len(modes)]
        tries += 1

        # ---- choose positives
        positives: List[Tuple[Any, int]] = []
        if mode == "only_2":
            k = min(len(pos2), max_positives)
            positives = rng.sample(pos2, k) if k > 0 else []
        elif mode == "only_1":
            k = min(len(pos1), max_positives)
            positives = rng.sample(pos1, k) if k > 0 else []
        elif mode == "mixed":
            k2 = min(len(pos2), max_positives // 2)
            k1 = min(len(pos1), max_positives - k2)
            if k2 > 0: positives += rng.sample(pos2, k2)
            if k1 > 0: positives += rng.sample(pos1, k1)
        else:  # "any"
            k = min(len(all_pos), max_positives)
            positives = rng.sample(all_pos, k) if k > 0 else []

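        # ---- guarantee at least one positive BEFORE sizing the negatives;
        # otherwise an empty "only_2"/"only_1" draw would add a positive on top
        # of target_size negatives and produce target_size + 1 documents
        if not positives and all_pos and max_positives > 0:
            positives = [rng.choice(all_pos)]

        # ---- choose negatives
        needed = target_size - len(positives)
        if needed < 0:
            positives = positives[:target_size]
            needed = 0

        negatives: List[Tuple[Any, int]] = []
        if needed > 0:
            kneg = min(len(neg), needed)
            negatives = rng.sample(neg, kneg) if kneg > 0 else []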

        selected = positives + negatives

        # ---- fill up to target_size with remaining docs (pos or neg) not already included
        if len(selected) < target_size:
            used_ids = {d for d, _ in selected}
            remaining = [p for p in all_docs if p[0] not in used_ids]
            fill = remaining[: max(0, target_size - len(selected))]
            selected += fill

        # If still < target_size because the overall pool is too small, that's fine:
        # we'll produce a smaller list (but signature uniqueness still applies).
        doc_ids = [d for d, _ in selected]
        sig = pool_signature(doc_ids)

        if sig in seen:
            continue

        seen.add(sig)
        rng.shuffle(selected)  # IMPORTANT: avoid positional shortcut
        unique_selected_docs.append(selected)

    return unique_selected_docs


# -----------------------------
# DATASET MAP FUNCTION (EXPANDS ROWS)
# -----------------------------
def formatting_prompts_func_unique(examples, indices):
    """
    Requires dataset.map(..., batched=True, with_indices=True)
    Assumes these exist in scope:
      - tokenizer
      - abstract_lookup
    """
    out_text, out_pools, out_scores = [], [], []
    n = len(examples["query_text"])

    for i in range(n):
        ex = {k: examples[k][i] for k in examples}
        row_id = indices[i]
        qtext = ex["query_text"]

        rng = random.Random(stable_seed(row_id, qtext))

        unique_pools = sample_unique_pools_for_query(
            ex,
            rng=rng,
            n_variants=N_VARIANTS,
            target_size=TARGET_SIZE,
            max_positives=MAX_POSITIVES,
            hard_neg_top=HARD_NEG_TOP,
            max_total_tries=MAX_TOTAL_TRIES,
        )

        for selected_docs in unique_pools:
            processed = create_single_conversation_entry(
                query_text=qtext,
                selected_docs=selected_docs,
                abstract_lookup=abstract_lookup,
            )

            formatted = tokenizer.apply_chat_template(
                processed["conversations"],
                tokenize=False,
                add_generation_prompt=False,
            ).removeprefix("<bos>")

            out_text.append(formatted)
            out_pools.append(processed["pool_k"])
            out_scores.append(processed["scores_k"])

    return {"text": out_text, "candidate_pool_k": out_pools, "candidate_scores_k": out_scores}


# def has_positives(example):
#     return any(int(s) > 0 for s in example["candidate_scores"])

# dataset_filtered = dataset.filter(has_positives)

# print(f"Original: {len(dataset)}, After filtering: {len(dataset_filtered)}")
# print(f"Removed {len(dataset) - len(dataset_filtered)} queries with zero positives")

# `dataset` was already filtered to queries with at least one positive (In [7]), so augment it directly
augmented_dataset = dataset.map(
    formatting_prompts_func_unique,
    batched=True,
    with_indices=True,
    remove_columns=dataset.column_names,
    load_from_cache_file=False,
)

print("Original size:", len(dataset))
print("Augmented size:", len(augmented_dataset))



Original size: 1498
Augmented size: 7490
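The augmented size is exactly 1498 × 5 = 7490, so every query yielded the full `N_VARIANTS` unique pools. Because each query's `random.Random` is seeded from an MD5 digest of `row_id||query_text` rather than Python's per-process salted `hash()`, rerunning the cell reproduces the same pools. A minimal check of that property, reusing the `stable_seed` definition from the cell above (`shuffled_ids` is a hypothetical helper for illustration):

```python
import hashlib
import random


def stable_seed(row_id: int, query_text: str) -> int:
    # Same definition as in the cell above: first 32 bits of an MD5 digest
    h = hashlib.md5(f"{row_id}||{query_text}".encode("utf-8")).hexdigest()
    return int(h[:8], 16)


def shuffled_ids(row_id: int, query_text: str, ids: list) -> list:
    # Hypothetical helper: shuffle a copy with a per-query deterministic RNG
    rng = random.Random(stable_seed(row_id, query_text))
    out = list(ids)
    rng.shuffle(out)
    return out


ids = list(range(10))
a = shuffled_ids(0, "example query", ids)
b = shuffled_ids(0, "example query", ids)
print(a == b)                                        # True: same seed, same order
print(0 <= stable_seed(0, "example query") < 2**32)  # True
```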


In [16]:
dataset = dataset.map(formatting_prompts_func_k, batched=True)

In [17]:
augmented_dataset[1]

{'text': '<start_of_turn>user\nYou are an expert academic paper reranker. Your task is to re-order the given list of passages (from [1] to [50]) based on their relevance to the query. Respond with only the ranking and nothing else.\n\nExample output for 8 passages:\n[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\nQuery: The dataset should be related to the natural language processing task.\n\nPassages:\n[1] Title: Tutorial: Safe and Reliable Machine Learning Abstract: This document serves as a brief overview of the "Safe and Reliable Machine Learning" tutorial given at the 2019 ACM Conference on Fairness, Accountability, and Transparency (FAT* 2019). The talk slides can be found here: https://bit.ly/2Gfsukp, while a video of the talk is available here: https://youtu.be/FGLOCkC4KmE, and a complete list of references for the tutorial here: https://bit.ly/2GdLPme.\n\n---\n\n[2] Title: Gradient Descent in RKHS with Importance Labeling Abstract: Labeling cost is often expensive and is a fu

In [18]:
augmented_dataset[1]['text']


'<start_of_turn>user\nYou are an expert academic paper reranker. Your task is to re-order the given list of passages (from [1] to [50]) based on their relevance to the query. Respond with only the ranking and nothing else.\n\nExample output for 8 passages:\n[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\nQuery: The dataset should be related to the natural language processing task.\n\nPassages:\n[1] Title: Tutorial: Safe and Reliable Machine Learning Abstract: This document serves as a brief overview of the "Safe and Reliable Machine Learning" tutorial given at the 2019 ACM Conference on Fairness, Accountability, and Transparency (FAT* 2019). The talk slides can be found here: https://bit.ly/2Gfsukp, while a video of the talk is available here: https://youtu.be/FGLOCkC4KmE, and a complete list of references for the tutorial here: https://bit.ly/2GdLPme.\n\n---\n\n[2] Title: Gradient Descent in RKHS with Importance Labeling Abstract: Labeling cost is often expensive and is a fundamental

In [19]:
bad_items = [
    (i, item)
    for i, item in enumerate(augmented_dataset)
    if not any(score > 0 for score in item["candidate_scores_k"])
]

print(f"Number of queries with zero positives: {len(bad_items)}\n")

for i, item in bad_items:
    print(f"Index: {i}")
    print(f"Prompt (first 200 chars): {item['text'][:200]}")  # augmented rows have no query_text column
    print(f"Candidate scores: {item['candidate_scores_k']}")
    print("-" * 80)


Number of queries with zero positives: 0
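The check above only confirms that every pool contains a positive. A somewhat broader validation is sketched below as a hypothetical helper `check_pool` (not part of the notebook): it also verifies that a row has no repeated doc IDs, at most `target_size` documents, and an assistant target that lists every passage exactly once. It works on plain lists, so it can be applied to a row's `candidate_pool_k`, `candidate_scores_k`, and ranking string.

```python
import re
from typing import List


def check_pool(pool: List[int], scores: List[int], target: str, target_size: int = 50) -> List[str]:
    """Return the problems found in one augmented row (an empty list means OK)."""
    problems = []
    if len(pool) != len(scores):
        problems.append("pool/scores length mismatch")
    if len(set(pool)) != len(pool):
        problems.append("duplicate doc ids")
    if len(pool) > target_size:
        problems.append(f"{len(pool)} docs > target_size={target_size}")
    if not any(s > 0 for s in scores):
        problems.append("no positive documents")
    ranked = [int(k) for k in re.findall(r"\[(\d+)\]", target)]
    if sorted(ranked) != list(range(1, len(pool) + 1)):
        problems.append("target does not list every passage exactly once")
    return problems


print(check_pool([11, 22, 33], [0, 2, 0], "[2] > [1] > [3]", target_size=3))  # []
print(check_pool([11, 11, 33], [0, 0, 0], "[1] > [2]", target_size=3))
# ['duplicate doc ids', 'no positive documents', 'target does not list every passage exactly once']
```

On a real row, the ranking string could be taken from `row['text'].split('<start_of_turn>model')[-1]`, assuming the Gemma 3 chat template applied above.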



In [20]:
from datasets import DatasetDict

print(f"Original dataset size: {len(augmented_dataset)}")

train_val_test_split = augmented_dataset.train_test_split(test_size=0.1, seed=42)

train_val_split = train_val_test_split['train'].train_test_split(test_size=(0.1/0.9), seed=42)

finetuning_splits = DatasetDict({
    'train': train_val_split['train'],
    'validation': train_val_split['test'],
    'test': train_val_test_split['test']
})

print(finetuning_splits)

Original dataset size: 7490
DatasetDict({
    train: Dataset({
        features: ['text', 'candidate_pool_k', 'candidate_scores_k'],
        num_rows: 5992
    })
    validation: Dataset({
        features: ['text', 'candidate_pool_k', 'candidate_scores_k'],
        num_rows: 749
    })
    test: Dataset({
        features: ['text', 'candidate_pool_k', 'candidate_scores_k'],
        num_rows: 749
    })
})
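Because `train_test_split` shuffles the 7,490 augmented rows independently, the up-to-five pools built from the same query can land in different splits, so validation and test queries are not unseen at training time. One way to avoid this is to assign whole queries to splits before expanding them. The sketch below assumes one group key per row (for example, the original row index of the query); `group_split` is a hypothetical helper, not a `datasets` API.

```python
import random
from typing import Dict, Hashable, List, Sequence


def group_split(
    groups: Sequence[Hashable], val_frac: float = 0.1, test_frac: float = 0.1, seed: int = 42
) -> Dict[str, List[int]]:
    """Assign row indices to splits so that every group lands in exactly one split."""
    unique = sorted(set(groups), key=str)  # deterministic order before shuffling
    random.Random(seed).shuffle(unique)
    n_test = round(len(unique) * test_frac)
    n_val = round(len(unique) * val_frac)
    test_g = set(unique[:n_test])
    val_g = set(unique[n_test:n_test + n_val])
    splits: Dict[str, List[int]] = {"train": [], "validation": [], "test": []}
    for i, g in enumerate(groups):
        if g in test_g:
            splits["test"].append(i)
        elif g in val_g:
            splits["validation"].append(i)
        else:
            splits["train"].append(i)
    return splits


# Five augmented rows per query, 20 queries -> 100 rows
row_groups = [q for q in range(20) for _ in range(5)]
parts = group_split(row_groups)
print({k: len(v) for k, v in parts.items()})  # {'train': 80, 'validation': 10, 'test': 10}
```

To use this here, `formatting_prompts_func_unique` would also need to emit the source `row_id` as a column; the resulting index lists can then be passed to `Dataset.select`.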


In [21]:
finetuning_splits['train'][100]

{'text': '<start_of_turn>user\nYou are an expert academic paper reranker. Your task is to re-order the given list of passages (from [1] to [50]) based on their relevance to the query. Respond with only the ranking and nothing else.\n\nExample output for 8 passages:\n[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\nQuery: The paper should discuss a labeling process.\n\nPassages:\n[1] Title: Needmining: Designing Digital Support to Elicit Needs from Social Media Abstract: Today\'s businesses face a high pressure to innovate in order to succeed in highly competitive markets. Successful innovations, though, typically require the identification and analysis of customer needs. While traditional, established need elicitation methods are time-proven and have demonstrated their capabilities to deliver valuable insights, they lack automation and scalability and, thus, are expensive and time-consuming. In this article, we propose an approach to automatically identify and quantify customer needs b

In [22]:
print(f"Example of formatted train data:\n{finetuning_splits['train'][100]['text']}")


Example of formatted train data:
<start_of_turn>user
You are an expert academic paper reranker. Your task is to re-order the given list of passages (from [1] to [50]) based on their relevance to the query. Respond with only the ranking and nothing else.

Example output for 8 passages:
[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]

Query: The paper should discuss a labeling process.

Passages:
[1] Title: Needmining: Designing Digital Support to Elicit Needs from Social Media Abstract: Today's businesses face a high pressure to innovate in order to succeed in highly competitive markets. Successful innovations, though, typically require the identification and analysis of customer needs. While traditional, established need elicitation methods are time-proven and have demonstrated their capabilities to deliver valuable insights, they lack automation and scalability and, thus, are expensive and time-consuming. In this article, we propose an approach to automatically identify and quantify cus

In [23]:
def calculate_length(example):
    # The tokenizer returned by FastModel batches its output, so [0] selects the single sequence
    tokens = tokenizer(text=example['text'], truncation=False)
    return {"token_length": len(tokens['input_ids'][0])}

print("Calculating token lengths for the training set...")
train_val_with_lengths = train_val_split.map(calculate_length, num_proc=4)

all_lengths = []
for split in train_val_with_lengths:
    all_lengths.extend(train_val_with_lengths[split]["token_length"])

max_tokens_in_set = max(all_lengths)
print(f"The maximum number of tokens is: {max_tokens_in_set}")

Calculating token lengths for the training set...


Map (num_proc=4):   0%|          | 0/5992 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/749 [00:00<?, ? examples/s]

The maximum number of tokens is: 14478
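A single maximum does not show how the lengths are distributed. Below is a minimal sketch that reports percentiles and counts how many examples would exceed the `max_seq_length = 16384` used when loading the model. It assumes you pass in the `all_lengths` list computed above. The helper name `summarize_lengths` is hypothetical, and the demo lengths are made up.

```python
import math

def summarize_lengths(lengths, max_seq_length):
    """Summarize token lengths and count examples that would be truncated."""
    if not lengths:
        raise ValueError("lengths must be non-empty")
    ordered = sorted(lengths)

    def percentile(p):
        # Nearest-rank percentile: smallest value with at least p% of the data at or below it
        rank = max(1, math.ceil(p / 100 * len(ordered)))
        return ordered[rank - 1]

    over = sum(1 for n in ordered if n > max_seq_length)
    return {
        "max": ordered[-1],
        "p50": percentile(50),
        "p95": percentile(95),
        "over_limit": over,
        "headroom": max_seq_length - ordered[-1],
    }

# Example with made-up lengths (not taken from the notebook's dataset)
stats = summarize_lengths([9000, 11000, 12500, 14478], max_seq_length=16384)
print(stats)
```

On the real data, call `summarize_lengths(all_lengths, max_seq_length=16384)`. A nonzero `over_limit` would mean some prompts get truncated during training.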


In [24]:
# import optuna
# import gc
# import torch
# from unsloth import FastModel

# def model_init(trial):
#     # 1. Clean up memory from previous runs (Crucial for single GPU!)
#     gc.collect()
#     torch.cuda.empty_cache()
    
#     # 2. Define the search space for LoRA parameters
#     # We define them here so we can pass them to get_peft_model
#     r_value = trial.suggest_categorical("peft_r", [8, 16, 32])
#     alpha_value = trial.suggest_categorical("peft_alpha", [16, 32, 64])
#     dropout_value = trial.suggest_float("peft_dropout", 0.0, 0.1)
    
#     # 3. Load the base model (Unsloth handles caching efficiently)
#     model, _ = FastModel.from_pretrained(
#         model_name = "unsloth/gemma-3-4b-it",
#         max_seq_length = 8192,
#         load_in_4bit = True,
#         load_in_8bit = False,
#     )
    
#     # 4. Apply LoRA with the trial's parameters
#     model = FastModel.get_peft_model(
#         model,
#         r = r_value,
#         lora_alpha = alpha_value,
#         lora_dropout = dropout_value,
#         target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
#         finetune_vision_layers = False,
#         finetune_language_layers = True,
#         finetune_attention_modules = True,
#         finetune_mlp_modules = True,
#         bias = "none",
#         random_state = 3407,
#     )
    
#     return model
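PEFT's standard LoRA scales the adapter update `B @ A` by `lora_alpha / r`. This means the r/alpha grid in `model_init` covers effective scales from 0.5 to 8, not just the two knobs independently. The short sketch below lets you inspect those combinations. The `lora_scale` helper is hypothetical. The `use_rslora` branch shows the rank-stabilized variant (`alpha / sqrt(r)`) for comparison only; the cell above does not enable it.

```python
from itertools import product

def lora_scale(r, alpha, use_rslora=False):
    """Scaling factor applied to the LoRA update B @ A."""
    # Standard LoRA scales by alpha / r; rank-stabilized LoRA uses alpha / sqrt(r)
    return alpha / (r ** 0.5) if use_rslora else alpha / r

# Grid from the commented-out search space above
r_values = [8, 16, 32]
alpha_values = [16, 32, 64]
scales = {(r, a): lora_scale(r, a) for r, a in product(r_values, alpha_values)}
for (r, a), s in sorted(scales.items()):
    print(f"r={r:>2} alpha={a:>2} -> scale {s:g}")
```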

In [25]:
# def hp_space(trial):
#     return {
#         "learning_rate": trial.suggest_float("learning_rate", 1e-6, 2e-4, log=True),
#         "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16]),
#         "gradient_accumulation_steps": trial.suggest_categorical("gradient_accumulation_steps", [1, 2, 4]),
#         "warmup_ratio": trial.suggest_float("warmup_ratio", 0.0, 0.1),
#     }
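The batch-size and gradient-accumulation choices in `hp_space` interact. What the optimizer sees is their product, which spans 4 to 64 examples per update on one GPU. The sketch below tabulates the grid. It assumes the 5,992-example split in the earlier `Map` progress bar is the training split. The update count is an approximation, because the Trainer's exact rounding at the end of an epoch can differ slightly. Both helper names are hypothetical.

```python
import math
from itertools import product

def effective_batch_size(per_device, grad_accum, num_gpus=1):
    # Number of examples contributing to one optimizer update
    return per_device * grad_accum * num_gpus

def approx_updates_per_epoch(num_examples, per_device, grad_accum, num_gpus=1):
    # Approximation: the Trainer's exact end-of-epoch rounding may differ slightly
    return math.ceil(num_examples / effective_batch_size(per_device, grad_accum, num_gpus))

num_train = 5992  # assumed training-split size, from the Map progress bar above
grid = {}
for bs, ga in product([4, 8, 16], [1, 2, 4]):
    grid[(bs, ga)] = (effective_batch_size(bs, ga), approx_updates_per_epoch(num_train, bs, ga))
for (bs, ga), (eff, steps) in sorted(grid.items()):
    print(f"batch={bs:>2} GA={ga} -> effective {eff:>2}, ~{steps} updates/epoch")
```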

In [26]:
# import optuna
# import gc
# import torch
# from unsloth import FastModel
# from unsloth.chat_templates import get_chat_template
# from trl import SFTTrainer, SFTConfig

# def objective(trial):
#     # 1. Clean up memory from previous trial
#     gc.collect()
#     torch.cuda.empty_cache()
    
#     # 2. Define Hyperparameters to Tune
#     # LoRA Params
#     r_value = trial.suggest_categorical("peft_r", [8, 16])
#     alpha_value = trial.suggest_categorical("peft_alpha", [16, 32]) # alpha/r is 1x, 2x or 4x with these choices
#     dropout_value = trial.suggest_float("peft_dropout", 0.0, 0.1)
    
#     # Training Params
#     lr_value = trial.suggest_float("learning_rate", 1e-6, 5e-5, log=True)
#     batch_size_value = trial.suggest_categorical("batch_size", [8, 16])
    
#     # 3. Load Model (Must be done fresh every trial)
#     model, tokenizer = FastModel.from_pretrained(
#         model_name = "unsloth/gemma-3-4b-it",
#         max_seq_length = 8192,
#         load_in_4bit = True,
#         load_in_8bit = False,
#     )
    
#     # Apply Chat Template (Crucial!)
#     tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
    
#     # Apply LoRA with trial parameters
#     model = FastModel.get_peft_model(
#         model,
#         r = r_value,
#         lora_alpha = alpha_value,
#         lora_dropout = dropout_value,
#         target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
#         finetune_vision_layers = False,
#         finetune_language_layers = True,
#         finetune_attention_modules = True,
#         finetune_mlp_modules = True,
#         bias = "none",
#         random_state = 3407,
#     )
    
#     # 4. Configure Trainer
#     # We use a smaller number of steps/epochs for tuning to save time
#     training_args = SFTConfig(
#         dataset_text_field="text",
#         max_seq_length=8192,
#         output_dir=f"hyperparam_results/trial_{trial.number}",
#         per_device_train_batch_size=batch_size_value,
#         gradient_accumulation_steps=1, 
#         num_train_epochs=1,      # 1 Epoch is enough to see convergence trends
#         learning_rate=lr_value,
#         logging_steps=10,
#         eval_strategy="steps",
#         eval_steps=20,           # Frequent eval to catch overfitting early
#         save_strategy="no",      # Don't save checkpoints to save disk space
#         optim="adamw_8bit",
#         weight_decay=0.01,
#         seed=3407,
#         report_to="none",
#     )
    
#     # The dataset is already formatted with the Gemma 3 chat template, so it is passed in directly.
#     # On its own, SFTTrainer computes the loss over the full text; the train_on_responses_only
#     # wrapper applied below masks the prompt so only the model's response contributes to the loss.
    
#     trainer = SFTTrainer(
#         model=model,
#         tokenizer=tokenizer,
#         train_dataset=finetuning_splits['train'],
#         eval_dataset=finetuning_splits['validation'],
#         args=training_args,
#     )
    
#     # Apply the response masking wrapper
#     from unsloth.chat_templates import train_on_responses_only
#     trainer = train_on_responses_only(
#         trainer,
#         instruction_part = "<start_of_turn>user\n",
#         response_part = "<start_of_turn>model\n",
#     )
    
#     # 5. Train
#     trainer.train()
    
#     # 6. Get Final Validation Loss
#     eval_stats = trainer.evaluate()
#     val_loss = eval_stats["eval_loss"]
    
#     # 7. Cleanup to prevent OOM
#     del model
#     del trainer
#     gc.collect()
#     torch.cuda.empty_cache()
    
#     return val_loss

# # Run the study
# study = optuna.create_study(direction="minimize")
# study.optimize(objective, n_trials=10) # Run 10 trials

# print("Best hyperparameters:", study.best_params)

In [27]:
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = augmented_dataset,
    eval_dataset = None,
    args = SFTConfig(
        max_seq_length = 16384,
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 1, # Use GA to mimic batch size!
        warmup_steps = 20, # When > 0, this overrides warmup_ratio
        warmup_ratio = 0.06, # Ignored here because warmup_steps is set
        # gradient_checkpointing = True,
        num_train_epochs = 2, # Two full passes over the training data
        # max_steps = 30,
        learning_rate = 1e-5, # Low LR for a multi-epoch run
        logging_steps = 1,
        optim = "adamw_torch",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        report_to = "none", # Use TrackIO/WandB etc
        bf16=True
    ),
)

Unsloth: Tokenizing ["text"] (num_proc=36):   0%|          | 0/7490 [00:00<?, ? examples/s]
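The config uses `per_device_train_batch_size = 2`, no gradient accumulation, and 2 epochs over 7,490 examples. That works out to 7,490 optimizer steps, which matches the Unsloth banner printed when training starts. Below is a minimal sketch of the resulting learning-rate schedule. It assumes two things: the Hugging Face rule that a positive `warmup_steps` overrides `warmup_ratio`, and the linear-warmup, half-cosine decay used by `lr_scheduler_type = "cosine"`. All helper names are hypothetical, and the step count is a simplified formula.

```python
import math

def total_training_steps(num_examples, per_device, grad_accum, epochs, num_gpus=1):
    # Simplified count; exact Trainer rounding can differ when grad_accum > 1
    batches_per_epoch = math.ceil(num_examples / (per_device * num_gpus))
    return (batches_per_epoch // grad_accum) * epochs

def resolve_warmup(total_steps, warmup_steps, warmup_ratio):
    # A positive warmup_steps takes precedence over warmup_ratio
    return warmup_steps if warmup_steps > 0 else math.ceil(total_steps * warmup_ratio)

def cosine_lr(step, base_lr, warmup, total_steps):
    # Linear warmup from 0, then half-cosine decay from base_lr down to 0
    if step < warmup:
        return base_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

total = total_training_steps(7490, per_device=2, grad_accum=1, epochs=2)
warmup = resolve_warmup(total, warmup_steps=20, warmup_ratio=0.06)
print("total steps:", total, "warmup steps:", warmup)
print("warmup if only the ratio were set:", resolve_warmup(total, warmup_steps=0, warmup_ratio=0.06))
for step in (0, 10, 20, total // 2, total):
    print(step, cosine_lr(step, 1e-5, warmup, total))
```

With the settings above, warmup lasts 20 steps rather than the roughly 450 that `warmup_ratio = 0.06` alone would give.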

In [28]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

Map (num_proc=36):   0%|          | 0/7490 [00:00<?, ? examples/s]
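`train_on_responses_only` sets the label of every token outside the model's reply to -100. As a result, the loss is computed only on the ranking the model should produce, not on the long prompt with 50 passages. The sketch below illustrates the idea on toy integer token IDs. It is a simplification, not Unsloth's actual implementation; the function name and token IDs are hypothetical.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss

def mask_to_responses(input_ids, instruction_ids, response_ids):
    """Return labels where only tokens inside model turns are supervised.

    Simplified illustration: scans for the marker token sequences and keeps
    tokens after each response marker until the next instruction marker.
    """
    labels = [IGNORE_INDEX] * len(input_ids)
    i, in_response = 0, False
    while i < len(input_ids):
        if input_ids[i:i + len(response_ids)] == response_ids:
            in_response = True
            i += len(response_ids)  # the marker itself stays masked
            continue
        if input_ids[i:i + len(instruction_ids)] == instruction_ids:
            in_response = False
            i += len(instruction_ids)
            continue
        if in_response:
            labels[i] = input_ids[i]
        i += 1
    return labels

# Toy IDs: [1, 2] stands for "<start_of_turn>user\n", [3, 4] for "<start_of_turn>model\n"
ids = [0, 1, 2, 10, 11, 12, 5, 3, 4, 20, 21, 5]
print(mask_to_responses(ids, instruction_ids=[1, 2], response_ids=[3, 4]))
```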

In [29]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

'<bos><start_of_turn>user\nYou are an expert academic paper reranker. Your task is to re-order the given list of passages (from [1] to [50]) based on their relevance to the query. Respond with only the ranking and nothing else.\n\nExample output for 8 passages:\n[2] > [5] > [4] > [8] > [6] > [1] > [3] > [7]\n\nQuery: Since I have sufficient computing resources, my plan is to use distributed training methods to handle the huge amount of data.\n\nPassages:\n[1] Title: Towards Gaussian Bayesian Network Fusion Abstract: Data sets are growing in complexity thanks to the increasing facilities we have nowadays to both generate and store data. This poses many challenges to machine learning that are leading to the proposal of new methods and paradigms, in order to be able to deal with what is nowadays referred to as Big Data. In this paper we propose a method for the aggregation of different Bayesian network structures that have been learned from separate data sets, as a first step towards mini

In [30]:
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

'                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
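Decoding the labels shows a long run of blanks (the masked prompt) followed by the target ranking. A numeric check is easier to assert on than a visual one. The minimal sketch below assumes you pass `trainer.train_dataset[100]["labels"]` as a Python list. The helper name is hypothetical, and the demo labels are made up.

```python
IGNORE_INDEX = -100

def supervised_span(labels):
    """Return (number of supervised tokens, index of first supervised token or None)."""
    positions = [i for i, tok in enumerate(labels) if tok != IGNORE_INDEX]
    return len(positions), (positions[0] if positions else None)

labels = [IGNORE_INDEX] * 9 + [20, 21, 5]  # toy labels, not from the dataset
print(supervised_span(labels))
```

If the count is 0, the response marker was not found in that example, and it provides no training signal.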

In [None]:
trainer_stats = trainer.train()

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,490 | Num Epochs = 2 | Total steps = 7,490
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2
 "-____-"     Trainable parameters = 547,651,584 of 12,734,976,624 (4.30% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,0.4095
2,0.4264
3,0.2696
4,0.3633
5,0.3349
6,0.4766
7,0.4248
8,0.4093
9,0.2457
10,0.1758
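With `logging_steps = 1`, each logged loss comes from a single batch of two examples, so the curve is noisy. Bias-corrected exponential smoothing is one common way to read the trend. The sketch below applies it to the ten values logged above. The `ema` helper and `beta = 0.9` are arbitrary choices for illustration.

```python
def ema(values, beta=0.9):
    """Bias-corrected exponential moving average of a sequence."""
    smoothed, avg = [], 0.0
    for t, v in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * v
        smoothed.append(avg / (1 - beta ** t))  # corrects the bias toward 0 at early steps
    return smoothed

# Per-step training losses copied from the log above (steps 1-10)
losses = [0.4095, 0.4264, 0.2696, 0.3633, 0.3349, 0.4766, 0.4248, 0.4093, 0.2457, 0.1758]
for step, (raw, smooth) in enumerate(zip(losses, ema(losses)), start=1):
    print(f"step {step:>2}: raw {raw:.4f}  smoothed {smooth:.4f}")
```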


In [None]:
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Merge the LoRA adapters into the base weights and save the full model
model.save_pretrained_merged(f"models/{timestamp}/gemma-3-finetune", tokenizer)
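A merged save folds each LoRA adapter back into its base weight, so the checkpoint no longer needs PEFT at load time. For a standard LoRA layer, the merged weight is W' = W + (alpha / r) * B @ A. The toy sketch below shows that arithmetic with plain Python lists. The helper names and values are hypothetical; real layers are large tensors.

```python
def matmul(a, b):
    # Plain-Python matrix product for small lists of lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def merge_lora(W, A, B, r, alpha):
    """Return W + (alpha / r) * B @ A, the merged weight for standard LoRA."""
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]

# Toy 2x2 weight with a rank-1 adapter (values are made up)
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]      # shape r x in_features
B = [[0.5], [0.25]]   # shape out_features x r
print(merge_lora(W, A, B, r=1, alpha=2))
```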

In [None]:
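# Save only the LoRA adapter weights (plus tokenizer) to a separate directory,
# so adapter files do not mix with the merged checkpoint saved above
model.save_pretrained(f"models/{timestamp}/gemma-3-finetune-lora")
tokenizer.save_pretrained(f"models/{timestamp}/gemma-3-finetune-lora")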