In [None]:
# @title Install Necessary Packages

first = False
if first:
  ! pip install google-cloud-aiplatform --upgrade --quiet
  ! pip install shapely==1.8.5 --quiet
  ! pip install sqlalchemy --upgrade --quiet
  ! pip install asyncio asyncpg cloud-sql-python-connector["asyncpg"] --quiet
  ! pip install numpy pandas --quiet
  ! pip install pgvector --quiet
  ! pip install pg8000 --quiet
  ! pip install gradio --quiet

In [None]:
# get_ipython().kernel.do_shutdown(True)

In [1]:
from google.colab import auth
auth.authenticate_user()

In [2]:
import os
auth_user=!gcloud config get-value account
auth_user=auth_user[0]
print('Authenticated User: ' + str(auth_user))


Authenticated User: moksh.google@shoppersstop.com


In [3]:
#@title Assignment of Variables
source_type='BigQuery'


# @markdown Provide the below details to start using the notebook
PROJECT_ID='ss-genai-npd-svc-prj-01' # @param {type:"string"}
LLM_ENDPOINT_REGION = 'asia-southeast1' # @param {type:"string"}
DATAPROJECT_ID='ss-genai-npd-svc-prj-01'  # This needs to be adjusted when using the bq public dataset

#set and show gcp project
!gcloud config set project {PROJECT_ID}
!gcloud config get-value project
#!bash gcloud auth application-default login


# BQ Schema (DATASET) where tables leave

schema="NA" # @param {type:"string"}.  ### DDL extraction performed at this level, for the entire schema
USER_DATASET= DATAPROJECT_ID + '.' + schema

table_id_list="ss-jarvis-npd-svc-prj-01.Customer.FC_Customer_Master, ss-jarvis-npd-svc-prj-01.Customer.Persona,ss-jarvis-npd-svc-prj-01.Product.Product_Master, ss-jarvis-npd-svc-prj-01.Sales.storesales,ss-jarvis-npd-svc-prj-01.Sales.ecommdenormdata,ss-jarvis-npd-svc-prj-01.Store.dim_location_full" # @param {type:"string"}


# BQ Schema (DATASET) where tables leave

# Execution Parameters
SQL_VALIDATION='ALL'
INJECT_ONE_ERROR=False
EXECUTE_FINAL_SQL=True
SQL_MAX_FIX_RETRY=3
AUTO_ADD_KNOWNGOOD_SQL=True

# Analytics Warehouse
ENABLE_ANALYTICS=True
DATASET_NAME='nl2sql'
DATASET_LOCATION='asia-south1'
LOG_TABLE_NAME='query_logs'
FULL_LOG_TEXT=''


# Palm Models to use
model_id='gemini-pro' # @param {type:"string"}
chat_model_id='codechat-bison-32k' # @param {type:"string"}
embeddings_model='textembedding-gecko@001'

Updated property [core/project].
ss-genai-npd-svc-prj-01


In [4]:
# @title Common Imports
import time
import datetime
from datetime import datetime, timezone
import hashlib
import vertexai
import pandas
import pandas_gbq
import matplotlib.pyplot as plt
from sqlalchemy import create_engine
from sqlalchemy import text
import pandas as pd
from google.colab import data_table
data_table.enable_dataframe_formatter()
import json
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from logging import exception
import asyncio
import asyncpg
from google.cloud.sql.connector import Connector
import numpy as np
from pgvector.asyncpg import register_vector
from google.cloud import aiplatform
from vertexai.language_models import TextEmbeddingModel
import gradio as gr

In [5]:
# @title Model Endpoint Creation
def createModel(PROJECT_ID, LLM_ENDPOINT_REGION, model_id):
  from vertexai.preview.language_models import TextGenerationModel
  from vertexai.preview.language_models import CodeGenerationModel
  from vertexai.preview.language_models import CodeChatModel
  from vertexai.preview.generative_models import GenerativeModel

  if model_id == 'code-bison-32k':
    model = CodeGenerationModel.from_pretrained('code-bison-32k')
  elif model_id == 'text-bison-32k':
    model = TextGenerationModel.from_pretrained('text-bison-32k')
  elif model_id == 'codechat-bison-32k':
    model = CodeChatModel.from_pretrained("codechat-bison-32k")
  elif model_id == 'gemini-pro':
    model = GenerativeModel("gemini-pro")
  else:
    raise ValueError
  return model


vertexai.init(project=PROJECT_ID, location=LLM_ENDPOINT_REGION)
model=createModel(PROJECT_ID, LLM_ENDPOINT_REGION,model_id)
chat_model=createModel(PROJECT_ID, LLM_ENDPOINT_REGION,chat_model_id)

In [6]:
# @title Functions to pull metadata from BQ (Tables, Columns, PKs and FKs)
def create_metadata_sqls(table_id_list):
  get_columns_sql="""SELECT
    TABLE_CATALOG as project_id, TABLE_SCHEMA as owner , TABLE_NAME as table_name, COLUMN_NAME as column_name,
    DATA_TYPE as data_type, ROUNDING_MODE as rounding_mode, description
  FROM
    {}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS WHERE TABLE_NAME IN {}
  ORDER BY
   project_id, owner, table_name, column_name """

  get_fkeys_sql="""SELECT T.CONSTRAINT_CATALOG, T.CONSTRAINT_SCHEMA, T.CONSTRAINT_NAME,
T.TABLE_CATALOG as project_id, T.TABLE_SCHEMA as owner, T.TABLE_NAME as table_name, T.CONSTRAINT_TYPE,
T.IS_DEFERRABLE, T.ENFORCED, K.COLUMN_NAME
FROM
{}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS T
JOIN {}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE K
ON K.CONSTRAINT_NAME=T.CONSTRAINT_NAME
WHERE T.TABLE_NAME IN {} AND
T.CONSTRAINT_TYPE="FOREIGN KEY"
ORDER BY
project_id, owner, table_name"""

  get_pkeys_sql="""
SELECT T.CONSTRAINT_CATALOG, T.CONSTRAINT_SCHEMA, T.CONSTRAINT_NAME,
T.TABLE_CATALOG as project_id, T.TABLE_SCHEMA as owner, T.TABLE_NAME as table_name, T.CONSTRAINT_TYPE,
T.IS_DEFERRABLE, T.ENFORCED, K.COLUMN_NAME
FROM
{}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS T
JOIN {}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE K
ON K.CONSTRAINT_NAME=T.CONSTRAINT_NAME
WHERE  T.TABLE_NAME IN {} AND
T.CONSTRAINT_TYPE="PRIMARY KEY"
ORDER BY
project_id, owner, table_name
"""


  get_table_comments_sql="""
select TABLE_CATALOG as project_id, TABLE_SCHEMA as owner , TABLE_NAME as table_name, OPTION_NAME, OPTION_TYPE, OPTION_VALUE as comments
FROM
{}.INFORMATION_SCHEMA.TABLE_OPTIONS
WHERE TABLE_NAME IN {} AND
 OPTION_NAME = "description"
ORDER BY
project_id, owner, table_name
"""
  table_id_list=table_id_list.split(",")
  result = []
  for table_id in table_id_list:
    # print(table_id)
    project, schema, table_name = table_id.split('.')
    project = project.strip()
    schema = schema.strip()
    table_name = table_name.strip()
    result.append({'project': project, 'schema': schema, 'table_name': table_name})
  result.sort(key=lambda x: x['project'])
  # print((result))
  project_schema_groups = {}
  for item in result:
      key = (item['project'], item['schema'])
      project_schema_groups.setdefault(key, []).append(item['table_name'])

  sql_statements_1, sql_statements_2, sql_statements_3, sql_statements_4 = [], [], [], []
  for (project, schema), table_names in project_schema_groups.items():
      table_names_placeholder = "({})".format(", ".join("'{}'".format(t) for t in table_names))
      sql_1 = get_columns_sql.format('{}.{}'.format(project, schema), table_names_placeholder)
      sql_2 = get_fkeys_sql.format('{}.{}'.format(project, schema), '{}.{}'.format(project, schema), table_names_placeholder)
      sql_3 = get_pkeys_sql.format('{}.{}'.format(project, schema), '{}.{}'.format(project, schema), table_names_placeholder)
      sql_4 = get_table_comments_sql.format('{}.{}'.format(project, schema), table_names_placeholder)
      sql_statements_1.append(f"({sql_1})")
      sql_statements_2.append(f"({sql_2})")
      sql_statements_3.append(f"({sql_3})")
      sql_statements_4.append(f"({sql_4})")

  get_columns_sql = " UNION ALL ".join(sql_statements_1) + ";"
  get_fkeys_sql = " UNION ALL ".join(sql_statements_2) + ";"
  get_pkeys_sql = " UNION ALL ".join(sql_statements_3) + ";"
  get_table_comments_sql = " UNION ALL ".join(sql_statements_4) + ";"

  return get_columns_sql, get_fkeys_sql, get_pkeys_sql, get_table_comments_sql




#Utility Functions

In [7]:

def schema_generator(sql):
  df = pandas_gbq.read_gbq(sql, project_id=PROJECT_ID)
  return df


def add_table_comments(columns_df, pkeys_df, fkeys_df, table_comments_df):

  for index, row in table_comments_df.iterrows():
    if row['comments'] is None: ## or row['comments'] is not None:
        q=f"table_name == '{row['table_name']}' and owner == '{row['owner']}'"
        context_prompt = f"""
        Generate detailed table comments for the table {row['project_id']}.{row['owner']}.{row['table_name']}
Remember that these comments should help LLMs to help build better SQL for any quries related to this table.
        Parameters:
        - column metadata: {columns_df.query(q).to_markdown(index = False)}
        - primary key metadata: {pkeys_df.query(q).to_markdown(index = False)}
        - foreign keys metadata: {fkeys_df.query(q).to_markdown(index = False)}
        - table metadata: {table_comments_df.query(q).to_markdown(index = False)}
      """
        #context_query = model.predict(context_prompt, max_output_tokens = 256, temperature= 0.2)
        # print(context_prompt)
        context_query = model.generate_content(context_prompt, stream=False)
        # print(clean_sql(str(context_query.candidates[0])))
        table_comments_df.at[index, 'comments'] = clean_sql(str(context_query.candidates[0].text))

  return table_comments_df



def add_column_comments(columns_df, pkeys_df, fkeys_df, table_comments_df):
  for index, row in columns_df.iterrows():
    if row['description'] == None:
      q=f"table_name == '{row['table_name']}' and owner == '{row['owner']}'"
      context_prompt = f"""
      Generate comments for the column {row['project_id']}.{row['owner']}.{row['table_name']}.{row['column_name']}

      Remember that these comments should help LLMs to help generate better SQL for any queries related to these columns.

      Consider the below information to generate a good comment

      Description of the column is : {row['description']}
      Data type of the column is : {row['data_type']}
      Details of the table of this column are below:
      {table_comments_df.query(q).to_markdown(index=False)}
      Details of the primary keys of the table containing this column are below:
      {pkeys_df.query(q).to_markdown(index=False)}
      Details of the foreign keys of the table containing this column are below:
      {fkeys_df.query(q).to_markdown(index=False)}
      """
      # print(context_prompt)
      context_query = model.generate_content(context_prompt, stream=False)
      print({row['column_name']})
      columns_df.at[index, 'column_comments'] = clean_sql(str(context_query.candidates[0].text))
  return columns_df

def get_column_sample(columns_df):
  sample_column_list=[]

  for index, row in columns_df.iterrows():
    get_column_sample_sql=f'''
        SELECT STRING_AGG(CAST(value AS STRING)) as sample_values
        FROM UNNEST((SELECT APPROX_TOP_COUNT({row["column_name"]}, 5) as osn
                     FROM `{row["project_id"]}.{row["owner"]}.{row["table_name"]}`
                ))
    '''
    column_samples_df=schema_generator(get_column_sample_sql)
    sample_column_list.append(column_samples_df['sample_values'].to_string(index=False))

  columns_df["sample_values"]=sample_column_list
  return columns_df



def clean_sql(result):
  result = result.replace("```sql", "").replace("```", "")
  return result


In [8]:

# Augment Table dataframe with detailed description. This detailed description column will be the one used as the document when adding the record to the VectorDB
def build_table_desc(table_comments_df,columns_df,pkeys_df,fkeys_df):
  aug_table_comments_df = table_comments_df

  #print(len(aug_table_comments_df))
  #print(len(table_comments_df))

  cur_table_name = ""
  cur_table_owner = ""
  cur_project_id = ""
  cur_full_table= cur_project_id + '.' + cur_table_owner + '.' + cur_table_name

  for index_aug, row_aug in aug_table_comments_df.iterrows():

    cur_table_name = str(row_aug['table_name'])
    cur_table_owner = str(row_aug['owner'])
    cur_project_id = str(row_aug['project_id'])
    cur_full_table= cur_project_id + '.' + cur_table_owner + '.' + cur_table_name
    #print('\n' + cur_table_owner + '.' + cur_table_name + ':')

    table_cols=[]
    table_cols_datatype=[]
    table_col_comments=[]
    table_pk_cols=[]
    table_fk_cols=[]

    for index, row in columns_df.loc[ (columns_df['owner'] == cur_table_owner) & (columns_df['table_name'] == cur_table_name) ].iterrows():
      # Inside each owner.table_name combination
      table_cols.append( row['column_name']  )
      table_cols_datatype.append( row['column_name'] + ' (' + row['data_type'] + ') '  )
      col_comments_text=f"""
        Column {row['column_name']} belongs to table {row['table_name']}, owner {row['owner']}.
        It has data type {row['data_type']}.
        Sample values are {row['sample_values']}.
        Description of the column is {row['description']}.
        User commments: {row['column_comments']}.
      """
      table_col_comments.append(col_comments_text)

    for index, row in pkeys_df.loc[ (pkeys_df['owner'] == cur_table_owner) & (pkeys_df['table_name'] == cur_table_name)  ].iterrows():
      # Inside each owner.table_name combination
      table_pk_cols.append( row['column_name']  )

    for index, row in fkeys_df.loc[ (fkeys_df['owner'] == cur_table_owner) & (fkeys_df['table_name'] == cur_table_name) ].iterrows():
      # Inside each owner.table_name combination
      fk_cols_text=f"""
      Column {row['column_name']} is equal to column {row['r_column_name']} in table {row['owner']}.{row['r_table_name']}
      """
      table_fk_cols.append(fk_cols_text)


    if len(",".join(table_pk_cols)) == 0:
      final_pk_cols = "None"
    else:
      final_pk_cols = ",".join(table_pk_cols)

    if len(",".join(table_fk_cols)) == 0:
      final_fk_cols = "None"
    else:
      final_fk_cols = ",".join(table_fk_cols)

    aug_table_desc=f"""
      Table Name: {cur_full_table} |
      Owner: {cur_table_owner} |
      Schema Columns:{",".join(table_cols)} |
      Column Types: {",".join(table_cols_datatype)} |
      Primary Key: {final_pk_cols} |
      Foreign Keys: {final_fk_cols} |
      Project_id: {str(row_aug['project_id'])} |
      Table Comments: {str(row_aug['comments'])}
    """

    #print ('Current aug dataset row: '  + str(row_aug['table_name']))
    #print(aug_table_desc)

    # Works well
    aug_table_comments_df.at[index_aug, 'detailed_description'] = aug_table_desc
  return aug_table_comments_df


In [9]:
# Augment columns dataframe with detailed description. This detailed description column will be the one used as the document when adding the record to the VectorDB

def build_column_desc(columns_df):
  aug_columns_df = columns_df

  print(len(aug_columns_df))
  #print(len(columns_df))

  cur_table_name = ""
  cur_table_owner = ""
  cur_full_table= cur_table_owner + '.' + cur_table_name

  for index_aug, row_aug in aug_columns_df.iterrows():

    cur_table_name = str(row_aug['table_name'])
    cur_table_owner = str(row_aug['owner'])
    cur_full_table= cur_table_owner + '.' + cur_table_name
    curr_col_name = str(row_aug['column_name'])
    curr_col_desc = str(row_aug['description'])

    print('\n' + cur_table_owner + '.' + cur_table_name + ':')

    col_comments_text=f"""
        Column Name: {row_aug['column_name']} |
        Sample values: {row_aug['sample_values']} |
        Data type: {row_aug['data_type']} |
        Table Name: {row_aug['table_name']} |
        Table Owner: {row_aug['owner']} |
        Project_id: {row_aug['project_id']}
        Column description: {row_aug['description']}
    """
        #Low value: {row_aug['low_value']} |
        #High value: {row_aug['high_value']} |
        #User commments: {row_aug['column_comments']}

    print(' Column ' + cur_full_table + '.' + curr_col_name + " Description: " + col_comments_text)

    aug_columns_df.at[index_aug, 'detailed_description'] = col_comments_text
  return aug_columns_df

# Vector Pre-Processing

In [24]:

# print(table_id_list)
get_columns_sql, get_fkeys_sql, get_pkeys_sql, get_table_comments_sql=create_metadata_sqls(table_id_list)

# print(get_columns_sql)
# print(get_fkeys_sql)
# print(get_pkeys_sql)
# print(get_table_comments_sql)

In [25]:

columns_df=schema_generator(get_columns_sql)
fkeys_df=schema_generator(get_fkeys_sql)
pkeys_df=schema_generator(get_pkeys_sql)
table_comments_df=schema_generator(get_table_comments_sql)

Downloading: 100%|[32m██████████[0m|
Downloading: |[32m          [0m|
Downloading: |[32m          [0m|
Downloading: 100%|[32m██████████[0m|


In [26]:
# data_table.DataTable(fkeys_df)
# data_table.DataTable(pkeys_df)
# data_table.DataTable(columns_df)
data_table.DataTable(table_comments_df)

Unnamed: 0,project_id,owner,table_name,OPTION_NAME,OPTION_TYPE,comments
0,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,description,STRING,"""Store demographic details"""
1,ss-jarvis-npd-svc-prj-01,Sales,storesales,description,STRING,"""Contains transaction level sales data from st..."
2,ss-jarvis-npd-svc-prj-01,Sales,ecommdenormdata,description,STRING,"""Contains Transaction level E-Commerce Sales d..."
3,ss-jarvis-npd-svc-prj-01,Product,Product_Master,description,STRING,"""Contains SKU level details about products"""
4,ss-jarvis-npd-svc-prj-01,Customer,FC_Customer_Master,description,STRING,"""Contains Analytical info about the members. T..."
5,ss-jarvis-npd-svc-prj-01,Customer,Persona,description,STRING,"""Contains Detailed Info about personas about e..."


In [27]:
# columns_df=get_column_sample(columns_df)
# data_table.DataTable(columns_df)
columns_df["sample_values"]=None

In [28]:

# Using Gemini to add table comments if comments are null
table_comments_df=add_table_comments(columns_df, pkeys_df, fkeys_df, table_comments_df)
data_table.DataTable(table_comments_df)






Unnamed: 0,project_id,owner,table_name,OPTION_NAME,OPTION_TYPE,comments
0,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,description,STRING,"""Store demographic details"""
1,ss-jarvis-npd-svc-prj-01,Sales,storesales,description,STRING,"""Contains transaction level sales data from st..."
2,ss-jarvis-npd-svc-prj-01,Sales,ecommdenormdata,description,STRING,"""Contains Transaction level E-Commerce Sales d..."
3,ss-jarvis-npd-svc-prj-01,Product,Product_Master,description,STRING,"""Contains SKU level details about products"""
4,ss-jarvis-npd-svc-prj-01,Customer,FC_Customer_Master,description,STRING,"""Contains Analytical info about the members. T..."
5,ss-jarvis-npd-svc-prj-01,Customer,Persona,description,STRING,"""Contains Detailed Info about personas about e..."


In [29]:
# columns_df=add_column_comments(columns_df, pkeys_df, fkeys_df, table_comments_df)
columns_df["column_comments"]=None
data_table.DataTable(columns_df)

Unnamed: 0,project_id,owner,table_name,column_name,data_type,rounding_mode,description,sample_values,column_comments
0,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,town,STRING,,The name of the town where the store is located,,
1,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,state,STRING,,The state where the town is situated,,
2,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,zone,STRING,,The geographical zone classification where the...,,
3,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,zplant,STRING,,he identification code for the store,,
4,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,citytier,STRING,,"The tier classification of the city (e.g., Met...",,
...,...,...,...,...,...,...,...,...,...
571,ss-jarvis-npd-svc-prj-01,Customer,Persona,airport_store_sales,"BIGNUMERIC(48, 10)",,Total sales in airport stores,,
572,ss-jarvis-npd-svc-prj-01,Customer,Persona,discount_slab,STRING,,Slab/category based on discount percentage,,
573,ss-jarvis-npd-svc-prj-01,Customer,Persona,ps_contri_slab,STRING,,Slab/category based on private sales contribution,,
574,ss-jarvis-npd-svc-prj-01,Customer,Persona,storename,STRING,,Name of the store where the customer made purc...,,


In [30]:
table_comments_df=build_table_desc(table_comments_df,columns_df,pkeys_df,fkeys_df)

In [31]:
data_table.DataTable(table_comments_df)

Unnamed: 0,project_id,owner,table_name,OPTION_NAME,OPTION_TYPE,comments,detailed_description
0,ss-jarvis-npd-svc-prj-01,Store,dim_location_full,description,STRING,"""Store demographic details""",\n Table Name: ss-jarvis-npd-svc-prj-01.S...
1,ss-jarvis-npd-svc-prj-01,Sales,storesales,description,STRING,"""Contains transaction level sales data from st...",\n Table Name: ss-jarvis-npd-svc-prj-01.S...
2,ss-jarvis-npd-svc-prj-01,Sales,ecommdenormdata,description,STRING,"""Contains Transaction level E-Commerce Sales d...",\n Table Name: ss-jarvis-npd-svc-prj-01.S...
3,ss-jarvis-npd-svc-prj-01,Product,Product_Master,description,STRING,"""Contains SKU level details about products""",\n Table Name: ss-jarvis-npd-svc-prj-01.P...
4,ss-jarvis-npd-svc-prj-01,Customer,FC_Customer_Master,description,STRING,"""Contains Analytical info about the members. T...",\n Table Name: ss-jarvis-npd-svc-prj-01.C...
5,ss-jarvis-npd-svc-prj-01,Customer,Persona,description,STRING,"""Contains Detailed Info about personas about e...",\n Table Name: ss-jarvis-npd-svc-prj-01.C...


In [32]:
columns_df=build_column_desc(columns_df)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        Data type: DATETIME |
        Table Name: storesales |
        Table Owner: Sales |
        Project_id: ss-jarvis-npd-svc-prj-01
        Column description: None
    

Sales.storesales:
 Column Sales.storesales.ny_mtd1 Description: 
        Column Name: ny_mtd1 |
        Sample values: None |
        Data type: DATETIME |
        Table Name: storesales |
        Table Owner: Sales |
        Project_id: ss-jarvis-npd-svc-prj-01
        Column description: None
    

Sales.storesales:
 Column Sales.storesales.nny_mtd1 Description: 
        Column Name: nny_mtd1 |
        Sample values: None |
        Data type: DATETIME |
        Table Name: storesales |
        Table Owner: Sales |
        Project_id: ss-jarvis-npd-svc-prj-01
        Column description: None
    

Sales.storesales:
 Column Sales.storesales.emp_id Description: 
        Column Name: emp_id |
        Sample values: None |
        Data type: STRING |
 

#BigQuery Vector Embedding

In [33]:
if ENABLE_ANALYTICS is True:
  # Create a BigQuery client
  bq_client = bigquery.Client(location=DATASET_LOCATION, project=PROJECT_ID)

  # Create a dataset
  try:
    dataset = bq_client.create_dataset(dataset=DATASET_NAME)
  except Exception as e:
    print('Failed to create the dataset\n')
    print(str(e))

Failed to create the dataset

409 POST https://bigquery.googleapis.com/bigquery/v2/projects/ss-genai-npd-svc-prj-01/datasets?prettyPrint=false: Already Exists: Dataset ss-genai-npd-svc-prj-01:nl2sql


In [None]:
# def text_embedding(question):
#     """Text embedding with a Large Language Model."""
#     model = TextEmbeddingModel.from_pretrained(embeddings_model)
#     embeddings = model.get_embeddings([question])
#     for embedding in embeddings:
#         vector = embedding.values
#         print(f"Length of Embedding Vector: {len(vector)}")
#     return vector

In [None]:
# Configure gcloud.
!gcloud config set project {PROJECT_ID}

# Grant Cloud SQL Client role to authenticated user
current_user = !gcloud auth list --filter=status:ACTIVE --format="value(account)"

!gcloud projects add-iam-policy-binding {PROJECT_ID} \
  --member=user:{current_user[0]} \
  --role="roles/cloudsql.client"

  # Enable Cloud SQL Admin API
!gcloud services enable sqladmin.googleapis.com
!gcloud services enable aiplatform.googleapis.com

In [34]:
def create_bq_metadata_table():
  # data_table.DataTable(table_comments_df)

  client = bigquery.Client()
  table_comments_df.to_gbq(destination_table=f'{DATAPROJECT_ID}.{DATASET_NAME}.table_comments',
            project_id=f'{DATAPROJECT_ID}',
            if_exists='replace')

  columns_df.to_gbq(destination_table=f'{DATAPROJECT_ID}.{DATASET_NAME}.column_comments',
            project_id=f'{DATAPROJECT_ID}',
            if_exists='replace')

  sql_ddl = f"""CREATE TABLE IF NOT EXISTS {DATAPROJECT_ID}.{DATASET_NAME}.sql(
                                  question string,
                                  generated_sql string);"""

  pandas_gbq.read_gbq(sql_ddl)
  return True;


In [36]:
def create_bq_embeddings():
  table_embeddings_ddl="""CREATE OR REPLACE TABLE `ss-genai-npd-svc-prj-01.nl2sql.table_comments_embeddings` AS
SELECT * FROM ML.GENERATE_EMBEDDING(
  MODEL `ss-genai-npd-svc-prj-01.nl2sql.embedding_model`,
  (
    select to_hex(md5(concat(project_id,'.',owner,'.',table_name,'.',detailed_description))) as idx, detailed_description, project_id, owner, table_name,current_datetime() as epoch_time,detailed_description as content  from ss-genai-npd-svc-prj-01.nl2sql.table_comments
    where length(detailed_description)>0
  )
);"""

  column_embeddings_ddl="""CREATE OR REPLACE TABLE `ss-genai-npd-svc-prj-01.nl2sql.column_comments_embeddings` AS
SELECT * FROM ML.GENERATE_EMBEDDING(
  MODEL `ss-genai-npd-svc-prj-01.nl2sql.embedding_model`,
  (
    select to_hex(md5(concat(project_id,'.',owner,'.',table_name,'.',detailed_description))) as idx, detailed_description, project_id, owner, table_name,column_name,current_datetime() as epoch_time,detailed_description as content  from
    ss-genai-npd-svc-prj-01.nl2sql.column_comments  where length(detailed_description)>0
  )
);"""
  sql_embeddings_ddl="""
CREATE OR REPLACE TABLE `ss-genai-npd-svc-prj-01.nl2sql.sql_embeddings` AS
SELECT * FROM ML.GENERATE_EMBEDDING(
  MODEL `ss-genai-npd-svc-prj-01.nl2sql.embedding_model`,
  (
    select to_hex(md5((question))) as idx, question,generated_sql,current_datetime() as epoch_time,question as content  from ss-genai-npd-svc-prj-01.nl2sql.sql where length(question)>0
    )
);"""

  client=bigquery.Client()

  pandas_gbq.read_gbq(table_embeddings_ddl)
  pandas_gbq.read_gbq(column_embeddings_ddl)
  pandas_gbq.read_gbq(sql_embeddings_ddl)
  return True;


In [37]:
if create_bq_metadata_table():
  if create_bq_embeddings():
    print("Embeddings for tables, columns and sql are successfully generated")
  else:
    print("Failed to generate embeddings for tables, columns and sql")
else:
  print("Failed to create metadata tables for embeedings")


100%|██████████| 1/1 [00:00<00:00, 7182.03it/s]
100%|██████████| 1/1 [00:00<00:00, 7958.83it/s]


Downloading: |[32m          [0m|
Downloading:   0%|[32m          [0m|



Downloading: 100%|[32m██████████[0m|
Downloading:   0%|[32m          [0m|



Downloading: 100%|[32m██████████[0m|
Downloading: |[32m          [0m|



Downloading: |[32m          [0m|
Embeddings for tables, columns and sql are successfully generated
