In [4]:
from pathlib import Path
import faiss, pickle

processed_dir = Path("../data/processed")

# Reload FAISS index
index = faiss.read_index(str(processed_dir / "evidence_faiss.index"))

# Reload metadata
with open(processed_dir / "id_to_row.pkl", "rb") as f:
    id_to_row = pickle.load(f)

print("Index size:", index.ntotal)


Index size: 5000


In [5]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

def retrieve(query, k=5):
    q_emb = model.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    scores, idxs = index.search(q_emb, k)
    results = []
    for score, i in zip(scores[0], idxs[0]):
        row = id_to_row[i]
        results.append({
            "score": float(score),
            "ligand": row.get("ligand_name"),
            "smiles": row.get("smiles"),
            "target": row.get("target"),
            "activity_type": row.get("activity_type"),
            "value": row.get("value"),
            "pValue": row.get("pValue")
        })
    return results


In [6]:
queries = [
    "EGFR inhibitors",
    "HER2 compounds with IC50 under 10 nM",
    "CYP2D6 metabolism"
]

for q in queries:
    print(f"\n🔎 {q}")
    for r in retrieve(q, k=3):
        print(f"  • {r['ligand']} | Target: {r['target']} | "
              f"{r['activity_type']}={r['value']} (p={r['pValue']}, score={r['score']:.3f})")



🔎 EGFR inhibitors
  • 91663444.0 | Target: MLGNKRLGLSGLTLALSLLVCLGALAEAYPSKPDNPGEDAPAEDMARYYSALRHYINLITRQRYGKRSSPETLISDLLMRESTENVPRTRLEDPAMW | BindingDB_Ki=4.6 (p=nan, score=0.362)
  • 44186669.0 | Target: MPPSISAFQAAYIGIEVLIALVSVPGNVLVIWAVKVNQALRDATFCFIVSLAVADVAVGALVIPLAILINIGPQTYFHTCLMVACPVLILTQSSILALLAIAVDRYLRVKIPLRYKMVVTPRRAAVAIAGCWILSFVVGLTPMFGWNNLSAVERAWAANGSMGEPVIKCEFEKVISMEYMVYFNFFVWVLPPLLLMVLIYLEVFYLIRKQLNKKVSASSGDPQKYYGKELKIAKSLALILFLFALSWLPLHILNCITLFCPSCHKPSILTYIAIFLTHGNSAMNPIVYAFRIQKFRVTFLKIWNDHFRCQPAPPIDEDLPEERPDD | BindingDB_Ki=2610.0 (p=nan, score=0.348)
  • 137796736.0 | Target: MGFQKFSPFLALSILVLLQAGSLHAAPFRSALESSPADPATLSEDEARLLLAALVQNYVQMKASELEQEQEREGSRIIAQKRACDTATCVTHRLAGLLSRSGGVVKNNFVPTNVGSKAFGRRRRDLQA | BindingDB_Ki=0.05 (p=nan, score=0.344)

🔎 HER2 compounds with IC50 under 10 nM
  • None | Target: None | IC50= 1851 (p=5.732593581247096, score=0.427)
  • prop-2-en-1-yl heptanoate | Target: solubility | solubility_aqsoldb=-3.5976237668 (p=nan, score=0.425)
  • None

In [8]:
from transformers import pipeline

rag_llm = pipeline("text-generation", model="google/flan-t5-base")

query = "Summarize evidence for EGFR inhibitors."
evidence = retrieve(query, k=5)

context = "\n".join([
    f"{e['ligand']} ({e['activity_type']}={e['value']}, p={e['pValue']}) on {e['target']}"
    for e in evidence
])

prompt = f"Question: {query}\nEvidence:\n{context}\n\nAnswer with citations."
print(rag_llm(prompt, max_length=200)[0]["generated_text"])


model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Device set to use mps:0
The model 'T5ForConditionalGeneration' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForConditionalGeneration', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTN

Question: Summarize evidence for EGFR inhibitors.
Evidence:
91663444.0 (BindingDB_Ki=4.6, p=nan) on MLGNKRLGLSGLTLALSLLVCLGALAEAYPSKPDNPGEDAPAEDMARYYSALRHYINLITRQRYGKRSSPETLISDLLMRESTENVPRTRLEDPAMW
TOX1633 (SR-p53=0.0, p=nan) on SR-p53
TOX5312 (NR-AhR=0.0, p=nan) on NR-AhR
TOX1636 (SR-p53=0.0, p=nan) on SR-p53
TOX1431 (SR-p53=0.0, p=nan) on SR-p53

Answer with citations.


In [10]:
from transformers import pipeline

# Use text2text-generation for FLAN-T5
rag_llm = pipeline(
    "text2text-generation",
    model="google/flan-t5-base",
    device_map="auto"   # uses MPS on Mac if available
)

def cite_block(evidence):
    lines = []
    for e in evidence:
        ligand = e.get("ligand") or e.get("smiles")
        lines.append(
            f"{ligand} | Target={e['target']} | {e['activity_type']}={e['value']} | p={e['pValue']}"
        )
    return "\n".join(lines)

query = "Summarize evidence for EGFR inhibitors."
evidence = retrieve(query, k=5)   # <- your FAISS retriever
context = cite_block(evidence)

prompt = (
    "You are a scientific assistant. Summarize the evidence below in 3–5 bullets. "
    "Each bullet MUST include a citation with ligand/target/value in parentheses.\n\n"
    f"Evidence:\n{context}\n\nSummary:"
)

out = rag_llm(prompt, max_new_tokens=256, truncation=True)[0]["generated_text"]
print(out)


Device set to use mps


                                                                                                                               


In [11]:
import time
t0 = time.perf_counter()
evidence = retrieve(query, k=5)
t1 = time.perf_counter()
out = rag_llm(prompt, max_new_tokens=256, truncation=True)[0]["generated_text"]
t2 = time.perf_counter()
print("retrieval_ms:", (t1-t0)*1000, "gen_ms:", (t2-t1)*1000)


retrieval_ms: 4216.097666998394 gen_ms: 14899.229374990682
