The following code embeds the table descriptions, which were generated by prompting the MPT-7B Instruct model with each table's schema, and stores the resulting embeddings in a vector database. For 10-20 tables, a simple cosine-similarity search would suffice. The Chroma vector database is used here so that the solution can scale as more denormalized gold-layer tables are added.
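
To make the "simple cosine similarity" alternative concrete, here is a minimal sketch that picks the closest table by brute force. The function names and the three-dimensional toy vectors are hypothetical stand-ins; in practice the vectors would come from the same embedding model used for the table descriptions.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def most_similar(query_vec, table_vecs):
    """Return (table_name, score) for the table embedding closest to the query."""
    scored = {name: cosine_similarity(query_vec, vec) for name, vec in table_vecs.items()}
    best = max(scored, key=scored.get)
    return best, scored[best]

# Toy vectors standing in for real description embeddings (hypothetical values).
toy_table_vecs = {
    "cars": [0.9, 0.1, 0.0],
    "sales": [0.2, 0.9, 0.1],
    "rent": [0.0, 0.2, 0.9],
}
toy_query_vec = [0.8, 0.3, 0.1]

best_table, best_score = most_similar(toy_query_vec, toy_table_vecs)
print(best_table, round(best_score, 3))
```

A linear scan like this is cheap for a few dozen tables; an index such as Chroma's becomes worthwhile as the number of descriptions grows.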

In [0]:
%pip install chromadb

Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
Collecting chromadb
  Downloading chromadb-0.3.22-py3-none-any.whl (69 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 69.2/69.2 kB 2.1 MB/s eta 0:00:00
Collecting duckdb>=0.7.1
  Downloading duckdb-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 15.2/15.2 MB 41.2 MB/s eta 0:00:00
Collecting numpy>=1.21.6
  Downloading numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.3/17.3 MB 78.0 MB/s eta 0:00:00
Collecting posthog>=2.4.0
  Downloading posthog-3.0.1-py2.py3-none-any.whl (37 kB)
Collecting clickhouse-connect>=0.5.7
  Downloading clickhouse_connect-0.5.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (922 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 922.6/922.6 kB 73.4 MB/s eta 0:00:00
Collecting sentence

In [0]:
from pyspark.sql.functions import monotonically_increasing_id
import chromadb
from chromadb.config import Settings
import time
from chromadb.utils import embedding_functions
from transformers import pipeline
import pandas as pd
from sys import version_info
import cloudpickle



In [0]:
%sql
USE CATALOG metsql;
USE DATABASE metabase;

In [0]:
metadf = spark.sql("SELECT * FROM metadata_table")
display(metadf)

table_metadata,table_names,table_description
"schema: tbl[bank_churn] cols [RowNumber, CustomerId, Surname, CreditScore, Geography, Gender, Age, Tenure, Balance, NumOfProducts, HasCrCard, IsActiveMember, EstimatedSalary, Exited, Complain, Satisfaction Score, Card Type, Point Earned]",bank_churn,This table contains information about customers who have left their banking provider in the last 12 months (the churn column). The data includes demographic and behavioural characteristics as well as financial details for each customer.
"schema: tbl[cars] cols [price, brand, model, year, title_status, mileage, color, vin, lot, state, country, condition]",cars,"The cars table contains price,brand,model,year,title status,mileage and various other attributes for different car models"
"schema: tbl[commodity_prices] cols [date, oil_brent, oil_dubai, coffee_arabica, coffee_robustas, tea_columbo, tea_kolkata, tea_mombasa, sugar_eu, sugar_us, sugar_world]",commodity_prices,A table with commodity prices over time
"schema: tbl[hmeq] cols [BAD, LOAN, MORTDUE, VALUE, REASON, JOB, YOJ, DEROG, DELINQ, CLAGE, NINQ, CLNO, DEBTINC]",hmeq,"This table contains data on customers with loans and mortgages from HME Equity company in 2022. The columns are as follows BAD (Boolean), LOAN(String),MORTDUE(Integer),VALUE(Double),REASON(Text),JOB(Text),YOJ(Integer),DEROG(Boolean),DELINQ(Boolean),CLAGE(Integer),NINQ(Boolean),CLNO(Integer),DEBTINC(Double)."
"schema: tbl[mushrooms] cols [cap_shape, cap_surface, cap_color, bruises, odor, gill_attachment, gill_spacing, gill_size, gill_color, stalk_shape, stalk_root, stalk_surface_above_ring, stalk_surface_below_ring, stalk_color_above_ring, stalk_color_below_ring, veil_type, veil_color, ring_number, ring_type, spore_point_color, population, habitat, poisonous]",mushrooms,"This dataset contains information about different types and varieties of mushrooms with their respective characteristics such as shape, color etc"
"schema: tbl[name_counts] cols [Name, Gender, Count, Probability]",name_counts,A table with names and their counts grouped by gender
"schema: tbl[rent] cols [Posted On, BHK, Rent, Size, Floor, Area Type, Area Locality, City, Furnishing Status, Tenant Preferred, Bathroom, Point of Contact]",rent,A table with various attributes related to rental properties including their location and other information about them
"schema: tbl[sales] cols [Invoice ID, Branch, City, Customer type, Gender, Product line, Unit price, Quantity, Tax 5%, Total, Date, Time, Payment, cogs, gross margin percentage, gross income, Rating]",sales,Sales data with multiple dimensions including branch location and customer types by product lines


In [0]:
persist_directory = "/dbfs/FileStore/shared_uploads/avinash.sooriyarachchi@databricks.com/metsqlchroma"

In [0]:
# Set up a Chroma client that persists to disk (DuckDB + Parquet) under persist_directory.
client = chromadb.Client(
    Settings(
        persist_directory=persist_directory,
        chroma_db_impl="duckdb+parquet",
    )
)
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
# Create collection. get_collection, get_or_create_collection, delete_collection also available!
collection = client.create_collection(name="metadata_vdb", embedding_function=sentence_transformer_ef) 

Using embedded DuckDB with persistence: data will be stored in: /dbfs/FileStore/shared_uploads/avinash.sooriyarachchi@databricks.com/metsqlchroma


In [0]:
metadf_pd = metadf.toPandas()
table_descriptions = metadf_pd.table_description.to_list()
table_names = metadf_pd.table_names.to_list()

In [0]:
table_metadata = metadf_pd.table_metadata.apply(lambda x: {"table_metadata": x}).to_list()

In [0]:
table_descriptions[0], table_metadata[0]["table_metadata"]

('This table contains information about customers who have left their banking provider in the last 12 months (the churn column). The data includes demographic and behavioural characteristics as well as financial details for each customer.',
 'schema: tbl[bank_churn] cols [RowNumber, CustomerId, Surname, CreditScore, Geography, Gender, Age, Tenure, Balance, NumOfProducts, HasCrCard, IsActiveMember, EstimatedSalary, Exited, Complain, Satisfaction Score, Card Type, Point Earned]')

In [0]:
collection.add(
  documents=table_descriptions, # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
  metadatas=table_metadata, # filter on these!
  ids=table_names, # unique for each doc 
)


In [0]:
client.persist()

True

In [0]:
result = collection.query(
    query_texts=["How many vehicles were sold in 2008?"],
    n_results=1,
    # where={"metadata_field": "is_equal_to_this"}, # optional filter
    # where_document={"$contains":"search_string"}  # optional filter
)

In [0]:
result

{'ids': [['cars']],
 'embeddings': None,
 'documents': [['The cars table contains price,brand,model,year,title status,mileage and various other attributes for different car models']],
 'metadatas': [[{'table_metadata': 'schema: tbl[cars] cols [price, brand, model, year, title_status, mileage, color, vin, lot, state, country, condition]'}]],
 'distances': [[1.0797241926193237]]}
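
The `distances` value is a distance, not a similarity score, so smaller means closer. As a rough sketch, assuming the collection uses Chroma's default `l2` space (squared Euclidean distance) and that the embedding model returns unit-length vectors (both are assumptions, not confirmed by this notebook), the value can be converted to a cosine similarity. The helper name is hypothetical.

```python
def squared_l2_to_cosine(distance):
    """For unit vectors, ||a - b||^2 = 2 - 2*cos(a, b), hence cos = 1 - d / 2.

    Only valid under the assumptions stated above (squared L2, unit-norm embeddings).
    """
    return 1.0 - distance / 2.0

cars_distance = 1.0797241926193237  # distance reported for the 'cars' match
cars_cosine = squared_l2_to_cosine(cars_distance)
print(round(cars_cosine, 4))
```

If those assumptions hold, this gives a more familiar scale for setting a "no relevant table" threshold.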

## Testing:
### Now load the vector database from disk and query again

In [0]:
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
# Create a new client with the same settings
client = chromadb.Client(
    Settings(
        persist_directory=persist_directory,
        chroma_db_impl="duckdb+parquet",
    )
)

# Load the collection
collection = client.get_collection(name="metadata_vdb", embedding_function=sentence_transformer_ef)

In [0]:
result = collection.query(
    query_texts=["How many people named Alex are there in the US?"],
    n_results=1,
    # where={"metadata_field": "is_equal_to_this"}, # optional filter
    # where_document={"$contains":"search_string"}  # optional filter
)

In [0]:
result

{'ids': [['name_counts']],
 'embeddings': None,
 'documents': [['A table with names and their counts grouped by gender']],
 'metadatas': [[{'table_metadata': 'schema: tbl[name_counts] cols [Name, Gender, Count, Probability]'}]],
 'distances': [[1.3860574960708618]]}

In [0]:
result['metadatas'][0][0]['table_metadata']

'schema: tbl[name_counts] cols [Name, Gender, Count, Probability]'

In [0]:
result['ids'][0][0]

'name_counts'

### Register with MLflow and Deploy to Serving
Define a custom pyfunc model class that wraps the vector database so it can be deployed.

In [0]:
import mlflow.pyfunc

class RelevantTableFinder(mlflow.pyfunc.PythonModel):

  def load_context(self, context):
    import chromadb
    from chromadb.config import Settings
    from chromadb.utils import embedding_functions
    self.sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
    # Create a new client with the same settings
    self.client = chromadb.Client(
        Settings(
            persist_directory=context.artifacts["persist_directory"],
            chroma_db_impl="duckdb+parquet",
        )
    )

    # Load the collection
    self.collection = self.client.get_collection(name="metadata_vdb", embedding_function=self.sentence_transformer_ef)




  def predict(self, context, model_input):
    import json
    question = model_input.iloc[:,0].to_list() # get the first column
    results = self.collection.query(
    query_texts=question,
    n_results=1,
    )
    # Return the best-matching table's name, description, and schema (first question only)
    responses = {"table": results['ids'][0][0], 'description': results['documents'][0][0], 'schema':results['metadatas'][0][0]['table_metadata']}
    result = {'Response': responses}
    return json.dumps(result)
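
Because `predict` indexes `[0][0]`, it answers only the first question even when `model_input` has several rows. A minimal sketch of one way to return a record per question follows; `format_matches` is a hypothetical helper, and the sample dictionary is hand-built in the shape of the query outputs shown in this notebook.

```python
import json

def format_matches(results):
    """Build one response record per query from a Chroma-style query result."""
    records = []
    for ids, docs, metas in zip(results["ids"], results["documents"], results["metadatas"]):
        if not ids:  # no match came back for this query
            records.append(None)
            continue
        records.append({
            "table": ids[0],
            "description": docs[0],
            "schema": metas[0]["table_metadata"],
        })
    return records

# Two queries, one top match each (values copied from the outputs in this notebook).
sample_results = {
    "ids": [["cars"], ["name_counts"]],
    "documents": [
        ["The cars table contains price,brand,model,year,title status,mileage "
         "and various other attributes for different car models"],
        ["A table with names and their counts grouped by gender"],
    ],
    "metadatas": [
        [{"table_metadata": "schema: tbl[cars] cols [price, brand, model, year, "
                            "title_status, mileage, color, vin, lot, state, country, condition]"}],
        [{"table_metadata": "schema: tbl[name_counts] cols [Name, Gender, Count, Probability]"}],
    ],
}

payload = json.dumps({"Response": format_matches(sample_results)})
print(payload)
```

Inside `predict`, the same helper could be applied to the dictionary returned by `self.collection.query(...)`; whether the serving client expects a list or a single record is a design choice left open here.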

In [0]:
#Sample Question
payload_pd = pd.DataFrame([["How many cars sold in 2018?"]],columns=['text'])
input_example = payload_pd

In [0]:
PYTHON_VERSION = "{major}.{minor}.{micro}".format(major=version_info.major,
                                                  minor=version_info.minor,
                                                  micro=version_info.micro)

In [0]:
conda_env = {
    'channels': ['defaults'],
    'dependencies': [
      'python={}'.format(PYTHON_VERSION),
      'pip',
      {
        'pip': [
          'mlflow',
          'transformers',
          'pandas',
          'chromadb',
          'cloudpickle=={}'.format(cloudpickle.__version__),
          'torch'],
      },
    ],
    'name': 'code_env'
}
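
Only `cloudpickle` is pinned in this environment. The `chroma_db_impl="duckdb+parquet"` setting used in this notebook belongs to the chromadb 0.3.x line (the install log above shows 0.3.22), and later Chroma releases changed how persistent clients are configured, so an unpinned `chromadb` may resolve to an incompatible version at serving time. A minimal sketch of pinning each package to whatever is installed in the current environment; `pin_requirements` is a hypothetical helper, not part of MLflow.

```python
from importlib import metadata

def pin_requirements(packages):
    """Return pip requirement strings pinned to the locally installed versions."""
    pinned = []
    for name in packages:
        try:
            pinned.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            pinned.append(name)  # not installed here: leave unpinned
    return pinned

pip_requirements = pin_requirements(["mlflow", "transformers", "pandas", "chromadb", "torch"])
print(pip_requirements)
```

The resulting list could replace the hand-written `'pip': [...]` entries so that the serving environment matches the one in which the collection was built.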


In [0]:
mlflow_pyfunc_model_path = "tablemetadata_vectordb"

In [0]:
artifacts = {
   "persist_directory": persist_directory
}

In [0]:
mlflow.pyfunc.log_model(
    artifact_path=mlflow_pyfunc_model_path,
    python_model=RelevantTableFinder(),
    artifacts=artifacts,
    conda_env=conda_env,
    input_example=input_example,
)



<mlflow.models.model.ModelInfo at 0x7f011ad66620>

In [0]:
loraadapter_location = "/dbfs/FileStore/shared_uploads/avinash.sooriyarachchi@databricks.com/text_to_sql/flt5txt2sql"