In [None]:
import json
import torch
from tqdm import tqdm
from bs4 import BeautifulSoup
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, util
import numpy as np

In [None]:
# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# One checkpoint supplies the tokenizer and encoder for BOTH queries and documents.
HF_BACKBONE = "avsolatorio/GIST-Embedding-v0"

hf_tokenizer = AutoTokenizer.from_pretrained(HF_BACKBONE)
hf_model = AutoModel.from_pretrained(HF_BACKBONE)
hf_model = hf_model.to(device)

# Queries are encoded with the same HF backbone as documents (no SentenceTransformer).
def encode_query(text):
    """
    Embed ``text`` with the shared HF backbone and return its first-position vector.

    The tokenizer is called with its default special tokens, so position 0 is
    the [CLS] token. Input is truncated to 512 tokens. Runs without gradients.

    Returns:
        Tensor of shape (batch, hidden_dim) on ``device``; (1, hidden_dim) for
        a single string.
    """
    encoded = hf_tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
    encoded = encoded.to(device)

    with torch.no_grad():
        hidden_states = hf_model(**encoded).last_hidden_state

    cls_vectors = hidden_states[:, 0, :]  # position 0 == [CLS]
    return cls_vectors


Device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/747 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [None]:
def encode_full_doc_tokens(full_text, max_len=512, add_special_tokens=True):
    """
    Encode every token of ``full_text`` with ``hf_model`` in fixed-size windows.

    The document is tokenized once WITHOUT special tokens, so the returned rows
    line up one-to-one with the global token indices (and character offsets)
    that ``late_chunker`` uses. Each window is then encoded independently.

    Args:
        full_text: document text.
        max_len: maximum number of positions fed to the model per forward pass,
            including the [CLS]/[SEP] positions when they are added.
        add_special_tokens: wrap each window as ``[CLS] ... [SEP]`` before
            encoding, the same way ``encode_query`` feeds text (it uses the
            tokenizer's default special tokens). The special positions are
            dropped from the output so indices stay aligned. Ignored when the
            tokenizer has no CLS/SEP ids.

    Returns:
        (full_emb, offsets): ``full_emb`` has shape (num_tokens, hidden_dim) and
        lives on ``device``; ``offsets`` is the tokenizer's list of
        (char_start, char_end) for each token.

    Raises:
        ValueError: if ``max_len`` leaves no room for document tokens.
    """
    base = hf_tokenizer(
        full_text,
        return_offsets_mapping=True,
        add_special_tokens=False
    )
    ids = base["input_ids"]               # full sequence token ids
    offsets = base["offset_mapping"]      # full-text global char offsets

    cls_id = hf_tokenizer.cls_token_id
    sep_id = hf_tokenizer.sep_token_id
    wrap = add_special_tokens and cls_id is not None and sep_id is not None
    step = max_len - 2 if wrap else max_len   # document tokens per window
    if step < 1:
        raise ValueError(f"max_len={max_len} leaves no room for document tokens")

    if not ids:
        # Empty document: return (0, hidden_dim) instead of failing in torch.cat([]).
        return torch.empty((0, hf_model.config.hidden_size), device=device), offsets

    segments = []
    with torch.no_grad():
        for i in range(0, len(ids), step):
            seg_ids = ids[i:i + step]
            if wrap:
                seg_ids = [cls_id] + seg_ids + [sep_id]

            input_ids = torch.tensor([seg_ids], dtype=torch.long, device=device)
            out = hf_model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))

            hidden = out.last_hidden_state.squeeze(0)   # (seg_len, hidden_dim)
            if wrap:
                hidden = hidden[1:-1]                   # drop [CLS] / [SEP] rows
            segments.append(hidden)

    # Stay on `device`: moving each segment to CPU and back was pure overhead.
    full_emb = torch.cat(segments, dim=0)
    return full_emb, offsets


In [None]:
def find_token_span_for_chunk(chunk_words, full_words, offsets, full_text=None):
    """
    Locate a chunk (given as whitespace-split words) in the document and return
    its inclusive token span.

    chunk_words: chunk.split()
    full_words: full_text.split()
    offsets: list of (char_start, char_end) for each token of full_text
        (e.g. the tokenizer's ``offset_mapping``)
    full_text: optional original document text. When omitted, character
        positions are reconstructed as if ``full_words`` were joined by single
        spaces, which is only exact if the original text used single spaces.

    Returns:
        (start_tok, end_tok) inclusive indices into ``offsets``, or (None, None)
        when the chunk is empty, not found, or overlaps no token.
    """
    import re

    n = len(chunk_words)
    if n == 0 or n > len(full_words):
        return None, None

    # 1) Find where chunk_words starts inside full_words (a WORD index).
    word_start = None
    for i in range(len(full_words) - n + 1):
        if full_words[i] == chunk_words[0] and full_words[i:i + n] == chunk_words:
            word_start = i
            break
    if word_start is None:
        return None, None
    word_end = word_start + n - 1

    # 2) Word indices are not token indices (a word may become several subword
    #    tokens), so convert the word span to a character span first.
    text = " ".join(full_words) if full_text is None else full_text
    word_spans = [m.span() for m in re.finditer(r"\S+", text)]
    if len(word_spans) != len(full_words):
        return None, None  # full_text and full_words disagree
    char_start = word_spans[word_start][0]
    char_end = word_spans[word_end][1]

    # 3) Every non-empty token overlapping [char_start, char_end) is in the chunk.
    hits = [
        t for t, (s, e) in enumerate(offsets)
        if e > s and s < char_end and e > char_start
    ]
    if not hits:
        return None, None
    return hits[0], hits[-1]


In [None]:
# ---------------------------------------------------
# ③ Pooling
# ---------------------------------------------------
def pool_chunk(token_emb, start_tok, end_tok):
    """
    Mean-pool the rows of ``token_emb`` in the inclusive range [start_tok, end_tok].

    Args:
        token_emb: tensor whose first dimension indexes tokens,
            e.g. (num_tokens, hidden_dim).
        start_tok, end_tok: inclusive token indices.

    Returns:
        Mean over the selected rows (shape ``token_emb.shape[1:]``).

    Raises:
        ValueError: if the range selects no rows; the mean of an empty slice
            would otherwise silently be NaN.
    """
    sub = token_emb[start_tok:end_tok + 1]
    if sub.size(0) == 0:
        raise ValueError(
            f"empty token span [{start_tok}, {end_tok}] for {token_emb.size(0)} tokens"
        )
    return sub.mean(dim=0)

In [None]:
def late_chunking_encode(full_text, chunks):
    """
    Embed each chunk by mean-pooling the document token embeddings that fall
    inside its inclusive [start_tok, end_tok] range.

    Args:
        full_text: the document the chunks were produced from.
        chunks: list of dicts with "text", "start_tok", "end_tok" (inclusive
            token indices into the tokenization of ``full_text``), as returned
            by ``late_chunker``.

    Returns:
        Tensor of shape (len(chunks), hidden_dim) on ``device``; an empty
        (0, hidden_dim) tensor when ``chunks`` is empty.
    """
    if not chunks:
        # torch.stack([]) would raise; also skips encoding the document at all.
        return torch.empty((0, hf_model.config.hidden_size), device=device)

    token_emb, _ = encode_full_doc_tokens(full_text)
    n_tokens = token_emb.size(0)

    pooled = []
    for ch in chunks:
        st, ed = ch["start_tok"], ch["end_tok"]
        if 0 <= st <= ed < n_tokens:
            pooled.append(pool_chunk(token_emb, st, ed))
        else:
            # Span is out of range or empty (pooling would give NaN):
            # fall back to encoding the chunk text on its own.
            pooled.append(encode_query(ch["text"]).squeeze(0))

    return torch.stack(pooled)


In [None]:
def late_chunker(full_text, max_tokens=512):
    """
    Split ``full_text`` into consecutive, non-overlapping windows of
    ``max_tokens`` tokenizer tokens (no special tokens); the last window holds
    the remainder.

    Returns a list of dicts:
    {
       "text": "...",       # full_text sliced by the window's character offsets
       "start_tok": int,    # first token index (inclusive) in the full sequence
       "end_tok": int       # last token index (inclusive) in the full sequence
    }
    """
    encoding = hf_tokenizer(
        full_text,
        return_offsets_mapping=True,
        add_special_tokens=False
    )
    offset_map = encoding["offset_mapping"]
    n_tokens = len(encoding["input_ids"])

    # Every window holds at least one token (max_tokens <= 0 => one token each).
    window = max(max_tokens, 1)

    result = []
    for first in range(0, n_tokens, window):
        last = min(first + window, n_tokens) - 1
        char_lo = offset_map[first][0]
        char_hi = offset_map[last][1]
        result.append({
            "text": full_text[char_lo:char_hi],
            "start_tok": first,
            "end_tok": last
        })
    return result


In [None]:
def find_gold_chunk(chunks, gold_start, gold_end):
    """
    Return the index of the first chunk whose token range fully contains the
    gold span, or None when no single chunk contains it.

    chunks: list of dicts { "text": ..., "start_tok": int, "end_tok": int }
        with inclusive start_tok / end_tok.
    gold_start, gold_end: gold span as token indices in the SAME coordinate
        system as the chunks' start_tok / end_tok; gold_end is compared
        inclusively.
    """
    return next(
        (
            idx
            for idx, ch in enumerate(chunks)
            if gold_start >= ch["start_tok"] and gold_end <= ch["end_tok"]
        ),
        None,
    )

In [None]:
import time
from tqdm import tqdm
import json
import numpy as np

def _nq_token_text(tok):
    """Visible text of one NQ ``document_tokens`` entry, or None for HTML markup.

    Accepts the original NQ dict form ({"token": ..., "html_token": bool, ...})
    or a bare string, where markup is recognised as ``<...>``. The cleaned
    dataset's exact token schema is not visible here -- TODO confirm.
    """
    import html as html_lib

    if isinstance(tok, dict):
        if tok.get("html_token"):
            return None
        text = tok.get("token", "")
    else:
        text = str(tok)
        if text.startswith("<") and text.endswith(">"):
            return None
    # BeautifulSoup decodes HTML entities in get_text(); do the same here.
    return html_lib.unescape(text)


def _squash(text):
    """Remove all whitespace (same definition of whitespace as str.split())."""
    return "".join(text.split())


def _gold_char_span(full_text, doc_tokens, gs, ge):
    """Map NQ token span [gs, ge) (NQ end_token is exclusive) to a
    (char_start, char_end) span of ``full_text``, or None if it is not found.

    Matching ignores whitespace: NQ tokens are space-separated (punctuation
    included) while ``full_text`` comes from BeautifulSoup. If the answer
    occurs several times, the occurrence closest to the amount of visible text
    preceding ``gs`` wins.
    """
    answer = "".join(_squash(t) for t in map(_nq_token_text, doc_tokens[gs:ge]) if t)
    if not answer:
        return None
    expected = sum(len(_squash(t)) for t in map(_nq_token_text, doc_tokens[:gs]) if t)

    # Original positions of full_text's non-whitespace characters, so a match
    # in the squashed string can be mapped back to real character offsets.
    keep = [i for i, c in enumerate(full_text) if not c.isspace()]
    squashed = "".join(full_text[i] for i in keep)

    best = None
    pos = squashed.find(answer)
    while pos >= 0:
        if best is None or abs(pos - expected) < abs(best - expected):
            best = pos
        elif pos > expected:
            break  # later matches only move further away from `expected`
        pos = squashed.find(answer, pos + 1)
    if best is None:
        return None
    return keep[best], keep[best + len(answer) - 1] + 1


def _gold_token_span(item, full_text):
    """Gold answer of an NQ ``item`` as inclusive subword-token indices of
    ``full_text`` (the coordinates ``late_chunker`` uses), or None.

    Uses the first short answer of the first annotation, falling back to its
    long answer.
    """
    annotations = item.get("annotations") or []
    if not annotations:
        return None
    ann = annotations[0]
    short_answers = ann.get("short_answers") or []
    if short_answers:
        gs = short_answers[0]["start_token"]
        ge = short_answers[0]["end_token"]
    else:
        long_answer = ann.get("long_answer") or {}
        gs = long_answer.get("start_token", -1)
        ge = long_answer.get("end_token", -1)
    if gs < 0 or ge <= gs:
        return None

    char_span = _gold_char_span(full_text, item.get("document_tokens") or [], gs, ge)
    if char_span is None:
        return None
    char_start, char_end = char_span

    offsets = hf_tokenizer(
        full_text, return_offsets_mapping=True, add_special_tokens=False
    )["offset_mapping"]
    hits = [
        t for t, (s, e) in enumerate(offsets)
        if e > s and s < char_end and e > char_start
    ]
    if not hits:
        return None
    return hits[0], hits[-1]


def evaluate(dataset_path, method_name="LateChunking"):
    """
    Retrieval evaluation of late chunking on an NQ-format JSONL file.

    For each example the HTML is converted to text and split with
    ``late_chunker``. The NQ gold span (NQ HTML-token indices) is mapped onto
    the same subword-token coordinates as the chunks. The chunks are then
    embedded with ``late_chunking_encode`` and ranked by cosine similarity to
    the question (``encode_query``). The 1-based rank of the chunk that fully
    contains the gold span is recorded.

    An example is counted in "skipped" in any of these cases:
      - no chunks are produced;
      - it has no usable annotation;
      - the gold span cannot be located in the extracted text;
      - the gold span does not fit in a single chunk;
      - encoding raises (also counted in "encode_errors").

    Returns:
        dict with recall@10, mrr, ndcg, counts and timing. The dict is also
        printed as JSON. nDCG is 1/log2(rank+1) because there is a single
        relevant chunk.
    """
    t0 = time.time()

    ranks = []
    skipped = 0
    encode_errors = 0
    total_chunks = 0      # chunks of the examples that were actually ranked
    total_samples = 0

    with open(dataset_path, "r", encoding="utf-8") as f, \
            tqdm(desc=f"Evaluating {method_name}") as tqdm_bar:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL
            item = json.loads(line)
            tqdm_bar.update(1)
            total_samples += 1

            doc_html = item["document_html"]
            question = item["question_text"]

            # extract text
            soup = BeautifulSoup(doc_html, "html.parser")
            full_text = soup.get_text(" ", strip=True)

            # chunking
            chunks = late_chunker(full_text, max_tokens=512)
            if not chunks:
                skipped += 1
                continue

            # Resolve the gold chunk BEFORE encoding, so unusable examples cost
            # no model forward passes.
            gold_span = _gold_token_span(item, full_text)
            gold_idx = None if gold_span is None else find_gold_chunk(chunks, *gold_span)
            if gold_idx is None:
                skipped += 1
                continue

            # embedding
            try:
                chunk_emb = late_chunking_encode(full_text, chunks)
                q_emb = encode_query(question)[0]
            except Exception as exc:  # keep going, but let Ctrl-C / SystemExit through
                skipped += 1
                encode_errors += 1
                if encode_errors == 1:
                    tqdm.write(
                        f"[{method_name}] encoding failed "
                        f"({type(exc).__name__}: {exc}); further failures are only counted"
                    )
                continue

            # ranking
            scores = util.cos_sim(q_emb, chunk_emb)[0]
            ranking = scores.argsort(descending=True).cpu().numpy()
            ranks.append(int(np.where(ranking == gold_idx)[0][0]) + 1)
            total_chunks += len(chunks)

        tqdm_time_str = tqdm.format_interval(tqdm_bar.format_dict["elapsed"])

    # ===== metrics =====
    recall10 = np.mean([1 if r <= 10 else 0 for r in ranks]) if ranks else 0
    mrr = np.mean([1.0 / r for r in ranks]) if ranks else 0
    ndcg = np.mean([1 / np.log2(r + 1) for r in ranks]) if ranks else 0

    elapsed_seconds = time.time() - t0
    items_per_sec = total_samples / elapsed_seconds if elapsed_seconds > 0 else 0
    avg_num_chunks = (total_chunks / len(ranks)) if ranks else 0

    # ==== result dict ====
    result = {
        "method": method_name,
        "recall@10": float(recall10),
        "mrr": float(mrr),
        "ndcg": float(ndcg),
        "total_samples": total_samples,
        "skipped": skipped,
        "encode_errors": encode_errors,
        "avg_num_chunks": float(avg_num_chunks),
        "elapsed_seconds": float(elapsed_seconds),
        "items_per_sec": float(items_per_sec),
        "tqdm_time_str": str(tqdm_time_str),
        "dataset": dataset_path
    }

    print(json.dumps(result, indent=2))
    return result


In [None]:
# Mount Google Drive (Colab only) so the project folder under MyDrive is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from the project folder on Drive so the relative "data/..." paths used below resolve.
%cd /content/drive/MyDrive/CIS-5200-final-Text2Vec

/content/drive/MyDrive/CIS-5200-final-Text2Vec


In [None]:
# ---------------------------------------------------
# ⑧ Run
# ---------------------------------------------------
if __name__ == "__main__":
    datasets = [
        "data/nq_dev_short_cleaned.jsonl",
        "data/nq_dev_medium_cleaned.jsonl",
        "data/nq_dev_long_cleaned.jsonl",
    ]

    # Keep each returned metrics dict (evaluate() also prints it) so the
    # numbers can be compared or saved after the loop instead of being lost.
    results_by_dataset = {}
    for ds in datasets:
        print("\n==============================")
        print(f"Evaluating dataset: {ds}")
        print("==============================\n")
        results_by_dataset[ds] = evaluate(ds)



Evaluating dataset: data/nq_dev_medium_cleaned.jsonl



Evaluating Late Chunking: 1244it [05:41,  3.64it/s]


Recall@10: 0.8409506398537477 MRR: 0.3032399074645411

Evaluating dataset: data/nq_dev_long_cleaned.jsonl



Evaluating Late Chunking: 277it [03:11,  1.84it/s]