# **BigQuery Q&A using langchain & LLM (go/bq-research-assistant)**


Demo: This notebook shows how to use an LLM to answer natural-language questions over a BigQuery public dataset.
This notebook has also been tested on big datasets (~2 million rows) and works well with low latency.

To use it on your own data, change the parameters (project, dataset, location, and Pub/Sub topic/subscription) in the section "LLM Model Initialization & App parameters initialization".


|Author:       |Alex Burdenko (aburdenko@)|
|--------------|---------|
|Last Updated: |6/17/2023|


Example input: What kind of assays are in PubChem?

# Python libraries setup (Run First, Run Once)
Note: this cell installs the dependencies and then restarts the kernel. Run it once; after the restart you can run all of the cells below without repeating this step.

In [6]:
#!pip3 uninstall -y google-cloud-aiplatform
!pip install google-cloud-aiplatform --upgrade --quiet

from google.cloud import aiplatform
print(f"Vertex AI SDK version: {aiplatform.__version__}")

# Install Python Libraries
!pip install langchain --upgrade --quiet
!pip install google-cloud-core --quiet
!pip install gradio --quiet
!pip install gradio_tools --quiet

!pip install chromadb --quiet

# Below libraries are required to build a SQL engine for BigQuery
!pip install SQLAlchemy --quiet
!pip install sqlalchemy-bigquery --quiet
!pip install google-cloud-pubsub --quiet

import IPython
print( 'restarting kernel...' )
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m321.3/321.3 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m74.5 MB/s[0m eta [36m0:00:00[0m
[?25hVertex AI SDK version: 1.26.1
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.0/90.0 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.1/49.1 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.7/19.7 MB[0m [31m67.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.0/57.0 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25

{'status': 'ok', 'restart': True}

# Since the kernel restarted at this point, you can run all cells below.

In [2]:
# Load IPython's autoreload extension (reloading it if it is already loaded); mode 2
# re-imports modules before each cell runs.
%reload_ext autoreload
%autoreload 2

In [3]:
# Authenticate with Google account.
# Interactive sign-in only exists inside Colab; elsewhere, fall back to whatever
# Application Default Credentials are already configured
# (e.g. via `gcloud auth application-default login`).
try:
    from google.colab import auth as google_auth
except ImportError:
    google_auth = None

if google_auth is not None:
    google_auth.authenticate_user()

import google.auth
credentials, project = google.auth.default()
from google.cloud.bigquery import magics
# Make the %%bigquery cell magic use the same credentials.
magics.context.credentials = credentials

### LLM Model Initialization & App parameters initialization

In [4]:
# @title Specify Project details and LOCATION of the BQ table

# PROJECT_ID: your project. The vw_* views, the Vertex AI calls and the Pub/Sub resources live here.
PROJECT_ID = "kallogjeri-project-345114"  # @param {type:"string"}
# DB_PROJECT_ID: project that owns the source tables the views select from.
DB_PROJECT_ID = "sciwalker-open-data"  # @param {type:"string"}

#DB_PROJECT_ID = "cloud-llm-preview1"  # @param {type:"string"}
# LOCATION: region passed to vertexai.init.
LOCATION = "us-central1"  # @param {type:"string"}
#DATASET_ID = 'blackbelt_capstone_healthcare' # @param {type:"string"}
# DATASET_ID: dataset name used in both projects (source tables in DB_PROJECT_ID, views in PROJECT_ID).
DATASET_ID = 'clinical_trials_aact' # @param {type:"string"}

#@markdown ### Enter the topic name and subscription to be used for pub/sub
TOPIC_NAME="customer-chat" # @param {type:"string"}
SUBSCRIPTION="colab-sub" # @param {type:"string"}


import sys
# True when running inside Colab; also exported as an environment variable via %env.
IN_COLAB = 'google.colab' in sys.modules
%env IN_COLAB=$IN_COLAB

# Make PROJECT_ID gcloud's default project, then print it back as a check.
!gcloud config set project $PROJECT_ID -q
!gcloud config get project

env: IN_COLAB=True
Updated property [core/project].
kallogjeri-project-345114


In [5]:
import vertexai
from langchain.llms import VertexAI

# Point the Vertex AI SDK at the configured project and region before building the LLM.
vertexai.init(project=PROJECT_ID, location=LOCATION)

# LangChain wrapper around the Vertex AI text-bison model. Temperature stays slightly
# above 0 because, per the original author, 0 caused some responses to get blocked.
llm = VertexAI(
    model_name="text-bison",
    temperature=0.1,
    top_p=1,
    top_k=40,
    max_output_tokens=1024,
    verbose=True,
)

ModuleNotFoundError: ignored

### Create SQL engine for BigQuery

In [None]:
from sqlalchemy import *
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import *
import pandas as pd

In [None]:
from google.cloud import bigquery

# BigQuery client bound to PROJECT_ID. The SQLAlchemy engine below defaults to the
# source DB_PROJECT_ID/DATASET_ID but is handed this client instead of building its own.
bq_client = bigquery.Client(project=PROJECT_ID)

table_uri = f"bigquery://{PROJECT_ID}/{DATASET_ID}"
engine_url = f"bigquery://{DB_PROJECT_ID}/{DATASET_ID}?user_supplied_client=True"
engine = create_engine(engine_url, connect_args=dict(client=bq_client))

In [None]:
from google.cloud import bigquery

def create_ds_view(tbl_name):
  """Creates a `SELECT *` view in PROJECT_ID over one table of the source dataset.

  Args:
    tbl_name: Name of a table in `DB_PROJECT_ID.DATASET_ID`.

  Returns:
    A (view_name, fq_view_name) tuple: the bare view name `vw_<tbl_name>` and its
    dataset-qualified form `<DATASET_ID>.vw_<tbl_name>`.
  """
  view_name = f"vw_{tbl_name}"
  target = bigquery.Table(f"{PROJECT_ID}.{DATASET_ID}.{view_name}")
  target.view_query = f"SELECT * FROM `{DB_PROJECT_ID}.{DATASET_ID}.{tbl_name}`"

  # exists_ok=True: a view that already exists under this name is not an error.
  bq_client.create_table(target, exists_ok=True)
  return view_name, f"{DATASET_ID}.{view_name}"



# List the source tables. INFORMATION_SCHEMA.COLUMNS has one row per column, so DISTINCT
# dedupes in SQL. ORDER BY makes the order (and so view_struct[0], which a later cell
# queries) identical on every run; a Python set gave an arbitrary order.
query = f"""SELECT DISTINCT table_name
  FROM `{DB_PROJECT_ID}.{DATASET_ID}`.INFORMATION_SCHEMA.COLUMNS
  ORDER BY table_name"""

from itertools import chain
tbl_names = [row[0] for row in engine.execute(query).fetchall()]

# Create the dataset that holds the views if it doesn't exist. Only NotFound is handled,
# so other failures (permissions, quota, a mistyped id) surface instead of being masked.
from google.cloud.exceptions import NotFound
try:
    bq_client.get_dataset(DATASET_ID)
except NotFound:
    dataset = bigquery.Dataset(f"{PROJECT_ID}.{DATASET_ID}")
    bq_client.create_dataset(dataset)


view_struct = [create_ds_view(tbl_name) for tbl_name in tbl_names]

view_names = [view_name for view_name, _ in view_struct]
print(view_names)
fq_view_names = [fq_view_name for _, fq_view_name in view_struct]


table_str = "','".join(tbl_names)
view_str = "','".join(view_names)

# Every (view, column) pair of the views created above.
column_query =  f"SELECT table_name, column_name \
  FROM `{PROJECT_ID}.{DATASET_ID}`.INFORMATION_SCHEMA.COLUMNS \
  WHERE table_name in ('{view_str}')"

print(column_query)

column_rows = engine.execute(column_query).fetchall()
# Flattened [table, column, table, column, ...] list, kept for backward compatibility.
columns = list(chain(*column_rows))
column_names = [f"{table}.{column}" for table, column in column_rows]

In [None]:
# Smoke-test the first view. Only one row is read, so LIMIT 1 avoids fetching 1000 for nothing.
# Backticks keep the hyphenated project id a single identifier.
query = f"""SELECT * FROM `{PROJECT_ID}.{DATASET_ID}.{view_struct[0][0]}` LIMIT 1"""
engine.execute(query).first()

In [None]:
from google.cloud import bigquery
from langchain import SQLDatabase
from sqlalchemy import create_engine, MetaData


# Create a BigQuery client.
bq_client = bigquery.Client(project=PROJECT_ID)

# From here on the engine points at PROJECT_ID, where the vw_* views live. No
# user_supplied_client is passed, so the dialect opens its own connection.
table_uri = f"bigquery://{PROJECT_ID}/{DATASET_ID}"
engine = create_engine(table_uri)

# Smoke-test the source data. Only the first row is consumed, so LIMIT 1 avoids
# materializing a result of up to a million rows just to read one.
query = f"""SELECT * FROM `{DB_PROJECT_ID}.{DATASET_ID}.{tbl_names[0]}` LIMIT 1"""
engine.execute(query).first()


# Create an instance of SQLDatabase restricted to the views.
db = SQLDatabase(
    engine=engine,
    metadata=MetaData(bind=engine),
    include_tables=view_names,
    view_support=True,
)


### SQL Chain setup for LLM

In [2]:
import json
import re

def convert_llm_response_to_gradio_json(llm_response, PROJECT_ID=None, DB_PROJECT_ID=None):
  """Splits a raw LLM response into the (answer, sql_query) pair shown in the Gradio UI.

  The response is expected to look like "... SQLQuery: <sql> Answer: <answer>".
  Despite the name, no JSON is produced: the answer comes back as a string wrapped
  in "['" ... "']".

  Args:
    llm_response: The LLM response string.
    PROJECT_ID: Project id to replace inside the SQL text. If omitted, the
      notebook-level PROJECT_ID is looked up at call time (skipped if undefined).
    DB_PROJECT_ID: Replacement project id. If omitted, the notebook-level
      DB_PROJECT_ID is looked up at call time.

  Returns:
    A tuple (answer, sql_query). When the "Answer:" or "SQLQuery:" marker is
    missing, the corresponding part is "['']" or "" respectively.
  """
  # Resolve defaults at call time. A `PROJECT_ID=PROJECT_ID` default is evaluated when the
  # cell runs, which raises NameError on a fresh kernel and freezes stale values.
  if PROJECT_ID is None:
    PROJECT_ID = globals().get("PROJECT_ID")
  if DB_PROJECT_ID is None:
    DB_PROJECT_ID = globals().get("DB_PROJECT_ID")

  # Extract SQLQuery value: text after the first "SQLQuery:" up to the next "Answer:".
  sql_parts = llm_response.split("SQLQuery:")
  sql_query = sql_parts[1].split("Answer:")[0].strip() if len(sql_parts) > 1 else ""
  # Swap PROJECT_ID for DB_PROJECT_ID in the query text.
  # NOTE(review): view names keep their vw_ prefix, which may not exist in DB_PROJECT_ID — confirm.
  if PROJECT_ID and DB_PROJECT_ID:
    sql_query = sql_query.replace(PROJECT_ID, DB_PROJECT_ID)

  # Extract Answer value: text after the first "Answer:" (up to any later "Answer:").
  answer_parts = llm_response.split("Answer:")
  answer = answer_parts[1].strip() if len(answer_parts) > 1 else ""

  # Present the answer as a list literal. endswith() checks the actual suffix, whereas
  # re.match(r".*']") accepted "']" anywhere on the first line.
  if not answer.startswith("['"):
    answer = "['" + answer
  if not answer.endswith("']"):
    answer = answer + "']"

  return answer, sql_query


# Smoke-test the parser on a hand-written response in the format the prompt asks for.
llm_response = (
    "SQLQuery: SELECT DISTINCT meshcode "
    f"FROM `{PROJECT_ID}`.vw_cid_meshcodes "
    "WHERE cid = 123456 "
    "ORDER BY COUNT(meshcode) DESC "
    "LIMIT 10 "
    "Answer: D000077600, D000077601, D000077602, D000077603, D000077604, "
    "D000077605, D000077606, D000077607, D000077608, D000077609"
)

convert_llm_response_to_gradio_json(llm_response, PROJECT_ID, DB_PROJECT_ID)

NameError: name 'PROJECT_ID' is not defined

In [None]:
from langchain.prompts.chat import BaseStringMessagePromptTemplate
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.prompts.prompt import PromptTemplate
from langchain import PromptTemplate, LLMChain

def bq_qna(question, top_k=10000):
  """Asks the LLM to write a GoogleSQL query, and an answer, for a natural-language question.

  Args:
    question: The user's question, inserted into the prompt as {input}.
    top_k: Row limit the prompt tells the LLM to use in its LIMIT clause when the
      question does not ask for a specific number of results. Defaults to 10000.

  Returns:
    The (answer, sql_query) tuple that convert_llm_response_to_gradio_json extracts
    from the raw LLM text.
  """
  # The LLM is called directly with the prompt below; no SQLDatabase/SQLDatabaseChain is
  # built per call (they were never used, and constructing them presumably issues
  # metadata queries against BigQuery on every question).
  # NOTE(review): the generated SQL is never executed here, so the "Answer:" section is
  # written by the LLM without seeing query results. Run the returned SQL (e.g. via
  # `engine`) if answers must be grounded in data.
  # NOTE(review): the ActivityDate instruction below looks carried over from another
  # dataset — confirm it applies to these views.

  #Define prompt for BigQuery SQL
  _googlesql_prompt = """You are a GoogleSQL expert. Given an input question, first create a syntactically correct GoogleSQL query to run, then look at the results of the query and return the answer to the input question.
  Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per GoogleSQL. You can order the results to return the most informative data in the database.
  Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
  Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
  Use the following format:
  Question: "Question here"
  SQLQuery: "SQL Query to run"
  Answer: "Final answer here"

  Limit the response to 1024 characters.

  Only use the following tables:
  ('{table_str}')

  If someone asks for aggregation on a STRING data type column, then CAST column as NUMERIC before you do the aggregation.

  If someone asks for specific month, use ActivityDate between current month's start date and current month's end date.

  If someone asks for column names in the table, use the following format:
  SELECT column_name
  FROM `{PROJECT_ID}.{DATASET_ID}`.INFORMATION_SCHEMA.COLUMNS
  WHERE table_name in ('{view_str}').

  Question: {input}"""

  GOOGLESQL_PROMPT = PromptTemplate(
      input_variables=["input", "PROJECT_ID", "DATASET_ID", "table_str", "view_str", "top_k"],
      template=_googlesql_prompt
  )

  # table_str lists the dataset-qualified views for the "Only use the following tables" hint.
  # INFORMATION_SCHEMA.COLUMNS.table_name holds bare view names, so that filter gets view_str.
  final_prompt = GOOGLESQL_PROMPT.format(
      input=question,
      PROJECT_ID=PROJECT_ID,
      DATASET_ID=DATASET_ID,
      table_str="','".join(fq_view_names),
      view_str="','".join(view_names),
      top_k=top_k,
  )

  result = llm.predict(final_prompt)
  return convert_llm_response_to_gradio_json(result)


### Testing the setup

In [None]:
# Smoke test 0: print the (answer, SQL) pair for a free-text question.
res = bq_qna(
    'Find all tetrazole ocid subject compounds in clinical trials.'
)
print(res)

In [None]:
# Smoke test 1: print the (answer, SQL) pair for a top-N question.
res = bq_qna(
    'What are the top 10 most common mesh codes for compounds with a PubChem CID of 123456?'
)
print(res)

In [None]:
# Smoke test 2: print the (answer, SQL) pair for an open-ended question.
res = bq_qna(
    'What kind of assays are in PubChem?'
)
print(res)

In [None]:
# Smoke test 3: print the (answer, SQL) pair for a counting question.
res = bq_qna(
    'How many molecules have EC50 assays for SGLT2 protein?'
)
print(res)

### UI for Demo

Sample Inputs:

  * What are the top 10 most common mesh codes for compounds with a PubChem CID of 123456?

  * What kinds of assays are in PubChem?

  * How many molecules have EC50 assays for SGLT2 protein?

In [None]:
from logging import debug
import gradio as gr
import warnings
warnings.filterwarnings('ignore')

# Create a Gradio interface: one question box in; answer and generated SQL out
# (bq_qna returns an (answer, sql_query) tuple, one value per output box).
# gr.Textbox replaces gr.inputs/gr.outputs.Textbox, which are deprecated and gone in
# Gradio 4. `debug` belongs to launch(), not to Interface().
interface = gr.Interface(
    fn=bq_qna,
    inputs=[gr.Textbox(label="Question", placeholder="What are the top 10 most common mesh codes for compounds with a PubChem CID of 123456?")],
    outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="SQL Query")],
)

# Display the Gradio interface.
interface.launch(debug=False)

# Interact with a front end via Pub/Sub

In [None]:
import os
import json
import threading
from google.cloud import pubsub

subscription_name = f"projects/{PROJECT_ID}/subscriptions/{SUBSCRIPTION}"

# State shared between the Pub/Sub callback (which runs on a background thread) and this
# cell. The callback declares these `global`; without that, its assignments only created
# locals, and `res` stayed None, so the published response was always null.
msg_result = None
res = None
response = dict()
# Set by the callback once it has answered a question.
answered = threading.Event()


def callback(message):
    """Answers a request message (one without a 'response' key) with bq_qna."""
    global msg_result, res, response
    data_dict = json.loads(message.data)
    # Messages that already carry a 'response' are not requests; they are left unacked.
    if 'response' not in data_dict:
      response = data_dict
      msg_result = message.ack_with_response()
      print(response)
      question = data_dict['question']
      res = bq_qna(question)
      answered.set()


with pubsub.SubscriberClient() as subscriber:
    future = subscriber.subscribe(subscription_name, callback)
    # Wait until a question has been answered, or until the streaming pull ends on its
    # own (e.g. the callback raised). Cancelling here, not inside the callback, means the
    # callback never touches `future` before this line has assigned it.
    while not answered.wait(timeout=1.0) and not future.done():
        pass
    future.cancel()
    # Re-raises any error from the streaming pull or the callback.
    result = future.result()

    print(msg_result)

    # Check if the future was cancelled.
    if future.cancelled():
        print("The future was cancelled.")
    else:
        print("The future was not cancelled.")

fq_topic_name = f"projects/{PROJECT_ID}/topics/{TOPIC_NAME}"

with pubsub.PublisherClient() as publisher:
  # Convert the dictionary to a JSON string.
  message_data = dict()
  message_data['response'] = res
  json_string = json.dumps(message_data)

  # Convert the JSON string to a bytestring.
  bytestring = bytes(json_string, "utf-8")
  future = publisher.publish(fq_topic_name, bytestring)
  message = future.result()

print(response)