# Enrich your Databricks Data Catalog with generative AI metadata using Amazon Bedrock

### [Reference Blog](https://aws.amazon.com/blogs/big-data/enrich-your-aws-glue-data-catalog-with-generative-ai-metadata-using-amazon-bedrock/)

### Prerequisite:
[Entity Extraction using Templating with Amazon Bedrock](https://youtu.be/jjn0EjiFT6I?si=Iex1B7yMSCBtbRtZ)

In [0]:
%pip install jsonschema

Note: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.


In [0]:
# Restart the Python process so the package installed by the %pip cell above can
# be used; pip's own output recommends this step after installing packages.
dbutils.library.restartPython()

# Create dummy tables for the experiment

In [0]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession

# getOrCreate() returns the already-running session when there is one and only
# builds a new session otherwise.
spark = SparkSession.builder.getOrCreate()


def _write_delta_table(target_table, rows, table_schema):
    """Write in-memory rows to a Delta table, replacing any previous contents.

    Args:
        target_table: Table name handed to ``saveAsTable``.
        rows: List of tuples, one per row, ordered like ``table_schema``.
        table_schema: ``StructType`` describing the columns of ``rows``.
    """
    (
        spark.createDataFrame(rows, table_schema)
        .write.mode("overwrite")  # re-running the cell replaces the table
        .format("delta")
        .saveAsTable(target_table)
    )


# 1️⃣ Products Table - 5 rows
products_data = [
    (101, "Laptop", "Electronics", 850.50),
    (102, "Book", "Education", 15.99),
    (103, "T-Shirt", "Clothing", 25.49),
    (104, "Phone", "Electronics", 699.00),
    (105, "Shoes", "Footwear", 49.99)
]
products_schema = StructType([
    StructField("product_id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("category", StringType(), True),
    StructField("price", DoubleType(), True)
])
_write_delta_table("delta_products", products_data, products_schema)


# 2️⃣ Employees Table - 3 rows
employees_data = [
    (1, "Alice", "Engineering", 75000),
    (2, "Bob", "Sales", 55000),
    (3, "Charlie", "HR", 48000)
]
employees_schema = StructType([
    StructField("emp_id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("department", StringType(), True),
    StructField("salary", IntegerType(), True)
])
_write_delta_table("delta_employees", employees_data, employees_schema)


# 3️⃣ Sales Table - 10 rows
# product_id values match the product_id values used in products_data.
sales_data = [
    (1, 101, 2, "2024-01-01"),
    (2, 103, 5, "2024-02-10"),
    (3, 102, 1, "2024-03-05"),
    (4, 104, 3, "2024-03-15"),
    (5, 105, 2, "2024-03-20"),
    (6, 101, 1, "2024-03-22"),
    (7, 102, 4, "2024-03-25"),
    (8, 103, 2, "2024-04-01"),
    (9, 104, 1, "2024-04-02"),
    (10, 105, 6, "2024-04-04")
]
sales_schema = StructType([
    StructField("sale_id", IntegerType(), True),
    StructField("product_id", IntegerType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("sale_date", StringType(), True)  # dates are stored as plain strings
])
_write_delta_table("delta_sales", sales_data, sales_schema)


# 4️⃣ Customers Table - 2 rows
customers_data = [
    (1001, "Daniel", "daniel@example.com", "India"),
    (1002, "Emma", "emma@example.com", "USA")
]
customers_schema = StructType([
    StructField("customer_id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("email", StringType(), True),
    StructField("country", StringType(), True)
])
_write_delta_table("delta_customers", customers_data, customers_schema)


# 5️⃣ Transactions Table - 7 rows
# customer_id / product_id values match those used in customers_data / products_data.
transactions_data = [
    ("TX100", 1001, 101, "2024-04-01", "Completed"),
    ("TX101", 1002, 103, "2024-04-03", "Pending"),
    ("TX102", 1001, 102, "2024-04-04", "Failed"),
    ("TX103", 1002, 104, "2024-04-05", "Completed"),
    ("TX104", 1001, 105, "2024-04-06", "Completed"),
    ("TX105", 1002, 101, "2024-04-07", "Pending"),
    ("TX106", 1001, 103, "2024-04-08", "Failed")
]
transactions_schema = StructType([
    StructField("txn_id", StringType(), True),
    StructField("customer_id", IntegerType(), True),
    StructField("product_id", IntegerType(), True),
    StructField("txn_date", StringType(), True),  # dates are stored as plain strings
    StructField("status", StringType(), True)
])
_write_delta_table("delta_transactions", transactions_data, transactions_schema)

In [0]:
import json
import boto3 
from botocore.config import Config

# [Bedrock](https://aws.amazon.com/bedrock/) Client Creation

In [0]:
# Credentials are deliberately NOT passed as literals. boto3 resolves them through
# its default provider chain (for example the AWS_ACCESS_KEY_ID /
# AWS_SECRET_ACCESS_KEY environment variables, shared config files, or an attached
# IAM role), which keeps keys out of the notebook source and its revision history.
bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")

In [0]:
# Bedrock model ARN (Claude 3.7 Sonnet, us-east-1).
# NOTE(review): no cell visible in this notebook reads `model_id`; the invoke_model
# call further down passes its own hard-coded inference-profile ARN. Confirm which
# model is intended.
model_id = "arn:aws:bedrock:us-east-1::model/anthropic.claude-3-7-sonnet-20250219-v1:0"
# Three-part name of the table to enrich. Later cells interpolate these values as
# {catalog}.{schema}.{table_name} into the DESCRIBE / ALTER TABLE statements and
# into the LLM prompt.
catalog = "workspace"
schema = "default"
table_name="delta_customers"

In [0]:
# Show the table's current columns, data types and comments before enrichment.
describe_statement = f"DESCRIBE TABLE  {catalog}.{schema}.{table_name}"
spark.sql(describe_statement).display()

col_name,data_type,comment
customer_id,int,
name,string,
email,string,
country,string,


# Fetch Column Names & Datatypes

In [0]:
# Run DESCRIBE TABLE and keep the resulting Spark DataFrame as `columns_df`.
columns_df = spark.sql(f"DESCRIBE TABLE {catalog}.{schema}.{table_name}")
# Collect the result as a pandas DataFrame, then flatten it into one dict per row
# (keys: col_name, data_type, comment) so it can be embedded in the LLM prompt.
columns_pdf = columns_df.toPandas()
existing_metadata = columns_pdf.to_dict(orient='records')

In [0]:
# Inspect the per-column records that get embedded in the prompt below.
existing_metadata

[{'col_name': 'customer_id', 'data_type': 'int', 'comment': None},
 {'col_name': 'name', 'data_type': 'string', 'comment': None},
 {'col_name': 'email', 'data_type': 'string', 'comment': None},
 {'col_name': 'country', 'data_type': 'string', 'comment': None}]

# Column Description Generation using in-context learning

In [0]:
def _extract_json_object(text):
    """Parse the JSON object contained in a model reply.

    The reply is parsed as-is first. If that fails (for example because the model
    wrapped the JSON in a sentence or a Markdown code fence), the span from the
    first ``{`` to the last ``}`` is parsed instead.

    Args:
        text: Raw text returned by the model.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If neither the full text nor the brace-delimited
            span is valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            # Nothing that looks like an object: surface the original error.
            raise
        return json.loads(text[start:end + 1])


# Step 2: Prepare prompt for LLM
# Doubled braces ({{ }}) are literal braces in the f-string; single braces are
# substituted with the notebook variables.
user_msg_template_table = f"""
You are given metadata for a Databricks  table called {catalog}.{schema}.{table_name}. Your task is to generate meaningful comments for the table and its columns.

Instructions:
1. Use the metadata to understand the structure.
2. Create a helpful and concise table description.
3. Add meaningful comments to each column based on the name and data type.
4. If a column is a primary key or foreign key (e.g., ends with '_id'), mention that in the comment.
5. Return your output strictly in the following JSON format:

{{
  "table_description": "your description here",
  "columns_with_comments": [
    {{"column_name": "col1", "comment": "description"}},
    {{"column_name": "col2", "comment": "description"}}
  ]
}}

If you can't infer a description, use "not available".
Here is the table metadata in <metadata></metadata> tags:

<metadata>
{existing_metadata}
</metadata>
"""

# Step 3: Call the Bedrock model
# NOTE(review): this inference-profile ARN is hard-coded and differs from the
# `model_id` variable defined in the configuration cell; confirm which model is
# intended before reusing this notebook in another account.
response = bedrock_client.invoke_model(
    modelId="arn:aws:bedrock:us-east-1:404091004961:inference-profile/us.anthropic.claude-3-sonnet-20240229-v1:0",
    contentType="application/json",
    accept="application/json",
    body=json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 800,
        "top_k": 250,
        "temperature": 0.5,
        "top_p": 0.999,
        "stop_sequences": [],
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_msg_template_table
                    }
                ]
            }
        ]
    })
)

# Step 4: Parse the response
# The HTTP body is a JSON envelope; the model's reply is the text of the first
# content block and is itself expected to be a JSON document.
response_body = json.loads(response['body'].read())

enriched_metadata_json = _extract_json_object(response_body['content'][0]['text'])
print(enriched_metadata_json)


{'table_description': 'This table stores customer information, including unique identifiers, names, email addresses, and countries.', 'columns_with_comments': [{'column_name': 'customer_id', 'comment': 'Unique identifier for each customer (primary key)'}, {'column_name': 'name', 'comment': 'Full name of the customer'}, {'column_name': 'email', 'comment': 'Email address of the customer'}, {'column_name': 'country', 'comment': 'Country where the customer is located'}]}


In [0]:
# Inspect the parsed model output (table description plus per-column comments).
enriched_metadata_json

{'table_description': 'This table stores customer information, including unique identifiers, names, email addresses, and countries.',
 'columns_with_comments': [{'column_name': 'customer_id',
   'comment': 'Unique identifier for each customer (primary key)'},
  {'column_name': 'name', 'comment': 'Full name of the customer'},
  {'column_name': 'email', 'comment': 'Email address of the customer'},
  {'column_name': 'country',
   'comment': 'Country where the customer is located'}]}

In [0]:
# Inspect only the generated table-level description.
enriched_metadata_json['table_description']

'This table stores customer information, including unique identifiers, names, email addresses, and countries.'

# [Schema Validation](https://python-jsonschema.readthedocs.io/en/latest/validate/)

In [0]:
from jsonschema import validate

# JSON Schema for the model output. It mirrors exactly what the catalog-update cell
# reads: a string `table_description` and a list of objects that each carry a
# string `column_name` and a string `comment`. Marking the keys as required makes
# an incomplete reply fail here, with a clear validation error, instead of later
# with a KeyError in the middle of the ALTER TABLE statements.
schema_table_input = {
    "type": "object",
    "properties": {
        "table_description": {"type": "string"},
        "columns_with_comments": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "column_name": {"type": "string"},
                    "comment": {"type": "string"},
                },
                "required": ["column_name", "comment"],
            },
        },
    },
    "required": ["table_description", "columns_with_comments"],
}

# Validate the object that is actually written to the catalog, rather than
# re-parsing the raw response text a second time.
validate(instance=enriched_metadata_json, schema=schema_table_input)

# Update Data Catalog

In [0]:
def _sql_string_literal(value):
    """Render ``value`` as a single-quoted Spark SQL string literal.

    Backslashes and single quotes are backslash-escaped so that model-generated
    text such as "customer's email" cannot terminate the literal early and break
    (or alter) the statement.
    """
    escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def _sql_identifier(name):
    """Backtick-quote ``name`` as a SQL identifier, doubling embedded backticks."""
    return "`" + str(name).replace("`", "``") + "`"


# Table-level description, stored through the table's 'comment' property.
spark.sql(
    f"ALTER TABLE {catalog}.{schema}.{table_name} "
    f"SET TBLPROPERTIES ('comment' = {_sql_string_literal(enriched_metadata_json['table_description'])})"
)

# Apply column comments
for col in enriched_metadata_json["columns_with_comments"]:
    col_name = col["column_name"]
    comment = col["comment"]
    # Both values come from the model, so neither is interpolated raw.
    spark.sql(
        f"ALTER TABLE {catalog}.{schema}.{table_name} "
        f"CHANGE COLUMN {_sql_identifier(col_name)} COMMENT {_sql_string_literal(comment)}"
    )

print(f"✅ Metadata for table {table_name} updated successfully.")

✅ Metadata for table delta_customers updated successfully.
