In [1]:
# system
import ast
import csv
import os
from typing import Dict

# langchain
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

# Vector DB
from docarray.index import InMemoryExactNNIndex

# code utils
import black

# utils
from utils.models import VecDBEntry
from utils.db import search_db
from utils.templates import get_extractor_chain, get_data_extractor_map_chain
from utils.models import Scope, ScopeEnum



In [2]:
# Map the generic OPENAI_* variables onto the AZURE_OPENAI_* names read by the
# langchain_openai Azure clients below, and pin the chat deployment name.
os.environ["AZURE_OPENAI_API_KEY"] = os.environ["OPENAI_API_KEY"]

# OPENAI_API_BASE is moved (not copied) to AZURE_OPENAI_ENDPOINT -- presumably so
# langchain_openai does not also pick it up as a base_url next to azure_deployment;
# NOTE(review): confirm. Popping only when present keeps this cell re-runnable:
# an unconditional `del` made a second run fail with KeyError, because the
# variable had already been removed by the first run.
if "OPENAI_API_BASE" in os.environ:
    os.environ["AZURE_OPENAI_ENDPOINT"] = os.environ.pop("OPENAI_API_BASE")
elif "AZURE_OPENAI_ENDPOINT" not in os.environ:
    # Neither source nor target is configured: fail as loudly as before.
    raise KeyError("OPENAI_API_BASE")

os.environ["AZURE_OPENAI_API_VERSION"] = os.environ["OPENAI_API_VERSION"]
os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = "firstcontact-gpt4-turbo"

In [3]:
# Chat client with OpenAI JSON mode requested through `response_format`;
# generate_extractor_code hands it to the data-extractor chain.
llm_json = AzureChatOpenAI(
    model="gpt-4-128k",
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
    model_kwargs=dict(response_format=dict(type="json_object")),
)

# Embeddings client that generate_extractor_code passes to search_db.
# NOTE(review): `model` names a chat model although the deployment name suggests
# an embeddings deployment; confirm which model "firstcontact-embeddings" serves
# and set `model` to match it.
embeddings_model = AzureOpenAIEmbeddings(
    model="gpt-4-128k",
    azure_deployment="firstcontact-embeddings",
)

## Extract data and generate SDK construct


In [4]:
# Load the precomputed term embeddings and build an exact nearest-neighbour
# index over them (queried later through search_db).
# CSV layout: a header row, then rows of term, label, scope, embedding, where
# embedding is stored as a Python list literal -- assumed "[0.1, -0.2, ...]";
# NOTE(review): confirm against the script that produced embeddings.csv.
docs = []

# newline="" is what the csv module requires for correct quoting/line handling.
with open("./embeddings.csv", newline="", encoding="utf-8") as fp:
    reader = csv.reader(fp)
    next(reader, None)  # skip the header row (no-op on an empty file)
    for term, label, scope, embedding in reader:
        docs.append(
            VecDBEntry(
                term=term,
                label=label,
                scope=scope,
                # literal_eval accepts only Python literals; eval() would execute
                # any code placed in the CSV file.
                embedding=ast.literal_eval(embedding),
            )
        )

db = InMemoryExactNNIndex[VecDBEntry]()
db.index(docs)

In [5]:
def _py_str(value) -> str:
    """Return ``value`` as a Python string literal for splicing into generated code."""
    # repr() escapes embedded quotes and backslashes (e.g. "Parkinson's"); wrapping
    # the raw text in '...' would emit invalid code that black.format_str rejects.
    # The quote style itself does not matter: black normalises it afterwards.
    return repr(str(value))


def _one_line(text) -> str:
    """Join the lines of ``text`` with spaces so it stays inside a single ``#`` comment."""
    return " ".join(str(text).splitlines())


def _position_literal(value) -> str:
    """Render a variant start/end position as a list literal for ``with_g_variant``.

    A list is emitted as-is, the sentinel ``"unknown"`` becomes ``[0]`` and any
    other scalar ``v`` becomes ``[v]``.
    """
    if isinstance(value, list):
        return str(value)
    return "[0]" if value == "unknown" else f"[{value}]"


def _variant_call(variant) -> str:
    """Render an extracted variant as a ``.with_g_variant(...)`` call fragment.

    Fields reported as ``"unknown"`` fall back to assembly ``GRCH38``,
    position ``[0]`` and chromosome ``'1'``.
    """
    assembly_id = "GRCH38" if variant.assembly_id == "unknown" else variant.assembly_id
    chromosome = "1" if variant.chromosome == "unknown" else variant.chromosome
    args = [
        _py_str(assembly_id),
        # 'N', 'N' are emitted verbatim -- presumably placeholder reference /
        # alternate bases; NOTE(review): confirm against the BeaconV2 SDK.
        "'N'",
        "'N'",
        _position_literal(variant.start),
        _position_literal(variant.end),
        _py_str(chromosome),
    ]
    return f".with_g_variant({', '.join(args)})"


def _code_filters(filters, min_score: float) -> list:
    """Map each extracted filter term to ontology entries found in ``db``.

    ``filters`` is the filters model produced by the extractor chain (it is
    serialised with ``model_dump``). Each returned hit is a dict with the keys
    ``scope``, ``term``, ``score``, ``label`` and ``query``; hits whose score is
    below ``min_score`` are dropped.
    """
    hits = []
    for fil in filters.model_dump(mode="json")["filters"]:
        entries, scores = search_db(db, fil["term"], embeddings_model)
        for entry, score in zip(entries, scores):
            # Skip (rather than stop at) a low score so the result does not
            # depend on search_db returning hits sorted by descending score.
            if score < min_score:
                continue
            hits.append(
                {
                    "scope": fil["scope"],
                    "term": entry.term,
                    "score": score,
                    "label": entry.label,
                    "query": fil["term"],
                }
            )
    return hits


def generate_extractor_code(query: str, min_score: float = 0.9) -> str:
    """Turn a natural-language query into a black-formatted BeaconV2 SDK snippet.

    The query is run through the data-extractor chain built on ``llm_json``.
    Every extracted filter term is looked up in ``db`` via ``search_db`` (using
    ``embeddings_model``), and the detected variant, scope and filters are
    rendered as a ``data = BeaconV2()...load()`` statement followed by
    explanatory comments and a ``dataframes = [data]`` line. Intermediate
    results are printed for inspection.

    Relies on the module-level ``llm_json``, ``db`` and ``embeddings_model``
    created in earlier cells.

    Args:
        query: Free-text query to interpret.
        min_score: Minimum similarity score for an ontology hit to be used
            (0.9 by default, the previously hard-coded threshold).

    Returns:
        The generated Python source code, formatted by black.
    """
    extractor_chain = get_data_extractor_map_chain(llm_json)
    extracted_data = extractor_chain.invoke(query)

    print(f"{extracted_data=}")
    print(f"filters text: {extracted_data['filters'].filters}")

    if len(extracted_data["filters"].filters) > 0:
        hits = _code_filters(extracted_data["filters"], min_score)
        print(f"filters coded: {hits}")
    else:
        hits = []
    # Replace the parsed filters model with the coded ontology hits.
    extracted_data["filters"] = hits

    sdk_construct = "data = BeaconV2()"
    sdk_comments = ""

    variant = extracted_data.get("variant")
    if variant is not None and variant.success:
        sdk_comments += "# Variants detected in the query.\n"
        # model_dump() is the Pydantic v2 replacement for the deprecated .dict().
        print(f"variant found: {variant.model_dump()}")
        sdk_construct += _variant_call(variant)

    # An explicit None check also covers a "scope" key that is present but None.
    scope = extracted_data.get("scope")
    if scope is None:
        scope = Scope(scope=ScopeEnum.UNKNOWN)

    if scope.scope != ScopeEnum.UNKNOWN:
        print(f"{scope=}")
        scope_name = scope.model_dump(mode="json")["scope"]
        sdk_construct += f".with_scope({_py_str(scope_name)})"
        sdk_comments += f"# Scope detected to be '{scope_name}'.\n"
    else:
        sdk_comments += "# Could not decide a scope for your query.\n"
        sdk_construct += ".with_scope('<ENTER YOUR SCOPE>')"

    for fil in hits:
        print(f"{fil=}")
        sdk_comments += f"# {_one_line(fil['term'])} -> '{_one_line(fil['label'])}'\n"
        sdk_construct += (
            f".with_filter('ontology', {_py_str(fil['term'])}, {_py_str(fil['scope'])}) "
        )

    sdk_construct += (
        ".load()\n\n"
        + sdk_comments
        + "\n# Please update this line with other dataframes"
        + "\ndataframes = [data]"
    )

    return black.format_str(sdk_construct, mode=black.FileMode())

In [6]:
# Example query mentioning a population, a condition and a genomic region.
example_query = "Individuals with Parkinson's with variants in first chromosome from the 10000-15000 bases"
sdk_construct = generate_extractor_code(example_query)

extracted_data={'scope': Scope(scope=<ScopeEnum.INDIVIDUALS: 'individuals'>), 'filters': Filters(filters=[Filter(term="Parkinson's", scope=<ScopeEnum.INDIVIDUALS: 'individuals'>)]), 'variant': Variant(success=True, assembly_id='unknown', chromosome='1', start=[10000], end=[15000]), 'granularity': Granularity(granularity=<GranularityEnum.RECORD: 'record'>), 'query': "Individuals with Parkinson's with variants in first chromosome from the 10000-15000 bases"}
filters text: [Filter(term="Parkinson's", scope=<ScopeEnum.INDIVIDUALS: 'individuals'>)]
filters coded: [{'scope': 'individuals', 'term': 'SNOMED:49049000', 'score': np.float64(0.9633136720184313), 'label': "Parkinson's disease", 'query': "Parkinson's"}]
variant found: {'success': True, 'assembly_id': 'unknown', 'chromosome': '1', 'start': [10000], 'end': [15000]}
scope=Scope(scope=<ScopeEnum.INDIVIDUALS: 'individuals'>)
fil={'scope': 'individuals', 'term': 'SNOMED:49049000', 'score': np.float64(0.9633136720184313), 'label': "Parkins

In [7]:
# print (not a bare expression) so the generated code's newlines render
# instead of the string's repr.
print(f"{sdk_construct}")

data = (
    BeaconV2()
    .with_g_variant("GRCH38", "N", "N", [10000], [15000], "1")
    .with_scope("individuals")
    .with_filter("ontology", "SNOMED:49049000", "individuals")
    .load()
)

# Variants detected in the query.
# Scope detected to be 'individuals'.
# SNOMED:49049000 -> 'Parkinson's disease'

# Please update this line with other dataframes
dataframes = [data]

