# SPARQL generation via Generative AI
---
This notebook demonstrates one way to generate SPARQL queries from natural-language questions. Here we focus on
prompting the model by implicitly showing it the database schema.

If you are running this notebook outside of an AWS environment (e.g., on your laptop), uncomment
the following cell and fill in your AWS credentials:

In [None]:
# %env AWS_ACCESS_KEY_ID=<...>
# %env AWS_SECRET_ACCESS_KEY=<...>
# %env AWS_SESSION_TOKEN=<...>

If you are running this notebook inside an AWS environment (e.g., in SageMaker Studio), use the
"conda_pytorch_p310" kernel and uncomment the following cell to install the required packages:

In [None]:
# %pip install -q boto3==1.34.*
# %pip install -q botocore==1.34.*
# %pip install -q jupyter==1.0.*
# %pip install -q sagemaker==2.212.*
# %pip install -q jinja2==3.1.*
# %pip install -q ipykernel==6.29.*
# %pip install -q awswrangler==3.7.*

In [64]:
import json
import re
from pathlib import Path
from typing import Optional, Union

import awswrangler as wr
import boto3
import botocore
from botocore.exceptions import ClientError
import jinja2
import pandas as pd
import sagemaker

In [65]:
# Shared AWS session plus the service clients used by later cells; every
# client is pinned to the SageMaker session's region.
sess = sagemaker.Session()
region = sess.boto_region_name
sm_client, bedrock_runtime, bedrock = (
    boto3.client(service_name, region_name=region)
    for service_name in ("sagemaker", "bedrock-runtime", "bedrock")
)

# Jinja environment used to render the prompt template (strips block whitespace).
jenv = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)

# Bedrock model and sampling temperature used by run_bedrock.
# Other model IDs tried: "anthropic.claude-v2:1",
# "anthropic.claude-3-haiku-20240307-v1:0".
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
temperature = 0.1

In [None]:
# Connection settings for the Neptune cluster. The hostname is the cluster's
# "cluster-ro" endpoint (presumably the read-only reader endpoint — confirm).
neptune_port = 8182
neptune_url = "sgh-uniprot.cluster-ro-c8eun1oaqppj.us-east-1.neptune.amazonaws.com"

# Open an IAM-authenticated client and report its status as a sanity check.
neptune_client = wr.neptune.connect(
    neptune_url,
    neptune_port,
    iam_enabled=True,
)
print(f"Neptune client status: {neptune_client.status()}")

In [51]:
def run_bedrock(prompt: str, *, max_tokens: int = 1024) -> str:
    """Send a single-turn prompt to the configured Bedrock model and return its reply text.

    Uses the module-level ``model_id`` and ``temperature`` settings and the
    Anthropic Messages request format (``anthropic_version`` bedrock-2023-05-31).

    Args:
        prompt: The user message to send.
        max_tokens: Upper bound on the number of tokens the model may generate.

    Returns:
        The concatenation of all ``text`` content parts in the model's response.

    Raises:
        ClientError: If the ``invoke_model`` call fails; the error code and
            message are printed before the exception is re-raised.
    """
    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ],
    }
    try:
        response = bedrock_runtime.invoke_model(modelId=model_id, body=json.dumps(request_body))
    except ClientError as err:
        error = err.response["Error"]
        print(f"Error invoking {model_id}: {error['Code']} {error['Message']}")
        raise

    result = json.loads(response["body"].read())
    if result.get("stop_reason") == "max_tokens":
        # A truncated reply can lose closing tags (e.g. </sparql>) downstream.
        print(f"Warning: {model_id} output was truncated at max_tokens={max_tokens}")
    return "".join(
        part.get("text", "")
        for part in result.get("content", [])
        if part.get("type") == "text"
    )

In [52]:
# Jinja prompt template rendered by generate_sparql_query; read explicitly as
# UTF-8 so decoding does not depend on the platform's locale encoding.
prompt_template = (Path.cwd() / "resources" / "prompt.txt").read_text(encoding="utf-8")

def _extract_between_tags(text: str, open_tag: str, close_tag: str) -> str:
    """Return the text between open_tag and the next close_tag, leaving a side untrimmed if its tag is missing."""
    start = text.find(open_tag)
    if start != -1:
        text = text[start + len(open_tag):]
    end = text.find(close_tag)
    if end != -1:
        text = text[:end]
    return text.strip()


def generate_sparql_query(question: str) -> str:
    """Ask the LLM to translate a natural-language question into a SPARQL query.

    The question is rendered into the Jinja ``prompt_template`` (as ``question``)
    and sent via ``run_bedrock``. If the reply wraps the query in
    ``<sparql>...</sparql>`` tags, only the text between them is kept; when a
    tag is missing, that end of the reply is kept as-is instead of raising.

    Args:
        question: The natural-language question to translate.

    Returns:
        The extracted SPARQL query text, with surrounding whitespace stripped.
    """
    prompt = jenv.from_string(prompt_template).render(question=question)
    response = run_bedrock(prompt).strip()
    # str.find (not str.index) so a reply without tags does not raise ValueError.
    return _extract_between_tags(response, "<sparql>", "</sparql>")

In [None]:
# Smoke test: generate a SPARQL query for a sample question and show it.
query = generate_sparql_query("Show me all proteins that are located in the mitochondrion")
print(query)

In [None]:
def query_neptune(query: str) -> Optional[Union[pd.DataFrame, str]]:
    """Run a SPARQL query against the Neptune cluster.

    Args:
        query: The SPARQL query text.

    Returns:
        The result of ``wr.neptune.execute_sparql`` on success; the error text
        (``str``) when the query fails or the response body is not valid JSON;
        ``None`` when a wrapped JSON-decoding failure suggests the result set
        was too big.

    Raises:
        Exception: Any other error from the Neptune client is re-raised unchanged.
    """
    print(f"query_neptune {query}")
    try:
        rv = wr.neptune.execute_sparql(neptune_client, query)
    except wr.exceptions.QueryFailed as ex:
        return str(ex)
    except json.JSONDecodeError as ex:
        # requests.exceptions.JSONDecodeError (raised by Response.json()) inherits
        # from json.JSONDecodeError when requests uses the stdlib json module
        # (its default unless simplejson is installed), so requests need not be imported.
        return str(ex)
    except Exception as ex:
        if "json.decoder.JSONDecodeError" in str(ex):
            print("Looks like result set was too big")
            return None
        raise
    print(f"query_neptune -> {rv}")
    return rv

def simplify_neptune_df(df: Optional[pd.DataFrame]) -> Optional[pd.DataFrame]:
    """Unwrap SPARQL JSON result bindings into their plain values.

    Each cell shaped like ``{'type': 'uri', 'value': 'http://...'}`` is replaced
    by its ``'value'`` entry (``'http://...'``); cells that cannot be indexed by
    ``"value"`` (plain scalars, ``None``, dicts without that key) are kept as-is.

    Args:
        df: Result frame to simplify, or ``None``.

    Returns:
        A new DataFrame with unwrapped cells, or ``None`` when ``df`` is ``None``.
    """
    if df is None:
        return df

    def unwrap(elem):
        try:
            return elem["value"]
        except (KeyError, TypeError, IndexError):
            # Not a binding dict; narrow catch so KeyboardInterrupt etc. propagate.
            return elem

    # DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map.
    elementwise = df.map if hasattr(pd.DataFrame, "map") else df.applymap
    return elementwise(unwrap)


# Matches a query whose first real clause is SELECT, allowing for the PREFIX /
# BASE declarations and `#` comment lines that typically precede it.
select_pat = re.compile(
    r"\s*(?:(?:PREFIX\s+[^\s:]*:\s*<[^>]*>|BASE\s*<[^>]*>|#[^\n]*)\s*)*SELECT\b",
    re.DOTALL | re.IGNORECASE,
)
# Matches a trailing `LIMIT n`, optionally followed by `OFFSET m`; group 1 is n.
# SPARQL keywords are case-insensitive, hence IGNORECASE.
limit_pat = re.compile(
    r"^.*\bLIMIT\s+([0-9]+)(?:\s+OFFSET\s+[0-9]+)?\s*$",
    re.DOTALL | re.IGNORECASE,
)


def add_limit_to_sparql_q(query: str, max_results: int) -> str:
    """Cap the number of rows a SPARQL SELECT query can return.

    Only queries whose first clause (after any PREFIX/BASE declarations and
    `#` comment lines) is SELECT are modified. If the query already ends with
    a LIMIT clause (optionally followed by OFFSET), that limit is lowered to
    ``max_results`` when it is larger; otherwise ``LIMIT max_results`` is
    appended on a new line.

    Args:
        query: The SPARQL query text.
        max_results: The maximum number of rows to allow.

    Returns:
        The possibly rewritten query; non-SELECT queries are returned unchanged.
    """
    print(f"add_limit {max_results} <<{query}>>")
    if not select_pat.match(query):
        # Only SELECT queries get a limit; ASK/CONSTRUCT/DESCRIBE pass through.
        return query
    match = limit_pat.search(query)
    if match is None:
        # Newline keeps the clause out of any trailing `#` comment.
        return query + f"\nLIMIT {max_results}"
    new_limit = min(int(match.group(1)), max_results)
    # Splice exactly the digits of the existing limit; keep everything after them.
    return query[: match.start(1)] + str(new_limit) + query[match.end(1):]