In [None]:
import polars as pl

# Input CSV for this notebook (currently one file under pubmed_baseline/csv).
csv_path = "../data/pubmed_baseline/csv/pubmed25n0001.csv" 
# Alternative input, currently disabled.
# NOTE(review): presumably the merged CSV of every shard — confirm before switching.
# csv_path = "../data/pubmed_baseline/merged_output.csv" 

# Read the whole CSV into a Polars DataFrame; later cells iterate over its rows.
df = pl.read_csv(csv_path)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel

# Select the compute device: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import os

# Hugging Face checkpoint used for both the tokenizer and the encoder.
model_name = "abhinand/MedEmbed-base-v0.1"
# NOTE(review): trust_remote_code=True lets the checkpoint repository execute its own
# Python code at load time — confirm this checkpoint actually requires it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Use fp16 only on an accelerator. On the CPU fallback, half precision is slow and
# some PyTorch builds lack Half kernels for ops the encoder needs, so stay in fp32.
if device != "cpu":
    model = model.half()
model = model.to(device)

# Replicate the model across GPUs 0 and 1 when more than one CUDA device is visible.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model, device_ids=[0, 1])

# Inference only: disable dropout and other training-time behaviour.
model.eval()
# Number of texts sent through the model per forward pass.
# NOTE(review): 5000 padded sequences per batch is very large for a transformer
# encoder — lower this if the embedding cell runs out of memory.
BATCH_SIZE = 5000

In [None]:
def generate_embeddings_batch(texts):
    """Embed a batch of texts by mean-pooling the model's last hidden state.

    The texts are tokenized together with padding and truncation, so shorter
    texts are padded to the longest one in the batch. Padding positions are
    excluded from the average via the tokenizer's attention mask; otherwise a
    text's embedding would depend on which other texts share its batch and
    would not match the embedding of the same text encoded on its own.

    Args:
        texts: List of strings to embed.

    Returns:
        A float32 NumPy array with one pooled embedding row per input text.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state

    # Pool in float32: the model may run in reduced precision, and summing many
    # token vectors in fp16/bf16 loses accuracy (bf16 also cannot go to NumPy).
    hidden = hidden.to(torch.float32)
    mask = inputs["attention_mask"].unsqueeze(-1).to(torch.float32)

    summed = (hidden * mask).sum(dim=1)
    # clamp guards against division by zero if a row had no real tokens.
    token_counts = mask.sum(dim=1).clamp(min=1.0)
    embeddings = summed / token_counts
    return embeddings.cpu().numpy()

In [None]:
import faiss
import numpy as np

# Embedding dimensionality used to size the FAISS index; it must equal the width of
# the vectors produced by the embedding functions in this notebook.
# NOTE(review): the earlier comment here referred to jinaai/jina-embeddings-v3 (1024 dims),
# which does not match the value below — presumably 768 is the hidden size of the
# model loaded in this notebook; confirm.
embedding_dim = 768  
index = faiss.IndexFlatIP(embedding_dim)  # Flat inner-product index (IndexFlatIP), not an L2-distance index

# Accumulators filled by the following cells:
# one metadata dict per paper, and one embedding vector per paper.
paper_metadata = []
all_embeddings = []


In [None]:
from tqdm import tqdm

# Prepare the data: one text to embed and one metadata record per row of df.
# Both accumulators are re-initialised here so that re-running this cell does not
# append a second copy of every paper and push metadata out of step with the index.
texts = []
paper_metadata = []
for row in df.iter_rows(named=True):
    # A missing Title/Abstract arrives as None; leave it out instead of formatting
    # it into the text, which would embed the literal word "None".
    text = " ".join(
        str(part) for part in (row["Title"], row["Abstract"]) if part is not None
    )
    texts.append(text)
    paper_metadata.append({
        "PMID": row["PMID"], 
        "Title": row["Title"], 
        "Abstract": row["Abstract"], 
        "Year": row["Year"]
    })

# Embed in batches of BATCH_SIZE texts; rows of all_embeddings stay in the same
# order as paper_metadata so index positions map directly onto metadata entries.
all_embeddings = []
for i in tqdm(range(0, len(texts), BATCH_SIZE)):
    batch_texts = texts[i:i + BATCH_SIZE]
    batch_embeddings = generate_embeddings_batch(batch_texts)
    all_embeddings.extend(batch_embeddings)

In [None]:
# Convert the collected embeddings to a float32 NumPy matrix for FAISS.
all_embeddings = np.array(all_embeddings).astype("float32")
if all_embeddings.ndim != 2 or all_embeddings.shape[1] != index.d:
    raise ValueError(
        f"Expected embeddings of shape (n, {index.d}), got {all_embeddings.shape}"
    )

# Empty the index first so re-running this cell does not add the same vectors twice;
# index positions must stay aligned with positions in paper_metadata.
index.reset()
index.add(all_embeddings)
assert index.ntotal == len(paper_metadata), (
    f"index holds {index.ntotal} vectors but there are {len(paper_metadata)} metadata records"
)

# Save the index and the metadata to disk.
# The metadata is a list of dicts, so NumPy stores it as a pickled object array
# and it must be loaded back with allow_pickle=True.
faiss.write_index(index, "pubmed_index.faiss")
np.save("pubmed_metadata.npy", paper_metadata)

In [None]:
# Fix: convert BFloat16 to Float32
def generate_embedding(text):
    """Embed `text` and return the result as a flat float32 NumPy array.

    The text is tokenized (with padding and truncation enabled), run through
    the model without gradient tracking, and the last hidden state is averaged
    over dim=1. The pooled tensor is cast to float32 before it is moved to the
    CPU and converted to NumPy.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        # Output vectors for every token
        token_states = model(**encoded).last_hidden_state
    # Mean pooling
    pooled = token_states.mean(dim=1)
    as_float32 = pooled.to(torch.float32)
    return as_float32.cpu().numpy().flatten()

In [None]:
import numpy as np

def search_papers(query, top_k=5, threshold=0.5):
    """Return metadata for the indexed papers most similar to `query`.

    The FAISS index and the metadata array are read from "pubmed_index.faiss"
    and "pubmed_metadata.npy" on every call.

    Args:
        query: Query text; embedded with `generate_embedding`.
        top_k: Maximum number of neighbours requested from the index.
        threshold: Minimum score (as returned by the index) a hit must reach.
            NOTE(review): scores are only comparable to a fixed threshold if
            the indexed and query vectors are scaled consistently — confirm.

    Returns:
        A list of metadata dicts, in the order the index returned the hits,
        restricted to entries whose "Title" and "Abstract" are both strings.
    """
    index = faiss.read_index("pubmed_index.faiss")
    # SECURITY: allow_pickle=True unpickles arbitrary objects and can execute code;
    # only load a metadata file that this notebook (or another trusted source) wrote.
    metadata = np.load("pubmed_metadata.npy", allow_pickle=True)

    query_emb = np.array(generate_embedding(query)).reshape(1, -1).astype("float32")

    # Run the search: one row of scores and one row of index positions for the query.
    scores, positions = index.search(query_emb, top_k)

    # Keep only valid hits whose score reaches the threshold.
    results = []
    for score, idx in zip(scores[0], positions[0]):
        # FAISS fills missing results with position -1 when fewer than top_k vectors
        # are available; without the lower bound, metadata[-1] would silently
        # return the last paper.
        if 0 <= idx < len(metadata) and score >= threshold:
            paper = metadata[idx]
            if paper and isinstance(paper.get("Title"), str) and isinstance(paper.get("Abstract"), str):
                results.append(paper)

    return results

# Test query: run a search and print the title, abstract and year of each hit.
# NOTE(review): "mdeical" looks like a misspelling of "medical" — confirm whether
# the typo is intentional (e.g. to probe robustness) before changing it.
query = "mdeical"
search_results = search_papers(query, top_k=3)
for paper in search_results:
    print(f"Title: {paper.get('Title', 'N/A')}\nAbstract: {paper.get('Abstract', 'N/A')}\nYear: {paper.get('Year', 'N/A')}\n")