In [1]:
import llm2geneset
import openai
import pandas as pd
from scipy.stats import hypergeom
from statsmodels.stats.multitest import multipletests
import re
import numpy as np
client = openai.Client()

In [2]:
def load(gmt_file):
    """Load GMT files and generate embeddings for gene set descriptions."""
    gmt = llm2geneset.read_gmt(gmt_file)
    def clean_elements(array):
        """Use regular expression to remove (GO:xxx) substring,  
        R-HSA-xxx substrings, and WPxxx substrings"""
        cleaned_array = []
        for element in array:
            cleaned_element = re.sub(r'\s*\(GO:\d+\)\s*|\s*R-HSA-\d+\s*|\s*WP\d+\s*', '', element)
            cleaned_array.append(cleaned_element)
        return cleaned_array
    gmt["descr_cleaned"] = clean_elements(gmt["descr"]) 
    gmtemb = llm2geneset.get_embeddings(client, gmt["descr_cleaned"])
    return (gmt, gmtemb)

In [3]:
# Get GO reference.
go, goemb = load("libs_human/gmt/GO_Biological_Process_2023.txt")

In [4]:
# Load query gene sets and embeddings.
wiki, wikiemb = load("libs_human/gmt/Reactome_2022.txt")
kegg, keggemb = load("libs_human/gmt/KEGG_2021_Human.txt")
react, reactemb = load("libs_human/gmt/Reactome_2022.txt")

In [13]:

def eval_geneset(iou_output, descr_llm, descr_go, descr_go_genes, query, queryemb):
    # Convert embedding lists to matricies.
    matrix1 = np.vstack(queryemb)  # Shape (m, d), where m is the number of vectors
    matrix2 = np.vstack(goemb)  # Shape (n, d), where n is the number of vectors

    # Compute pairwise dot products using matrix multiplication
    A = matrix1 @ matrix2.T  # Shape (m, n)

    # Which column has highest dot product. 
    max_cols = np.argmax(A, axis=1)
    max_values = A[np.arange(A.shape[0]), max_cols]
    mask = max_values > 0.7 

    # Filter to rows where > 0.7
    f_cols = max_cols[mask].tolist()   # column indices of max
    f_vals = max_values[mask]     # corresponding max values
    f_rows = np.where(mask)[0].tolist()   # (optional) which row indices

    for row, col in zip(f_rows, f_cols):
        #print(f"Value {A[row, col]:.2f} at index ({row}, {col})")
        #print(query["descr_cleaned"][row], go["descr_cleaned"][col])
        #print(wikipathway["genes"][row])
        #print(go["genes"][col])
        curated_genes = go["genes"][col]
        parsed_llm_genes = query["genes"][row]
        
        # Save descriptions used in query gene sets for LLM eval
        descr_llm.append(query["descr_cleaned"][row])
        descr_go.append(go["descr_cleaned"][col])
        # Save reference genes for LLM eval.
        descr_go_genes.append(go["genes"][col])

        llm_genes = list(set(parsed_llm_genes)) # make sure unique genes are selected

        intersection = set(llm_genes).intersection(set(curated_genes))
        p_val = hypergeom.sf(len(intersection)-1,
                                19846, 
                                len(curated_genes), 
                                len(llm_genes))
        
        # generatio == recall 
        generatio = float(len(intersection)) / len(set(curated_genes))                                                                                                 
        bgratio = float(len(set(llm_genes))) / 19846                                                                                                    
                                                                                                                                                            
        richFactor = None                                                                                                                                      
        foldEnrich = None                                                                                                                                      
        if len(llm_genes) > 0:   
            # richFactor == precision                                                                                                                              
            richFactor = float(len(intersection)) / len(set(llm_genes))                                                                                        
            foldEnrich = generatio / bgratio                                                                                                                   

        
        x = {
            "model": "human",
            "query": query["descr_cleaned"][row],
            "go": go["descr_cleaned"][col],
            'ncurated': len(curated_genes),
            'nllm': len(llm_genes),
            'ninter': len(intersection),
            'generatio': generatio,
            'bgratio': bgratio,
            'richFactor': richFactor,
            'foldEnrich': foldEnrich,
            'jaccard': float(len(intersection)) / len(set(curated_genes).union(set(llm_genes))),
            'p_val': p_val
        }
        iou_output.append(x)

iou_output = []
descr_llm = []
descr_go = []
descr_go_genes = []
test_llm = []

eval_geneset(iou_output, descr_llm, descr_go, descr_go_genes, kegg, keggemb)
eval_geneset(iou_output, descr_llm, descr_go, descr_go_genes, react, reactemb)
eval_geneset(iou_output, descr_llm, descr_go, descr_go_genes, react, reactemb)


In [9]:
aclient = openai.AsyncClient()
models = ["gpt-4o-mini-2024-07-18", "gpt-3.5-turbo-0125", "gpt-4o-2024-08-06"]

for model in models:
    descr_llm_genes = await llm2geneset.get_genes_bench(aclient,descr_llm,model=model,prompt_type='basic',use_sysmsg=True)
    for idx in range(len(descr_llm)):
        curated_genes = descr_go_genes[idx]
        parsed_llm_genes = descr_llm_genes[idx]['parsed_genes']
        llm_genes = list(set(parsed_llm_genes)) # make sure unique genes are selected

        intersection = set(llm_genes).intersection(set(curated_genes))
        union = set(llm_genes).union(set(curated_genes))
        jaccard_similarity = len(intersection) / len(union) if len(union) > 0 else 0
        p_val = hypergeom.sf(len(intersection)-1,
                                19846, 
                                len(curated_genes), 
                                len(llm_genes))

        # generatio == recall 
        generatio = float(len(intersection)) / len(set(curated_genes))                                                                                                 
        bgratio = float(len(set(llm_genes))) / 19846                                                                                                    
                                                                                                                                                            
        richFactor = None                                                                                                                                      
        foldEnrich = None                                                                                                                                      
        if len(llm_genes) > 0:   
            # richFactor == precision                                                                                                                              
            richFactor = float(len(intersection)) / len(set(llm_genes))                                                                                        
            foldEnrich = generatio / bgratio                                                                                                                   
    
        x = {
            "model": model,
            "query": descr_llm[idx],
            "go": descr_go[idx],
            'ncurated': len(curated_genes),
            'nllm': len(llm_genes),
            'ninter': len(intersection),
            'generatio': generatio,
            'bgratio': bgratio,
            'richFactor': richFactor,
            'foldEnrich': foldEnrich,
            'jaccard': float(len(intersection)) / len(set(curated_genes).union(set(llm_genes))),
            'p_val': p_val
        }
        iou_output.append(x)

 35%|██████████████████████████████████████████████▍                                                                                       | 491/1418 [00:29<00:43, 21.53it/s]

In [10]:
df = pd.DataFrame(iou_output)
df.to_csv("outputs/human_agreement.csv", index=False, sep="\t")

(1418, 13)
0.5698166431593794
0.20954245124679313
0.1838016686768446
0.09579101546853223
