In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Summarize contracts (PDF files) using OCR (Vision API) and LLM (PaLM API)

## Overview

This notebook shows how to run OCR on a large number of contract PDF files stored in a GCS bucket and summarize the extracted text with an LLM.

#### **Steps**
Using Spark, 
1) It reads a metadata table of the [Contract Understanding Atticus Dataset (CUAD)](https://www.atticusprojectai.org/cuad) from the **public_datasets** dataset located in the [metastore](../gcp_services/README.md) (notebook should be connected with the public metastore if using this specific dataset).    
   This metadata table contains the paths of the pdf files in the bucket.  
   If you want to apply this to a different dataset, you can read the PDF files in your bucket with `spark.read.format("binaryFile")` (no metastore needed) - more details [here](../gcp_services/README.md). 
2) It runs OCR using the Vision API - it starts a series of async operations and then checks their completion status.
3) It calls [Vertex AI PaLM API](https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart#try_text_prompts) to summarize each text page.
4) It saves the output to BigQuery

#### Related content

- [Summarization with Large Documents using LangChain](https://github.com/GoogleCloudPlatform/generative-ai/blob/dev/language/examples/oss-samples/langchain/summarization_with_large_documents_langchain.ipynb)
- [Design summarization prompts](https://cloud.google.com/vertex-ai/docs/generative-ai/text/summarization-prompts)

## Setup

#### Identity and Access Management (IAM)

Make sure the service account running this notebook has the required permissions:

- **Run the notebook**
  - AI Platform Notebooks Service Agent
  - Notebooks Admin
  - Vertex AI Administrator
- **Read tables from Dataproc Metastore**
  - Dataproc Metastore Editor
  - Dataproc Metastore Metadata Editor
  - Dataproc Metastore Metadata User
  - Dataproc Metastore Service Agent
- **Read files from bucket**
  - Storage Object Viewer
- **Run Dataproc jobs**
  - Dataproc Service Agent
  - Dataproc Worker
- **Call Google APIs (PaLM and Vision)**
  - Service Usage Consumer
  - VisionAI Admin
- **BigQuery**
  - BigQuery Data Editor

#### Imports

In [None]:
import os
import sys
import re
import json
import time

from pyspark.sql.functions import udf, col, lit, split, explode, size, avg, count, regexp_replace, collect_list
from pyspark.sql.types import StructType, StructField, StringType, ArrayType

import google.auth
import google.auth.transport.requests
import requests

from google.cloud import storage

#### Authentication

In [None]:
# Get credentials to authenticate with Google APIs.
# `credentials` and `project_id` are module-level names read later by the UDFs
# (perform_ocr, check_completion, summarize_page) to build request URLs and headers.
credentials, project_id = google.auth.default()
auth_req = google.auth.transport.requests.Request()
# Refresh the credentials; the UDFs send `credentials.token` as a Bearer token.
# NOTE(review): the token value is captured when a UDF is shipped to executors and is
# presumably short-lived - confirm it outlives long-running jobs.
credentials.refresh(auth_req)

#### Setup Spark Session

In [None]:
from pyspark.sql import SparkSession

In [None]:
# Create (or reuse) the Spark session for this notebook; Hive support is enabled
# so that tables registered in the metastore can be read by name.
spark = (
    SparkSession.builder
    .appName("OCR contract PDF files and summarize")
    .enableHiveSupport()
    .getOrCreate()
)

#### Parameters

In [None]:
# PDFs
# Table read with spark.read.table(); its `path` column is used as the PDF location.
input_dataset_table = "public_datasets.cuad_v1"
# Change the maximum number of files you want to consider
limit_files = 5
# OCR
# NOTE(review): the Vision API writes its output under this bucket, so the caller
# presumably needs write access to it - confirm before running against your own project.
gcs_output_bucket = "gs://dataproc-metastore-public-binaries" # Output bucket (with gs:// scheme) where OCR text files will be saved
output_path_prefix = "cuad_v1/output_ocr" # path prefix after bucket name where the folder structure will be created
# BigQuery
output_dataset_bq = "output_dataset" # create the BigQuery dataset beforehand
output_table_bq = "ocr_page_summaries" # destination table; the write below uses mode("overwrite")
bq_temp_bucket_name = "workspaces-bq-temp-bucket-dev" # passed as the connector's "temporaryGcsBucket" option

## Read dataset

#### Read CUAD V1 dataset from metastore

In [None]:
# Read the metadata table by name and keep at most `limit_files` rows.
cuad_v1_df = spark.read.table(input_dataset_table).limit(limit_files)

|                path|    modificationTime| length|             content|
|--------------------|--------------------|-------|--------------------|
|gs://dataproc-met...|2023-05-15 20:53:...|3683550|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:53:...|2881262|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:54:...|1778356|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:53:...|1557129|[25 50 44 46 2D 3...|
|gs://dataproc-met...|2023-05-15 20:53:...|1452180|[25 50 44 46 2D 3...|

In [None]:
# Keep only the file location, exposed as `pdf_path` for the OCR step.
files_df = cuad_v1_df.select(col("path").alias("pdf_path"))

## Run OCR using Vision API

#### Run OCR - Start async operations

In [None]:
### Spark User Defined Function (UDF)
def perform_ocr(gcs_source_uri, gcs_output_bucket, output_path_prefix):
    """Start an asynchronous Vision API OCR operation for one PDF stored in GCS.

    The destination URI mirrors the source object's path (without its bucket)
    under ``gcs_output_bucket/output_path_prefix``.

    Args:
        gcs_source_uri: ``gs://`` URI of the PDF to annotate.
        gcs_output_bucket: ``gs://<bucket>`` where the OCR output is written.
        output_path_prefix: Path prefix inside the output bucket.

    Returns:
        ``[gcs_destination_uri, operation_name]`` - the URI passed to the API as
        the output destination and the name of the started operation.

    Raises:
        requests.exceptions.RequestException: If the HTTP call fails or times out.
        RuntimeError: If the response does not contain an operation name
            (for example because the API returned an error payload).
    """
    gcs_uri, file_name = os.path.split(gcs_source_uri)
    # Strip only the leading "gs://<bucket>" so the remaining folder structure is kept.
    sub_paths = re.sub(r"gs://[^/]+", "", gcs_uri, count=1)
    gcs_destination_uri = gcs_output_bucket + "/" + output_path_prefix + sub_paths + "/" + file_name

    response = requests.post(
        f"https://us-vision.googleapis.com/v1/projects/{project_id}/locations/us/files:asyncBatchAnnotate",
        headers={'Authorization': 'Bearer %s' % credentials.token,
                 'x-goog-user-project': project_id,
                 'Content-Type': 'application/json; charset=utf-8'},
        json={
            "requests": [{
                "inputConfig": {
                    "gcsSource": {"uri": gcs_source_uri},
                    "mimeType": "application/pdf"
                },
                "features": [{"type": "DOCUMENT_TEXT_DETECTION"}],
                "outputConfig": {
                    "gcsDestination": {"uri": gcs_destination_uri},
                    "batchSize": 100
                }
            }]
        },
        # Without a timeout a stalled connection would block the Spark task forever.
        timeout=60,
    )

    try:
        operation = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page); reported below with the raw text.
        operation = {}

    if not isinstance(operation, dict) or "name" not in operation:
        # Surface the API's own error instead of a bare KeyError('name').
        raise RuntimeError(
            f"Vision API did not start an OCR operation for {gcs_source_uri} "
            f"(HTTP {response.status_code}): {response.text}"
        )

    return [gcs_destination_uri, operation["name"]]

In [None]:
# Return type of perform_ocr: where the OCR output is written and the name of
# the Vision API operation that produces it.
schema = StructType(
    [
        StructField("ocr_text_path", StringType(), False),
        StructField("vision_api_async_operation_name", StringType(), False)
    ]
)

# perform_ocr has a side effect (it starts a remote operation), so it is marked
# non-deterministic: this prevents the optimizer from duplicating the call when
# several columns are derived from the returned struct.
perform_ocr_udf = udf(perform_ocr, schema).asNondeterministic()

In [None]:
# Add an `ocr_async_op` struct column produced by perform_ocr_udf for every PDF path.
# The bucket and prefix parameters are wrapped in lit() so they are passed as constant columns.
ocr_async_op_df = files_df.withColumn("ocr_async_op", perform_ocr_udf(files_df["pdf_path"], lit(gcs_output_bucket), lit(output_path_prefix)))

In [None]:
# Promote the two fields of the `ocr_async_op` struct to top-level columns,
# then discard the struct itself.
ocr_async_op_df = (
    ocr_async_op_df
    .withColumn("ocr_text_path", col("ocr_async_op")["ocr_text_path"])
    .withColumn("vision_api_async_operation_name", col("ocr_async_op")["vision_api_async_operation_name"])
    .drop("ocr_async_op")
)

In [None]:
# Mark the DataFrame as cached BEFORE the first action. show() evaluates the
# perform_ocr UDF; without a cache every later action would evaluate it again,
# starting duplicate OCR operations whose names differ from the ones displayed here.
ocr_async_op_df.cache().show(10, 50)

|                                          pdf_path|                                     ocr_text_path|                   vision_api_async_operation_name|
|--------------------------------------------------|--------------------------------------------------|--------------------------------------------------|
|gs://dataproc-metastore-public-binaries/cuad_v1...|gs://dataproc-metastore-public-binaries/cuad_v1...|projects/dataproc-workspaces-notebooks/operatio...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|gs://dataproc-metastore-public-binaries/cuad_v1...|projects/dataproc-workspaces-notebooks/operatio...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|gs://dataproc-metastore-public-binaries/cuad_v1...|projects/dataproc-workspaces-notebooks/operatio...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|gs://dataproc-metastore-public-binaries/cuad_v1...|projects/dataproc-workspaces-notebooks/operatio...|
|gs://dataproc-metastore-public-binaries/cuad_v1...|gs://dataproc-metastore-public-binaries/cuad_v1...|projects/dataproc-workspaces-notebooks/operatio...|

In [None]:
# cache() is lazy: it only marks the DataFrame. count() forces every row to be
# computed now, so all OCR operations are started (and their names stored)
# before the notebook waits and checks their status.
ocr_async_op_df.cache()
ocr_async_op_df.count()

#### Check status of OCR operations

In [None]:
### Spark User Defined Function (UDF)
def check_completion(operation_name):
    """Report the state of a Vision API long-running operation as a string.

    Args:
        operation_name: Operation resource name, appended to the Vision API base URL.

    Returns:
        "done" when the operation finished without an error, "processing" when
        it has not finished, or a message describing the error otherwise.
    """
    request_headers = {'Authorization': 'Bearer %s' % credentials.token,
                       'x-goog-user-project': project_id}
    response = requests.get(
        f"https://us-vision.googleapis.com/v1/{operation_name}",
        headers=request_headers,
    )
    status = response.json()

    finished = "done" in status and status["done"]
    has_error = "error" in status

    # Finished operations: either failed or succeeded.
    if finished and has_error:
        return f'Operation error: code {status["error"]["code"]} and message {status["error"]["message"]}'
    if finished:
        return "done"

    # Not finished: an "error" key here means the status lookup itself failed.
    if has_error:
        return f"Error getting operation: {status['error']}"
    return "processing"

In [None]:
# Wrap the Python function as a Spark UDF (no return type is given to udf()).
# NOTE(review): this rebinds `check_completion` from the plain function to the UDF,
# so re-running this cell without re-running the definition wraps an already-wrapped
# object. A separate name (like `perform_ocr_udf` above) would make the cell re-runnable.
check_completion = udf(check_completion)

In [None]:
# Pause for 45 seconds before the status check that follows.
# NOTE(review): fixed wait, presumably to give the async OCR operations time to finish;
# larger batches may need longer, or a polling loop - confirm.
time.sleep(45)

In [None]:
# Add a `status` column by applying the check_completion UDF to each operation name.
# This only defines the column; the UDF runs when an action is executed on the DataFrame.
check_completion_df = ocr_async_op_df.withColumn("status", check_completion(ocr_async_op_df["vision_api_async_operation_name"]))

#### Get processed OCR text files from bucket

In [None]:
### Spark User Defined Function (UDF)
def read_completed_ocr(path):
    """Read the OCR output written by the Vision API for one document.

    The Vision API treats the destination URI as a prefix and writes one JSON
    file per batch of pages (named ``output-<first>-to-<last>.json``), so a
    document longer than the batch size is spread over several files. All of
    them are read, in page order.

    Args:
        path: ``gs://bucket/prefix`` under which the OCR JSON files were written.

    Returns:
        A ``(ocr_text, ocr_pages)`` tuple matching the UDF's struct schema:
        the stripped concatenation of all page texts and the list of page texts.
        On failure, ``("Error getting ocr from pdf: <reason>", [])`` so that the
        result still fits the schema instead of failing the Spark task.
    """

    def first_page(blob):
        # Order shards by their first page number; plain name order would put
        # "output-1001-to-1100" before "output-201-to-300".
        match = re.search(r"output-(\d+)-to-\d+\.json$", blob.name)
        if match:
            return (0, int(match.group(1)), blob.name)
        return (1, 0, blob.name)

    try:
        location = re.match(r"gs://([^/]+)/?(.*)", path)
        if location is None:
            raise ValueError(f"not a gs:// path: {path!r}")
        bucket_name, prefix = location.group(1), location.group(2)

        storage_client = storage.Client()
        bucket = storage_client.get_bucket(bucket_name)
        # Skip "directory" placeholder objects.
        blobs = [blob for blob in bucket.list_blobs(prefix=prefix) if not blob.name.endswith('/')]
        if not blobs:
            raise FileNotFoundError(f"no OCR output found under {path}")

        ocr_pages = []
        for blob in sorted(blobs, key=first_page):
            structured_ocr = json.loads(blob.download_as_bytes().decode("utf-8"))
            for page in structured_ocr['responses']:
                # Pages without detected text carry no fullTextAnnotation.
                if 'fullTextAnnotation' in page and 'text' in page['fullTextAnnotation']:
                    ocr_pages.append(page['fullTextAnnotation']['text'])

        return "".join(ocr_pages).strip(), ocr_pages

    except Exception as e:
        return "Error getting ocr from pdf: " + str(e), []

In [None]:
# Return type of read_completed_ocr: the full document text plus one string per page.
# NOTE(review): this reuses (overwrites) the `schema` name defined earlier for perform_ocr.
schema = StructType(
    [
        StructField("ocr_text", StringType(), False),
        StructField("ocr_pages", ArrayType(StringType(), False), False)
    ]
)

# NOTE(review): rebinding `read_completed_ocr` to the UDF means this cell cannot be
# re-run on its own without re-running the function definition first.
read_completed_ocr = udf(read_completed_ocr, schema)

In [None]:
# Display the rows with their `status`; this action is what runs the check_completion UDF.
check_completion_df.show()

In [None]:
# Keep only the rows whose status is exactly "done".
completion_df = check_completion_df.filter(col("status") == "done")

In [None]:
# Mark the filtered DataFrame for caching. cache() is lazy: nothing is stored until
# an action runs on completion_df.
completion_df.cache()

#### Get complete OCR text

In [None]:
# Fetch OCR output only for operations that finished ("done"): reading from the
# unfiltered check_completion_df would also try to load output for operations
# that are still processing or that failed.
fetch_ocr_df = completion_df.withColumn("ocr_output", read_completed_ocr(completion_df['ocr_text_path']))
# Flatten the struct returned by the UDF and add the page count per document.
ocr_df = fetch_ocr_df.select("pdf_path","ocr_output") \
                     .withColumn("ocr_text", fetch_ocr_df["ocr_output"]["ocr_text"]) \
                     .withColumn("ocr_pages", fetch_ocr_df["ocr_output"]["ocr_pages"]) \
                     .withColumn("number_pages", size(col("ocr_pages"))) \
                     .drop("ocr_output")

In [None]:
# Show up to 5 rows, truncating each value to 5 characters.
# NOTE(review): this action runs before ocr_df.cache() is called, so the OCR files
# are presumably read again by later actions - consider caching first.
ocr_df.show(5,5)

In [None]:
# Mark ocr_df for caching (lazy: populated by the next action that uses it).
ocr_df.cache()


|  pdf_path|  ocr_text| ocr_pages|number_pages|
|----------|----------|----------|------------|
|gs://da...|THIS AG...|[THIS A...|           8|
|gs://da...|Exhibit...|[Exhibi...|          40|
|gs://da...|Exhibit...|[Exhibi...|          44|
|gs://da...|Exhibit...|[Exhibi...|         100|
|gs://da...|TRANSPO...|[TRANSP...|          25|

## Summarize pages using PaLM API

In [None]:
### Spark User Defined Function (UDF)
def summarize_page(page):
    """Summarize the text of one page with the Vertex AI PaLM text model.

    Args:
        page: Text of a single OCR'd page; it is embedded in the prompt as-is.

    Returns:
        The summary returned by the model. ``None`` when the first prediction
        carries no ``content`` field. A string starting with "Error getting"
        when the API reports an error, returns an unexpected payload, or keeps
        answering 429 (quota exceeded) after all retry attempts.
    """
    MODEL_ID = "text-bison"
    MAX_ATTEMPTS = 5              # upper bound on calls per page when quota is exhausted
    RETRY_DELAY_SECONDS = 5       # base delay, doubled after every 429 response
    REQUEST_TIMEOUT_SECONDS = 120 # avoid blocking a Spark task on a stalled connection

    def predict_palm(prompt):
        """Call the predict endpoint, retrying a bounded number of times on 429."""
        for attempt in range(MAX_ATTEMPTS):
            prediction = requests.post(
                f"https://us-central1-aiplatform.googleapis.com/v1/projects/{project_id}/locations/us-central1/publishers/google/models/{MODEL_ID}:predict",
                headers={'Authorization': 'Bearer %s' % credentials.token,
                         'Content-Type': 'application/json'},
                json={
                    "instances": [
                        {"prompt": prompt}
                    ],
                    "parameters": {
                        "temperature": 0.2,
                        "maxOutputTokens": 256,
                        "topK": 40,
                        "topP": 0.95
                    }
                },
                timeout=REQUEST_TIMEOUT_SECONDS,
            ).json()

            if "predictions" in prediction:
                pred = prediction["predictions"][0]
                # A prediction without "content" yields a null summary.
                return pred["content"] if "content" in pred else None

            if "error" not in prediction:
                return "Error getting predictions"

            if prediction["error"]["code"] != 429:
                return f"Error getting prediction: {prediction['error']}"

            # 429 = quota exceeded: back off and try again, but never forever.
            if attempt < MAX_ATTEMPTS - 1:
                time.sleep(RETRY_DELAY_SECONDS * 2 ** attempt)

        return f"Error getting prediction: quota still exceeded after {MAX_ATTEMPTS} attempts"

    prompt = f"""Provide a summary with about two sentences for the following article page:
    {page}
    Summary:"""

    return predict_palm(prompt)

In [None]:
# Wrap the Python function as a Spark UDF (no return type is given to udf()).
# NOTE(review): rebinding `summarize_page` to the UDF means this cell cannot be
# re-run on its own without re-running the function definition first.
summarize_page = udf(summarize_page)

In [None]:
# One output row per element of `ocr_pages`, keeping the source PDF path alongside each page.
# NOTE(review): no page index is kept, so the original page order presumably cannot be
# restored after later shuffles - consider posexplode if order matters.
ocr_pages_df = ocr_df.select("pdf_path", explode(ocr_df["ocr_pages"]).alias("page"))

In [None]:
# Add a `summary` column by applying the summarize_page UDF to every page's text.
summaries_df = ocr_pages_df.withColumn("summary", summarize_page(ocr_pages_df["page"]))

|  pdf_path|      page|   summary|
|----------|----------|----------|
|gs://da...|THIS AG...|This is...|
|gs://da...|Definit...|(b)    ...|
|gs://da...|- 5- fo...|This se...|

In [None]:
# Show up to 5 rows, truncating each value to 10 characters.
# NOTE(review): summaries_df is not cached, so this action and the BigQuery write
# presumably each run the summarize_page UDF (and therefore call the API) - confirm.
summaries_df.show(5,10)

|page|
|----------|
|[THIS AGREEMENT is dated May 3, 2006. NON-COMPETITION AGREEMENT AND RIGHT OF FIRST OFFER BETWEEN: AND: WHEREAS: GLAMIS GOLD LTD., a company incorporated under the laws of the Province of British Columbia, having an office at 310-5190 Neil Road, Reno, Nevada 89502 ("Glamis") WESTERN COPPER CORPORATION, a company incorporated under the laws of the Province of British Columbia, having an office at 2050-1111 West Georgia Street, Vancouver, B.C. V6E 4M3 ("Western Copper") (A) Glamis, Western Copper and Western Silver Corporation ("Western Silver") are parties to an arrangement agreement dated as of February 23, 2006 (the "Arrangement Agreement"), pursuant to which, among other things, Western Copper will acquire certain assets of Western Silver and Glamis will become the sole shareholder of Western Silver and the indirect owner, through Western Silver, of certain corporations and mineral properties in Mexico (the "Arrangement"); and 1162967.3...|

|summary|
|----------|
|[This is a non-competition agreement and right of first offer between Glamis Gold Ltd. and Western Copper Corporation. Glamis Gold Ltd. will not compete with Western Copper Corporation in certain areas of Mexico and will grant Western Copper Corporation a right of first offer with respect to the proposed disposition by Glamis Gold Ltd. of mineral properties or legal interests therein located in Mexico that Glamis Gold Ltd. acquired under the Arrangement.,  (b) the headings in this Agreement are for convenience of reference only and shall not affect its interpretation...|

## Save to BigQuery

In [None]:
# Gather every page summary of a document into a single list column,
# producing one row per PDF.
agreggated_df = (
    summaries_df
    .groupBy("pdf_path")
    .agg(collect_list("summary").alias("page_summary_list"))
)

In [None]:
# Destination in "<project>:<dataset>.<table>" form.
bq_destination_table = project_id + ":" + output_dataset_bq + "." + output_table_bq

# Write the per-document summaries to BigQuery, replacing any existing table
# contents; the connector stages data in the temporary GCS bucket.
(
    agreggated_df.write
    .format("com.google.cloud.spark.bigquery")
    .option("table", bq_destination_table)
    .option("temporaryGcsBucket", bq_temp_bucket_name)
    .option("enableListInference", True)
    .mode("overwrite")
    .save()
)