In [1]:
import json
from owlready2 import get_ontology
import pandas as pd
import numpy as np
import re
import faiss
from tqdm import tqdm
from fuzzywuzzy import process
from owlready2 import get_ontology
import pickle
from sentence_transformers import SentenceTransformer
import swifter
import hashlib
import duckdb
import uuid
from datetime import datetime
from ete3 import NCBITaxa
import torch

ncbi = NCBITaxa()
pd.set_option('display.max_rows', 50)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ncbi_ontology = get_ontology("ncbitaxon.owl").load()

In [30]:
COLUMN_NAME_MAP = {
        "sample_id": ["sample_id", "SampleID", "sampleid", "SampleId"],
        "environment": ["environment", "env", "condition", "habitat"],
        "species": ["species", "organism", "taxa"]
    }

def normalize_column(name):
    name = re.sub(r"[^a-zA-Z0-9]", "_", name)
    name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)  
    name = re.sub(r"_+", "_", name)

    # Map known columns to standardized names
    for standard_name, variations in COLUMN_NAME_MAP.items():
        if name in variations:
            return standard_name
    return name.strip("_").lower()

taxa_levels = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]

In [31]:
df_test_i = pd.read_excel("vdpTable.xlsx") 
df_test_j = pd.read_excel("FujitaTable.xlsx")

# for i in range(len(df_list)):
#     df_test_0.columns = [normalize_column(col) for col in df_test_0.columns] 
#     df_test_0 = df_test_0.drop_duplicates(subset=["kingdom", "phylum", "class", "order", "family", "genus", "species"], keep="first")
#     df_test_0 = df_test_0.replace("unidentified", np.nan)

#     for rank in taxa_levels:
#         df[rank] = (df_test_0[rank].str.lower()
#         .str.strip()
#         .str.replace(r"\s*[-/]\s*", "TEMPREPLACE", regex=True)  
#         .str.replace(r"[^\w\s]", "", regex=True)                
#         .str.replace(r"\s+", " ", regex=True)                   
#         .str.replace("TEMPREPLACE", "-", regex=False)           
#     )

#     df_test_0["lowest_known_taxon"] = df_test_0[taxa_levels].apply(
#         lambda x: f"{x.dropna().iloc[-1]}/{x.dropna().index[-1]}" if not x.dropna().empty else np.nan, 
#         axis=1
#     )
#     df_test_0 = df_test_{i}[["lowest_known_taxon"]]

def process_df(df):
    df.columns = [normalize_column(col) for col in df.columns]
    df = df.drop_duplicates(subset=["kingdom", "phylum", "class", "order", "family", "genus", "species"], keep="first")
    df = df.replace("unidentified", np.nan)

    for rank in taxa_levels:
        df[rank] = (df[rank].str.lower()
                    .str.strip()
                    .str.replace(r"\s*[-/]\s*", "TEMPREPLACE", regex=True)
                    .str.replace(r"[^\w\s]", "", regex=True)
                    .str.replace(r"\s+", " ", regex=True)
                    .str.replace("TEMPREPLACE", "-", regex=False)
                   )

    df["lowest_known_taxon"] = df[taxa_levels].apply(
        lambda x: f"{x.dropna().iloc[-1]}/{x.dropna().index[-1]}" if not x.dropna().empty else np.nan,
        axis=1
    )
    return df[["lowest_known_taxon"]]

df_test_i = process_df(df_test_i)
df_test_j = process_df(df_test_j)
df_combined = pd.concat([df_test_i, df_test_j], ignore_index=True)
df_combined



Unnamed: 0,lowest_known_taxon
0,streptococcus/genus
1,haemophilus/genus
2,dispar-parvula/species
3,atypica-dispar/species
4,neisseria/genus
...,...
512,aquitalea/genus
513,anaeromusa-anaeroarcus/genus
514,propionivibrio/genus
515,microbacter/genus


In [32]:
df_test = df_combined.copy()

# Translate a single name to a taxon IRI
def get_taxid(name):
    name = name.strip().lower()
    result = ncbi.get_name_translator([name])
    if result and name in result:
        return f"NCBITaxon_{result[name][0]}"
    return None

# Extract genus portion, split on "-" if needed, and resolve all parts
def get_taxids_from_taxon_string(value):
    if not isinstance(value, str):
        return None

    # 1. Take only the part before "/" (if present)
    part = value.split("/")[0].strip()

    # 2. Split on "-" to get all candidate names
    names = [n.strip() for n in part.split("-") if n.strip()]

    # 3. Look up each name
    ids = [get_taxid(name) for name in names]
    ids = [i for i in ids if i]  # remove None values

    return ids if ids else None

# Apply to dataframe
df_test["ete3"] = df_test["lowest_known_taxon"].apply(get_taxids_from_taxon_string)

# Apply all the steps
#df_test["ete3"] = df_test["taxon_id"]

In [33]:
# df_test.at[1, "taxon_id"]= ["NCBITaxon_1540"]
# df_test.at[21, "taxon_id"]= ["NCBITaxon_83767"]

df_test = df_test.dropna()
df_test

Unnamed: 0,lowest_known_taxon,ete3
0,streptococcus/genus,[NCBITaxon_1301]
1,haemophilus/genus,[NCBITaxon_724]
2,dispar-parvula/species,[NCBITaxon_509375]
3,atypica-dispar/species,[NCBITaxon_509375]
4,neisseria/genus,[NCBITaxon_482]
...,...,...
512,aquitalea/genus,[NCBITaxon_407217]
513,anaeromusa-anaeroarcus/genus,"[NCBITaxon_81463, NCBITaxon_151038]"
514,propionivibrio/genus,[NCBITaxon_83766]
515,microbacter/genus,[NCBITaxon_1548510]


In [34]:
df_test = df_test[df_test["ete3"].notna()]
df_test.reset_index(drop=True, inplace=True)
df_test["lowest_known_taxon"]


0               streptococcus/genus
1                 haemophilus/genus
2            dispar-parvula/species
3            atypica-dispar/species
4                   neisseria/genus
                   ...             
239                 aquitalea/genus
240    anaeromusa-anaeroarcus/genus
241            propionivibrio/genus
242               microbacter/genus
243              paludibacter/genus
Name: lowest_known_taxon, Length: 244, dtype: object

THIS DF IS EXPORTED AND TRANSFORMED BY LLM TO INTRODUCE TYPOS

In [35]:

df2 = pd.read_csv("corruptedvdpTable.csv") 
df_test ["lowest_known_taxon_c"] = df2["Corrupted Entry"]
df_test["ete3_c"] = df_test["lowest_known_taxon_c"].apply(get_taxids_from_taxon_string)
df_test


Unnamed: 0,lowest_known_taxon,ete3,lowest_known_taxon_c,ete3_c
0,streptococcus/genus,[NCBITaxon_1301],STC/genus,
1,haemophilus/genus,[NCBITaxon_724],haemophilus/genus,[NCBITaxon_724]
2,dispar-parvula/species,[NCBITaxon_509375],dispra-parvula/species,
3,atypica-dispar/species,[NCBITaxon_509375],atypia-disbpar/species,
4,neisseria/genus,[NCBITaxon_482],neisseria/genus,[NCBITaxon_482]
...,...,...,...,...
239,aquitalea/genus,[NCBITaxon_407217],Aqutalea/genus,
240,anaeromusa-anaeroarcus/genus,"[NCBITaxon_81463, NCBITaxon_151038]",Anaeromus-anaarocus/genus,
241,propionivibrio/genus,[NCBITaxon_83766],Propiovibrio/genus,
242,microbacter/genus,[NCBITaxon_1548510],Micrbacter/genus,


In [36]:
with open("taxon_data_r.pkl", "rb") as f:
    taxon_data_r = pickle.load(f)

In [37]:
# from typing import Optional

# def find_taxon_id_faiss(
#     query: str | list[str], 
#     st_number,
#     threshold: float = 70
# ) -> str | None | list[str | None]:
    
#     single_input = isinstance(query, str)
#     queries = [query] if single_input else query
    
#     results = []
    
#     for q in queries:
#         try:
#             current_query = q.lower()
            
#             # Encode and search
#             query_embedding = encoder.encode([current_query])
#             faiss.normalize_L2(query_embedding)
#             D, I = index.search(query_embedding, k=15)
            
#             # Retrieve candidates
#             candidates = [taxon_data_r[i] for i in I[0]]
            
#             # Handle rank filtering
#             if "/" in current_query:
#                 query_part, rank = current_query.split("/")
#                 candidates = [c for c in candidates if c[1] == rank]
#                 current_query = query_part
            
#             if not candidates:
#                 results.append(None)
#                 continue
            
#             # Find best match
#             candidate_names = [c[0] for c in candidates]
#             best_match, score = process.extractOne(current_query, candidate_names)
            
#             results.append(
#                 next((iri for name, rank, iri in candidates if name == best_match), None)
#                 if score >= threshold
#                 else None
#             )
            
#         except Exception as e:
#             print(f"Error processing query "{q}": {str(e)}")
#             results.append(None)
    
#     return results[0] if single_input else results

In [38]:
# import time
# def find_taxon_id_faiss(query, threshold: float = 70):
#     # Ensure input is treated as a list
    
    
#     queries = [query] if isinstance(query, str) else query

#     results = []

#     for q in queries:
#         try:
#             # Split at "/" to separate rank (if present)
#             if "/" in q:
#                 taxon_part, rank = q.split("/")
#                 taxon_names = taxon_part.split("-")
#             else:
#                 taxon_names = q.split("-")
#                 rank = None

#             taxon_ids = []
#             for name in taxon_names:
#                 query_embedding = encoder.encode([name.lower()])
#                 faiss.normalize_L2(query_embedding)
#                 D, I = index.search(query_embedding, k=5)

#                 candidates = [taxon_data_r[i] for i in I[0]]
#                 if rank:
#                     candidates = [c for c in candidates if c[1] == rank]

#                 if candidates:
#                     candidate_names = [c[0] for c in candidates]
#                     matches = process.extract(name, candidate_names)
#                     top_matches = [m for m in matches if m[1] >= threshold][:5]
                    
#                     # CORRECTED: Use proper unpacking of candidate tuple
#                     name_to_iri = {candidate[0]: candidate[2] for candidate in candidates}
                    
#                     for match, score in top_matches:
#                         if match in name_to_iri:
#                             taxon_ids.append(name_to_iri[match])

#             results.append(taxon_ids if taxon_ids else None)

#         except Exception as e:
#             print(f"Error processing '{q}': '{e}'")
#             results.append(None)

#     return results[0]


In [39]:
def find_taxon_id_faiss(query, rank_on: bool = False, threshold: float = 0, top_n: int = 5):
    """
    Find taxon IDs for queries, returning top N matches per taxon name component
    
    Args:
        query: Input string or list of strings (e.g., "homo-sapiens" or ["homo", "pan"])
        threshold: Minimum similarity score (0-100)
        top_n: Maximum number of matches to return per taxon component
        
    Returns:
        For single queries: List of matching taxon IDs (up to top_n) or None
        For multiple queries: List of lists (each containing up to top_n matches)
    """
    queries = [query] if isinstance(query, str) else query
    results = []

    for q in queries:
        try:
            if not q or pd.isna(q):  # Handle empty/NA values
                results.append(None)
                continue
                
            # Parse query
            
            if "/" in q:
                taxon_part, rank = q.split("/", 1)
                taxon_names = [n.strip() for n in taxon_part.split("-") if n.strip()]
                if not rank_on:
                    rank = None
            else:
                taxon_names = [n.strip() for n in q.split("-") if n.strip()]
                rank = None

            all_matches = []
            for name in taxon_names:
                # Get embeddings and search
                query_embedding = encoder.encode([name.lower()])
                faiss.normalize_L2(query_embedding)
                D, I = index.search(query_embedding, k=top_n*2)  # Search extra to account for rank filtering
                
                # Process candidates
                candidates = [taxon_data_r[i] for i in I[0] if i < len(taxon_data_r)]
                if rank:
                    candidates = [c for c in candidates if c[1] == rank]
                
                if not candidates:
                    continue
                
                # Get top matches
                candidate_names = [c[0] for c in candidates]
                matches = process.extract(name, candidate_names, limit=top_n*2)
                top_matches = [
                    (match, score, next(iri for (n, r, iri) in candidates if n == match))
                    for match, score in matches 
                    if score >= threshold
                ][:top_n]
                
                all_matches.extend(top_matches)

            # Format results
            if all_matches:
                # Deduplicate while preserving order
                seen = set()
                unique_matches = [
                    iri for match, score, iri in all_matches
                    if not (iri in seen or seen.add(iri))
                ][:top_n]
                results.append(unique_matches if unique_matches else None)
            else:
                results.append(None)
                
        except Exception as e:
            print(f"Error processing '{q}': {str(e)}")
            results.append(None)

    return results[0] if isinstance(query, str) else results

In [40]:
# encoder = SentenceTransformer("BAAI/bge-base-en-v1.5")
# index = faiss.read_index("ncbi_faiss_bgesmallenv15.index")
# columnname = "bgesmallenv15"
# #columnname = f"{indices[i].strip('.index').split('_')[2]}"
# df_test[columnname] = df_test["lowest_known_taxon"].apply(find_taxon_id_faiss)
# df_test[f"{columnname}_c"] = df_test["lowest_known_taxon_c"].apply(find_taxon_id_faiss)

In [44]:
import gc
import time

encoders = [
    "menadsa/BioS-MiniLM",
    "all-MiniLM-L6-v2",
    "BAAI/bge-base-en-v1.5",
    "pritamdeka/S-BioBert-snli-multinli-stsb",
    "intfloat/e5-small-v2",
    "intfloat/e5-large-v2",
    "intfloat/multilingual-e5-large",
    "juanpablomesa/all-mpnet-base-v2-bioasq-matryoshka",
    "NeuML/pubmedbert-base-embeddings"
    ]
indices = [
"ncbi_faiss_biosminilm.index",
"ncbi_faiss_allminilml6v2.index",
"ncbi_faiss_bgebaseenv15.index",
"ncbi_faiss_sbiobertsnlimultinlistsb.index",
"ncbi_faiss_e5smallv2.index",
"ncbi_faiss_e5largev2.index",
"ncbi_faiss_multilinguale5large.index",
"ncbi_faiss_allmpnetbasev2bioasqmatryoshka.index",
"ncbi_faiss_pubmedbertbaseembeddings.index"
]


for i in range(8,len(indices)):
    start_time = time.time()
    print(i)
    
    encoder = SentenceTransformer(encoders[i])
    index = faiss.read_index(indices[i])
    columnname = f"{indices[i].strip('.index').split('_')[2]}"
    df_test[columnname] = df_test["lowest_known_taxon"].apply(find_taxon_id_faiss)
    df_test[f"{columnname}_c"] = df_test["lowest_known_taxon_c"].apply(find_taxon_id_faiss)
    del encoder
    del index   
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(elapsed_time)
    gc.collect()
    torch.cuda.empty_cache()



# df_test["allminilmv6"] = df_test["lowest_known_taxon"].apply(
#     lambda x: find_taxon_id_faiss(x)
# )



8
201.4943709373474


In [None]:
encoder = SentenceTransformer("all-MiniLM-L6-v2")
index = faiss.read_index("ncbi_taxons_faiss_allminilmv6.index")

In [45]:
df_test

Unnamed: 0,lowest_known_taxon,ete3,lowest_known_taxon_c,ete3_c,biosminilm,biosminilm_c,allminilml6v2,allminilml6v2_c,bgebaseenv15,bgebaseenv15_c,...,e5smallv2,e5smallv2_c,e5largev2,e5largev2_c,multilinguale5larg,multilinguale5larg_c,allmpnetbasev2bioasqmatryoshka,allmpnetbasev2bioasqmatryoshka_c,pubmedbertbaseembeddings,pubmedbertbaseembeddings_c
0,streptococcus/genus,[NCBITaxon_1301],STC/genus,,"[NCBITaxon_1301, NCBITaxon_1214155, NCBITaxon_...","[NCBITaxon_137528, NCBITaxon_421157, NCBITaxon...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_211...","[NCBITaxon_53171, NCBITaxon_367637, NCBITaxon_...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_276...","[NCBITaxon_1352352, NCBITaxon_133923, NCBITaxo...",...,"[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_502...","[NCBITaxon_1352352, NCBITaxon_133923, NCBITaxo...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_142...","[NCBITaxon_1352352, NCBITaxon_137528, NCBITaxo...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_828...","[NCBITaxon_1352352, NCBITaxon_2588968, NCBITax...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_176...","[NCBITaxon_137528, NCBITaxon_153378, NCBITaxon...","[NCBITaxon_1301, NCBITaxon_1346, NCBITaxon_130...","[NCBITaxon_1898542, NCBITaxon_1069644, NCBITax..."
1,haemophilus/genus,[NCBITaxon_724],haemophilus/genus,[NCBITaxon_724],"[NCBITaxon_724, NCBITaxon_123834, NCBITaxon_94...","[NCBITaxon_724, NCBITaxon_123834, NCBITaxon_94...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_727, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_727, ...",...,"[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ..."
2,dispar-parvula/species,[NCBITaxon_509375],dispra-parvula/species,,"[NCBITaxon_509375, NCBITaxon_112346, NCBITaxon...","[NCBITaxon_78922, NCBITaxon_2495529, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_112346, NCBITaxon...","[NCBITaxon_78922, NCBITaxon_2495529, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_2502799, NCBITaxo...","[NCBITaxon_78922, NCBITaxon_79506, NCBITaxon_2...",...,"[NCBITaxon_509375, NCBITaxon_509376, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_78922, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_295252, NCBITaxon...","[NCBITaxon_78922, NCBITaxon_2495529, NCBITaxon...","[NCBITaxon_943112, NCBITaxon_491984, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_509375, NCBITaxo...","[NCBITaxon_509375, NCBITaxon_316177, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_78922, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_516669, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_78922, NCBITaxon..."
3,atypica-dispar/species,[NCBITaxon_509375],atypia-disbpar/species,,"[NCBITaxon_3233203, NCBITaxon_2612832, NCBITax...","[NCBITaxon_2612832, NCBITaxon_988072, NCBITaxo...","[NCBITaxon_3233203, NCBITaxon_2612832, NCBITax...","[NCBITaxon_2612832, NCBITaxon_988072, NCBITaxo...","[NCBITaxon_2819460, NCBITaxon_2761155, NCBITax...","[NCBITaxon_2612832, NCBITaxon_689118, NCBITaxo...",...,"[NCBITaxon_2819460, NCBITaxon_249409, NCBITaxo...","[NCBITaxon_2612832, NCBITaxon_88328, NCBITaxon...","[NCBITaxon_2819460, NCBITaxon_1698944, NCBITax...","[NCBITaxon_2612832, NCBITaxon_88328, NCBITaxon...","[NCBITaxon_1698944, NCBITaxon_2761155, NCBITax...","[NCBITaxon_1187032, NCBITaxon_1033, NCBITaxon_...","[NCBITaxon_2612832, NCBITaxon_988072, NCBITaxo...","[NCBITaxon_689118, NCBITaxon_316172, NCBITaxon...","[NCBITaxon_3233203, NCBITaxon_2819460, NCBITax...","[NCBITaxon_2612832, NCBITaxon_88328, NCBITaxon..."
4,neisseria/genus,[NCBITaxon_482],neisseria/genus,[NCBITaxon_482],"[NCBITaxon_482, NCBITaxon_325214, NCBITaxon_49...","[NCBITaxon_482, NCBITaxon_325214, NCBITaxon_49...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_32...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_32...",...,"[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_20...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_20..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
239,aquitalea/genus,[NCBITaxon_407217],Aqutalea/genus,,"[NCBITaxon_407217, NCBITaxon_2480983, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_115444, NCBITaxon...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_115444, NCBITaxon...",...,"[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_3044273, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1872623, NCBITaxo...","[NCBITaxon_407217, NCBITaxon_1537400, NCBITaxo..."
240,anaeromusa-anaeroarcus/genus,"[NCBITaxon_81463, NCBITaxon_151038]",Anaeromus-anaarocus/genus,,"[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_81464, NCBITaxon_2...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_2676953, NCBITaxon...",...,"[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_151038, NCBITaxon_...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_151038, NCBITaxon_...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_2946544, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_1872520, NCBITaxon...","[NCBITaxon_81463, NCBITaxon_3397854, NCBITaxon..."
241,propionivibrio/genus,[NCBITaxon_83766],Propiovibrio/genus,,"[NCBITaxon_83766, NCBITaxon_2976531, NCBITaxon...","[NCBITaxon_2991701, NCBITaxon_2004668, NCBITax...","[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon...","[NCBITaxon_662, NCBITaxon_2991701, NCBITaxon_2...","[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon...","[NCBITaxon_83766, NCBITaxon_662, NCBITaxon_299...",...,"[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon...","[NCBITaxon_2991701, NCBITaxon_2004668, NCBITax...","[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon...","[NCBITaxon_85274, NCBITaxon_2991701, NCBITaxon...","[NCBITaxon_83766, NCBITaxon_2976531, NCBITaxon...","[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon...","[NCBITaxon_83766, NCBITaxon_2976531, NCBITaxon...","[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon...","[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon...","[NCBITaxon_83766, NCBITaxon_2212460, NCBITaxon..."
242,microbacter/genus,[NCBITaxon_1548510],Micrbacter/genus,,"[NCBITaxon_1548510, NCBITaxon_2502250, NCBITax...","[NCBITaxon_1675603, NCBITaxon_2499509, NCBITax...","[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_2984579, NCBITaxon_105113, NCBITaxo...","[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_1429240, NCBITaxon_2219151, NCBITax...",...,"[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_56333, NCBITaxon_1535469, NCBITaxon...","[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_647744, NCBITaxon_53457, NCBITaxon_...","[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_3003933, NCBITaxon_2984579, NCBITax...","[NCBITaxon_1548510, NCBITaxon_1972636, NCBITax...","[NCBITaxon_2984579, NCBITaxon_206417, NCBITaxo..."


In [150]:
# def find_taxon_id_faiss_nr(query, threshold: float = 70):
#     # Ensure input is treated as a list
#     queries = [query] if isinstance(query, str) else query
    
#     results = []

#     for q in queries:
#         try:
#             # Split at "/" to separate rank (if present)
#             if "/" in q:
#                 taxon_part, rank = q.split("/")
#                 taxon_names = taxon_part.split("-")
#             else:
#                 taxon_names = q.split("-")

#             taxon_ids = []
#             for name in taxon_names:
#                 query_embedding = encoder.encode([name.lower()])
#                 faiss.normalize_L2(query_embedding)
#                 D, I = index.search(query_embedding, k=5)

#                 candidates = [taxon_data_r[i] for i in I[0]]

#                 if candidates:
#                     candidate_names = [c[0] for c in candidates]
#                     best_match, score = process.extractOne(name, candidate_names)
#                     if score >= threshold:
#                         match_iri = next(iri for (n, r, iri) in candidates if n == best_match)
#                         taxon_ids.append(match_iri)

#             # Always return a list: either empty or filled
#             results.append(taxon_ids if taxon_ids else None)

#         except Exception as e:
#             print(f"Error processing '{q}': '{e}'")
#             results.append(None)


#     return results

In [151]:

find_taxon_id_faiss_nr("pandoraea/genus")

NameError: name 'find_taxon_id_faiss_nr' is not defined

In [152]:
df_test

Unnamed: 0,lowest_known_taxon,ete3,lowest_known_taxon_c,ete3_c,allminilmv6,allminilmv6_c,bgebas,bgebas_c,bgesmall,bgesmall_c,biobertsnl,biobertsnl_c,distilbert,distilbert_c,e5small,e5small_c,e5larg,e5larg_c,multile5larg,multile5larg_c
0,streptococcus/genus,[NCBITaxon_1301],STC/genus,,"[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_211...","[NCBITaxon_53171, NCBITaxon_367637, NCBITaxon_...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_276...","[NCBITaxon_1352352, NCBITaxon_133923, NCBITaxo...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_130...","[NCBITaxon_1352352, NCBITaxon_133923, NCBITaxo...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_258...","[NCBITaxon_166218, NCBITaxon_3064781, NCBITaxo...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_302...","[NCBITaxon_137528, NCBITaxon_172500, NCBITaxon...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_502...","[NCBITaxon_1352352, NCBITaxon_133923, NCBITaxo...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_142...","[NCBITaxon_1352352, NCBITaxon_137528, NCBITaxo...","[NCBITaxon_1301, NCBITaxon_1306, NCBITaxon_828...","[NCBITaxon_1352352, NCBITaxon_2588968, NCBITax..."
1,haemophilus/genus,[NCBITaxon_724],haemophilus/genus,[NCBITaxon_724],"[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_727, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_727, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ...","[NCBITaxon_724, NCBITaxon_740, NCBITaxon_726, ..."
2,dispar-parvula/species,[NCBITaxon_509375],dispra-parvula/species,,"[NCBITaxon_509375, NCBITaxon_112346, NCBITaxon...","[NCBITaxon_78922, NCBITaxon_2495529, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_2502799, NCBITaxo...","[NCBITaxon_78922, NCBITaxon_79506, NCBITaxon_2...","[NCBITaxon_509375, NCBITaxon_465108, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_78922, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_1630944, NCBITaxo...","[NCBITaxon_2495529, NCBITaxon_78922, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_516669, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_78922, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_509376, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_78922, NCBITaxon...","[NCBITaxon_509375, NCBITaxon_295252, NCBITaxon...","[NCBITaxon_78922, NCBITaxon_2495529, NCBITaxon...","[NCBITaxon_943112, NCBITaxon_491984, NCBITaxon...","[NCBITaxon_2495529, NCBITaxon_509375, NCBITaxo..."
3,atypica-dispar/species,[NCBITaxon_509375],atypia-disbpar/species,,"[NCBITaxon_3233203, NCBITaxon_2612832, NCBITax...","[NCBITaxon_2612832, NCBITaxon_988072, NCBITaxo...","[NCBITaxon_2819460, NCBITaxon_2761155, NCBITax...","[NCBITaxon_2612832, NCBITaxon_689118, NCBITaxo...","[NCBITaxon_2819460, NCBITaxon_2054316, NCBITax...","[NCBITaxon_2612832, NCBITaxon_689118, NCBITaxo...","[NCBITaxon_2819460, NCBITaxon_1920382, NCBITax...","[NCBITaxon_689118, NCBITaxon_695336, NCBITaxon...","[NCBITaxon_2612832, NCBITaxon_88328, NCBITaxon...","[NCBITaxon_2612832, NCBITaxon_88328, NCBITaxon...","[NCBITaxon_2819460, NCBITaxon_249409, NCBITaxo...","[NCBITaxon_2612832, NCBITaxon_88328, NCBITaxon...","[NCBITaxon_2819460, NCBITaxon_1698944, NCBITax...","[NCBITaxon_2612832, NCBITaxon_88328, NCBITaxon...","[NCBITaxon_1698944, NCBITaxon_2761155, NCBITax...","[NCBITaxon_1187032, NCBITaxon_1033, NCBITaxon_..."
4,neisseria/genus,[NCBITaxon_482],neisseria/genus,[NCBITaxon_482],"[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_32...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_32...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_51315, NCBITaxon_989953, NCBITaxon_...","[NCBITaxon_51315, NCBITaxon_989953, NCBITaxon_...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_33...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_33...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48...","[NCBITaxon_482, NCBITaxon_192066, NCBITaxon_48..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,devosiaceae/family,[NCBITaxon_2831106],devosiaceae/family,[NCBITaxon_2831106],"[NCBITaxon_2831106, NCBITaxon_2845824, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845824, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845824, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845824, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_33105, NCBITaxon...","[NCBITaxon_2831106, NCBITaxon_33105, NCBITaxon...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax...","[NCBITaxon_2831106, NCBITaxon_2845825, NCBITax..."
216,planctomicrobium/genus,[NCBITaxon_1779141],pglanctomicrobium/genus,,"[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_162291, NCBITaxo...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_162291, NCBITaxo...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_204472, NCBITaxo...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_162291, NCBITaxo...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_339015, NCBITaxon_2569552, NCBITaxo...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_162291, NCBITaxo...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax...","[NCBITaxon_1779141, NCBITaxon_2608787, NCBITax..."
217,blastococcus/genus,[NCBITaxon_38501],blastococcus/genus,[NCBITaxon_38501],"[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1564162, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1564162, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon...","[NCBITaxon_38501, NCBITaxon_1873459, NCBITaxon..."
218,sumerlaea/genus,[NCBITaxon_2315418],sumerlaea/genus,[NCBITaxon_2315418],"[NCBITaxon_2315418, NCBITaxon_3066091, NCBITax...","[NCBITaxon_2315418, NCBITaxon_3066091, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2838781, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2838781, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315421, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315421, NCBITax...","[NCBITaxon_3149658, NCBITaxon_2708163, NCBITax...","[NCBITaxon_3149658, NCBITaxon_2708163, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315419, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315419, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315420, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315420, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2838780, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2838780, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315420, NCBITax...","[NCBITaxon_2315418, NCBITaxon_2315420, NCBITax..."


In [46]:
def top_k_match(baseline_list, predictions, k):
    if not predictions or not baseline_list:
        return False
    return any(b in predictions[:k] for b in baseline_list)

In [47]:
top_k_values = [1, 3, 5]
methods = df_test.columns
results = {}

for method in methods:
    for k in top_k_values:
        correct = df_test.apply(lambda row: top_k_match(row["ete3"], row[method], k), axis=1)
        accuracy = correct.sum() / len(df_test)
        results[f"{method}/top{k}"] = accuracy


In [48]:
accuracy_df = pd.DataFrame.from_dict(results, orient="index", columns=["accuracy"])
accuracy_df.reset_index(inplace=True)
accuracy_df.columns = ["method_k", "accuracy"]

accuracy_df


Unnamed: 0,method_k,accuracy
0,lowest_known_taxon/top1,0.000000
1,lowest_known_taxon/top3,0.000000
2,lowest_known_taxon/top5,0.000000
3,ete3/top1,1.000000
4,ete3/top3,1.000000
...,...,...
61,pubmedbertbaseembeddings/top3,0.905738
62,pubmedbertbaseembeddings/top5,0.905738
63,pubmedbertbaseembeddings_c/top1,0.577869
64,pubmedbertbaseembeddings_c/top3,0.577869


In [49]:
import pandas as pd

# Start with a clean copy
df = accuracy_df.copy()

# --- Step 1: Split method name and top-k parts ---
df[['Method', 'Top@k']] = df['method_k'].str.split('/', expand=True)

# --- Step 2: Flag canonical vs. "_c" models ---
df['is_c'] = df['Method'].str.endswith('_c')
df['Method'] = df['Method'].str.replace('_c$', '', regex=True)

print(df)

# --- Step 3: Pivot base (non-_c) models ---
pivot_base = df[~df['is_c']].pivot(index='Method', columns='Top@k', values='accuracy')
pivot_base.columns = [f'Top @ {k[-1]}' for k in pivot_base.columns]
pivot_base['MRR'] = pivot_base.mean(axis=1)

# --- Step 4: Pivot "_c" models ---
pivot_c = df[df['is_c']].pivot(index='Method', columns='Top@k', values='accuracy')
pivot_c.columns = [f'Top @ {k[-1]}_c' for k in pivot_c.columns]
pivot_c['MRR_c'] = pivot_c.mean(axis=1)

# --- Step 5: Merge everything together ---
final_df = pivot_base.join(pivot_c, how='outer').reset_index()

# Optional: Sort by MRR
final_df = final_df.sort_values(by=['MRR', 'MRR_c'], ascending=False)

# Optional: Round for readability
final_df = final_df.round(4)



final_df

# df_test.at[1, "taxon_id"]= ["NCBITaxon_1540"]




















                           method_k  accuracy                    Method Top@k  \
0           lowest_known_taxon/top1  0.000000        lowest_known_taxon  top1   
1           lowest_known_taxon/top3  0.000000        lowest_known_taxon  top3   
2           lowest_known_taxon/top5  0.000000        lowest_known_taxon  top5   
3                         ete3/top1  1.000000                      ete3  top1   
4                         ete3/top3  1.000000                      ete3  top3   
..                              ...       ...                       ...   ...   
61    pubmedbertbaseembeddings/top3  0.905738  pubmedbertbaseembeddings  top3   
62    pubmedbertbaseembeddings/top5  0.905738  pubmedbertbaseembeddings  top5   
63  pubmedbertbaseembeddings_c/top1  0.577869  pubmedbertbaseembeddings  top1   
64  pubmedbertbaseembeddings_c/top3  0.577869  pubmedbertbaseembeddings  top3   
65  pubmedbertbaseembeddings_c/top5  0.577869  pubmedbertbaseembeddings  top5   

     is_c  
0   False  
1  

Unnamed: 0,Method,Top @ 1,Top @ 3,Top @ 5,MRR,Top @ 1_c,Top @ 3_c,Top @ 5_c,MRR_c
6,ete3,1.0,1.0,1.0,1.0,0.3074,0.3074,0.3074,0.3074
4,e5largev2,0.918,0.918,0.9344,0.9235,0.7336,0.7336,0.7418,0.7363
2,bgebaseenv15,0.9098,0.9262,0.9344,0.9235,0.7254,0.7336,0.7377,0.7322
5,e5smallv2,0.9139,0.918,0.9221,0.918,0.7131,0.7131,0.7131,0.7131
0,allminilml6v2,0.9098,0.9139,0.9139,0.9126,0.6311,0.6311,0.6352,0.6325
1,allmpnetbasev2bioasqmatryoshka,0.9057,0.9098,0.9098,0.9085,0.6189,0.623,0.623,0.6216
3,biosminilm,0.9057,0.9057,0.9098,0.9071,0.6189,0.6189,0.6189,0.6189
9,pubmedbertbaseembeddings,0.9016,0.9057,0.9057,0.9044,0.5779,0.5779,0.5779,0.5779
8,multilinguale5larg,0.8852,0.8934,0.8934,0.8907,0.7377,0.7459,0.75,0.7445
10,sbiobertsnlimultinlistsb,0.75,0.7541,0.7541,0.7527,0.5861,0.5902,0.5902,0.5888


Testing RAG retrieval

In [51]:
from sentence_transformers import SentenceTransformer
import faiss
import duckdb
import numpy as np
import json

from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import faiss
import duckdb
import torch

class PythiaRAGSystem:
    def __init__(self, db_path="document_vectors.db", faiss_path="docs.faiss", model_name="EleutherAI/pythia-2.8b"):
        # Initialize retrieval components
        self.retrieval_model = SentenceTransformer("juanpablomesa/all-mpnet-base-v2-bioasq-matryoshka")
        self.conn = duckdb.connect(db_path)
        self.index = faiss.read_index(faiss_path)
        
        # Initialize Pythia
        self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16  # FP16 for memory efficiency
        )
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
    
    def retrieve(self, query_text: str, top_k: int = 50):
        """Retrieve relevant documents"""
        query_embedding = self.retrieval_model.encode(
            query_text,
            normalize_embeddings=True
        ).astype('float32')
        
        distances, indices = self.index.search(np.array([query_embedding]), top_k)
        
        results = self.conn.execute(f"""
        SELECT text, metadata 
        FROM documents 
        WHERE id IN ({','.join(map(str, indices[0]))})
        """).fetchall()
        
        return [
            {"text": r[0], "metadata": json.loads(r[1]), "score": float(distances[0][i])}
            for i, r in enumerate(results)
        ]

In [None]:
documents = retrieve(query_text, top_k)

In [52]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import NanoBEIREvaluator

model = SentenceTransformer('juanpablomesa/all-mpnet-base-v2-bioasq-matryoshka')

datasets = ["QuoraRetrieval", "MSMARCO"]
query_prompts = {
    "QuoraRetrieval": "Instruct: Given a question, retrieve questions that are semantically equivalent to the given question\nQuery: ",
    "MSMARCO": "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: "
}

evaluator = NanoBEIREvaluator(
    dataset_names=datasets,
    query_prompts=query_prompts,
)

results = evaluator(model)

Loading NanoBEIR datasets:   0%|          | 0/2 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [3]:
results

{'NanoQuoraRetrieval_cosine_accuracy@1': 0.92,
 'NanoQuoraRetrieval_cosine_accuracy@3': 0.98,
 'NanoQuoraRetrieval_cosine_accuracy@5': 1.0,
 'NanoQuoraRetrieval_cosine_accuracy@10': 1.0,
 'NanoQuoraRetrieval_cosine_precision@1': 0.92,
 'NanoQuoraRetrieval_cosine_precision@3': 0.40666666666666657,
 'NanoQuoraRetrieval_cosine_precision@5': 0.25999999999999995,
 'NanoQuoraRetrieval_cosine_precision@10': 0.13999999999999999,
 'NanoQuoraRetrieval_cosine_recall@1': 0.8173333333333334,
 'NanoQuoraRetrieval_cosine_recall@3': 0.9420000000000001,
 'NanoQuoraRetrieval_cosine_recall@5': 0.9793333333333334,
 'NanoQuoraRetrieval_cosine_recall@10': 1.0,
 'NanoQuoraRetrieval_cosine_ndcg@10': 0.9597276057012641,
 'NanoQuoraRetrieval_cosine_mrr@10': 0.9540000000000001,
 'NanoQuoraRetrieval_cosine_map@100': 0.9394920634920635,
 'NanoMSMARCO_cosine_accuracy@1': 0.4,
 'NanoMSMARCO_cosine_accuracy@3': 0.74,
 'NanoMSMARCO_cosine_accuracy@5': 0.78,
 'NanoMSMARCO_cosine_accuracy@10': 0.88,
 'NanoMSMARCO_cosine

In [53]:
import os
import json

def prepare_beir_dataset(
    table_rows,
    text_column="stringified_row",
    id_column="id",
    query_templates=None,
    output_dir="my_beir_dataset"
):
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "qrels"), exist_ok=True)

    with open(os.path.join(output_dir, "corpus.jsonl"), 'w', encoding='utf-8') as f:
        for i, row in enumerate(table_rows):
            doc_id = row.get(id_column, f"doc{i}")
            f.write(json.dumps({
                "_id": str(doc_id),
                "title": "",
                "text": row[text_column]
            }) + "\n")

    with open(os.path.join(output_dir, "queries.jsonl"), 'w', encoding='utf-8') as qf, \
         open(os.path.join(output_dir, "qrels", "qrels.tsv"), 'w', encoding='utf-8') as rf:

        rf.write("query-id\tcorpus-id\tscore\n")

        for i, row in enumerate(table_rows):
            query_id = f"q{i}"
            doc_id = row.get(id_column, f"doc{i}")
            if query_templates:
                query_text = query_templates[i % len(query_templates)].format(**row)
            else:
                query_text = f"What is the genus and abundance in sample {row.get('sample_id', 'unknown')}?"

            qf.write(json.dumps({"_id": query_id, "text": query_text}) + "\n")
            rf.write(f"{query_id}\t{doc_id}\t1\n")

    print(f"✅ BEIR dataset created at: {output_dir}")

# Call the function
prepare_beir_dataset(
    table_rows=table_data,
    text_column="stringified_row",
    id_column="id",
    query_templates=[
        "Which genus is most abundant in {sample_id}?",
        "Tell me about {genus} in {sample_id}."
    ],
    output_dir="my_beir_dataset"
)


NameError: name 'table_data' is not defined