# Code Documentation of CheckThat! Subtask 4b: Neural Representation learning

Run these setup cells first.

In [None]:
import os
import random
import pickle
import time
from datetime import datetime

import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, InputExample, losses
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity


In [None]:
# helpers
def normalize(txt: str) -> str:
    """Lower-case ``txt`` and collapse each whitespace run into one space.

    The input is converted with ``str()`` first, so non-string values are
    accepted. Leading and trailing whitespace is dropped.
    """
    tokens = str(txt).lower().split()
    return " ".join(tokens)

def compute_metrics(q_emb, doc_emb, gt_ids, doc_ids):
    """Rank all documents per query by cosine similarity and score the ranking.

    Args:
        q_emb: Query embeddings, one row per query.
        doc_emb: Document embeddings, one row per entry of ``doc_ids``.
        gt_ids: Ground-truth document id of each query, aligned with the rows
            of ``q_emb``.
        doc_ids: Document ids, aligned with the rows of ``doc_emb``.

    Returns:
        dict of floats with the keys "MRR@1", "MRR@5", "MRR@10", "Recall@5",
        "Recall@10" and "MedianRank". Ranks are 1-based.

    Raises:
        ValueError: If a ground-truth id does not occur in ``doc_ids``.
    """
    sims = cosine_similarity(q_emb, doc_emb)

    # id -> row position, built once. The first occurrence wins so the result
    # matches list.index() even if an id is duplicated. This replaces one
    # linear scan over doc_ids per query.
    doc_pos = {}
    for pos, doc_id in enumerate(doc_ids):
        doc_pos.setdefault(doc_id, pos)

    ranks = []
    for i, row in enumerate(sims):
        try:
            gt_idx = doc_pos[gt_ids[i]]
        except KeyError:
            # Keep the exception type list.index() raised for unknown ids.
            raise ValueError(f"{gt_ids[i]!r} is not in doc_ids") from None
        # Document positions sorted by descending similarity.
        order = np.argsort(-row)
        rank = int(np.where(order == gt_idx)[0][0]) + 1
        ranks.append(rank)
    ranks = np.array(ranks)

    # Reciprocal rank of every query; a cut-off k zeroes the entries whose
    # gold document is ranked below k.
    reciprocal = 1.0 / ranks

    def _mrr_at(k):
        return float(np.mean(np.where(ranks <= k, reciprocal, 0.0)))

    metrics = {
        "MRR@1": _mrr_at(1),
        "MRR@5": _mrr_at(5),
        "MRR@10": _mrr_at(10),
        "Recall@5": float((ranks <= 5).mean()),
        "Recall@10": float((ranks <= 10).mean()),
        "MedianRank": float(np.median(ranks))
    }
    return metrics

## Base-setup trials

These trials employ dual encoders using an in-batch negatives training regime. Reproduce the base-setup results here by switching out models from the sentence-transformers library:

- sentence-transformers/all-MiniLM-L6-v2
- sentence-transformers/multi-qa-mpnet-base-dot-v1
- sentence-transformers/msmarco-bert-base-dot-v5
- intfloat/e5-large-v2

Additionally, you can reproduce the batch-size trials (e.g. using intfloat/e5-large-v2) and the experiment without fine-tuning.

Places to make changes are clearly marked with "### Change here".

The cell below serves to reproduce the exact results presented in the paper. To reproduce results exactly, only change the following in the "exp" dictionary under CONFIG

- encoder_model (pick one of the commented-out model names in the cell and paste it after "encoder_model" in the exp variable)
- batch_size (pick values like 8, 16, 32, 64, depending on computational resources and paste it after "batch_size" in the exp variable)
- "fine_tune" = False (only for the intfloat/e5-large-v2 experiment without fine-tuning; for all other experiments keep "fine_tune" = True)

In [None]:
# ---------------- REPRODUCIBILITY ------------------------------------------------
# cuBLAS workspace configuration used together with PyTorch's deterministic mode.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
seed = 42
# Seed Python, NumPy and PyTorch (CPU, then all CUDA devices) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
# Disable cuDNN autotuning and request deterministic kernels everywhere.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

### Models
# "sentence-transformers/all-MiniLM-L6-v2"
# "sentence-transformers/multi-qa-mpnet-base-dot-v1"
# "sentence-transformers/msmarco-bert-base-dot-v5"
# "intfloat/e5-large-v2"
###

# ---------------- CONFIG --------------------------------------------------------
# Experiment settings; the "### Change here" markers flag the keys meant to be edited.
exp = {
    "experiment_name": "intfloat-e5-large-v2_bs64_lr7e-6",
    ### Change here
    "encoder_model": "intfloat/e5-large-v2",
    ### ---
    "query_field": "tweet_text",
    "normalize": True,
    ### Change here
    "fine_tune": True,
    ### ---
    "epochs": 2,
    ### Change here
    "batch_size": 64,
    ### ---
    "lr": 7e-6,
    "use_hard_negatives": False,
}

# Fixed input/output locations, relative to the working directory.
DATA_DIR = "../data"
OUT_CSV = "../experiment_results/clef_neural_rep_exp_results.csv"
PRED_DIR = "../predictions"

# Every run saves its model into a directory stamped with the start time.
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_DIR = "../models/" + exp["experiment_name"] + "_" + RUN_ID
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---------------- LOAD DATA ----------------------------------------------------
# NOTE(review): read_pickle unpickles the file -- only load a collection file you trust.
df_coll = pd.read_pickle(DATA_DIR + "/subtask4b_collection_data.pkl")
df_train = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_train.tsv", sep="\t")
df_dev = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_dev.tsv", sep="\t")
df_test = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_test_gold.tsv", sep="\t")

# build document texts
def build_doc(row):
    """Return the text of one collection row: title and abstract joined by " [SEP] ".

    The result is passed through normalize() when exp["normalize"] is truthy.
    """
    combined = row.title + " [SEP] " + row.abstract
    if not exp["normalize"]:
        return combined
    return normalize(combined)

doc_ids = df_coll.cord_uid.tolist()
doc_texts = df_coll.apply(build_doc, axis=1).tolist()

# ---------------- MODEL -------------------------------------------------------
model = SentenceTransformer(exp["encoder_model"], device=DEVICE)

# ---------------- TRAINING DATA ------------------------------------------------
if exp["fine_tune"]:
    # cord_uid -> document text (first occurrence wins, like list.index).
    # One dict lookup per tweet replaces the linear `in` and `.index` scans
    # over doc_ids that were repeated for every training row.
    text_by_uid = {}
    for doc_uid, doc_text in zip(doc_ids, doc_texts):
        text_by_uid.setdefault(doc_uid, doc_text)

    # One (query, positive document) pair per training tweet; tweets whose
    # gold document is not in the collection are skipped.
    examples = []
    for _, row in df_train.iterrows():
        uid = row.cord_uid
        pos = text_by_uid.get(uid)
        if pos is None:
            continue
        q = normalize(row[exp["query_field"]]) if exp["normalize"] else row[exp["query_field"]]
        examples.append(InputExample(texts=[q, pos]))

    train_dl = DataLoader(examples, shuffle=True, batch_size=exp["batch_size"],
                          num_workers=4, pin_memory=True)
    loss_fn = losses.MultipleNegativesRankingLoss(model)
    # Warm up over the first 10% of all training steps.
    warmup = int(len(train_dl) * exp["epochs"] * 0.1)

    model.fit(
        train_objectives=[(train_dl, loss_fn)],
        epochs=exp["epochs"],
        warmup_steps=warmup,
        optimizer_params={"lr": exp["lr"]},
        weight_decay=0.02,
        output_path=SAVE_DIR,
        use_amp=True,
    )
    torch.cuda.empty_cache()

# ---------------- EMBEDDINGS --------------------------------------------------
print("Encoding corpus …")
doc_emb = model.encode(doc_texts, batch_size=256,
                       convert_to_numpy=True, show_progress_bar=True)

def encode_queries(df):
    """Embed the exp["query_field"] column of ``df`` with the global model.

    Queries are passed through normalize() first when exp["normalize"] is
    truthy. Returns what model.encode() returns with convert_to_numpy=True.
    """
    queries = df[exp["query_field"]].tolist()
    if exp["normalize"]:
        queries = list(map(normalize, queries))
    return model.encode(
        queries,
        batch_size=256,
        convert_to_numpy=True,
        show_progress_bar=True,
    )

print("Encoding queries …")
q_emb_tr = encode_queries(df_train)
q_emb_de = encode_queries(df_dev)
q_emb_te = encode_queries(df_test)

# ---------------- METRICS -----------------------------------------------------
print("Computing metrics …")
# Every split is ranked against the same corpus embeddings.
train_metrics = compute_metrics(q_emb_tr, doc_emb, df_train.cord_uid.tolist(), doc_ids)
dev_metrics = compute_metrics(q_emb_de, doc_emb, df_dev.cord_uid.tolist(), doc_ids)
test_metrics = compute_metrics(q_emb_te, doc_emb, df_test.cord_uid.tolist(), doc_ids)

# Report the splits in the order train, dev, test.
for split_name, split_metrics in (
    ("Train", train_metrics),
    ("Dev", dev_metrics),
    ("Test", test_metrics),
):
    print(f"=== {split_name} Metrics ===")
    for k, v in split_metrics.items():
        print(f"{k:8s}: {v:.4f}")


## Chunked tokenization + additional collection columns

Use the cell below to reproduce the document metadata experiments.

The places to change are clearly marked with "### Change here". For "use_fields" in exp, pass any combination of the document fields as a Python list: ["title", "abstract", "authors", "journal", "source_x"]

In [None]:
# ---------------- REPRODUCIBILITY ------------------------------------------------
# cuBLAS workspace configuration used together with PyTorch's deterministic mode.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
seed = 42
# Seed Python, NumPy and PyTorch (CPU, then all CUDA devices) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
# Disable cuDNN autotuning and request deterministic kernels everywhere.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# ---------------- CONFIG --------------------------------------------------------
# Experiment settings; "use_fields" is the key meant to be edited in this cell.
exp = {
    "experiment_name": "intfloat-e5-large-v2_lr7e-6_chunked",
    "encoder_model": "intfloat/e5-large-v2",
    "query_field": "tweet_text",
    "normalize": True,
    "fine_tune": True,
    "epochs": 2,
    "batch_size": 64,
    "lr": 7e-6,
    ### Change here
    # specify which document columns to include in encoding
    # e.g. ["title", "abstract", "authors", "journal", "source_x"]
    "use_fields": ["title", "abstract"],
    ### ---
}

# Fixed input/output locations, relative to the working directory.
DATA_DIR = "../data"
OUT_CSV = "../experiment_results/clef_neural_rep_exp_results_paper_metadata.csv"

# Every run saves its model into a directory stamped with the start time.
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_DIR = "../models/" + exp["experiment_name"] + "_" + RUN_ID
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---------------- HELPERS -------------------------------------------------------
def normalize(txt: str) -> str:
    """Return ``txt`` as a lower-cased string with single-space separators.

    ``str()`` is applied first; any run of whitespace becomes one space and
    whitespace at either end is removed.
    """
    lowered = str(txt).lower()
    words = lowered.split()
    return " ".join(words)

# ---------------- LOAD DATA ----------------------------------------------------
# NOTE(review): read_pickle unpickles the file -- only load a collection file you trust.
df_coll = pd.read_pickle(DATA_DIR + "/subtask4b_collection_data.pkl")
# Query tweets for the three splits (tab-separated).
df_train = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_train.tsv", sep="\t")
df_dev = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_dev.tsv", sep="\t")
df_test = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_test_gold.tsv", sep="\t")

# build document texts using chosen fields, with custom preprocessing per field

def build_doc(row):
    """Assemble one document string from the columns listed in exp["use_fields"].

    Fields are rendered in the configured order and joined with " [SEP] ":
    title and abstract are used as-is; authors become "Authors: " plus the
    first five ";"-separated names; journal becomes "Journal: <value>";
    source_x becomes "Source: " plus its ";"-separated entries. Any other
    field name is rendered as the string value of that attribute ("" when the
    row has no such attribute). The joined text is passed through normalize()
    when exp["normalize"] is truthy.
    """
    def _semicolon_items(value):
        # Split a ";"-separated cell into stripped entries.
        return [item.strip() for item in str(value).split(";")]

    # One renderer per known field; each is only called when its field is selected.
    renderers = {
        "title": lambda r: r.title,
        "abstract": lambda r: r.abstract,
        "authors": lambda r: "Authors: " + ", ".join(_semicolon_items(r.authors)[:5]),
        "journal": lambda r: "Journal: " + str(r.journal),
        "source_x": lambda r: "Source: " + ", ".join(_semicolon_items(r.source_x)),
    }

    parts = []
    for field in exp["use_fields"]:
        render = renderers.get(field)
        if render is None:
            # fallback for any other column name
            parts.append(str(getattr(row, field, "")))
        else:
            parts.append(render(row))

    txt = " [SEP] ".join(parts)
    if exp["normalize"]:
        return normalize(txt)
    return txt

# Build the id list and the text list in collection order.
ndocs = len(df_coll)
doc_ids = df_coll.cord_uid.tolist()
doc_texts = []
for _, r in df_coll.iterrows():
    doc_texts.append(build_doc(r))

# ---------------- MODEL & TOKENIZER -------------------------------------------
# The tokenizer is loaded from the same checkpoint as the encoder.
tokenizer = AutoTokenizer.from_pretrained(exp["encoder_model"])
model = SentenceTransformer(exp["encoder_model"], device=DEVICE)

def chunk_text(text, max_length=510, stride=50):
    """Split ``text`` into overlapping token windows and return them as strings.

    Args:
        text: The document text to split.
        max_length: Maximum length of one window, passed to the tokenizer as
            ``max_length``.
        stride: Overlap between consecutive windows, passed to the tokenizer
            as ``stride``.

    Returns:
        A list of decoded window strings (special tokens removed). A text that
        fits into one window yields a single-element list.
    """
    enc = tokenizer(
        text,
        # Truncation has to be switched on: with truncation=False the
        # tokenizer ignores max_length/stride and returns the whole text as
        # one sequence, so no overflow windows are produced at all.
        # NOTE(review): documents longer than max_length are now really split,
        # so their embeddings differ from runs made without truncation.
        truncation=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
    )
    return [tokenizer.decode(ids, skip_special_tokens=True)
            for ids in enc["input_ids"]]

# ---------------- TRAINING DATA ------------------------------------------------
if exp["fine_tune"]:
    # cord_uid -> document text (first occurrence wins, like list.index).
    # One dict lookup per tweet replaces the linear `in` and `.index` scans
    # over doc_ids that were repeated for every training row.
    text_by_uid = {}
    for doc_uid, full_text in zip(doc_ids, doc_texts):
        text_by_uid.setdefault(doc_uid, full_text)

    # One (query, chunk) pair per chunk of the gold document; tweets whose
    # gold document is not in the collection are skipped.
    examples = []
    for _, row in df_train.iterrows():
        uid = row.cord_uid
        doc_text = text_by_uid.get(uid)
        if doc_text is None:
            continue
        q_text = normalize(row[exp["query_field"]]) if exp["normalize"] else row[exp["query_field"]]
        for chunk in chunk_text(doc_text):
            examples.append(InputExample(texts=[q_text, chunk]))

    train_dl = DataLoader(
        examples,
        shuffle=True,
        batch_size=exp["batch_size"],
        num_workers=4,
        pin_memory=True
    )
    loss_fn = losses.MultipleNegativesRankingLoss(model)
    # Warm up over the first 10% of all training steps.
    warmup = int(len(train_dl) * exp["epochs"] * 0.1)

    model.fit(
        train_objectives=[(train_dl, loss_fn)],
        epochs=exp["epochs"],
        warmup_steps=warmup,
        optimizer_params={"lr": exp["lr"]},
        weight_decay=0.02,
        output_path=SAVE_DIR,
        use_amp=True,
    )
    torch.cuda.empty_cache()

# ---------------- EMBEDDINGS --------------------------------------------------
print("Encoding corpus with chunking …")
# One vector per document: the average of the mean-pooled and the max-pooled
# chunk embeddings.
all_doc_embs = []
for txt in doc_texts:
    chunks = chunk_text(txt)
    chunk_embs = model.encode(chunks, batch_size=exp["batch_size"],
                              convert_to_numpy=True, show_progress_bar=False)
    mean_emb = chunk_embs.mean(axis=0)
    max_emb = chunk_embs.max(axis=0)
    all_doc_embs.append((mean_emb + max_emb) / 2.0)

# stack into (n_docs × dim) array
doc_emb = np.vstack(all_doc_embs)

def encode_queries(df, batch_size=256):
    """Embed the exp["query_field"] column of ``df`` with the global model.

    Queries are passed through normalize() first when exp["normalize"] is
    truthy; ``batch_size`` is forwarded to model.encode().
    """
    queries = df[exp["query_field"]].tolist()
    if exp["normalize"]:
        queries = list(map(normalize, queries))
    return model.encode(
        queries,
        batch_size=batch_size,
        convert_to_numpy=True,
        show_progress_bar=True,
    )

print("Encoding queries …")
q_emb_tr = encode_queries(df_train)
q_emb_de = encode_queries(df_dev)
q_emb_te = encode_queries(df_test)

# ---------------- METRICS -----------------------------------------------------
print("Computing metrics …")
# Every split is ranked against the same corpus embeddings.
train_metrics = compute_metrics(
    q_emb_tr, doc_emb, df_train.cord_uid.tolist(), doc_ids
)
dev_metrics = compute_metrics(
    q_emb_de, doc_emb, df_dev.cord_uid.tolist(), doc_ids
)
test_metrics = compute_metrics(
    q_emb_te, doc_emb, df_test.cord_uid.tolist(), doc_ids
)

# Report the splits in the order train, dev, test.
for split_name, split_metrics in (
    ("Train", train_metrics),
    ("Dev", dev_metrics),
    ("Test", test_metrics),
):
    print(f"=== {split_name} Metrics ===")
    for k, v in split_metrics.items():
        print(f"{k:8s}: {v:.4f}")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss


Encoding corpus with chunking …
Encoding queries …


Batches:   0%|          | 0/51 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Computing metrics …
=== Train Metrics ===
MRR@1   : 0.6622
MRR@5   : 0.7366
MRR@10  : 0.7430
Recall@5: 0.8487
Recall@10: 0.8947
MedianRank: 1.0000
=== Dev Metrics ===
MRR@1   : 0.6443
MRR@5   : 0.7040
MRR@10  : 0.7112
Recall@5: 0.7936
Recall@10: 0.8457
MedianRank: 1.0000
Logged results to ../experiment_results/clef_neural_rep_exp_results_paper_metadata.csv


## Adding a single Hard Negative per query using MultipleNegativesRankingLoss

No changes needed - setup will exactly reproduce the hard negative trial.

In [None]:
# ---------------- REPRODUCIBILITY ------------------------------------------------
# cuBLAS workspace configuration used together with PyTorch's deterministic mode.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
seed = 42
# Seed Python, NumPy and PyTorch (CPU, then all CUDA devices) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
# Disable cuDNN autotuning and request deterministic kernels everywhere.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# ---------------- CONFIG --------------------------------------------------------
# Settings of the hard-negative trial.
exp = {
    "experiment_name": "intfloat-e5-large-v2_lr7e-6_HN_MNRL",
    "encoder_model": "intfloat/e5-large-v2",
    "query_field": "tweet_text",
    "normalize": True,
    "fine_tune": True,
    "epochs": 2,
    "batch_size": 64,
    "lr": 7e-6,
    "use_hard_negatives": True,
}

# Fixed input/output locations, relative to the working directory.
DATA_DIR = "../data"
OUT_CSV = "../experiment_results/clef_neural_rep_exp_results.csv"
PRED_DIR = "../predictions"

# Every run saves its model into a directory stamped with the start time.
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_DIR = "../models/" + exp["experiment_name"] + "_" + RUN_ID
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---------------- LOAD DATA ----------------------------------------------------
# NOTE(review): read_pickle unpickles the file -- only load a collection file you trust.
df_coll = pd.read_pickle(DATA_DIR + "/subtask4b_collection_data.pkl")
df_train = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_train.tsv", sep="\t")
df_dev = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_dev.tsv", sep="\t")
df_test = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_test_gold.tsv", sep="\t")

# build document texts
def build_doc(row):
    """Join a row's title and abstract with " [SEP] " into one document text.

    When exp["normalize"] is truthy the joined text goes through normalize().
    """
    joined = row.title + " [SEP] " + row.abstract
    if exp["normalize"]:
        joined = normalize(joined)
    return joined

# Ids and texts of all documents, in collection order.
doc_ids = df_coll["cord_uid"].tolist()
doc_texts = df_coll.apply(build_doc, axis=1).tolist()

# ---------------- MODEL -------------------------------------------------------
model = SentenceTransformer(exp["encoder_model"], device=DEVICE)

# MEMORY SAVING TWEAKS
# Best effort: a failure here is reported and the run continues.
try:
    # The wrapped Hugging Face model sits in the first module of the pipeline.
    xformer = model._first_module().auto_model

    # Gradient checkpointing on, KV-cache off.
    xformer.gradient_checkpointing_enable()
    xformer.config.use_cache = False

    print("Enabled HF gradient-checkpointing and disabled use_cache")
except Exception as e:
    print("Could not enable gradient checkpointing:", e)
# ---------------- TRAINING DATA ------------------------------------------------
if exp["fine_tune"]:
    # load the 1xBM25 negative per tweet
    # The context manager closes the file handle, which the bare open() call
    # left to the garbage collector.
    # NOTE(review): pickle can run code while loading -- only use a trusted cache file.
    with open("../cache/hard_negs_v4.pkl", "rb") as neg_file:
        hard_negs = pickle.load(neg_file)

    # cord_uid -> document text (first occurrence wins, like list.index).
    # One dict lookup replaces the linear `in` and `.index` scans over
    # doc_ids that were repeated for every training row.
    text_by_uid = {}
    for doc_uid, doc_text in zip(doc_ids, doc_texts):
        text_by_uid.setdefault(doc_uid, doc_text)

    examples = []
    for _, row in df_train.iterrows():
        uid = row.cord_uid
        pos = text_by_uid.get(uid)
        if pos is None:
            # gold document is not part of the collection
            continue

        # normalize query as before
        q = normalize(row[exp["query_field"]]) if exp["normalize"] else row[exp["query_field"]]

        # pull exactly one BM25 negative for this tweet; an empty candidate
        # list is treated like a missing entry instead of raising IndexError
        candidates = hard_negs.get(int(row.post_id), [None])
        neg_uid = candidates[0] if len(candidates) > 0 else None
        neg = text_by_uid.get(neg_uid)
        if neg is not None:
            examples.append(InputExample(texts=[q, pos, neg]))
        else:
            # fallback to pair-only if no usable hard negative was found
            examples.append(InputExample(texts=[q, pos]))

    train_dl = DataLoader(examples, shuffle=True, batch_size=exp["batch_size"],
                          num_workers=4, pin_memory=True)
    loss_fn = losses.MultipleNegativesRankingLoss(model)
    # Warm up over the first 10% of all training steps.
    warmup = int(len(train_dl) * exp["epochs"] * 0.1)

    model.fit(
        train_objectives=[(train_dl, loss_fn)],
        epochs=exp["epochs"],
        warmup_steps=warmup,
        optimizer_params={"lr": exp["lr"]},
        weight_decay=0.02,
        output_path=SAVE_DIR,
        use_amp=True,
    )
    torch.cuda.empty_cache()

# ---------------- EMBEDDINGS --------------------------------------------------
print("Encoding corpus …")
doc_emb = model.encode(doc_texts, batch_size=256,
                       convert_to_numpy=True, show_progress_bar=True)

def encode_queries(df):
    """Embed the exp["query_field"] column of ``df`` with the global model.

    Queries are passed through normalize() first when exp["normalize"] is
    truthy. Returns what model.encode() returns with convert_to_numpy=True.
    """
    raw_queries = df[exp["query_field"]].tolist()
    if exp["normalize"]:
        prepared = [normalize(query) for query in raw_queries]
    else:
        prepared = raw_queries
    return model.encode(
        prepared,
        batch_size=256,
        convert_to_numpy=True,
        show_progress_bar=True,
    )

print("Encoding queries …")
q_emb_tr = encode_queries(df_train)
q_emb_de = encode_queries(df_dev)
q_emb_te = encode_queries(df_test)

# ---------------- METRICS -----------------------------------------------------
print("Computing metrics …")
# Every split is ranked against the same corpus embeddings.
train_metrics = compute_metrics(q_emb_tr, doc_emb, df_train.cord_uid.tolist(), doc_ids)
dev_metrics = compute_metrics(q_emb_de, doc_emb, df_dev.cord_uid.tolist(), doc_ids)
test_metrics = compute_metrics(q_emb_te, doc_emb, df_test.cord_uid.tolist(), doc_ids)

# Report the splits in the order train, dev, test.
for split_name, split_metrics in (
    ("Train", train_metrics),
    ("Dev", dev_metrics),
    ("Test", test_metrics),
):
    print(f"=== {split_name} Metrics ===")
    for k, v in split_metrics.items():
        print(f"{k:8s}: {v:.4f}")

## Combine chunked tokenization, additional metadata fields and hard negatives

No changes needed - will exactly reproduce the trial using hard negatives as well as chunked tokenization with additional metadata fields.

In [None]:
# ---------------- REPRODUCIBILITY ------------------------------------------------
# cuBLAS workspace configuration used together with PyTorch's deterministic mode.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
seed = 42
# Seed Python, NumPy and PyTorch (CPU, then all CUDA devices) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
# Disable cuDNN autotuning and request deterministic kernels everywhere.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# ---------------- CONFIG --------------------------------------------------------
# Settings of the combined trial: hard negatives, all metadata fields, chunking.
exp = {
    "experiment_name": "intfloat-e5-large-v2_lr7e-6_HN_chunked_all_fields",
    "encoder_model": "intfloat/e5-large-v2",
    "query_field": "tweet_text",
    "normalize": True,
    "fine_tune": True,
    "epochs": 2,
    "batch_size": 64,
    "lr": 7e-6,
    "use_hard_negatives": True,
    "use_fields": ["title", "abstract", "authors", "journal", "source_x"],
    "chunk_max_length": 510,
    "chunk_stride": 50,
}

# Fixed input/output locations, relative to the working directory.
DATA_DIR = "../data"
CACHE_NEG = "../cache/hard_negs_v4.pkl"
OUT_CSV = "../experiment_results/clef_neural_rep_exp_results.csv"

# Every run saves its model into a directory stamped with the start time.
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_DIR = "../models/" + exp["experiment_name"] + "_" + RUN_ID

# Use the GPU when PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Create the model directory and the directory of the results CSV up front.
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(os.path.dirname(OUT_CSV), exist_ok=True)

def build_doc(row):
    """Build the text of one collection row from exp["use_fields"].

    Each configured field is rendered in order and the pieces are joined with
    " [SEP] ". title/abstract are taken as-is, authors are cut to the first
    five ";"-separated names and prefixed with "Authors: ", journal is
    prefixed with "Journal: ", source_x entries are re-joined with ", " and
    prefixed with "Source: ". Unknown field names fall back to the string
    value of the attribute ("" if absent). The result is normalized when
    exp["normalize"] is truthy.
    """
    def _split(value):
        # ";"-separated cell -> list of stripped entries
        return [piece.strip() for piece in str(value).split(";")]

    def _render(field):
        if field == "title":
            return row.title
        if field == "abstract":
            return row.abstract
        if field == "authors":
            return "Authors: " + ", ".join(_split(row.authors)[:5])
        if field == "journal":
            return "Journal: " + str(row.journal)
        if field == "source_x":
            return "Source: " + ", ".join(_split(row.source_x))
        return str(getattr(row, field, ""))

    txt = " [SEP] ".join([_render(field) for field in exp["use_fields"]])
    if not exp["normalize"]:
        return txt
    return normalize(txt)

# ---------------- LOAD & PREPARE DATA ------------------------------------------
# NOTE(review): read_pickle unpickles the file -- only load a collection file you trust.
df_coll = pd.read_pickle(DATA_DIR + "/subtask4b_collection_data.pkl")
# Query tweets for the three splits (tab-separated).
df_train = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_train.tsv", sep="\t")
df_dev = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_dev.tsv", sep="\t")
df_test = pd.read_csv(DATA_DIR + "/subtask4b_query_tweets_test_gold.tsv", sep="\t")

# Ids and texts of all documents, in collection order.
doc_ids = df_coll.cord_uid.tolist()
doc_texts = []
for _, r in df_coll.iterrows():
    doc_texts.append(build_doc(r))

# ---------------- TOKENIZER & CHUNKING -----------------------------------------
tokenizer = AutoTokenizer.from_pretrained(exp["encoder_model"])


def chunk_text(text):
    """Split ``text`` into overlapping token windows and return them as strings.

    The window length comes from exp["chunk_max_length"] and the overlap from
    exp["chunk_stride"]; both are handed to the tokenizer unchanged.

    Returns:
        A list of decoded window strings (special tokens removed). A text that
        fits into one window yields a single-element list.
    """
    enc = tokenizer(
        text,
        # Truncation has to be switched on: with truncation=False the
        # tokenizer ignores max_length/stride and returns the whole text as
        # one sequence, so no overflow windows are produced at all.
        # NOTE(review): documents longer than the window are now really split,
        # so their embeddings differ from runs made without truncation.
        truncation=True,
        max_length=exp["chunk_max_length"],
        stride=exp["chunk_stride"],
        return_overflowing_tokens=True,
    )
    return [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in enc["input_ids"]
    ]

# ---------------- MODEL & MEMORY TWEAKS ----------------------------------------
model = SentenceTransformer(exp["encoder_model"], device=DEVICE)
try:
    # The wrapped Hugging Face model sits in the first module of the pipeline.
    tm = model._first_module().auto_model
    tm.gradient_checkpointing_enable()
    tm.config.use_cache = False
    print("Enabled gradient checkpointing, disabled cache")
except Exception as e:
    # Still best effort, but say why the tweaks were skipped instead of hiding
    # the failure. Unlike a bare `except`, this does not swallow
    # KeyboardInterrupt or SystemExit.
    print("Could not enable gradient checkpointing:", e)

# ---------------- TRAINING DATA ------------------------------------------------
if exp["fine_tune"]:
    if exp["use_hard_negatives"]:
        # The context manager closes the file handle, which the bare open()
        # call left to the garbage collector.
        # NOTE(review): pickle can run code while loading -- only use a trusted cache file.
        with open(CACHE_NEG, "rb") as neg_file:
            hard_negs = pickle.load(neg_file)
    else:
        hard_negs = {}

    # cord_uid -> document text (first occurrence wins, like list.index).
    # One dict lookup replaces the linear `in` and `.index` scans over
    # doc_ids that were repeated for every training row.
    text_by_uid = {}
    for doc_uid, doc_text in zip(doc_ids, doc_texts):
        text_by_uid.setdefault(doc_uid, doc_text)

    examples = []
    for _, row in df_train.iterrows():
        uid = row.cord_uid
        pos_text = text_by_uid.get(uid)
        if pos_text is None:
            # gold document is not part of the collection
            continue

        q_text = normalize(row[exp["query_field"]]) if exp["normalize"] else row[exp["query_field"]]
        pos_chunks = chunk_text(pos_text)

        neg_chunks = None
        if exp["use_hard_negatives"]:
            # An empty candidate list is treated like a missing entry instead
            # of raising IndexError.
            candidates = hard_negs.get(int(row.post_id), [None])
            neg_uid = candidates[0] if len(candidates) > 0 else None
            neg_text = text_by_uid.get(neg_uid)
            if neg_text is not None:
                neg_chunks = chunk_text(neg_text)

        # Every positive chunk is paired with every negative chunk; without a
        # usable negative the example is a plain (query, positive) pair.
        for p in pos_chunks:
            if neg_chunks:
                for n in neg_chunks:
                    examples.append(InputExample(texts=[q_text, p, n]))
            else:
                examples.append(InputExample(texts=[q_text, p]))

    train_dl = DataLoader(
        examples,
        shuffle=True,
        batch_size=exp["batch_size"],
        num_workers=4,
        pin_memory=True
    )
    loss_fn = losses.MultipleNegativesRankingLoss(model)
    # Warm up over the first 10% of all training steps.
    warmup = int(len(train_dl) * exp["epochs"] * 0.1)

    model.fit(
        train_objectives=[(train_dl, loss_fn)],
        epochs=exp["epochs"],
        warmup_steps=warmup,
        optimizer_params={"lr": exp["lr"]},
        weight_decay=0.02,
        output_path=SAVE_DIR,
        use_amp=True,
    )
    torch.cuda.empty_cache()

# ---------------- EMBEDDINGS & EVALUATION -------------------------------------
print("Encoding corpus with chunking …")
# One vector per document: the average of the mean-pooled and the max-pooled
# chunk embeddings.
all_doc_embs = []
for txt in doc_texts:
    ch = chunk_text(txt)
    embs = model.encode(
        ch,
        batch_size=exp["batch_size"],
        convert_to_numpy=True,
        show_progress_bar=False,
    )
    mean_part = np.mean(embs, axis=0)
    max_part = np.max(embs, axis=0)
    all_doc_embs.append((mean_part + max_part) / 2.0)
# Stack the per-document vectors into one 2-D array.
doc_emb = np.vstack(all_doc_embs)

def encode_queries(df):
    """Embed the exp["query_field"] column of ``df`` with the global model.

    Queries are passed through normalize() first when exp["normalize"] is
    truthy. Returns what model.encode() returns with convert_to_numpy=True.
    """
    raw_queries = df[exp["query_field"]].tolist()
    if exp["normalize"]:
        prepared = [normalize(query) for query in raw_queries]
    else:
        prepared = raw_queries
    return model.encode(prepared, batch_size=256,
                        convert_to_numpy=True, show_progress_bar=True)

print("Encoding queries …")
q_emb_tr = encode_queries(df_train)
q_emb_de = encode_queries(df_dev)
q_emb_te = encode_queries(df_test)

print("Computing metrics …")
# Every split is ranked against the same corpus embeddings.
train_metrics = compute_metrics(
    q_emb_tr, doc_emb, df_train.cord_uid.tolist(), doc_ids
)
dev_metrics = compute_metrics(
    q_emb_de, doc_emb, df_dev.cord_uid.tolist(), doc_ids
)
test_metrics = compute_metrics(
    q_emb_te, doc_emb, df_test.cord_uid.tolist(), doc_ids
)

# Report the splits in the order train, dev, test.
for split_name, split_metrics in (
    ("Train", train_metrics),
    ("Dev", dev_metrics),
    ("Test", test_metrics),
):
    print(f"=== {split_name} Metrics ===")
    for k, v in split_metrics.items():
        print(f"{k:8s}: {v:.4f}")
