## Long Form Summarization using `text-bison-32k` and `DocAI`

#### Imports 

In [1]:
from vertexai.preview.language_models import TextGenerationModel
from google.api_core.client_options import ClientOptions
from concurrent.futures import ThreadPoolExecutor
from google.cloud import documentai
from pypdf import PdfWriter
from pypdf import PdfReader 
from tqdm import tqdm
import tiktoken
import vertexai
import requests
import logging
import json
import time
import yaml
import os

##### Setup logging 

In [2]:
# Notebook-wide logger. logging.getLogger returns the same object on every call,
# so guard the handler registration: re-running this cell would otherwise stack
# another StreamHandler and every message would be printed multiple times.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

#### Essentials 

In [3]:
# Point the Google client libraries at the service-account key file.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = './../credentials/vai-key.json'
# IPython shell capture: stores the output lines of the gcloud command.
# The summarization helpers further down fetch their own token on each call.
access_token = !gcloud auth print-access-token

In [4]:
# Project, region and model shared by every Vertex AI call in this notebook.
PROJECT_ID = 'arun-genai-bb'
LOCATION = 'us-central1'
MODEL_NAME = 'text-bison-32k@latest'
ENCODING_NAME = 'cl100k_base'  # tiktoken encoding used for local token counting
CONTEXT_LENGTH = 32000  # text-bison-32k
# Derive both the host and the resource path from LOCATION so the endpoint
# cannot drift away from the region passed to vertexai.init below.
# NOTE(review): the path uses `/ui/` where the public REST reference uses
# `/v1/` — confirm this is intentional.
STREAMING_API_URL = (
    f'https://{LOCATION}-aiplatform.googleapis.com/ui/projects/{PROJECT_ID}'
    f'/locations/{LOCATION}/publishers/google/models/{MODEL_NAME}:serverStreamingPredict'
)
DOCAI_PROCESSOR_NAME = 'projects/390991481152/locations/us/processors/ad9557a5be49204e'  # copy from notebook 00
vertexai.init(project=PROJECT_ID, location=LOCATION)

In [5]:
# Document AI client bound to the `us` regional endpoint.
client_options = ClientOptions(api_endpoint='us-documentai.googleapis.com')
docai_client = documentai.DocumentProcessorServiceClient(
    client_options=client_options,
)

In [6]:
# Tokenizer used for counting tokens locally.
# NOTE(review): presumably these counts only approximate the model's own
# tokenization — confirm before relying on them for hard limits.
encoder = tiktoken.get_encoding(ENCODING_NAME)
logger.info(f'Using encoder=={encoder.name}')

Using encoder==cl100k_base


In [7]:
# Load the SDK handle for the text model and log the resolved model id.
# NOTE(review): `_model_id` is a private attribute of the SDK object.
model = TextGenerationModel.from_pretrained(MODEL_NAME)
logger.info(f'Using model=={model._model_id}')

Using model==text-bison-32k@latest


#### Use Google DocumentAI to process input PDF

##### Break PDF into smaller PDFs for OCR

In [8]:
# Local working directories and the base name (without extension) of the PDF to summarize.
LOCAL_INPUT_DIR = './DATA/INPUT'
LOCAL_OUTPUT_DIR = './DATA/OUTPUT'
FILE_NAME = 'file-2'

In [9]:
# Open the source PDF and index its pages by zero-based position.
reader = PdfReader(f'{LOCAL_INPUT_DIR}/{FILE_NAME}.pdf')
pages = dict(enumerate(reader.pages))

In [12]:
# Split the source PDF into consecutive parts of at most d pages each.
n = len(reader.pages)
d = 15  # docai has a current constraint of 15 pages per document
for i in range(0, n, d):
    writer = PdfWriter()
    # The last part may hold fewer than d pages, so stop at the document end.
    for j in range(i, min(i + d, n)):
        writer.add_page(pages[j])
    parts_dir = f'{LOCAL_INPUT_DIR}/{FILE_NAME}/PARTS/'
    os.makedirs(parts_dir, exist_ok=True)
    # The OCR step parses a page number out of this name, so keep the
    # `<name>_<first>-<last>.pdf` pattern stable. `<last>` is the nominal
    # i+d even when the final part is shorter.
    part_path = f'{parts_dir}{FILE_NAME}_{i+1}-{i+d}.pdf'
    with open(part_path, 'wb') as f:
        writer.write(f)

In [None]:
def layout_to_text(layout: documentai.Document.Page.Layout, text: str) -> str:
    """
    Resolve a Document AI layout element to the text it covers.

    A layout does not carry its own text; it carries text segments, each a
    start/end offset pair into the full document text. The slices for all
    segments are concatenated in order.
    """
    pieces = []
    # One element can be backed by several segments (e.g. when it spans lines).
    for segment in layout.text_anchor.text_segments:
        begin = int(segment.start_index)
        finish = int(segment.end_index)
        pieces.append(text[begin:finish])
    return ''.join(pieces)

In [None]:
def get_file_paths(dir_name: str) -> list:
    """
    List the regular files directly inside ``dir_name``.

    Sub-directories are skipped and the directory is not walked recursively.
    Each returned path is ``dir_name`` joined with the entry name, in the
    order ``os.listdir`` yields the entries.
    """
    candidates = (os.path.join(dir_name, entry) for entry in os.listdir(dir_name))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]

In [None]:
def ocr_docai(file_path: str) -> dict:
    """
    OCR one PDF part with Document AI and return its text keyed by page number.

    The part is expected to be named ``<name>_<first>-<last>.pdf``. Keys start
    at ``<first>`` (the part's first page in the source document) and increase
    by one per page returned by the processor, so the maps of all parts can be
    merged and sorted into document order.

    Args:
        file_path: Path to a PDF part.

    Returns:
        Mapping of page number to the concatenated text of that page's paragraphs.

    Raises:
        ValueError: If no leading page number can be parsed from the file name.
    """
    # Read the bytes first so the file handle is closed before the network call.
    with open(file_path, 'rb') as f:
        pdf = f.read()

    raw_document = documentai.RawDocument(content=pdf, mime_type='application/pdf')
    request = documentai.ProcessRequest(name=DOCAI_PROCESSOR_NAME, raw_document=raw_document)
    response = docai_client.process_document(request=request)
    text = response.document.text

    # basename/splitext cope with OS-specific separators and dots in the name;
    # the range label is whatever follows the last underscore ("1-15" -> 1).
    stem = os.path.splitext(os.path.basename(file_path))[0]
    page_range = stem.rsplit('_', 1)[-1]
    page_number = int(page_range.split('-')[0])

    pages_map = {}
    for page in response.document.pages:
        pages_map[page_number] = ''.join(
            layout_to_text(paragraph.layout, text) for paragraph in page.paragraphs
        )
        page_number += 1
    return pages_map

In [None]:
%%time 

# Directory holding the PDF parts written by the split step; built from the
# shared constant rather than repeating the path literal.
input_dir = f'{LOCAL_INPUT_DIR}/{FILE_NAME}/PARTS/'
file_paths = get_file_paths(input_dir)

# OCR the parts concurrently. Passing `total` lets tqdm render a real progress
# bar, which it cannot do for the bare iterator returned by executor.map.
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
    pages_map_list = list(tqdm(executor.map(ocr_docai, file_paths), total=len(file_paths)))

# Merge the per-part maps and order them by page number.
merged_dict = {k: v for d in pages_map_list for k, v in d.items()}
sorted_pages_map = dict(sorted(merged_dict.items()))

# From here on `pages` is the list of page texts in document order (it
# replaces the PDF page objects stored under the same name earlier).
pages = list(sorted_pages_map.values())

Save concatenated pages as txt for later use (if needed)

In [None]:
extracted_pages = ''.join(pages)
# The per-document output folder may not exist yet.
output_dir = f'{LOCAL_OUTPUT_DIR}/{FILE_NAME}'
os.makedirs(output_dir, exist_ok=True)
# `extracted_pages` is a str, so the file must be opened in text mode; writing
# a str to a file opened with 'wb' raises TypeError.
with open(f'{output_dir}/{FILE_NAME}.txt', 'w', encoding='utf-8') as out:
    out.write(extracted_pages)

In [None]:
def get_total_tokens(contexts: list) -> int:
    """
    Count the tokens across all strings in ``contexts``.

    Each string is encoded with the notebook-level ``encoder`` and the encoded
    lengths are summed. An empty list yields 0.
    """
    return sum(len(encoder.encode(context)) for context in contexts)

In [None]:
# Token count of the whole extracted document, treated as a single context.
total_tokens = get_total_tokens([extracted_pages])
logger.info(f'Total tokens in the input doc = {total_tokens}')

In [None]:
def get_max_tokens_per_page(contexts: list) -> int:
    """
    Return the token count of the longest string in ``contexts``.

    Args:
        contexts: Strings (here, page texts) to measure with the notebook-level
            ``encoder``.

    Returns:
        The largest encoded length, or 0 when ``contexts`` is empty.
    """
    return max((len(encoder.encode(context)) for context in contexts), default=0)

#### Map Reduce 1

In [None]:
def get_summary_via_streaming_api(chunk: str) -> str:
    access_token = !gcloud auth print-access-token
    prompt = f'You are a Financial Regulations & Derivatives Expert. Summarize the following information into five brief sentences in English, capturing the essential details.\n\n{chunk}'
    headers = {
        "Authorization": f"Bearer {access_token[0]}",
        "Content-Type": "application/json; charset=utf-8"
    }
    
    data = {
        "inputs": [
            {
                "struct_val": {
                    "prompt": {
                        "string_val": [prompt]
                    }
                }
            }
        ],
        "parameters": {
            "struct_val": {
                "temperature": {"float_val": 0.0},
                "maxOutputTokens": {"int_val": 256},
                "topK": {"int_val": 40},
                "topP": {"float_val": 0.8}
            }
        }
    }
    response = requests.post(STREAMING_API_URL, headers=headers, json=data)
    content = json.loads(response.content)
    output = []

    for item in content:
        try:
            text = item['outputs'][0]['structVal']['content']['stringVal'][0]
            output.append(text)
        except Exception as e:
            logger.error(f'Content error => {content}')
    output = ''.join(output)
    return output


In [None]:
%%time 

CONTEXTS_PER_CALL = 5  # process 5 pages per API call
MAX_OUTPUT_TOKENS = 256

def reduce(contexts: list) -> list:
    """
    First map-reduce pass: summarize ``contexts`` in groups.

    Consecutive groups of CONTEXTS_PER_CALL contexts are newline-joined into
    one chunk each, and every chunk is summarized by
    ``get_summary_via_streaming_api`` on a small thread pool. The token
    figures are logged for information only; nothing is enforced here.

    Returns one summary per chunk, in chunk order.
    """
    max_input_tokens = CONTEXT_LENGTH - MAX_OUTPUT_TOKENS
    logger.info(f'Max input tokens = {max_input_tokens}')
    max_tokens_per_page = get_max_tokens_per_page(contexts)
    logger.info(f'Max tokens = {max_tokens_per_page}')
    logger.info(f'Processing {CONTEXTS_PER_CALL} contexts per API call')

    chunks = [
        '\n'.join(contexts[start:start + CONTEXTS_PER_CALL])
        for start in range(0, len(contexts), CONTEXTS_PER_CALL)
    ]

    # max_workers can result in running over quota limits for invocation | current limit for text bison is 60/min
    # for our experiments, we set max_workers=4 cores without any limit breach
    with ThreadPoolExecutor(max_workers=4) as executor:
        progress = tqdm(executor.map(get_summary_via_streaming_api, chunks), total=len(chunks))
        return list(progress)


In [None]:
# First pass: summarize the page texts in groups and report how much text remains.
logger.info(f'Number of pages to process = {len(pages)}')
summaries = reduce(pages)
logger.info(f'Number of generated summaries = {len(summaries)}')
n_tokens = get_total_tokens(summaries)
logger.info(f'Total number of tokens in generated summaries = {n_tokens}')

In [None]:
# Spot-check a few first-pass summaries. Indices past the end of the list are
# skipped, so short documents do not raise IndexError here.
sample_indices = [idx for idx in (5, 15, 20) if idx < len(summaries)]
for position, idx in enumerate(sample_indices):
    if position:
        logger.info('-' * 100)
    logger.info(summaries[idx])

#### Map Reduce 2

In [None]:
def get_summary_via_streaming_api(context: str) -> str:
    access_token = !gcloud auth print-access-token
    prompt = f"""For the context below, create a consolidated refined short summary with the most important pointers only.\n\n{context}\n\nDo not repeat pointers. Breakdown the summary into SECTIONS. Make it crisp and concise."""
    headers = {
        "Authorization": f"Bearer {access_token[0]}",
        "Content-Type": "application/json; charset=utf-8"
    }
    
    data = {
        "inputs": [
            {
                "struct_val": {
                    "prompt": {
                        "string_val": [prompt]
                    }
                }
            }
        ],
        "parameters": {
            "struct_val": {
                "temperature": {"float_val": 0.0},
                "maxOutputTokens": {"int_val": 4096},
                "topK": {"int_val": 40},
                "topP": {"float_val": 0.8}
            }
        }
    }
    response = requests.post(STREAMING_API_URL, headers=headers, json=data)
    content = json.loads(response.content)
    output = []

    for item in content:
        try:
            text = item['outputs'][0]['structVal']['content']['stringVal'][0]
            output.append(text)
        except Exception as e:
            logger.error(f'Content error => {content}')
    output = ''.join(output)
    return output

In [None]:
%%time 

CONTEXTS_PER_CALL = 50  # process 50 summaries per API call

def reduce(contexts: list) -> list:
    """
    Second map-reduce pass: condense ``contexts`` in larger groups.

    Consecutive groups of CONTEXTS_PER_CALL contexts are newline-joined into
    one chunk each, the number of chunks is logged, and every chunk is
    summarized by ``get_summary_via_streaming_api`` on a small thread pool.

    Returns one summary per chunk, in chunk order.
    """
    logger.info(f'Processing {CONTEXTS_PER_CALL} contexts per API call')

    chunks = [
        '\n'.join(contexts[start:start + CONTEXTS_PER_CALL])
        for start in range(0, len(contexts), CONTEXTS_PER_CALL)
    ]
    logger.info(len(chunks))

    with ThreadPoolExecutor(max_workers=4) as executor:
        progress = tqdm(executor.map(get_summary_via_streaming_api, chunks), total=len(chunks))
        return list(progress)

In [None]:
# Second pass: condense the first-pass summaries in groups.
reduced_summaries = reduce(summaries)

In [None]:
# Report how many condensed summaries the second pass produced.
logger.info(f'Total number of summaries after map reduce 2 = {len(reduced_summaries)}')

In [None]:
# Show up to the first three second-pass summaries, each followed by a rule.
# Slicing avoids IndexError when the pass produced fewer than three.
for reduced_summary in reduced_summaries[:3]:
    logger.info(reduced_summary)
    logger.info('-' * 100)

In [None]:
# Join the second-pass summaries into one text and log its token count.
consolidated_summaries = '\n'.join(reduced_summaries)
logger.info(get_total_tokens([consolidated_summaries]))

In [None]:
# Log the joined second-pass text that feeds the final consolidation call.
logger.info(consolidated_summaries)

#### Final Consolidation

In [None]:
def get_summary_via_streaming_api(context: str) -> str:
    access_token = !gcloud auth print-access-token
    prompt = f"""Given the context below, combine and merge duplicate sections and pointers.\n\n{context}\nAdd SECTIONS and bullets wherever needed. Clean rewrite and re-number sections."""
    headers = {
        "Authorization": f"Bearer {access_token[0]}",
        "Content-Type": "application/json; charset=utf-8"
    }
    
    data = {
        "inputs": [
            {
                "struct_val": {
                    "prompt": {
                        "string_val": [prompt]
                    }
                }
            }
        ],
        "parameters": {
            "struct_val": {
                "temperature": {"float_val": 0.0},
                "maxOutputTokens": {"int_val": 8192},
                "topK": {"int_val": 40},
                "topP": {"float_val": 0.8}
            }
        }
    }
    response = requests.post(STREAMING_API_URL, headers=headers, json=data)
    content = json.loads(response.content)
    output = []

    for item in content:
        try:
            text = item['outputs'][0]['structVal']['content']['stringVal'][0]
            output.append(text)
        except Exception as e:
            logger.error(f'Content error => {content}')
    output = ''.join(output)
    return output

In [None]:
# Final consolidation: one call over the joined second-pass summaries.
final_summary = get_summary_via_streaming_api(consolidated_summaries)
logger.info(final_summary)