In [None]:
import polars as pl
import faiss
import torch
import numpy as np
from tqdm import tqdm 
from transformers import AutoTokenizer, AutoModel
from multiprocessing import Pool, cpu_count

In [None]:
# Load one PubMed baseline shard (exported as CSV) and keep only complete records.
csv_path = "../data/pubmed_baseline/csv/pubmed25n1274.csv" 
columns_to_check = ["PMID", "Title", "Abstract", "Authors", "Year", "Journal"]

df = pl.read_csv(csv_path)
print(f'Number of rows: {len(df)}')

# Drop every row that is missing any of the metadata fields listed above.
df = df.drop_nulls(subset=columns_to_check)
print(f'Number of rows after dropping nulls: {len(df)}')

# Normalize the Year column to a 64-bit integer.
df = df.with_columns(pl.col("Year").cast(pl.Int64))

In [None]:
from torch.nn.parallel import DistributedDataParallel as DDP

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the MedEmbed model
model_name = "abhinand/MedEmbed-base-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device).eval()

# Run in fp16 only on GPU: some CPU kernels (e.g. LayerNorm in some PyTorch versions)
# have no half-precision implementation.
if device.type == "cuda":
    model = model.half()

# Multi-GPU inference. The previous DistributedDataParallel(model, device_ids=[0, 1, 3, 4], ...)
# call cannot run in this notebook: DDP requires torch.distributed.init_process_group() with one
# process per GPU (device_ids may hold at most one device) and exists to sync gradients in training.
# DataParallel instead splits each input batch across several GPUs inside this single process.
GPU_IDS = [i for i in (0, 1, 3, 4) if i < torch.cuda.device_count()]  # originally chosen GPUs that exist
if len(GPU_IDS) > 1:
    model = torch.nn.DataParallel(model, device_ids=GPU_IDS, output_device=GPU_IDS[0])

In [None]:
# One embedding input per paper: "<Title>. <Abstract>" (element-wise polars string concatenation).
texts = df.get_column("Title") + ". " + df.get_column("Abstract")

def embed_text_in_batches(texts, batch_size=32, max_length=512):
    """Embed `texts` with the global `tokenizer`/`model` and return one vector per text.

    Texts are tokenized batch by batch (padded to the longest text in the batch, truncated
    to `max_length` tokens). The last-layer hidden state of the first token ([CLS]) is used
    as each text's embedding.

    Args:
        texts: iterable of strings (e.g. a polars Series); materialized into a list once.
        batch_size: number of texts per forward pass; lower it if the GPU runs out of memory.
        max_length: maximum number of tokens kept per text.

    Returns:
        float32 numpy array with one row per input text, in input order.

    Raises:
        ValueError: if `batch_size` < 1 or `texts` is empty.
    """
    texts = list(texts)  # a plain list supports the slicing below for any input container
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not texts:
        raise ValueError("embed_text_in_batches() needs at least one text")

    all_embeddings = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Processing Batches"):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(batch_texts, padding=True, truncation=True,
                           return_tensors="pt", max_length=max_length).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # CLS pooling; upcast to float32 (the model may run in fp16) before copying to host memory
        all_embeddings.append(outputs.last_hidden_state[:, 0, :].to(torch.float32).cpu().numpy())

        del inputs, outputs  # drop references to this batch's device tensors before the next one

    # Release cached GPU blocks once, after the loop. PyTorch's caching allocator reuses freed
    # memory between batches, so calling empty_cache() per batch only adds allocation overhead
    # without making more memory available to PyTorch.
    torch.cuda.empty_cache()

    return np.vstack(all_embeddings)  # stack per-batch results into a single (n_texts, dim) array

# Generate one vector per paper.
# The previous batch_size=9000 would push 9000 sequences of up to 512 tokens through the encoder
# in a single forward pass, which almost certainly exceeds GPU memory. Lower this further
# (e.g. to 8) if CUDA still runs out of memory.
EMBED_BATCH_SIZE = 256
embeddings = embed_text_in_batches(texts, batch_size=EMBED_BATCH_SIZE)

In [None]:
# Build the FAISS vector index
d = embeddings.shape[1]  # embedding dimension
N = embeddings.shape[0]  # number of embeddings
if N == 0:
    raise ValueError("No embeddings to index")

# Number of IVF inverted lists (k-means centroids). FAISS's index guidelines suggest
# roughly 4*sqrt(N) to 16*sqrt(N); cap it at N because k-means training needs at least
# as many training vectors as centroids.
nlist = min(int(4 * np.sqrt(N)), N)
quantizer = faiss.IndexFlatL2(d)  # coarse quantizer: assigns each vector to its nearest list by L2 distance
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)

# Train the IVF index (IVF must learn its centroids first), then add all vectors.
# Ids are assigned sequentially on add, so vector id i corresponds to row i of df.
index.train(embeddings)
index.add(embeddings)

In [None]:
# Save the FAISS index.
# Vector id i in the index came from row i of df, so the metadata CSV is written from the
# same df to keep that row alignment for id -> paper lookups at search time.
faiss.write_index(index, "faiss_medical_index_IndexIVFFlat.ivf")

df.write_csv("faiss_metadata_IndexIVFFlat.csv")

# Prints: "Vector index built; FAISS index and metadata saved!"
print("向量索引建立完成，已儲存 FAISS 索引和 Metadata！")

# Search

In [1]:
import faiss
import pandas as pd
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

# Select the compute device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reload the model & tokenizer; queries must be embedded with the same encoder that built the index
model_name = "abhinand/MedEmbed-base-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device).eval()

# Reload the FAISS index.
# NOTE(review): the build cells above save "faiss_medical_index_IndexIVFFlat.ivf" and
# "faiss_metadata_IndexIVFFlat.csv"; this section loads a different (2020) index/metadata pair.
# Confirm that pair was built with the same model and the same row order.
index = faiss.read_index("../output/2020/faiss.index")

# IVF indexes scan only `nprobe` inverted lists per query (FAISS's constructor default is 1),
# which can miss true nearest neighbours. Probe more lists for better recall; FAISS caps the
# effective value at nlist. Indexes without an `nprobe` attribute are left unchanged.
NPROBE = 16
if hasattr(index, "nprobe"):
    index.nprobe = max(index.nprobe, NPROBE)

# Load the metadata; row i must correspond to vector id i in the index
df = pd.read_csv("../output/2020/metadata.csv")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Search function
def search_papers(query, top_k=5):
    """Return up to `top_k` papers most similar to `query`, newest first.

    Uses the module-level `tokenizer`, `model`, `device`, `index` and `df`; `df` row i must
    correspond to vector id i in `index`.

    Args:
        query: free-text query string.
        top_k: number of nearest neighbours to request from the FAISS index.

    Returns:
        A list of dicts with keys PMID, Title, Abstract, Authors, Year, Journal, Keyword,
        sorted by Year descending (same-year papers keep FAISS similarity order). May hold
        fewer than `top_k` entries when the index returns fewer valid hits.
    """
    # Step 1: embed the query, using the first ([CLS]) token's last hidden state as its vector
    inputs = tokenizer(query, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    query_embedding = outputs.last_hidden_state[:, 0, :].to(torch.float32).cpu().numpy()

    # Step 2: FAISS search; returns (distances, ids), each of shape (n_queries, top_k)
    _distances, ids = index.search(query_embedding, top_k)

    # Step 3: map ids to metadata rows. FAISS pads with -1 when it finds fewer than top_k hits
    # (e.g. an IVF index with a small nprobe), and df.iloc[-1] would silently return the last
    # row, so drop those padding ids.
    hit_ids = ids[0][ids[0] >= 0]
    results = df.iloc[hit_ids]

    # Step 4: newest first; a stable sort keeps similarity order among same-year papers
    results = results.sort_values(by="Year", ascending=False, kind="stable")

    # Step 5: format the output as one dict per paper
    columns = ["PMID", "Title", "Abstract", "Authors", "Year", "Journal", "Keyword"]
    return results[columns].to_dict(orient="records")

# Smoke-test the search with a sample query
query = "Contrastive learning for medical image analysis"
results = search_papers(query, top_k=5)

# Display the results: one block per paper, numbered from 1 in the order search_papers returns them
for i, res in enumerate(results):
    print(f"🔹 {i+1}. {res['Title']} ({res['Year']})")
    print(f"    📝 Abstract: {res['Abstract']}")
    print(f"    👩‍⚕️ Authors: {res['Authors']}")
    print(f"    🏥 Journal: {res['Journal']}")
    print(f"    🔑 Keywords: {res['Keyword']}")
    print(f"    🔗 PMID: {res['PMID']}\n")

🔹 1. Self-supervised learning for medical image analysis: Discriminative, restorative, or adversarial? (2024)
    📝 Abstract: Discriminative, restorative, and adversarial learning have proven beneficial for self-supervised learning schemes in computer vision and medical imaging. Existing efforts, however, fail to capitalize on the potentially synergistic effects these methods may offer in a ternary setup, which, we envision can significantly benefit deep semantic representation learning. Towards this end, we developed DiRA, the first framework that unites discriminative, restorative, and adversarial learning in a unified manner to collaboratively glean complementary visual information from unlabeled medical images for fine-grained semantic representation learning. Our extensive experiments demonstrate that DiRA: (1) encourages collaborative learning among three learning ingredients, resulting in more generalizable representation across organs, diseases, and modalities; (2) outperforms 