%md
# Contract Analysis
## Metadata
This notebook takes a CSV file (dropped into the workspace) and creates dynamic SQL queries for extracting metadata from the output of the previous steps. The CSV is meant to make category management a lot easier - you only need metadata_name, metadata_description, and any ENUM categories, and it will construct the SQL queries for you.

In [None]:
# Declare the notebook parameters as Databricks widgets: (name, default value).
for _widget_name, _widget_default in (("catalog", ""), ("schema", "")):
    dbutils.widgets.text(_widget_name, _widget_default)

dbutils.widgets.dropdown(
    "contract_type",
    "master_agreement",
    ["master_agreement", "amendment"],
)

for _widget_name, _widget_default in (
    ("llm_endpoint", "databricks-claude-sonnet-4-5"),
    ("batch_size", "100"),
    ("max_input_char", "400000"),
):
    dbutils.widgets.text(_widget_name, _widget_default)

# Read the current widget values. Widget values are text, so the two numeric
# settings are stripped of surrounding whitespace and converted to int.
_read_widget = dbutils.widgets.get

catalog = _read_widget("catalog")
schema = _read_widget("schema")
contract_type = _read_widget("contract_type")
llm_endpoint = _read_widget("llm_endpoint")
batch_size = int(_read_widget("batch_size").strip())
max_input_char = int(_read_widget("max_input_char").strip())

## Load file list and filter

In [None]:
import pandas as pd

# Load file list based on contract type.
# Amendments are driven by an explicit list of file names read from a CSV;
# master agreements have no list here and are filtered later in the notebook.
if contract_type == "amendment":
    file_list_df = pd.read_csv('./amendment_file_names.csv')
    file_names = file_list_df['file_name'].tolist()
else:
    # For master agreements, we'll filter based on classified table
    file_names = None

# Create SQL tuple for filtering if we have file names
if file_names:
    # Each name becomes a single-quoted SQL string literal. Backslashes and
    # single quotes are backslash-escaped so that a name such as O'Brien.pdf
    # cannot terminate the literal early and corrupt the IN (...) list.
    quoted_file_names = [
        "'" + str(name).replace("\\", "\\\\").replace("'", "\\'") + "'"
        for name in file_names
    ]
    file_names_sql_tuple = "(" + ", ".join(quoted_file_names) + ")"
    print(f"There are {len(file_names)} files in the list")
    print(file_names_sql_tuple)
else:
    # Also reached when the CSV yields an empty list.
    file_names_sql_tuple = None
    print("No file filter - will use classified table filter")

## Create append-only contract file tracking table

In [None]:
# Create append-only tracking table for contract files
spark.sql(f"""
CREATE TABLE IF NOT EXISTS {catalog}.{schema}.contract_files (
  file_name STRING,
  contract_type STRING,
  added_timestamp TIMESTAMP
)
USING DELTA
""")

# Add files to tracking table if using file list and they don't already exist
if file_names:
    from pyspark.sql.functions import lit, current_timestamp
    from pyspark.sql.types import StructType, StructField, StringType

    # De-duplicate the list first (keeping first-seen order). The MERGE below
    # only compares source rows against the target table, so a name repeated
    # within the CSV would otherwise be inserted once per occurrence.
    unique_file_names = list(dict.fromkeys(file_names))

    # Create dataframe with file names, stamped with the insertion time
    file_df = spark.createDataFrame(
        [(name, contract_type) for name in unique_file_names],
        ["file_name", "contract_type"]
    ).withColumn("added_timestamp", current_timestamp())

    # Only insert files that don't already exist (matched on file_name alone;
    # there is no WHEN MATCHED clause, so existing rows are never modified)
    file_df.createOrReplaceTempView("new_files")
    spark.sql(f"""
    MERGE INTO {catalog}.{schema}.contract_files AS target
    USING new_files AS source
    ON target.file_name = source.file_name
    WHEN NOT MATCHED THEN INSERT *
    """)

In [None]:
%sql
-- Display every row of the contract_files tracking table.
-- The three-part table name is assembled from the :catalog and :schema
-- parameter markers (presumably bound to the widgets of the same name).
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.contract_files')

## Make metadata prompt

In [None]:
# Load metadata CSV
metadata_df = pd.read_csv('./metadata.csv')

# Filter metadata for the selected contract type
metadata_df = metadata_df[metadata_df['type'] == contract_type]

# Fail fast when nothing matches: with no metadata rows there are no fields to
# extract, and the later cells would build a metadata table holding only
# path/contract_type plus an empty response struct.
if metadata_df.empty:
    raise ValueError(
        f"./metadata.csv contains no rows with type == {contract_type!r}; "
        "cannot build the metadata prompt or response schema."
    )

# Columns of the metadata table: the two fixed key columns followed by one
# column per metadata field defined for this contract type.
column_names = ['path', 'contract_type'] + metadata_df['metadata_name'].tolist()

In [None]:
# Bootstrap the metadata table: when it is absent, save an empty Delta table
# whose columns are `column_names`, each a nullable STRING.
from pyspark.sql.types import StructType, StructField, StringType

string_fields = [StructField(col, StringType(), True) for col in column_names]
table_schema = StructType(string_fields)
df = spark.createDataFrame([], table_schema)

metadata_table_name = f"{catalog}.{schema}.metadata"
metadata_table_exists = spark.catalog.tableExists(metadata_table_name)
if not metadata_table_exists:
    df.write.format("delta").mode("overwrite").saveAsTable(metadata_table_name)

In [None]:
# Choose the base prompt markdown file for the selected contract type and read it.
prompt_path = (
    './amendment_prompt.md'
    if contract_type == "amendment"
    else './master_agreement_prompt.md'
)
with open(prompt_path, 'r') as f:
    base_prompt = f.read().strip()


def _field_bullet(row):
    """Return the prompt bullet for one metadata row.

    The bullet is "- <metadata_name>: <metadata_description>", followed by
    " ENUM: <enum_fields>" when the row has a non-null, non-empty enum_fields
    value. A missing enum_fields column is treated as no enum.
    """
    enum = row.get('enum_fields', None)
    enum_suffix = f" ENUM: {enum}" if pd.notnull(enum) and enum else ""
    return f"- {row['metadata_name']}: {row['metadata_description']}{enum_suffix}"


# One prompt bullet and one STRING struct field per row of metadata_df.
metadata_rows = [row for _, row in metadata_df.iterrows()]
fields = [_field_bullet(row) for row in metadata_rows]
response_struct_fields = [f"{row['metadata_name']}:STRING" for row in metadata_rows]

# The field list goes at the bottom of the base prompt, after a blank line.
prompt = "\n\n".join([base_prompt, "\n".join(fields)])

# Struct type of the extracted fields, and the same struct wrapped under `result`.
json_struct = f"STRUCT<{','.join(response_struct_fields)}>"
response_format = f"STRUCT<result:{json_struct}>"

In [None]:
# Ensure metadata table has all necessary columns
fields_ddl = ",\n  ".join([f"{col} STRING" for col in column_names])
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS {catalog}.{schema}.metadata (
  {fields_ddl}
)
USING DELTA
"""
spark.sql(create_table_sql)

# CREATE TABLE IF NOT EXISTS leaves an existing table untouched, so it cannot
# add columns to a table created from an earlier version of the metadata CSV.
# Compare the table's current columns with `column_names` and add whatever is
# missing as STRING columns. Names are compared case-insensitively.
existing_metadata_columns = {
    existing.lower() for existing in spark.table(f"{catalog}.{schema}.metadata").columns
}
missing_metadata_columns = [
    col for col in column_names if col.lower() not in existing_metadata_columns
]
if missing_metadata_columns:
    missing_ddl = ", ".join(f"{col} STRING" for col in missing_metadata_columns)
    spark.sql(f"ALTER TABLE {catalog}.{schema}.metadata ADD COLUMNS ({missing_ddl})")

In [None]:
def _sql_quote(text):
    """Return ``text`` as a single-quoted Spark SQL string literal.

    Backslashes and single quotes inside the value are backslash-escaped, so
    free text such as a prompt containing an apostrophe cannot terminate the
    literal early and break (or alter) the surrounding statement.
    """
    escaped = str(text).replace("\\", "\\\\").replace("'", "\\'")
    return "'" + escaped + "'"


# Build WHERE clause based on contract type
if contract_type == "amendment" and file_names_sql_tuple:
    where_clause = f"WHERE file_name in {file_names_sql_tuple}"
elif contract_type == "master_agreement":
    where_clause = "WHERE c.is_master_agreement"
else:
    # Reached for an amendment run without a file list: no row filter is applied.
    where_clause = ""

# Build context fields based on contract type.
# These are argument fragments for the SQL CONCAT below; the doubled
# backslashes reach SQL as the two-character escape \n inside a literal.
if contract_type == "amendment":
    context_fields = """
          'Vendor Name:', vendor_name, '\\n',
          'File Path:', path, '\\n',
          'Related Master Agreement Name:', related_master_agreement_name, '\\n',
          'Initial Master Agreement Expiry:', effective_date, '\\n',
          'Doc Info:\\n', combined_doc_info, '\\n',
          'Text:\\n', truncated, '\\n'
    """
else:
    context_fields = """
          'Vendor Name:', vendor_name, '\\n',
          'File Path:', path, '\\n',
          'Has Amendments:', has_amendments, '\\n',
          'Initial Master Agreement Expiry:', initial_master_agreement_expiry_date, '\\n',
          'Final Master Agreement Expiry:', final_expiry_date, '\\n',
          'Doc Info:\\n', combined_doc_info, '\\n',
          'Text:\\n', truncated, '\\n'
    """

# Render every interpolated value as a properly escaped SQL string literal.
# The prompt in particular is free text (markdown plus metadata descriptions)
# and is likely to contain apostrophes.
contract_type_sql = _sql_quote(contract_type)
llm_endpoint_sql = _sql_quote(llm_endpoint)
prompt_sql = _sql_quote(prompt)
response_format_sql = _sql_quote(response_format)
json_struct_sql = _sql_quote(json_struct)

# Compose the SQL query.
# Innermost SELECT: rows of `flat` joined to doc_info and classified, minus
# paths already present in the metadata table (anti join), limited to
# batch_size rows. Middle SELECT: calls AI_QUERY on prompt + context, truncated
# to max_input_char characters, and parses the reply with from_json.
# Outer MERGE: inserts rows whose path is not yet in the metadata table.
sql_query = f"""
MERGE INTO {catalog}.{schema}.metadata AS target
USING (
  SELECT
    path,
    {contract_type_sql} as contract_type,
    metadata.*,
    cast(to_json(metadata) as string) AS combined_metadata
  FROM (
    SELECT 
      path,
      from_json(AI_QUERY(
        {llm_endpoint_sql},
        SUBSTRING(CONCAT(
          {prompt_sql}, '\\n',
          {context_fields}
        ),0,{max_input_char}),
        responseFormat => {response_format_sql}
      ), {json_struct_sql}
      ) as metadata
    FROM (
      SELECT * EXCEPT(c.path, d.path)
      FROM {catalog}.{schema}.flat f
      LEFT JOIN {catalog}.{schema}.doc_info d
        ON f.path = d.path
      LEFT JOIN {catalog}.{schema}.classified c
        ON f.path = c.path
      LEFT ANTI JOIN {catalog}.{schema}.metadata m
        ON f.path = m.path
      {where_clause}
      LIMIT CAST({batch_size} AS INT)
    )
  )
) AS source
ON target.path = source.path
WHEN NOT MATCHED THEN INSERT *;
"""

In [None]:
# Execute the query!
# Runs the MERGE statement held in `sql_query`. Its source subquery is limited
# to `batch_size` rows and skips paths already present in the metadata table,
# so each run of this cell processes at most one batch of new paths.
spark.sql(sql_query)

In [None]:
%sql
-- Display every row of the metadata table.
-- The three-part table name is assembled from the :catalog and :schema
-- parameter markers (presumably bound to the widgets of the same name).
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.metadata')