This notebook runs `AI_PARSE_DOCUMENT` in small batches so that a large corpus of documents can be parsed safely and incrementally.

In [0]:
# Notebook parameters, exposed as Databricks widgets.
# NOTE: `catalog` and `schema` are interpolated verbatim into SQL text by later
# cells, so they must be plain, trusted identifiers.
dbutils.widgets.text("catalog", "dev", "Catalog")
dbutils.widgets.text("schema", "raw", "Schema")
dbutils.widgets.text("batch_size", "100", "Batch Size")
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
batch_size = int(dbutils.widgets.get("batch_size"))
# A batch size of 0 would turn every batch into `LIMIT 0`: nothing gets parsed,
# the remaining count never drops, and the batch loop would spin forever.
if batch_size < 1:
    raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
# Regex (matched against file_name) selecting every file type of interest.
file_pattern = r'\.(pdf|jpg|jpeg|png|doc|docx|ppt|pptx)$'

In [0]:
# Count the candidate files per extension to see what the corpus contains.
# The pattern is passed as a raw SQL literal (r'...'): in a regular SQL string
# literal the backslash of `\.` is consumed as an escape, leaving a bare `.`
# that matches any character instead of a literal dot.
df = spark.sql(f"""
    SELECT 
        regexp_extract(file_name, r'\\.([a-zA-Z0-9]+)$', 1) AS file_extension,
        COUNT(*) AS count
    FROM {catalog}.{schema}.bytes
    WHERE file_name RLIKE r'{file_pattern}'
    GROUP BY file_extension
    ORDER BY count DESC
""")
display(df)

In [0]:
# Restrict this pass to PDFs and images; Office formats (doc/docx/ppt/pptx)
# are left out for now because a bug is suspected with them.
file_pattern = r"\.(pdf|jpg|jpeg|png)$"

In [0]:
%sql
-- Create the destination table for parse results if it does not exist yet.
-- The table name is assembled from the :catalog and :schema parameters
-- (presumably bound to the notebook widgets of the same names).
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.parsed') (
  -- Source file path; the batch cells join on it to find unparsed files.
  path STRING NOT NULL PRIMARY KEY,
  -- Output of AI_PARSE_DOCUMENT for that file.
  parsed VARIANT
);

In [0]:
import time

def remaining_count():
    """Return the number of matching files that have no row in `parsed` yet.

    Reads the notebook-level `catalog`, `schema` and `file_pattern` at call
    time, so it follows whatever pattern was assigned most recently.

    Returns:
        int: count of rows in `bytes` whose file_name matches `file_pattern`
        and whose path is absent from `parsed`.
    """
    # The pattern is a raw SQL literal (r'...') so the backslash in `\.`
    # reaches the regex engine instead of being eaten as a string escape.
    return spark.sql(f"""
        SELECT COUNT(*) AS cnt
        FROM {catalog}.{schema}.bytes AS b
        LEFT JOIN {catalog}.{schema}.parsed AS p
          ON b.path = p.path
        WHERE b.file_name RLIKE r'{file_pattern}'
          AND p.path IS NULL
    """).collect()[0]["cnt"]

batch = 0  # number of batches completed so far
start = time.time()

# Each iteration parses up to `batch_size` pending files (lowest `length`
# first) in its own MERGE statement, then re-counts what is still pending.
while True:
    remaining = remaining_count()
    print(f"Batch {batch+1}, remaining: {remaining}")
    if remaining == 0:
        break

    t0 = time.time()
    # The inner query uses exactly the same filter as remaining_count(), so
    # the counter and the work selection cannot disagree.
    spark.sql(f"""
        MERGE INTO {catalog}.{schema}.parsed AS target
        USING (
          SELECT 
            b.path,
            AI_PARSE_DOCUMENT(b.content) AS parsed
          FROM (
            SELECT b.path, content
            FROM {catalog}.{schema}.bytes AS b
            LEFT JOIN {catalog}.{schema}.parsed AS p
              ON b.path = p.path
            WHERE b.file_name RLIKE r'{file_pattern}'
              AND p.path IS NULL
            ORDER BY b.length
            LIMIT CAST({batch_size} AS INTEGER)
          ) AS b
        ) AS source
        ON target.path = source.path
        WHEN NOT MATCHED THEN INSERT *
    """)
    # Use the same 1-based batch number as the "remaining" message above.
    print(f"Batch {batch+1} done in {time.time() - t0:.1f}s")
    batch += 1

print(f"All done in {time.time() - start:.1f}s")

In [0]:
%sql
-- Inspect suspicious parses: rows whose extracted element text, joined
-- together, is shorter than 100 characters.
-- Targets the widget-selected table (same IDENTIFIER form as the CREATE TABLE
-- cell) so this looks at the table the batch loop actually writes to.
SELECT * 
FROM IDENTIFIER(:catalog || '.' || :schema || '.parsed')
WHERE length(concat_ws(
        '\n\n',
        transform(
          try_cast(parsed:document:elements AS ARRAY<VARIANT>),
          element -> try_cast(element:content AS STRING)
        )
      )) < 100

In [0]:
%sql
-- Browse the parse results in the widget-selected catalog/schema rather than
-- a hard-coded table, so this shows what the batch loop produced.
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.parsed')

In [0]:
%sql
-- Spot-check documents whose path contains "border" (case-insensitive),
-- reading from the widget-selected catalog/schema.
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.parsed')
WHERE path ILIKE '%border%'

In [0]:
%sql
-- Deletes failed parsing runs: rows whose extracted element text, joined
-- together, is shorter than 100 characters. Once removed, those paths count
-- as unparsed again for the batch cells.
-- The target is the widget-selected table, not a hard-coded one, so the
-- cleanup hits the same table the batch loop fills.
DELETE FROM IDENTIFIER(:catalog || '.' || :schema || '.parsed')
WHERE length(concat_ws(
        '\n\n',
        transform(
          try_cast(parsed:document:elements AS ARRAY<VARIANT>),
          element -> try_cast(element:content AS STRING)
        )
      )) < 100

In [0]:
# Second pass: switch the selection to Office documents only.
file_pattern = r"\.(doc|docx|ppt|pptx)$"

In [0]:
import time

def remaining_count():
    """Return the number of matching files that have no row in `parsed` yet.

    Reads the notebook-level `catalog`, `schema` and `file_pattern` at call
    time, so it follows whatever pattern was assigned most recently.

    Returns:
        int: count of rows in `bytes` whose file_name matches `file_pattern`
        and whose path is absent from `parsed`.
    """
    # The pattern is a raw SQL literal (r'...') so the backslash in `\.`
    # reaches the regex engine instead of being eaten as a string escape.
    return spark.sql(f"""
        SELECT COUNT(*) AS cnt
        FROM {catalog}.{schema}.bytes AS b
        LEFT JOIN {catalog}.{schema}.parsed AS p
          ON b.path = p.path
        WHERE b.file_name RLIKE r'{file_pattern}'
          AND p.path IS NULL
    """).collect()[0]["cnt"]

batch = 0  # number of batches completed so far
start = time.time()

# Each iteration parses up to `batch_size` pending files (lowest `length`
# first) in its own MERGE statement, then re-counts what is still pending.
while True:
    remaining = remaining_count()
    print(f"Batch {batch+1}, remaining: {remaining}")
    if remaining == 0:
        break

    t0 = time.time()
    # `path` must be qualified as b.path in the inner SELECT: both joined
    # tables expose a `path` column, so the bare name is ambiguous.
    spark.sql(f"""
        MERGE INTO {catalog}.{schema}.parsed AS target
        USING (
          SELECT 
            b.path,
            AI_PARSE_DOCUMENT(b.content) AS parsed
          FROM (
            SELECT b.path, content
            FROM {catalog}.{schema}.bytes AS b
            LEFT JOIN {catalog}.{schema}.parsed AS p
              ON b.path = p.path
            WHERE b.file_name RLIKE r'{file_pattern}'
              AND p.path IS NULL
            ORDER BY b.length
            LIMIT CAST({batch_size} AS INTEGER)
          ) AS b
        ) AS source
        ON target.path = source.path
        WHEN NOT MATCHED THEN INSERT *
    """)
    # Use the same 1-based batch number as the "remaining" message above.
    print(f"Batch {batch+1} done in {time.time() - t0:.1f}s")
    batch += 1

print(f"All done in {time.time() - start:.1f}s")