# Step 5: Extract Metadata

This notebook extracts detailed metadata from contracts for integration with contract management systems such as Icertis.

**How it works:**
1. You select a contract type (master_agreement, amendment, scope_of_work, or termination)
2. The notebook loads a CSV file listing which files to process, if one exists for that contract type; otherwise it selects documents from the `classified` and `doc_info` tables
3. It reads `metadata.csv` to determine which fields to extract for that contract type
4. It loads a prompt template from `prompt_{contract_type}.md`
5. It builds a dynamic SQL query that calls the LLM to extract all fields at once

**Contract types supported:**
- **master_agreement** -- full metadata extraction (parties, dates, terms, clauses, etc.)
- **amendment** -- amendment-specific fields (parent agreement, what changed, new terms)
- **scope_of_work** -- SOW-specific fields (deliverables, milestones, pricing, acceptance criteria)
- **termination** -- termination-specific fields (reason, type, obligations, settlement)

**Before you run this:**
- Steps 1-4 must be complete
- The `flat`, `doc_info`, and `classified` tables must exist
- The `metadata.csv` and `prompt_{type}.md` files must be in the repo
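
The prerequisites above can be checked up front. The following is a minimal sketch and not part of the notebook itself: `missing_prerequisites`, its `table_exists` parameter, and `repo_dir` are hypothetical names introduced for illustration. It assumes the files sit in the notebook's working directory, as the relative paths used later suggest.

```python
import os
from typing import Callable, List


def missing_prerequisites(
    catalog: str,
    schema: str,
    contract_type: str,
    table_exists: Callable[[str], bool],
    repo_dir: str = ".",
) -> List[str]:
    """Return descriptions of required tables and files that are missing."""
    missing = []
    # Tables produced by Steps 1-4
    for table in ("flat", "doc_info", "classified"):
        full_name = f"{catalog}.{schema}.{table}"
        if not table_exists(full_name):
            missing.append(f"table {full_name}")
    # Files expected in the repo
    for file_name in ("metadata.csv", f"prompt_{contract_type}.md"):
        if not os.path.exists(os.path.join(repo_dir, file_name)):
            missing.append(f"file {file_name}")
    return missing
```

In the notebook this could be called as `missing_prerequisites(catalog, schema, contract_type, spark.catalog.tableExists)` once the configuration cell has run; an empty list means everything was found.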

**Output tables:**
- `metadata_{contract_type}` -- one table per contract type with all extracted fields

## Configuration

Select the contract type from the dropdown. The notebook will load the right metadata fields and prompt template automatically.

In [None]:
dbutils.widgets.text("catalog", "shm", "Catalog")
dbutils.widgets.text("schema", "contract", "Schema")
dbutils.widgets.dropdown("contract_type", "master_agreement", ["master_agreement", "amendment", "scope_of_work", "termination"], "Contract Type")
dbutils.widgets.text("llm_endpoint", "databricks-claude-sonnet-4-5", "LLM Endpoint")
dbutils.widgets.text("batch_size", "100", "Batch Size")
dbutils.widgets.text("max_input_char", "400000", "Max Input Characters")

catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
contract_type = dbutils.widgets.get("contract_type")
llm_endpoint = dbutils.widgets.get("llm_endpoint")
batch_size = int(dbutils.widgets.get("batch_size").strip())
max_input_char = int(dbutils.widgets.get("max_input_char").strip())

---
## 5A: Load File List

If a CSV file exists for the selected contract type (e.g., `amendment_file_names.csv`), it is loaded here; it must have a `file_name` column. The query in 5C applies this list only to amendments, which require it. The other contract types are selected from the `classified` and `doc_info` tables.

In [None]:
import pandas as pd
import os

# Try to load a file list CSV for the selected contract type
file_list_path = f'./{contract_type}_file_names.csv'
if os.path.exists(file_list_path):
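    file_list_df = pd.read_csv(file_list_path)
    file_names = file_list_df['file_name'].dropna().astype(str).tolist()
    # Escape backslashes and single quotes so each name is a valid SQL string literal
    # (assumes Spark's default backslash escaping)
    escaped_names = [name.replace("\\", "\\\\").replace("'", "\\'") for name in file_names]
    # An empty list would produce invalid SQL ("IN ()"), so treat it as "no list"
    file_names_sql_tuple = "(" + ", ".join(f"'{name}'" for name in escaped_names) + ")" if escaped_names else None
    print(f"Loaded {len(file_names)} files from {file_list_path}")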
else:
    file_names = []
    file_names_sql_tuple = None
    print(f"No file list found at {file_list_path} -- will use classified table filter")

---
## 5B: Build Metadata Prompt

The `metadata.csv` file defines what fields to extract for each contract type. Each row has:
- `type` -- which contract type this field belongs to
- `metadata_name` -- the field name
- `metadata_description` -- what to extract
- `enum_fields` -- allowed values (if any)

The prompt is assembled from the base template (`prompt_{type}.md`) plus the field definitions.
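
To make the mapping concrete, here is a small sketch of how rows in that layout turn into prompt lines and a Spark `STRUCT` type string. The sample rows and the `build_field_spec` helper are hypothetical and only illustrate the column layout described above; the real `metadata.csv` defines its own fields. The sketch uses the standard-library `csv` module so it runs without pandas.

```python
import csv
import io

# Hypothetical sample rows; the real metadata.csv will contain different fields.
SAMPLE_METADATA_CSV = """type,metadata_name,metadata_description,enum_fields
amendment,amendment_date,Date the amendment takes effect,
amendment,change_type,Kind of change the amendment makes,"PRICING, TERM, SCOPE"
termination,termination_reason,Stated reason for termination,
"""


def build_field_spec(csv_text, contract_type):
    """Return (prompt field lines, STRUCT type string) for one contract type."""
    lines, struct_fields = [], []
    for row in csv.DictReader(io.StringIO(csv_text)):
        if row["type"] != contract_type:
            continue  # field belongs to another contract type
        enum = (row.get("enum_fields") or "").strip()
        enum_str = f" ENUM: {enum}" if enum else ""
        lines.append(f"- {row['metadata_name']}: {row['metadata_description']}{enum_str}")
        # Every field is requested as a string
        struct_fields.append(f"{row['metadata_name']}:STRING")
    return lines, "STRUCT<" + ",".join(struct_fields) + ">"


field_lines, struct_type = build_field_spec(SAMPLE_METADATA_CSV, "amendment")
print("\n".join(field_lines))
print(struct_type)
```

Only the two `amendment` rows are kept; the `termination` row is ignored because its `type` does not match.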

In [None]:
# Load metadata definitions and filter for the selected contract type
metadata_df = pd.read_csv('./metadata.csv')
metadata_df = metadata_df[metadata_df['type'] == contract_type]
print(f"Found {len(metadata_df)} metadata fields for {contract_type}")

column_names = ['path', 'contract_type'] + metadata_df['metadata_name'].tolist()

In [None]:
# Create the output table if it doesn't exist
from pyspark.sql.types import StructType, StructField, StringType

target_table = f"{catalog}.{schema}.metadata_{contract_type}"
if not spark.catalog.tableExists(target_table):
    # Every extracted field is stored as a nullable string column
    table_schema = StructType([StructField(col, StringType(), True) for col in column_names])
    spark.createDataFrame([], table_schema).write.format("delta").mode("overwrite").saveAsTable(target_table)
    print(f"Created table: {target_table}")
else:
    print(f"Table already exists: {target_table}")

In [None]:
# Build the prompt from the template + field definitions
with open(f"./prompt_{contract_type}.md", 'r') as f:
    base_prompt = f.read().strip()

fields = []
response_struct_fields = []
for _, row in metadata_df.iterrows():
    name = row['metadata_name']
    desc = row['metadata_description']
    enum = row.get('enum_fields', None)
    if pd.notnull(enum) and enum:
        enum_str = f" ENUM: {enum}"
    else:
        enum_str = ""
    fields.append(f"- {name}: {desc}{enum_str}")
    response_struct_fields.append(f"{name}:STRING")

prompt = base_prompt + "\n\n" + "\n".join(fields)
json_struct = "STRUCT<" + ",".join(response_struct_fields) + ">"
response_format = f"STRUCT<result:{json_struct}>"

print(f"Prompt length: {len(prompt)} characters")
print(f"Fields: {len(fields)}")

---
## 5C: Run Metadata Extraction

This builds and executes the SQL query that calls the LLM for each contract. The query dynamically adapts based on the contract type:
- **master_agreement** -- filters to classified master agreements, includes amendment info
- **amendment** -- filters to files in the amendment file list, includes master agreement context
- **scope_of_work** -- filters to SOW documents from doc_info classification
- **termination** -- filters to termination documents from doc_info classification

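It skips documents already in the metadata table, so re-running is safe (idempotent). Each run processes at most `batch_size` new documents, so re-run the extraction cell until the row count stops growing.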
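
Because each run handles one batch, processing a large corpus means repeating the extraction cell. The sketch below shows one way to automate that, under stated assumptions: `run_until_done`, `run_batch`, `count_rows`, and `max_batches` are hypothetical names, and the stopping rule (stop when a batch adds no rows) is a simplification, since a batch that fails silently would also end the loop.

```python
from typing import Callable


def run_until_done(
    run_batch: Callable[[], None],
    count_rows: Callable[[], int],
    max_batches: int = 50,
) -> int:
    """Re-run an idempotent batch step until it stops adding rows.

    Returns the number of batches that added at least one row.
    """
    productive = 0
    previous = count_rows()
    for _ in range(max_batches):  # hard cap guards against endless LLM spend
        run_batch()
        current = count_rows()
        if current == previous:
            break  # nothing new was inserted, so nothing is left to process
        productive += 1
        previous = current
    return productive
```

In this notebook, `run_batch` could be `lambda: spark.sql(sql_query)` and `count_rows` could count the rows in `metadata_{contract_type}`, once the query below has been built.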

In [None]:
# Build WHERE clause and context fields based on contract type
if contract_type == "amendment" and file_names_sql_tuple:
    where_clause = f"WHERE file_name in {file_names_sql_tuple}"
    context_fields = """
          'Vendor Name:', vendor_name, '\\n',
          'File Path:', path, '\\n',
          'Related Master Agreement Name:', related_master_agreement_name, '\\n',
          'Initial Master Agreement Expiry:', effective_date, '\\n',
          'Doc Info:\\n', combined_doc_info, '\\n',
          'Text:\\n', truncated, '\\n'
    """
elif contract_type == "master_agreement":
    where_clause = "WHERE c.is_master_agreement"
    context_fields = """
          'Vendor Name:', vendor_name, '\\n',
          'File Path:', path, '\\n',
          'Has Amendments:', has_amendments, '\\n',
          'Initial Master Agreement Expiry:', initial_master_agreement_expiry_date, '\\n',
          'Final Master Agreement Expiry:', final_expiry_date, '\\n',
          'Doc Info:\\n', combined_doc_info, '\\n',
          'Text:\\n', truncated, '\\n'
    """
elif contract_type == "scope_of_work":
    where_clause = "WHERE d.document_type = 'SCOPE_OF_WORK' OR d.agreement_type = 'SOW'"
    context_fields = """
          'Vendor Name:', vendor_name, '\\n',
          'File Path:', path, '\\n',
          'Doc Info:\\n', combined_doc_info, '\\n',
          'Text:\\n', truncated, '\\n'
    """
elif contract_type == "termination":
    where_clause = "WHERE d.document_type = 'TERMINATION' OR d.agreement_type = 'TERMINATION'"
    context_fields = """
          'Vendor Name:', vendor_name, '\\n',
          'File Path:', path, '\\n',
          'Doc Info:\\n', combined_doc_info, '\\n',
          'Text:\\n', truncated, '\\n'
    """
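elif contract_type == "amendment":
    # Reached only when no usable amendment file list was loaded in 5A
    raise ValueError(f"Amendment extraction requires a non-empty file list at {file_list_path}")
else:
    raise ValueError(f"Unknown contract type: {contract_type}")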

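# Escape backslashes and single quotes so the prompt is a valid SQL string literal
# (assumes Spark's default backslash escaping in string literals)
prompt_sql = prompt.replace("\\", "\\\\").replace("'", "\\'")

# Build the SQL query
sql_query = f"""
MERGE INTO {catalog}.{schema}.metadata_{contract_type} AS target
USING (
  SELECT
    path,
    '{contract_type}' as contract_type,
    metadata.*,
    cast(to_json(metadata) as string) AS combined_metadata
  FROM (
    SELECT
      path,
      from_json(AI_QUERY(
        '{llm_endpoint}',
        SUBSTRING(CONCAT(
          '{prompt_sql}', '\\n',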
          {context_fields}
        ),0,{max_input_char}),
        responseFormat => '{response_format}'
      ), '{json_struct}'
      ) as metadata
    FROM (
      SELECT * EXCEPT(c.path, d.path)
      FROM {catalog}.{schema}.flat f
      LEFT JOIN {catalog}.{schema}.doc_info d
        ON f.path = d.path
      LEFT JOIN {catalog}.{schema}.classified c
        ON f.path = c.path
      LEFT ANTI JOIN {catalog}.{schema}.metadata_{contract_type} m
        ON f.path = m.path
      {where_clause}
      LIMIT CAST({batch_size} AS INT)
    )
  )
) AS source
ON target.path = source.path
WHEN NOT MATCHED THEN INSERT *;
"""

print("Query built successfully")

In [None]:
# Execute the metadata extraction query
spark.sql(sql_query)

In [None]:
# Check results
count = spark.sql(f"SELECT * FROM {catalog}.{schema}.metadata_{contract_type}").count()
print(f"There are {count} rows in metadata_{contract_type}")
spark.sql(f"SELECT * FROM {catalog}.{schema}.metadata_{contract_type} LIMIT 5").display()