In [1]:
!pip install sentence-transformers datasets



In [2]:
pip install numpy<2

zsh:1: no such file or directory: 2
Note: you may need to restart the kernel to use updated packages.


In [8]:
# Step 0: Auto Device Selection
import torch

def get_best_device():
    """
    Automatically detect the best available device: CUDA > MPS > CPU.
    """
    if torch.cuda.is_available():
        print("Using CUDA GPU.")
        return torch.device("cuda")
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        print("Using Apple MPS.")
        return torch.device("mps")
    else:
        print("Using CPU.")
        return torch.device("cpu")

DEVICE = get_best_device()

Using Apple MPS.


In [9]:
# Step 1: Load the first Demo 100 Chinese news documents from TREC NeuCLIR1 using streaming mode
from datasets import load_dataset

# Streaming mode: only download and iterate over needed samples
streamed = load_dataset('neuclir/neuclir1', split='zho', streaming=True)
docs = []
for i, item in enumerate(streamed):
    if i >= 100:
        break
    docs.append({
        "doc_id": item["id"],
        "text": f"{item['title']} {item['text']}"
    })
print("Step 1 complete: Loaded 100 Chinese news documents into variable 'docs'.")

Step 1 complete: Loaded 100 Chinese news documents into variable 'docs'.


In [10]:
# Step 2: Load official NeuCLIR 2024 queries for Chinese (or other language as needed)
import json

QUERYS_FILE = "/Users/lianqi/Desktop/TODOS/IR/Project_multiretrieval/neuclir24.topics.0614.jsonl.txt"
# Load all queries from NeuCLIR 2024 topics file
queries_zh = []  # Chinese
#queries_fa = []  # Persian
#queries_ru = []  # Russian

with open(QUERYS_FILE, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        q_big = json.loads(line)
        # q_big["topics"] 是一个列表，遍历里面每个 topic
        for q in q_big.get("topics", []):
            # 中文
            if q["lang"] == "zho":
                queries_zh.append({
                    "query_id": q_big["topic_id"],  # 用顶层 topic_id 作为 query_id
                    "title": q["topic_title"],
                    "description": q["topic_description"],
                    "narrative": q["topic_narrative"]
                })
"""
            # 波斯语
            if q["lang"] == "fas":
                queries_fa.append({
                    "query_id": q_big["topic_id"],
                    "title": q["topic_title"],
                    "description": q["topic_description"],
                    "narrative": q["topic_narrative"]
                })
            # 俄语
            if q["lang"] == "rus":
                queries_ru.append({
                    "query_id": q_big["topic_id"],
                    "title": q["topic_title"],
                    "description": q["topic_description"],
                    "narrative": q["topic_narrative"]
                })

"""
# For demo, just use first 2 Chinese queries for a quick test
queries = queries_zh[:2]
# queries = queries_fa[:2] # Persian (uncomment to use)
# queries = queries_ru[:2] # Russian (uncomment to use)

print("Selected queries for demo:")
for q in queries:
    print(q)

# Example: use queries_zh for Chinese, queries_fa for Persian, queries_ru for Russian
# To use a different language, just replace 'queries_zh' with your target, e.g.:
#   queries = queries_zh      # For Chinese document retrieval
#   queries = queries_fa      # For Persian document retrieval
#   queries = queries_ru      # For Russian document retrieval


Selected queries for demo:
{'query_id': 300, 'title': '“日本自杀率 COVID-19', 'description': 'COVID-19 大流行对日本自杀率有何影响？', 'narrative': '查找有关冠状病毒大流行对日本自杀率影响的文章和报告。如果文章提及在 COVID-19 大流行期间其他次要原因，也可归为是相关的。只谈论日本高自杀率及其原因而不提及 COVID-19 大流行的文章是不相关的。与其他国家的 COVID-19 大流行期间自杀率相关的文章是无关紧要的。'}
{'query_id': 300, 'title': '日本新冠肺炎 (COVID-19) 自杀率\n', 'description': 'COVID-19 大流行对日本自杀率有何影响？', 'narrative': '查找有关冠状病毒大流行对日本自杀率影响的文章和报告。如果在 COVID-19 大流行期间提到其他次要原因，它们也被认为是相关的。只谈论日本高自杀率及其原因而不提及 COVID-19 大流行的文章是不相关的。与其他国家的 COVID-19 大流行中自杀流行率相关的文章是无关紧要的。'}


In [11]:
# Step 3: Prepare for mBERT, m3, MGTE
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
import torch

# Model info: (name, loading_type)
model_info = [
    ("bert-base-multilingual-cased", "hf"),        # mBERT
    ("bert-base-multilingual-uncased", "hf"),      # m3
    ("intfloat/multilingual-e5-large-instruct", "st")  # MGTE
]

def encode_with_hf(model_name, texts, return_token_embs=False):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(DEVICE)
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(texts, return_tensors='pt', truncation=True, padding=True, max_length=256)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        outputs = model(**inputs)
        last_hidden = outputs.last_hidden_state
        attention_mask = inputs['attention_mask']
        # Mean pooling for sentence embedding
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
        sum_embeddings = torch.sum(last_hidden * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        embeddings = sum_embeddings / sum_mask
        if return_token_embs:
            # Get all valid token embeddings (remove padding)
            token_embeddings = []
            for i in range(last_hidden.size(0)):
                mask = attention_mask[i].bool()
                valid_tokens = last_hidden[i][mask]
                token_embeddings.append(valid_tokens)
            return embeddings, token_embeddings
        else:
            return embeddings

def encode_with_st(model_name, texts, return_token_embs=False):
    model = SentenceTransformer(model_name, device=str(DEVICE))
    if return_token_embs:
        tokenizer = model.tokenizer
        transformer = model._first_module().auto_model
        with torch.no_grad():
            inputs = tokenizer(texts, return_tensors='pt', truncation=True, padding=True, max_length=256)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            outputs = transformer(**inputs)
            last_hidden = outputs.last_hidden_state
            attention_mask = inputs['attention_mask']
            token_embeddings = []
            for i in range(last_hidden.size(0)):
                mask = attention_mask[i].bool()
                valid_tokens = last_hidden[i][mask]
                token_embeddings.append(valid_tokens)
            # Mean pooling for sentence embedding
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
            sum_embeddings = torch.sum(last_hidden * input_mask_expanded, 1)
            sum_mask = input_mask_expanded.sum(1)
            embeddings = sum_embeddings / sum_mask
            return embeddings, token_embeddings
    else:
        return model.encode(texts, convert_to_tensor=True, device=DEVICE)


# Prepare doc texts and IDs
doc_texts = [d["text"] for d in docs]
doc_ids = [d["doc_id"] for d in docs]

# Prepare TF-IDF for BWE-Agg-IDF
tfidf = TfidfVectorizer()
tfidf.fit(doc_texts + [f"{q['title']} {q['description']} {q['narrative']}" for q in queries])
idf_dict = dict(zip(tfidf.get_feature_names_out(), tfidf.idf_))

print("Step 3 complete: Models and IDF dictionary ready.")

Step 3 complete: Models and IDF dictionary ready.


In [12]:
# Step 4: Retrieval loop with multiple encoders and strategies
import numpy as np

# --- Utility: Align tokenization for IDF and embeddings ---
def get_idf_aligned(tokenizer, text, idf_dict):
    """
    Given a tokenizer and text, return the list of IDF weights for the tokens,
    using the same tokenization as for embedding extraction. 
    This ensures IDF weights and token embeddings are strictly aligned 
    to avoid shape mismatch.
    """
    tokens = tokenizer.tokenize(text)
    # Remove special tokens like [CLS], [SEP] if present
    tokens = [t for t in tokens if not t.startswith('[')]
    idfs = [idf_dict.get(t.lower(), 1.0) for t in tokens]
    return idfs

# --- Main retrieval loop ---

for model_name, loader in model_info:
    print(f"\n\n=== Using encoder: {model_name} ===")
    if loader == "st":
        model = SentenceTransformer(model_name, device=str(DEVICE))
        tokenizer = model.tokenizer
        doc_sent_embs, doc_token_embs = encode_with_st(model_name, doc_texts, return_token_embs=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        doc_sent_embs, doc_token_embs = encode_with_hf(model_name, doc_texts, return_token_embs=True)
    
    for query in queries:
        q_text = f"{query['title']} {query['description']} {query['narrative']}"
        if loader == "st":
            q_sent_emb, q_token_embs = encode_with_st(model_name, [q_text], return_token_embs=True)
            q_token_embs = q_token_embs[0]
        else:
            q_sent_emb, q_token_embs = encode_with_hf(model_name, [q_text], return_token_embs=True)
            q_token_embs = q_token_embs[0]

        print(f"\n--- Query: {query['title']} ---")

        # --- maxSim ---
        maxsim_scores = []
        for doc_toks in doc_token_embs:
            sims = []
            for q_tok_vec in q_token_embs:
                sim = torch.nn.functional.cosine_similarity(q_tok_vec.unsqueeze(0), doc_toks)
                sims.append(sim.max().item())
            score = np.mean(sims)
            maxsim_scores.append(score)
        maxsim_scores = torch.tensor(maxsim_scores)
        topk = torch.topk(maxsim_scores, 3)
        print(f"\n   [maxSim] Top 3 results:")
        for rank, idx in enumerate(topk.indices, 1):
            print(f"      Rank {rank}: DocID {doc_ids[idx]}, Score: {maxsim_scores[idx]:.4f}")
            print(f"         Text snippet: {doc_texts[idx][:60]}...")

        # --- BWE-Agg-IDF ---
        # Use utility to ensure IDF and embedding alignment!
        q_idfs = get_idf_aligned(tokenizer, q_text, idf_dict)
        emb_num = q_token_embs.shape[0]
        idf_num = len(q_idfs)
        if emb_num > idf_num:
            q_token_embs = q_token_embs[:idf_num]
        elif idf_num > emb_num:
            q_idfs = q_idfs[:emb_num]
        if len(q_idfs) > 0:
            q_idfs_tensor = torch.tensor(q_idfs, device=DEVICE)
            q_vec = (q_token_embs.T @ q_idfs_tensor).T / sum(q_idfs)
        else:
            q_vec = q_token_embs.mean(0)

        agg_idf_scores = []
        for doc_idx, doc_toks in enumerate(doc_token_embs):
            # Use the same alignment for doc tokens
            doc_idfs = get_idf_aligned(tokenizer, doc_texts[doc_idx], idf_dict)
            emb_num = doc_toks.shape[0]
            idf_num = len(doc_idfs)
            if emb_num > idf_num:
                doc_toks = doc_toks[:idf_num]
            elif idf_num > emb_num:
                doc_idfs = doc_idfs[:emb_num]
            if len(doc_idfs) > 0:
                d_idfs_tensor = torch.tensor(doc_idfs, device=doc_toks.device)
                doc_vec = (doc_toks.T @ d_idfs_tensor).T / sum(doc_idfs)
            else:
                doc_vec = doc_toks.mean(0)
            score = torch.nn.functional.cosine_similarity(
                q_vec.unsqueeze(0), doc_vec.unsqueeze(0)
            ).item()
            agg_idf_scores.append(score)
        agg_idf_scores = torch.tensor(agg_idf_scores)
        topk = torch.topk(agg_idf_scores, 3)
        print(f"\n   [BWE-Agg-IDF] Top 3 results:")
        for rank, idx in enumerate(topk.indices, 1):
            print(f"      Rank {rank}: DocID {doc_ids[idx]}, Score: {agg_idf_scores[idx]:.4f}")
            print(f"         Text snippet: {doc_texts[idx][:60]}...")

print("\nDemo complete! You now have maxSim and BWE-Agg-IDF retrieval results for all three encoders.")

# model_info = [
#     ("bert-base-multilingual-cased", "hf"),        # mBERT
#     ("bert-base-multilingual-uncased", "hf"),      # m3
#     ("intfloat/multilingual-e5-large-instruct", "st")  # MGTE
# ]



=== Using encoder: bert-base-multilingual-cased ===

--- Query: “日本自杀率 COVID-19 ---


  q_vec = (q_token_embs.T @ q_idfs_tensor).T / sum(q_idfs)
Token indices sequence length is longer than the specified maximum sequence length for this model (2118 > 512). Running this sequence through the model will result in indexing errors



   [maxSim] Top 3 results:
      Rank 1: DocID ddd54cb0-fe35-4a8f-8a90-bc58b34590d7, Score: 0.5022
         Text snippet: 新加坡研究：疫苗對抗Delta變異株 保護力達69% 新加坡表示，當地一項研究顯示，接種疫苗對防止感染Delta變異株的...
      Rank 2: DocID a27b31d8-ad61-4280-9ede-ea5ab3479a26, Score: 0.4940
         Text snippet: 習近平與捷克總統通話 盼捷方更多人士「正確看待中國」 大陸國家主席習近平7月7日晚上與捷克總統澤曼（Milos Zema...
      Rank 3: DocID de56ae79-6512-4b7d-9754-f6dfde5c5ad8, Score: 0.4930
         Text snippet: 北市頻傳私打疫苗 柯文哲：求生是人的本能 台北市日前接連爆發好心肝診所、振興醫院私打疫苗事件，市長柯文哲今天表示，求生是...

   [BWE-Agg-IDF] Top 3 results:
      Rank 1: DocID 6a84f09d-88f5-4bf6-b6b7-e6a2da7a4af8, Score: 0.6981
         Text snippet: 食物價格一路漲 超市開始囤貨求自保 芝加哥一家超市內，顧客在選購水果。(Getty Images)

食物價格一路走高，...
      Rank 2: DocID 4654911d-e8ef-49b7-b12f-1ae2ce837a67, Score: 0.6967
         Text snippet: 吃素不能打疫苗？醫師們解答籲轉念 網讚：達賴喇嘛都打AZ了 隨著疫苗陸續抵台，政府逐漸開放多類族群接種疫苗，不過仍有民眾...
      Rank 3: DocID e95d764e-0d76-40aa-8119-154f64ab9db7, Score: 0.6935
         Text snippet: 柯文哲：很多商品都中國製 為何疫苗不能用中國代理的？ 新冠肺炎疫情在台灣近日逐漸趨緩，不過能夠有效

Token indices sequence length is longer than the specified maximum sequence length for this model (2102 > 512). Running this sequence through the model will result in indexing errors



   [maxSim] Top 3 results:
      Rank 1: DocID de56ae79-6512-4b7d-9754-f6dfde5c5ad8, Score: 0.5855
         Text snippet: 北市頻傳私打疫苗 柯文哲：求生是人的本能 台北市日前接連爆發好心肝診所、振興醫院私打疫苗事件，市長柯文哲今天表示，求生是...
      Rank 2: DocID 799f1e12-cf9c-468a-90f7-01ef7c827554, Score: 0.5748
         Text snippet: 台日友好是「善的循環」展現 謝長廷：會是世界和平的模範 日本贈我疫苗，表示是為回報台灣311地震的恩情。謝長廷感動表示，...
      Rank 3: DocID ddd54cb0-fe35-4a8f-8a90-bc58b34590d7, Score: 0.5703
         Text snippet: 新加坡研究：疫苗對抗Delta變異株 保護力達69% 新加坡表示，當地一項研究顯示，接種疫苗對防止感染Delta變異株的...

   [BWE-Agg-IDF] Top 3 results:
      Rank 1: DocID 95a18cdd-da7a-4815-a07d-7365d8c2e268, Score: 0.8669
         Text snippet: 欧科云链链上大师重磅上线，一起来用“链上Bloomberg”听听行业脉搏跳动 Wednesday, 7 July 202...
      Rank 2: DocID ddd54cb0-fe35-4a8f-8a90-bc58b34590d7, Score: 0.8584
         Text snippet: 新加坡研究：疫苗對抗Delta變異株 保護力達69% 新加坡表示，當地一項研究顯示，接種疫苗對防止感染Delta變異株的...
      Rank 3: DocID 21fa73c8-e8c5-461b-a567-c08f87cf978b, Score: 0.8568
         Text snippet: 张汉晖大使出席庆祝中国共产党成立100周年暨《中俄睦邻友好合作条约》签署20周年专场音乐会--国际

Token indices sequence length is longer than the specified maximum sequence length for this model (1502 > 512). Running this sequence through the model will result in indexing errors



   [maxSim] Top 3 results:
      Rank 1: DocID 5f486e3f-5e1d-4ba5-8a81-e2a31dccf2e1, Score: 0.8243
         Text snippet: 日本政府確定方針！ 東京奧運將在「緊急事態」中開幕 ▲日本政府確定東京第四次宣布緊急事態宣言的方針，勢必影響奧運開幕式。...
      Rank 2: DocID abc181f7-f049-441c-959c-5e70adb6ecb5, Score: 0.8228
         Text snippet: 新加坡不承認科興疫苗 港議員憂旅遊氣泡計劃受影響｜即時新聞｜港澳｜on.cc東網 內地自主研發的科興疫苗不被新加坡衞生部...
      Rank 3: DocID ddd54cb0-fe35-4a8f-8a90-bc58b34590d7, Score: 0.8223
         Text snippet: 新加坡研究：疫苗對抗Delta變異株 保護力達69% 新加坡表示，當地一項研究顯示，接種疫苗對防止感染Delta變異株的...

   [BWE-Agg-IDF] Top 3 results:
      Rank 1: DocID 5f486e3f-5e1d-4ba5-8a81-e2a31dccf2e1, Score: 0.8395
         Text snippet: 日本政府確定方針！ 東京奧運將在「緊急事態」中開幕 ▲日本政府確定東京第四次宣布緊急事態宣言的方針，勢必影響奧運開幕式。...
      Rank 2: DocID aaf37a05-2522-4af8-93d7-255e9c7a78e7, Score: 0.8371
         Text snippet: 疑與男友吵架想不開 女墜落6米深橋下送醫 一名50多歲女子疑似因為與男友發生爭執，今晚心情不好，從台中市東勢區永安橋墜下...
      Rank 3: DocID abc181f7-f049-441c-959c-5e70adb6ecb5, Score: 0.8364
         Text snippet: 新加坡不承認科興疫苗 港議員憂旅遊氣泡計劃受影響｜即時新聞｜港澳｜on.cc東網 內地自主研發的科