In [None]:
import json 
import numpy as np 
from pathlib import Path
from sentence_transformers import SentenceTransformer

In [None]:
# Location of the exported Hetionet cancer MVP records file.
records_path = Path("data/records/hetionet_cancer_mvp_records_v3.json")

# Pin the encoding: JSON text is UTF-8, and without an explicit value open()
# falls back to the platform locale, which can mis-decode non-ASCII names.
with open(records_path, "r", encoding="utf-8") as f:
    records = json.load(f)

# Quick look at what was loaded: record count, the fields of the first record,
# and the first 300 characters of the text that will later be embedded.
print("Loaded records:", len(records))
print("Example keys:", records[0].keys())
print("Example text:", records[0]["search_text"][:300])

Loaded records: 5367
Example keys: dict_keys(['id', 'entity_type', 'identifier', 'name', 'search_text', 'metadata'])
Example text: Gene: SLC2A3. Description: solute carrier family 2 (facilitated glucose transporter), member 3. Associated diseases: germ cell cancer (associates); melanoma (upregulates).


In [None]:
# Identifier of the sentence-transformers checkpoint used as the text encoder.
model_name = "sentence-transformers/all-MiniLM-L6-v2"

# Build the encoder once; later cells reuse `model` for both the corpus and queries.
model = SentenceTransformer(model_name)
print(f"Model loaded: {model_name}")

Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

[1mBertModel LOAD REPORT[0m from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Model loaded: sentence-transformers/all-MiniLM-L6-v2


In [38]:
# Two parallel lists taken from the records, in record order:
# the text to embed and the id of the record it came from.
texts = [record["search_text"] for record in records]
ids = [record["id"] for record in records]

print(f"Num texts: {len(texts)}")
print(f"First text sample:\n {texts[0][:500]}")

Num texts: 5367
First text sample:
 Gene: SLC2A3. Description: solute carrier family 2 (facilitated glucose transporter), member 3. Associated diseases: germ cell cancer (associates); melanoma (upregulates).


In [39]:
# Encode every record text in batches of 64. The vectors are requested as a
# NumPy array and L2-normalized, so a plain dot product between two rows is
# their cosine similarity.
embeddings = model.encode(
    texts,
    normalize_embeddings=True,
    convert_to_numpy=True,
    batch_size=64,
    show_progress_bar=True,
)

print(f"Embeddings shape: {embeddings.shape}")
print(f"dtype: {embeddings.dtype}")

Batches:   0%|          | 0/84 [00:00<?, ?it/s]

Embeddings shape: (5367, 384)
dtype: float32


In [None]:
# Directory that holds the persisted vector index (embedding matrix + records).
out_dir = Path("data/embeddings/vector_index")
# parents=True: the target is nested, so create any missing intermediate
# directories instead of raising FileNotFoundError when "data/embeddings"
# does not exist yet. exist_ok=True keeps the cell safe to re-run.
out_dir.mkdir(parents=True, exist_ok=True)

np.save(out_dir / "hetionet_mvp_embeddings.npy", embeddings)

# The records are stored next to the matrix in the same order they were
# encoded, so row i of the .npy file belongs to entry i of this JSON list.
with open(out_dir / "hetionet_mvp_records.json", "w", encoding="utf-8") as f:
    json.dump(records, f, indent=2)

print("Saved embeddings to:", out_dir / "hetionet_mvp_embeddings.npy")
print("Saved records to:", out_dir / "hetionet_mvp_records.json")

Saved embeddings to: /Users/shaunak/Documents/Hacklytics2026/vector_index/hetionet_mvp_embeddings.npy
Saved records to: /Users/shaunak/Documents/Hacklytics2026/vector_index/hetionet_mvp_records.json


# Local semantic search test

In [None]:
import numpy as np

def semantic_search(
    query,
    model,
    embeddings,
    records,
    top_k=5,
    kind_filter=None,
    must_contain=None,
):
    """
    Rank records by similarity between their embedding and the embedded query.

    Parameters
    ----------
    query : str
        Free-text query; embedded with ``model.encode``.
    model : object
        Encoder exposing ``encode(list_of_str, convert_to_numpy=...,
        normalize_embeddings=..., show_progress_bar=...)``.
    embeddings : array-like
        One row per record, in the same order as ``records``. The dot product
        with the query vector is used as the score, which equals cosine
        similarity only if these rows are normalized.
    records : list of dict
        Record dicts; ``entity_type`` and ``search_text`` are read for
        filtering, the remaining fields are copied into the results.
    top_k : int
        Maximum number of hits to return. Values <= 0 return an empty list.
    kind_filter : None or str
        None or one of {"Gene", "Disease", "Compound"}; keeps only records
        whose ``entity_type`` equals this value.
    must_contain : None or str
        Optional string that must appear in search_text (case-insensitive,
        surrounding whitespace ignored).

    Returns
    -------
    list of dict
        Up to ``top_k`` hits sorted by descending score, each with the keys
        ``score``, ``entity_type``, ``id``, ``identifier``, ``name``,
        ``search_text`` and ``metadata``. Empty when nothing passes the filters.
    """

    # A non-positive top_k can never produce hits. Without this guard a
    # negative value would be used as a slice bound and return almost
    # every candidate instead of none.
    if top_k <= 0:
        return []

    # 1) Select candidates in a single pass over the records.
    token = None if must_contain is None else must_contain.lower().strip()

    def _keep(record):
        if kind_filter is not None and record.get("entity_type") != kind_filter:
            return False
        if token is not None:
            # Treat a missing or None search_text as empty text.
            text = record.get("search_text") or ""
            return token in text.lower()
        return True

    # Explicit integer dtype so an empty selection is still a valid index array.
    candidate_idxs = np.array(
        [i for i, r in enumerate(records) if _keep(r)],
        dtype=np.intp,
    )

    # If nothing survives filters, skip the (comparatively expensive) encode.
    if len(candidate_idxs) == 0:
        return []

    # 2) Embed query
    q_emb = model.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False
    )[0]

    # 3) Cosine similarity (works because embeddings are normalized)
    scores = embeddings @ q_emb

    # 4) Top-k among candidates; the stable sort keeps ties in record order
    #    so repeated runs return the same ranking.
    candidate_scores = scores[candidate_idxs]
    k = min(top_k, len(candidate_idxs))
    top_local = np.argsort(-candidate_scores, kind="stable")[:k]
    top_idxs = candidate_idxs[top_local]

    # 5) Formatting results
    results = []
    for i in top_idxs:
        r = records[i]
        results.append({
            "score": float(scores[i]),
            "entity_type": r.get("entity_type"),
            "id": r.get("id"),
            "identifier": r.get("identifier"),
            "name": r.get("name"),
            "search_text": r.get("search_text"),
            "metadata": r.get("metadata", {})
        })

    return results

In [None]:
# Smoke-test the search with a few (question, entity-type filter) pairs
# and print the five best hits for each.

queries = [
    ("genes associated with lung cancer", "Gene"),
    ("compounds used for breast cancer", "Compound"),
    ("diseases associated with TP53", "Disease"),
    ("cancer drugs targeting EGFR", "Compound"),
    ("What are the top genes related to cancer?", "Gene"),
]

for query_text, entity_kind in queries:
    # Separator line plus a header describing the current query
    print(f"\n{'=' * 100}")
    print(f"QUERY: {query_text}")
    print(f"FILTER: {entity_kind}")

    hits = semantic_search(
        query_text,
        model,
        embeddings,
        records,
        top_k=5,
        kind_filter=entity_kind,
    )

    for hit in hits:
        # One header line per hit, followed by the first 300 characters of its text
        print(f"\n[{hit['score']:.3f}] {hit['entity_type']} | {hit['id']} | {hit['name']}")
        print(hit["search_text"][:300])


QUERY: genes associated with lung cancer
FILTER: Gene

[0.710] Gene | Gene::1810 | DR1
Gene: DR1. Description: down-regulator of transcription 1, TBP-binding (negative cofactor 2). Associated diseases: lung cancer (upregulates).

[0.705] Gene | Gene::25771 | TBC1D22A
Gene: TBC1D22A. Description: TBC1 domain family, member 22A. Associated diseases: lung cancer (upregulates).

[0.704] Gene | Gene::128602 | C20orf85
Gene: C20orf85. Description: chromosome 20 open reading frame 85. Associated diseases: lung cancer (associates).

[0.698] Gene | Gene::92689 | FAM114A1
Gene: FAM114A1. Description: family with sequence similarity 114, member A1. Associated diseases: lung cancer (upregulates).

[0.696] Gene | Gene::118611 | C10orf90
Gene: C10orf90. Description: chromosome 10 open reading frame 90. Associated diseases: lung cancer (upregulates).

QUERY: compounds used for breast cancer
FILTER: Compound

[0.613] Compound | Compound::DB00481 | Raloxifene
Compound: Raloxifene. Treats or relates to

In [None]:

# Peek at the first record: its field names on one line, the full dict on the next.
print(records[0].keys(), records[0], sep="\n")

dict_keys(['id', 'entity_type', 'identifier', 'name', 'search_text', 'metadata'])
{'id': 'Gene::9055', 'entity_type': 'Gene', 'identifier': 9055, 'name': 'PRC1', 'search_text': 'Gene: PRC1. Description: protein regulator of cytokinesis 1. Related diseases: breast cancer (upregulates); pancreatic cancer (upregulates); breast cancer (associates).', 'metadata': {'description': 'protein regulator of cytokinesis 1', 'source': 'Entrez Gene', 'license': 'CC0 1.0', 'url': 'http://identifiers.org/ncbigene/9055', 'chromosome': '15'}}


In [35]:

# Check the entity-type filter: request 20 hits restricted to compounds and
# list the entity_type of each one.
hits = semantic_search(
    "compounds used for breast cancer",
    model,
    embeddings,
    records,
    top_k=20,
    kind_filter="Compound",
)
print([hit["entity_type"] for hit in hits])

['Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound', 'Compound']


In [47]:
# Structural overview of `records`: container type and size
print(type(records))
print(f"num records: {len(records)}")

# Field names carried by the first record
print(f"first record keys: {records[0].keys()}")

# Walk through the first record field by field; string values are clipped
# to 200 characters, everything else is printed as-is
for field, value in records[0].items():
    shown = value[:200] if isinstance(value, str) else value
    print(f"{field}: {shown}")

<class 'list'>
num records: 5367
first record keys: dict_keys(['id', 'entity_type', 'identifier', 'name', 'search_text', 'metadata'])
id: Gene::6515
entity_type: Gene
identifier: 6515
name: SLC2A3
search_text: Gene: SLC2A3. Description: solute carrier family 2 (facilitated glucose transporter), member 3. Associated diseases: germ cell cancer (associates); melanoma (upregulates).
metadata: {'description': 'solute carrier family 2 (facilitated glucose transporter), member 3', 'source': 'Entrez Gene', 'license': 'CC0 1.0', 'url': 'http://identifiers.org/ncbigene/6515', 'chromosome': '12'}


In [None]:
# Fields that every record is expected to provide.
required = {"id", "entity_type", "identifier", "name", "search_text", "metadata"}

missing_counts = dict.fromkeys(required, 0)  # field -> number of records lacking it
bad_records = []                             # (record index, set of missing fields)

for idx, record in enumerate(records):
    absent = required.difference(record.keys())
    if not absent:
        continue
    bad_records.append((idx, absent))
    for field in absent:
        missing_counts[field] += 1

print(f"Missing key counts: {missing_counts}")
print(f"Number of bad records: {len(bad_records)}")

# Show at most the first five offending records
print(f"Sample bad records: {bad_records[:5]}")

Missing key counts: {'identifier': 0, 'search_text': 0, 'entity_type': 0, 'name': 0, 'metadata': 0, 'id': 0}
Number of bad records: 0
Sample bad records: []


In [49]:

# Inspect the embedding matrix: container type, shape and element dtype
print(type(embeddings))
print(f"embeddings shape: {embeddings.shape}")
print(f"embeddings dtype: {embeddings.dtype}")

# The matrix must have exactly one row per record for index-based lookup to line up
print(f"len(records): {len(records)}")
print(f"embedding rows: {embeddings.shape[0]}")
print(f"rows match: {len(records) == embeddings.shape[0]}")

<class 'numpy.ndarray'>
embeddings shape: (5367, 384)
embeddings dtype: float32
len(records): 5367
embedding rows: 5367
rows match: True
