In [1]:
import llm2geneset
import openai
import pandas as pd
from scipy.stats import hypergeom
from statsmodels.stats.multitest import multipletests
import re
import numpy as np
client = openai.Client()

In [2]:
def load(gmt_file):
    gmt = llm2geneset.read_gmt(gmt_file)
    def clean_elements(array):
        """Use regular expression to remove (GO:xxx) substring,  
        R-HSA-xxx substrings, and WPxxx substrings"""
        cleaned_array = []
        for element in array:
            cleaned_element = re.sub(r'\s*\(GO:\d+\)\s*|\s*R-HSA-\d+\s*|\s*WP\d+\s*', '', element)
            cleaned_array.append(cleaned_element)
        return cleaned_array
    gmt["descr_cleaned"] = clean_elements(gmt["descr"]) 
    gmtemb = llm2geneset.get_embeddings(client, gmt["descr_cleaned"])
    return (gmt, gmtemb)

In [3]:
go, goemb = load("libs_human/gmt/GO_Biological_Process_2023.txt")

In [4]:
wiki, wikiemb = load("libs_human/gmt/Reactome_2022.txt")
kegg, keggemb = load("libs_human/gmt/KEGG_2021_Human.txt")
react, reactemb = load("libs_human/gmt/Reactome_2022.txt")

In [5]:

def eval_geneset(iou_output, query, queryemb):
    # Convert lists to matrices
    matrix1 = np.vstack(queryemb)  # Shape (m, d), where m is the number of vectors
    matrix2 = np.vstack(goemb)  # Shape (n, d), where n is the number of vectors

    # Compute pairwise dot products using matrix multiplication
    A = matrix1 @ matrix2.T  # Shape (m, n)

    # Which column has highest dot product. 
    max_cols = np.argmax(A, axis=1)
    max_values = A[np.arange(A.shape[0]), max_cols]
    mask = max_values > 0.7 

    # Filter to rows where > 0.7
    f_cols = max_cols[mask].tolist()   # column indices of max
    f_vals = max_values[mask]     # corresponding max values
    f_rows = np.where(mask)[0].tolist()   # (optional) which row indices

    for row, col in zip(f_rows, f_cols):
        #print(f"Value {A[row, col]:.2f} at index ({row}, {col})")
        #print(query["descr_cleaned"][row], go["descr_cleaned"][col])
        #print(wikipathway["genes"][row])
        #print(go["genes"][col])
        curated_genes = go["genes"][col]
        parsed_llm_genes = query["genes"][row]
        
        descr_llm.append(query["descr_cleaned"][row])
        descr_go_genes.append(go["genes"][col])

        llm_genes = list(set(parsed_llm_genes)) # make sure unique genes are selected

        intersection = set(llm_genes).intersection(set(curated_genes))
        p_val = hypergeom.sf(len(intersection)-1,
                                19846, 
                                len(curated_genes), 
                                len(llm_genes))
        
        # generatio == recall 
        generatio = float(len(intersection)) / len(set(curated_genes))                                                                                                 
        bgratio = float(len(set(llm_genes))) / 19846                                                                                                    
                                                                                                                                                            
        richFactor = None                                                                                                                                      
        foldEnrich = None                                                                                                                                      
        if len(llm_genes) > 0:   
            # richFactor == precision                                                                                                                              
            richFactor = float(len(intersection)) / len(set(llm_genes))                                                                                        
            foldEnrich = generatio / bgratio                                                                                                                   

        
        x = {
            "query": query["descr_cleaned"][row],
            "go": go["descr_cleaned"][col],
            'ncurated': len(curated_genes),
            'nllm': len(llm_genes),
            'ninter': len(intersection),
            'generatio': generatio,
            'bgratio': bgratio,
            'richFactor': richFactor,
            'foldEnrich': foldEnrich,
            'jaccard': float(len(intersection)) / len(set(curated_genes).union(set(llm_genes))),
            'p_val': p_val
        }
        iou_output.append(x)

iou_output = []
descr_llm = []
descr_go_genes = []
test_llm = []

eval_geneset(iou_output, descr_llm, descr_go_genes, kegg, keggemb)
eval_geneset(iou_output, descr_llm, descr_go_genes, react, reactemb)
eval_geneset(iou_output, descr_llm, descr_go_genes, react, reactemb)


In [81]:
df = pd.DataFrame(iou_output)
_, pvals_corrected, _, _ = multipletests(df['p_val'], method='bonferroni')
df['p_val_adj'] = pvals_corrected
print(df.shape)
print(df[(df['p_val_adj'] < 0.01) ].shape[0] / df.shape[0])

(1418, 12)
0.7031029619181947


In [97]:
df.jaccard.mean()

0.1462119980862434

In [105]:
models = ["gpt-4o-mini-2024-07-18", "gpt-3.5-turbo-0125", "gpt-4o-2024-08-06"]
model = models[0]
aclient = openai.AsyncClient()
llm_genes = await llm2geneset.get_genes_bench(aclient,descr_llm,model=model,prompt_type='basic',use_sysmsg=True)

<bound method DataFrame.query of                                             query  \
0                               Adherens junction   
1                 Adipocytokine signaling pathway   
2     Amino sugar and nucleotide sugar metabolism   
3                     Aminoacyl-tRNA biosynthesis   
4             Antigen processing and presentation   
...                                           ...   
1413     tRNA Modification In Nucleus And Cytosol   
1414                              tRNA Processing   
1415             tRNA Processing In Mitochondrion   
1416                   tRNA Processing In Nucleus   
1417          trans-Golgi Network Vesicle Budding   

                                                     go  ncurated  nllm  \
0                            Adherens Junction Assembly         9    71   
1               Adiponectin-Activated Signaling Pathway         5    69   
2                         Amino Sugar Metabolic Process         9    48   
3                             