In [3]:
import sys
sys.path.insert(0,'/scratch/aqi5157/pip_packages_and_cache/packages')

In [5]:
import time
import requests
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

In [14]:
# using SAPbert for now
SapBERT_tokenizer = AutoTokenizer.from_pretrained('cambridgeltl/SapBERT-from-PubMedBERT-fulltext')
SapBERT_model = AutoModel.from_pretrained('cambridgeltl/SapBERT-from-PubMedBERT-fulltext')

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def embed_sentence(sentence, tokenizer, model, device='cpu'):
    encoded_input = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')
    # If using GPU, you could do: 
    # for k in encoded_input:
    #     encoded_input[k] = encoded_input[k].to(device)
    with torch.no_grad():
        model_output = model(**encoded_input)
    return mean_pooling(model_output, encoded_input['attention_mask']).cpu().numpy()


def get_gene_description_ncbi(
    gene_id_or_symbol,
    api_key=None,
    organism_taxid=9606,     # default: human
    sleep_sec=0.1           # delay between requests
):
    """
    Fetch the gene description from NCBI using E-utilities (ESearch + ESummary).
    
    Parameters
    ----------
    gene_id_or_symbol : str
        E.g. "BRCA1" or "672".
    api_key : str, optional
        NCBI API key to help avoid 429 errors and get higher rate limits.
    organism_taxid : int, optional
        By default uses 9606 for Homo sapiens. Adjust if needed.
    sleep_sec : float, optional
        Seconds to sleep after each request to avoid rate-limiting.
    """
    try:
        # 1) ESearch: symbol -> Gene ID
        base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
        params = {
            'db': 'gene',
            'term': f"{gene_id_or_symbol}[sym] AND txid{organism_taxid}[Organism]",
            'retmode': 'json',
        }
        if api_key:
            params['api_key'] = api_key

        r = requests.get(base_url, params=params)
        r.raise_for_status()
        data = r.json()
        time.sleep(sleep_sec)  # Throttle requests

        gene_ids = data["esearchresult"]["idlist"]
        # If no symbol found, fallback to checking if it's a numeric ID
        if len(gene_ids) == 0:
            if gene_id_or_symbol.isdigit():
                gene_ids = [gene_id_or_symbol]
            else:
                return None
        
        gene_id = gene_ids[0]

        # 2) ESummary: get summary info for that gene ID
        summary_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi"
        params_summary = {
            'db': 'gene',
            'id': gene_id,
            'retmode': 'json',
        }
        if api_key:
            params_summary['api_key'] = api_key

        r_sum = requests.get(summary_url, params=params_summary)
        r_sum.raise_for_status()
        sum_data = r_sum.json()
        time.sleep(sleep_sec)  # Throttle requests

        if 'result' not in sum_data:
            return None

        gene_record = sum_data['result'].get(gene_id, {})
        summary_text = gene_record.get('summary', '') or gene_record.get('description', '')
        return summary_text if summary_text.strip() != '' else None

    except requests.exceptions.HTTPError as he:
        print(f"HTTP Error fetching gene {gene_id_or_symbol}: {he}")
        return None
    except Exception as e:
        print(f"Error fetching gene {gene_id_or_symbol}: {e}")
        return None

def get_omics_gene_centroid_embedding(
    gene_list,
    tokenizer,
    model,
    api_key=None,
    organism_taxid=9606,
    device='cpu',
    sleep_sec=0.1
):
    """
    1) For each gene in gene_list, fetch description from NCBI Gene.
    2) Embed the description using embed_sentence.
    3) Average all gene embeddings => centroid.
    """
    valid_embeddings = []
    for gene_symbol in gene_list:
        desc = get_gene_description_ncbi(
            gene_symbol,
            api_key=api_key,
            organism_taxid=organism_taxid,
            sleep_sec=sleep_sec
        )
        if desc:
            emb = embed_sentence(desc, tokenizer, model, device=device)
            valid_embeddings.append(emb)

    if not valid_embeddings:
        return None

    stack = np.vstack(valid_embeddings)   # shape (N, D)
    centroid = np.mean(stack, axis=0, keepdims=True)  # shape (1, D)
    return centroid

# -----------------------------------------------------------------------
# Score an Omics Set's LLM-Proposed Name via Gene-Centroid
# -----------------------------------------------------------------------
def score_omics_set_with_centroid(
    llm_assigned_name,
    gene_list,
    tokenizer,
    model,
    api_key=None,
    device='cpu'
):
    """
    1) Compute the centroid embedding of gene_list by fetching 
       each gene's description from NCBI with an API key.
    2) Embed the llm_assigned_name with the same model.
    3) Return their cosine similarity.
    """
    centroid_emb = get_omics_gene_centroid_embedding(
        gene_list=gene_list,
        tokenizer=tokenizer,
        model=model,
        api_key=api_key,
        device=device
    )
    if centroid_emb is None:
        print("No valid descriptions or embeddings; returning similarity=0.0.")
        return 0.0

    name_emb = embed_sentence(llm_assigned_name, tokenizer, model, device=device)
    sim = float(cosine_similarity(centroid_emb, name_emb)[0][0])
    return sim



In [15]:
# testing

NCBI_API_KEY_TEST = "75c5d8bd2e751177f5d17258476cf736c508"

    # Example usage
llm_assigned_name = "Pancreatic Islet Development and Glucose Metabolism Regulation"
gene_list_example = ['HOTTIP', 'NEU1', 'ARFGAP3', 'MAPKAPK2', 'SLC35B4', 'MIS12', 'HINT3', 'TOR1AIP2', 'NME1', 'AIDA', 'MITD1', 'DPY19L1', 'POLR1F', 'SYCP2', 'BCL3', 'ASH1L-AS1', 'UBE2F', 'CARD19', 'ZBTB17', 'SPRYD4', 'SYNGR2', 'FOXK2', 'NLGN2', 'RAB43', 'CYBC1', 'CWC25', 'OFD1', 'PIGW', 'TENT4B', 'DIMT1', 'NOPCHAP1', 'POGLUT1', 'CCDC142', 'DERL1', 'KRBOX5', 'LYRM7', 'NIN', 'VTI1B', 'IFT20', 'SEPHS1', 'RAB6A', 'NDFIP2', 'H3-3A', 'MBOAT2', 'ZSWIM9', 'SLC35D1', 'BOD1', 'PTCD3', 'DNAJC27', 'CSRP2', 'PURB', 'PPFIBP1', 'ALG10', 'PIGK', 'SKP1', 'CHCHD1', 'SEPTIN6', 'GALNT1', 'RNF2', 'CAPRIN1', 'DHFR2', 'EPB41L4A-AS1', 'MRPS9', 'PRSS12', 'NAAA', 'MAP7D3', 'PTBP3', 'PIBF1', 'DNAAF2', 'ARL2BP', 'KCMF1', 'AMFR', 'CDS2', 'UACA', 'PARS2', 'ATXN7L3B', 'KRT10', 'GUF1', 'DNAJC3-DT', 'CBX1', 'GET4', 'ANAPC4', 'TRIM4', 'ATG16L2', 'NEAT1', 'TRMT44', 'CYP2D6', 'CLN3', 'GK5', 'ZNF276', 'TMEM80', 'RBSN', 'FDPS', 'TRERF1', 'EGF', 'HOMEZ', 'CDC42EP5', 'FEM1A', 'LONRF3']

sim_score = score_omics_set_with_centroid(
    llm_assigned_name=llm_assigned_name,
    gene_list=gene_list_example,
    tokenizer=SapBERT_tokenizer,
    model=SapBERT_model,
    api_key=YOUR_NCBI_API_KEY,  # <--- pass your key here
    device='cpu'
)

print(f"Assigned name: {llm_assigned_name}")
print(f"Gene set: {gene_list_example}")
print(f"Centroid-based similarity = {sim_score:.4f}")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Assigned name: Pancreatic Islet Development and Glucose Metabolism Regulation
Gene set: ['HOTTIP', 'NEU1', 'ARFGAP3', 'MAPKAPK2', 'SLC35B4', 'MIS12', 'HINT3', 'TOR1AIP2', 'NME1', 'AIDA', 'MITD1', 'DPY19L1', 'POLR1F', 'SYCP2', 'BCL3', 'ASH1L-AS1', 'UBE2F', 'CARD19', 'ZBTB17', 'SPRYD4', 'SYNGR2', 'FOXK2', 'NLGN2', 'RAB43', 'CYBC1', 'CWC25', 'OFD1', 'PIGW', 'TENT4B', 'DIMT1', 'NOPCHAP1', 'POGLUT1', 'CCDC142', 'DERL1', 'KRBOX5', 'LYRM7', 'NIN', 'VTI1B', 'IFT20', 'SEPHS1', 'RAB6A', 'NDFIP2', 'H3-3A', 'MBOAT2', 'ZSWIM9', 'SLC35D1', 'BOD1', 'PTCD3', 'DNAJC27', 'CSRP2', 'PURB', 'PPFIBP1', 'ALG10', 'PIGK', 'SKP1', 'CHCHD1', 'SEPTIN6', 'GALNT1', 'RNF2', 'CAPRIN1', 'DHFR2', 'EPB41L4A-AS1', 'MRPS9', 'PRSS12', 'NAAA', 'MAP7D3', 'PTBP3', 'PIBF1', 'DNAAF2', 'ARL2BP', 'KCMF1', 'AMFR', 'CDS2', 'UACA', 'PARS2', 'ATXN7L3B', 'KRT10', 'GUF1', 'DNAJC3-DT', 'CBX1', 'GET4', 'ANAPC4', 'TRIM4', 'ATG16L2', 'NEAT1', 'TRMT44', 'CYP2D6', 'CLN3', 'GK5', 'ZNF276', 'TMEM80', 'RBSN', 'FDPS', 'TRERF1', 'EGF', 'HOMEZ', '