# Step 3: Assemble Context

This notebook sets up vector search and assembles all the context needed for contract classification.

**What it does:**
1. Creates a `sections` table by splitting each document into its section headers and text
2. Sets up Databricks Vector Search indexes on the `flat`, `doc_info`, and `sections` tables
3. Assembles a context table that combines, for each contract, its own info, documents from the same vendor folder, and semantically similar documents found via vector search

The assembled table is what gets fed to the LLM in Step 4 (Classify). It gives the model everything it needs to determine whether a contract is a master agreement, identify amendments, and figure out final expiry dates.

**Before you run this:**
- Steps 1 and 2 must be complete
- The parsed table (`parsed` by default) and the `flat`, `references`, and `doc_info` tables must exist with data
- A Vector Search endpoint, named by the `endpoint_name` widget, is required; the notebook creates it if it doesn't exist

**Output tables:**
- `sections` -- document text split by section header
- `assembled` -- full context per contract (own info + folder docs + vector search results)

## Configuration

In [None]:
dbutils.widgets.text("catalog", "shm", "Catalog")
dbutils.widgets.text("schema", "contract", "Schema")
dbutils.widgets.text("endpoint_name", "contract_vs", "Vector Search Endpoint")
dbutils.widgets.text("parsed_table", "parsed", "Parsed Table Name")
dbutils.widgets.text("vs_char_limit", "8000", "Vector Search Character Limit")

catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
endpoint_name = dbutils.widgets.get("endpoint_name")
parsed_table = dbutils.widgets.get("parsed_table")
vs_char_limit = dbutils.widgets.get("vs_char_limit")

In [None]:
%pip install databricks-vectorsearch
%restart_python

In [None]:
from databricks.vector_search.client import VectorSearchClient

client = VectorSearchClient()

# %restart_python cleared the Python state, so read the widget values again
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
endpoint_name = dbutils.widgets.get("endpoint_name")
parsed_table = dbutils.widgets.get("parsed_table")
vs_char_limit = dbutils.widgets.get("vs_char_limit")

# Tables to index for vector search.
# columns_to_sync: columns copied into the index; vs_col: the text column that gets embedded.
tables = {
  "flat": {
    'columns_to_sync': ['path', 'vendor_name', 'file_name', 'vendor_folder_paths', 'preamble'],
    'vs_col': 'preamble'
  },
  "doc_info": {
    'columns_to_sync': ['path', 'combined_doc_info'],
    'vs_col': 'combined_doc_info'
  }, 
  "sections": {
    'columns_to_sync': ['path', 'section_id', 'section_title', 'combined_text'],
    'vs_col': 'combined_text'
  }
}

---
## 3A: Create Sections Table

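This splits each parsed document into sections based on its section headers. Each row is one section of one document, holding the section title, the section text, and a `combined_text` column that joins the two. Text that appears before the first section header belongs to no section and is dropped. The table is mainly useful for agentic flows where the agent searches specific sections.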
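
The grouping rule can be sketched in plain Python. This is a simplified illustration, not the notebook's implementation: `split_into_sections` and the `elements` structure (dicts with `id`, `type`, and `content`) are hypothetical stand-ins for one document's parsed elements.

```python
def split_into_sections(elements):
    """Group text elements under the most recent preceding section header."""
    ordered = sorted(elements, key=lambda e: int(e["id"]))
    sections = []   # one entry per header, in document order
    current = None  # section currently being filled
    for el in ordered:
        if el["type"] == "section_header":
            current = {
                "section_id": len(sections) + 1,  # headers are numbered 1, 2, 3, ...
                "section_title": el["content"],
                "lines": [],
            }
            sections.append(current)
        elif el["type"] == "text" and current is not None:
            current["lines"].append(el["content"])
        # Text before the first header has no section and is skipped.

    rows = []
    for s in sections:
        if not s["lines"]:
            continue  # a header with no text produces no row
        text = "\n".join(s["lines"])
        rows.append({
            "section_id": s["section_id"],
            "section_title": s["section_title"],
            "text": text,
            "combined_text": s["section_title"] + "\n" + text,
        })
    return rows


elements = [
    {"id": "0", "type": "text", "content": "Cover page"},
    {"id": "1", "type": "section_header", "content": "1. Term"},
    {"id": "2", "type": "text", "content": "This Agreement begins on the Effective Date."},
    {"id": "3", "type": "section_header", "content": "2. Fees"},
    {"id": "4", "type": "text", "content": "Fees are due in 30 days."},
]
rows = split_into_sections(elements)
```

Here `rows` holds two sections; the "Cover page" text is not part of either because it precedes the first header.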

In [None]:
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.sections') (
    path STRING,
    section_id INT,
    section_title STRING,
    text STRING,
    combined_text STRING
)

In [None]:
MERGE INTO IDENTIFIER(:catalog || '.' || :schema || '.sections') AS target
USING (
  WITH exploded_elements AS (
    SELECT
      p.path,
      element:id::STRING AS element_id,
      element:type::STRING AS element_type,
      element:content::STRING AS element_content
    FROM (SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :parsed_table)) AS p,
    LATERAL explode(cast(p.parsed:document:elements AS ARRAY<VARIANT>)) AS t(element)
    WHERE element:id IS NOT NULL
  ),
  section_headers AS (
    SELECT
      path,
      element_id,
      element_content,
      ROW_NUMBER() OVER (PARTITION BY path ORDER BY element_id::INT) AS mono_section_id
    FROM exploded_elements
    WHERE element_type = 'section_header'
  ),
  section_tracking AS (
    SELECT
      e.path,
      e.element_id,
      e.element_type,
      e.element_content,
      sh.mono_section_id AS section_id,
      sh.element_content AS section_title
    FROM exploded_elements e
    LEFT JOIN LATERAL (
      SELECT mono_section_id, element_content
      FROM section_headers sh
      WHERE sh.path = e.path AND sh.element_id::INT <= e.element_id::INT
      ORDER BY sh.element_id::INT DESC
      LIMIT 1
    ) sh ON TRUE
  )
  SELECT
    path,
    section_id,
    section_title,
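    -- COLLECT_LIST has no guaranteed order, so sort by element id before joining
    CONCAT_WS('\n', TRANSFORM(ARRAY_SORT(COLLECT_LIST(NAMED_STRUCT('pos', element_id::INT, 'content', element_content))), s -> s.content)) AS text,
    CONCAT_WS('\n', section_title, TRANSFORM(ARRAY_SORT(COLLECT_LIST(NAMED_STRUCT('pos', element_id::INT, 'content', element_content))), s -> s.content)) AS combined_text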
  FROM section_tracking
  WHERE element_type = 'text' 
    AND section_id IS NOT NULL
  GROUP BY path, section_id, section_title
) AS source
ON target.path = source.path AND target.section_id = source.section_id
WHEN MATCHED THEN
  UPDATE SET
    section_title = source.section_title,
    text = source.text,
    combined_text = source.combined_text
WHEN NOT MATCHED THEN
  INSERT (path, section_id, section_title, text, combined_text)
  VALUES (source.path, source.section_id, source.section_title, source.text, source.combined_text)

---
## 3B: Set Up Vector Search

This creates a Vector Search endpoint (if needed) and indexes three tables:
- `flat` (indexed on preamble) -- for finding documents with similar openings
- `doc_info` (indexed on combined_doc_info) -- for finding documents with similar metadata
- `sections` (indexed on combined_text) -- for finding specific sections across documents

If an index already exists, creating it fails and the notebook triggers a sync of the existing index instead.
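
One possible refinement of this create-or-sync pattern is to fall back to a sync only when the failure looks like an "already exists" error, so that other problems (permissions, a missing endpoint) are not silently swallowed. The sketch below is a hedged illustration: `create_or_sync_index` is a hypothetical helper, and the check on the error text is an assumption about how the error is worded, not documented behavior. The client calls are the same ones the notebook uses.

```python
def create_or_sync_index(client, endpoint_name, source_table, index_name,
                         primary_key, columns_to_sync, embedding_column,
                         embedding_model="databricks-gte-large-en"):
    """Create a Delta Sync index, or sync it if it already exists.

    Returns "created" or "synced". Errors that do not look like
    "already exists" are re-raised.
    """
    try:
        client.create_delta_sync_index(
            endpoint_name=endpoint_name,
            source_table_name=source_table,
            index_name=index_name,
            pipeline_type="TRIGGERED",
            primary_key=primary_key,
            columns_to_sync=columns_to_sync,
            embedding_source_column=embedding_column,
            embedding_model_endpoint_name=embedding_model,
        )
        return "created"
    except Exception as e:
        # ASSUMPTION: an existing index is reported with "already exists"
        # (any casing, with spaces or underscores) somewhere in the error text.
        message = str(e).lower().replace("_", " ")
        if "already exists" not in message:
            raise
        client.get_index(index_name=index_name).sync()
        return "synced"
```

In the notebook this would be called once per entry in `tables`, passing `f"{catalog}.{schema}.{tbl}"` as `source_table` and `f"{catalog}.{schema}.{tbl}_index"` as `index_name`.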

In [None]:
# Create the endpoint if it doesn't exist.
# list_endpoints() may omit the 'endpoints' key when the workspace has none.
endpoints = [ep['name'] for ep in client.list_endpoints().get('endpoints', [])]
if endpoint_name not in endpoints:
    client.create_endpoint(
        name=endpoint_name,
        endpoint_type="STANDARD"
    )

# Enable change data feed on all source tables
for tbl in tables.keys():
    spark.sql(f"""
        ALTER TABLE IDENTIFIER('{catalog}.{schema}.{tbl}') 
        SET TBLPROPERTIES (delta.enableChangeDataFeed = true)
    """)

# Create or sync vector search indexes.
# NOTE: `path` is the primary key for every index, but `sections` has one row per
# (path, section_id), so `path` is not unique there. That index likely needs a
# unique key column in the source table to sync every section.
for tbl, cfg in tables.items():
    index_name = f"{catalog}.{schema}.{tbl}_index"
    try:
        index = client.create_delta_sync_index(
            endpoint_name=endpoint_name,
            source_table_name=f"{catalog}.{schema}.{tbl}",
            index_name=index_name,
            pipeline_type="TRIGGERED",
            primary_key="path",
            columns_to_sync=cfg['columns_to_sync'],
            embedding_source_column=cfg['vs_col'],
            embedding_model_endpoint_name="databricks-gte-large-en",
        )
        print(f'Creating index for {tbl}')
    except Exception as e:
        # Creation fails when the index already exists; fall back to syncing it
        print(e)
        index = client.get_index(index_name=index_name)
        index.sync()
        print(f'Syncing index for {tbl}')

### Test vector search

Quick sanity check to make sure the indexes work.

In [None]:
SELECT * 
FROM vector_search(
  index => :catalog || '.' || :schema || '.flat_index',
  query_text => 'Contract No. 1885-16859 Amendment No. 1',
  query_type => 'HYBRID',
  num_results => 5
)

---
## 3C: Assemble Context

This is the key step that brings everything together. For each contract, it assembles:

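- **doc_info** -- the contract's own extracted metadata
- **preamble** -- the contract's first ~100 words
- **vs_results** -- doc_info from the 5 closest vector search matches (the query does not exclude the contract itself, so it can appear among them)
- **vs_preamble_results** -- preambles from the 5 closest vector search matches (same caveat)
- **other_folder_docs** -- number of distinct preambles found among other documents in the same vendor folder
- **other_preambles / other_doc_infos** -- preambles and doc_info from other documents in the same vendor folder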

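The folder-level documents are especially important: amendments are almost always in the same folder as their master agreement. The merge only inserts paths that are not yet in `assembled`; rows that already exist are left unchanged.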
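
The folder-level part of the assembly can be sketched in plain Python. This is an illustration under simplifying assumptions, not the notebook's implementation: `folder_context` and the `docs` mapping are hypothetical, only preambles are shown, and the results are sorted for determinism whereas the SQL leaves their order unspecified.

```python
SEPARATOR = "\n \n ##### \n \n"  # same delimiter the SQL passes to concat_ws


def folder_context(docs):
    """For each document, collect the preambles of other documents sharing a vendor folder.

    `docs` maps path -> {"vendor_folder_paths": [...], "preamble": str}.
    """
    # Index documents by folder, mirroring the explode() of vendor_folder_paths.
    by_folder = {}
    for path, doc in docs.items():
        for folder in doc["vendor_folder_paths"]:
            by_folder.setdefault(folder, set()).add(path)

    result = {}
    for path, doc in docs.items():
        siblings = set()
        for folder in doc["vendor_folder_paths"]:
            siblings |= by_folder[folder] - {path}  # same folder, different document
        # Deduplicate preambles, as DISTINCT does in the SQL.
        preambles = sorted({docs[s]["preamble"] for s in siblings})
        result[path] = {
            "other_folder_docs": len(preambles),
            "other_preambles": SEPARATOR.join(preambles),
        }
    return result


docs = {
    "acme/master.pdf": {"vendor_folder_paths": ["acme"], "preamble": "Master Services Agreement"},
    "acme/amend1.pdf": {"vendor_folder_paths": ["acme"], "preamble": "Amendment No. 1"},
    "acme/amend2.pdf": {"vendor_folder_paths": ["acme"], "preamble": "Amendment No. 2"},
    "zeta/nda.pdf": {"vendor_folder_paths": ["zeta"], "preamble": "Mutual NDA"},
}
context = folder_context(docs)
```

A document that is alone in its folder ends up with a count of 0 and an empty string, which matches what the `LEFT JOIN` plus `concat_ws` combination yields in the SQL.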

In [None]:
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.assembled') (
  path STRING,
  doc_info STRING,
  preamble STRING,
  vs_results STRING,
  vs_preamble_results STRING,
  other_folder_docs BIGINT,
  other_preambles STRING,
  other_doc_infos STRING
);

MERGE INTO IDENTIFIER(:catalog || '.' || :schema || '.assembled') AS target
USING (
  WITH doc_info_search_results AS (
    SELECT
      doc.path,
      concat_ws('\n \n ##### \n \n', collect_list(vs.combined_doc_info)) AS vs_combined_doc_info
    FROM IDENTIFIER(:catalog || '.' || :schema || '.doc_info') doc
    CROSS JOIN LATERAL (
      SELECT combined_doc_info
      FROM vector_search(
        index => :catalog || '.' || :schema || '.doc_info_index',
        query_text => substring(doc.combined_doc_info, 1, CAST(:vs_char_limit AS INT)),
        query_type => 'HYBRID',
        num_results => 5
      )
    ) vs
    GROUP BY doc.path
  ),
  preamble_search_results AS (
    SELECT
      f.path,
      concat_ws('\n \n ##### \n \n', collect_list(vs_preamble.preamble)) AS vs_preamble_results
    FROM IDENTIFIER(:catalog || '.' || :schema || '.flat') f
    CROSS JOIN LATERAL (
      SELECT preamble
      FROM vector_search(
        index => :catalog || '.' || :schema || '.flat_index',
        query_text => substring(f.preamble, 1, CAST(:vs_char_limit AS INT)),
        query_type => 'HYBRID',
        num_results => 5
      )
    ) vs_preamble
    GROUP BY f.path
  )
  SELECT
    base.path,
    base.preamble,
    doc.combined_doc_info AS doc_info,
    vs.vs_combined_doc_info AS vs_results,
    preamble_vs.vs_preamble_results AS vs_preamble_results,
    count(DISTINCT rel.other_preamble) AS other_folder_docs,
    concat_ws('\n \n ##### \n \n', collect_list(DISTINCT rel.other_preamble)) AS other_preambles,
    concat_ws('\n \n ##### \n \n', collect_list(DISTINCT rel.other_combined_doc_info)) AS other_doc_infos
  FROM IDENTIFIER(:catalog || '.' || :schema || '.flat') base
  LEFT JOIN (
    SELECT
      o.other_path AS path,
      t.preamble AS other_preamble,
      doc.combined_doc_info AS other_combined_doc_info
    FROM (
      SELECT
        path,
        vendor_folder_path,
        preamble
      FROM IDENTIFIER(:catalog || '.' || :schema || '.flat')
      LATERAL VIEW OUTER explode(vendor_folder_paths) AS vendor_folder_path
    ) t
    JOIN (
      SELECT
        path AS other_path,
        vendor_folder_path AS other_vendor_folder_path
      FROM IDENTIFIER(:catalog || '.' || :schema || '.flat')
      LATERAL VIEW OUTER explode(vendor_folder_paths) AS vendor_folder_path
    ) o
      ON t.vendor_folder_path = o.other_vendor_folder_path
      AND t.path <> o.other_path
    LEFT JOIN IDENTIFIER(:catalog || '.' || :schema || '.doc_info') doc
      ON t.path = doc.path
  ) rel
  ON base.path = rel.path
  LEFT JOIN IDENTIFIER(:catalog || '.' || :schema || '.doc_info') doc
  ON base.path = doc.path
  LEFT JOIN doc_info_search_results vs
  ON base.path = vs.path
  LEFT JOIN preamble_search_results preamble_vs
  ON base.path = preamble_vs.path
  GROUP BY base.path, base.preamble, doc.combined_doc_info, vs.vs_combined_doc_info, preamble_vs.vs_preamble_results
) AS source
ON target.path = source.path
WHEN NOT MATCHED THEN INSERT (
  path,
  doc_info,
  preamble,
  vs_results,
  vs_preamble_results,
  other_folder_docs,
  other_preambles,
  other_doc_infos
) VALUES (
  source.path,
  source.doc_info,
  source.preamble,
  source.vs_results,
  source.vs_preamble_results,
  source.other_folder_docs,
  source.other_preambles,
  source.other_doc_infos
)

In [None]:
-- Check: how many assembled records?
SELECT COUNT(*) as total_assembled FROM IDENTIFIER(:catalog || '.' || :schema || '.assembled')