In [None]:
!pip install -U sentence-transformers huggingface_hub

Collecting sentence-transformers
  Downloading sentence_transformers-4.1.0-py3-none-any.whl.metadata (13 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_

In [4]:
# 📦 最终版 Memmap Retriever (Colab/Drive 安全版)
# 完全处理好设备问题 + 支持断点续跑 + 支持直接检索，无需重新生成！

import os
import json
import numpy as np
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from huggingface_hub import hf_hub_download
from google.colab import drive

# ------------------------------
# Google Drive: mount it so the embedding files are written outside the Colab VM
# ------------------------------
drive.mount('/content/drive')

# Folder on Drive holding corpus_emb.dat / corpus_doc_ids.json / progress.json.
# 🔥 Change this to your own folder.
MEMMAP_DIR = "/content/drive/MyDrive/your_project_folder/embeddings"

# ------------------------------
# 加载语料
# ------------------------------
def load_corpus(repo_id: str, filename: str, repo_type: str = "dataset"):
    """Download a JSONL corpus from the Hugging Face Hub and index it by document id.

    Every non-blank line is parsed as a JSON object. Its ``_id`` becomes the key and
    its ``title`` / ``text`` fields (defaulting to "") become the value.

    Returns:
        (corpus, doc_ids): ``corpus`` maps id -> {"title": ..., "text": ...};
        ``doc_ids`` lists the ids in first-seen order. A repeated id keeps its
        first position but the fields of its last occurrence.
    """
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)

    corpus = {}
    with open(local_path, encoding='utf-8') as handle:
        records = (json.loads(raw) for raw in handle if raw.strip())
        for record in records:
            corpus[record.get("_id")] = {
                "title": record.get("title", ""),
                "text": record.get("text", ""),
            }
    return corpus, list(corpus)

# ------------------------------
# 构建 memmap（支持断点续跑）
# ------------------------------
def build_embeddings_memmap(corpus, doc_ids, model_name: str, batch_size: int = 32, memmap_dir: str = "embeddings",
                            max_seq_length: "int | None" = None):
    """Encode every document's ``text`` into a float32 memmap on disk, resumably.

    Files under ``memmap_dir``:
      * ``corpus_emb.dat``      -- raw float32 matrix of shape (len(doc_ids), d); row i
                                   is the embedding of ``corpus[doc_ids[i]]["text"]``.
      * ``corpus_doc_ids.json`` -- ``doc_ids`` in row order (written only if missing).
      * ``progress.json``       -- ``{"last_idx": k}`` after each batch; removed when done.

    If ``progress.json`` exists from an interrupted run, encoding resumes at row
    ``last_idx`` instead of starting over.

    Args:
        corpus: mapping id -> dict with a "text" field (only "text" is embedded).
        doc_ids: ids in the row order to use.
        model_name: SentenceTransformer model id/path.
        batch_size: documents per encode call and per progress checkpoint.
        memmap_dir: output directory (created if needed).
        max_seq_length: optional token limit, applied via ``model.max_seq_length``;
            None keeps the model's own default. ``encode`` has no
            ``max_length``/``truncation`` keywords, so truncation is set here.

    Returns:
        (memmap_path, doc_ids_path, d, N)
    """
    model = SentenceTransformer(model_name, trust_remote_code=True)
    if max_seq_length is not None:
        model.max_seq_length = max_seq_length
    d = model.get_sentence_embedding_dimension()
    N = len(doc_ids)

    os.makedirs(memmap_dir, exist_ok=True)
    memmap_path = os.path.join(memmap_dir, "corpus_emb.dat")
    doc_ids_path = os.path.join(memmap_dir, "corpus_doc_ids.json")
    progress_path = os.path.join(memmap_dir, "progress.json")
    progress_tmp_path = progress_path + ".tmp"

    if not os.path.exists(memmap_path):
        # Pre-allocate the full (N, d) file once; every run then reopens it read/write.
        np.memmap(memmap_path, dtype='float32', mode='w+', shape=(N, d))

    mmap = np.memmap(memmap_path, dtype='float32', mode='r+', shape=(N, d))

    start_idx = 0
    if os.path.exists(progress_path):
        with open(progress_path, 'r', encoding='utf-8') as pf:
            start_idx = json.load(pf).get('last_idx', 0)
        print(f"🔄 检测到中断，从 {start_idx} 继续")

    if not os.path.exists(doc_ids_path):
        with open(doc_ids_path, 'w', encoding='utf-8') as f:
            json.dump(doc_ids, f, ensure_ascii=False)

    total_batches = (N + batch_size - 1) // batch_size
    for i in tqdm(range(start_idx, N, batch_size), initial=start_idx // batch_size, total=total_batches):
        batch_texts = [corpus[_id]["text"] for _id in doc_ids[i:i + batch_size]]
        embeddings = model.encode(batch_texts, convert_to_numpy=True, batch_size=batch_size)
        mmap[i:i + len(embeddings)] = embeddings
        # Persist the rows before recording them as done, so a resume never skips unwritten rows.
        mmap.flush()
        # Write-then-rename: an interruption mid-write cannot leave a truncated
        # progress.json that would make the next resume fail to parse.
        with open(progress_tmp_path, 'w', encoding='utf-8') as pf:
            json.dump({'last_idx': i + len(embeddings)}, pf)
        os.replace(progress_tmp_path, progress_path)

    # Release the writable mapping before returning.
    del mmap
    if os.path.exists(progress_path):
        os.remove(progress_path)
    print("✅ 全量写入完成")

    return memmap_path, doc_ids_path, d, N

# ------------------------------
# Memmap Retriever
# ------------------------------
class MemmapRetriever:
    """Dense retriever over a float32 embedding memmap written by ``build_embeddings_memmap``.

    Row i of the memmap is the embedding of document ``doc_ids[i]``. ``search``
    scores each query against every row of the matrix.
    """

    def __init__(self, memmap_path, doc_ids_path, dimension, num_docs, model_name):
        """Open the embedding file read-only and load the query encoder.

        Args:
            memmap_path: raw float32 file holding a (num_docs, dimension) matrix.
            doc_ids_path: JSON list of document ids, one per memmap row.
            dimension: embedding width d.
            num_docs: number of rows N.
            model_name: SentenceTransformer used to encode queries; it should be
                the model that produced the corpus embeddings.

        Raises:
            ValueError: if the file size, the id-list length, or the model's
                embedding dimension disagrees with ``(num_docs, dimension)``. A
                wrong shape would otherwise silently read misaligned rows.
        """
        self.dimension = dimension
        self.num_docs = num_docs

        expected_bytes = num_docs * dimension * np.dtype('float32').itemsize
        actual_bytes = os.path.getsize(memmap_path)
        if actual_bytes != expected_bytes:
            raise ValueError(
                f"{memmap_path} has {actual_bytes} bytes, expected {expected_bytes} "
                f"for a float32 matrix of shape ({num_docs}, {dimension})"
            )
        self.mmap = np.memmap(memmap_path, dtype='float32', mode='r', shape=(num_docs, dimension))
        # Shares memory with the memmap (no copy); the tensor is only ever read.
        self.corpus_embeddings = torch.from_numpy(self.mmap)

        with open(doc_ids_path, 'r', encoding='utf-8') as f:
            self.doc_ids = json.load(f)
        if len(self.doc_ids) != num_docs:
            raise ValueError(f"{doc_ids_path} lists {len(self.doc_ids)} ids, expected {num_docs}")

        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        model_dim = self.model.get_sentence_embedding_dimension()
        if model_dim is not None and model_dim != dimension:
            raise ValueError(f"model {model_name} produces {model_dim}-d embeddings, memmap has {dimension}")

    def search(self, queries: dict, top_k: int = 5, score_function: str = 'cos_sim', batch_size: int = 32):
        """Return the ``top_k`` best-scoring documents for each query.

        Args:
            queries: mapping query_id -> query text.
            top_k: results per query. It is clamped to the corpus size so that
                asking for more documents than exist does not raise.
            score_function: 'cos_sim' (cosine similarity) or 'dot' (raw dot product).
            batch_size: batch size for query encoding.

        Returns:
            {query_id: {doc_id: score}}, each inner dict in descending score order.

        Raises:
            ValueError: for an unsupported ``score_function``. This is checked
                before the queries are encoded.
        """
        if score_function not in ('cos_sim', 'dot'):
            raise ValueError(f"Unsupported score_function {score_function}")
        if not queries:
            return {}

        query_ids = list(queries.keys())
        # Truncation is governed by self.model.max_seq_length; encode() has no
        # max_length/truncation keywords.
        query_emb = self.model.encode(
            [queries[q] for q in query_ids],
            convert_to_tensor=True,
            batch_size=batch_size,
        ).to('cpu')  # corpus_embeddings is memmap-backed, i.e. on CPU

        if score_function == 'cos_sim':
            sim = util.cos_sim(query_emb, self.corpus_embeddings)
        else:
            sim = torch.matmul(query_emb, self.corpus_embeddings.T)

        top = torch.topk(sim, k=min(top_k, sim.shape[1]), dim=1)
        return {
            qid: {self.doc_ids[i]: s for i, s in zip(idx_row, score_row)}
            for qid, idx_row, score_row in zip(query_ids, top.indices.tolist(), top.values.tolist())
        }

# ------------------------------
# 主程序
# ------------------------------
if __name__ == '__main__':
    repo_id = "COMP631GroupSYCZ/Corpus"
    filename = "corpus.jsonl"
    model_name = "Lajavaness/bilingual-embedding-small"

    emb_path = os.path.join(MEMMAP_DIR, "corpus_emb.dat")
    ids_path = os.path.join(MEMMAP_DIR, "corpus_doc_ids.json")
    progress_path = os.path.join(MEMMAP_DIR, "progress.json")

    # Build, or resume, when the outputs are missing OR a previous build was
    # interrupted. build_embeddings_memmap creates both output files before it
    # encodes anything and deletes progress.json only after the last batch, so
    # the two files existing does not mean the embeddings are complete.
    # (A crash before the first batch is checkpointed is still not detectable here.)
    needs_build = (
        not os.path.exists(emb_path)
        or not os.path.exists(ids_path)
        or os.path.exists(progress_path)
    )
    if needs_build:
        corpus, doc_ids = load_corpus(repo_id=repo_id, filename=filename)
        build_embeddings_memmap(corpus, doc_ids, model_name=model_name, batch_size=32, memmap_dir=MEMMAP_DIR)

    with open(ids_path, 'r', encoding='utf-8') as f:
        num_docs = len(json.load(f))
    # Derive the embedding width from the file instead of hard-coding 384, so a
    # different model_name cannot silently misalign the rows.
    float_bytes = np.dtype('float32').itemsize
    dimension = os.path.getsize(emb_path) // (float_bytes * num_docs) if num_docs else 0

    retriever = MemmapRetriever(
        memmap_path=emb_path,
        doc_ids_path=ids_path,
        dimension=dimension,
        num_docs=num_docs,
        model_name=model_name
    )

    queries = {
        "q1": "我昨晚梦见飞翔的鱼和奇怪的建筑，想了解这两个梦境的意义。"
    }
    results = retriever.search(queries, top_k=5)
    print(json.dumps(results, ensure_ascii=False, indent=2))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
{
  "q1": {
    "65": 0.8136338591575623,
    "9318": 0.7953237295150757,
    "82": 0.7892376780509949,
    "8366": 0.7865391969680786,
    "8869": 0.7863246202468872
  }
}


In [3]:
# Map the retrieved doc ids back to their source text.
# Relies on `retriever` and `queries` defined in the retriever cell above.
results = retriever.search(queries, top_k=3)

# Reload the raw corpus (document contents) to look the ids up.
corpus, _ = load_corpus(
    repo_id="COMP631GroupSYCZ/Corpus",
    filename="corpus.jsonl"
)

PREVIEW_CHARS = 200  # print only the first N characters of each document

for query_id, doc_scores in results.items():
    print(f"🔍 查询: {query_id}")
    for doc_id, score in doc_scores.items():
        # doc_id already has the corpus-key type (the retriever's id list is the
        # JSON dump of corpus.keys()), so index directly; str() would break
        # non-string ids.
        text = corpus[doc_id]['text'][:PREVIEW_CHARS]
        print(f"📄 文档ID: {doc_id}, 相似度: {score:.4f}")
        print(f"内容: {text}")
        print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━")


🔍 查询: q1
📄 文档ID: 65, 相似度: 0.7896
内容: 梦见飞鱼是什么意思？做梦梦见飞鱼好不好？梦见飞鱼有现实的影响和反应，也有梦者的主观想象，请看下面由周公解梦官网整理的梦见飞鱼的详细解说吧。
飞鱼不是真的能飞而是一种跳跃滑翔的过程，让什么看起来像是在飞。在梦中，飞鱼往往是一种超越的精神体现。
梦见大海上很多飞鱼成群结队地在跳跃，预示着自己最近做事会超越自己的预料。得到很好的评价。
梦见飞鱼出现在沙漠中，表示自己最近会解决掉一些烦恼自己很久的事情。

━━━━━━━━━━━━━━━━━━━━━━━━━━━━
📄 文档ID: 9318, 相似度: 0.7724
内容: 梦见鱼在空中飞是什么意思？做梦梦见鱼在空中飞好不好？梦见鱼在空中飞有现实的影响和反应，也有梦者的主观想象，请看下面由周公解梦官网整理的梦见鱼在空中飞的详细解说吧。
梦见鱼在天上飞，表示近期有很多致富的机会有你面前，要好好把握。
传说，鲤鱼只要能跃过龙门，就会化为天上的飞龙;后以“鲤鱼跳龙门”比喻中举、升官等飞黄腾达之事;如今又用作比喻逆流前进，奋发向上，步步高升，官运亨通。
梦见鲤鱼跳龙门，预示梦
━━━━━━━━━━━━━━━━━━━━━━━━━━━━
📄 文档ID: 54, 相似度: 0.7679
内容: 梦见怪鱼是什么意思？做梦梦见怪鱼好不好？梦见怪鱼有现实的影响和反应，也有梦者的主观想象，请看下面由周公解梦官网整理的梦见怪鱼的详细解说吧。
梦见怪鱼，鱼代表内心突然出现的想法，或者生活中新出现的人，长相凶恶，感觉极具攻击性，说明你最近的情绪并不稳定，受到一些想法或事物的影响。
梦见大怪鱼，预示着你最近又麻烦事，提示你，最近做事小心为是。
梦见抓到怪鱼，身体有毛病，肾虚，想省的钱没省成。http:/
━━━━━━━━━━━━━━━━━━━━━━━━━━━━


In [None]:
# One-off migration: copy embeddings built under the local default directory
# (/content/embeddings) into MEMMAP_DIR on Drive, so the retriever cell can
# reuse them instead of re-encoding the corpus.
import shutil

from google.colab import drive

drive.mount('/content/drive')  # only prints a notice if Drive is already mounted

LOCAL_EMB_DIR = "/content/embeddings"
os.makedirs(MEMMAP_DIR, exist_ok=True)
for name in ("corpus_emb.dat", "corpus_doc_ids.json"):
    src = os.path.join(LOCAL_EMB_DIR, name)
    if os.path.exists(src):
        shutil.copy2(src, MEMMAP_DIR)
        print(f"copied {src} -> {MEMMAP_DIR}")
    else:
        # Nothing to migrate, e.g. when the embeddings were built on Drive directly.
        print(f"skipped {src}: not found")


Mounted at /content/drive
