# Importing necessary libraries

In [1]:
import pandas as pd
import time
import os
import time
from typing import Union
import concurrent.futures
import threading
import numpy as np
import cv2
import pickle

from docsumo_image_util.parse.ocr.google import read_data, read_everything
from docsumo_image_util.parse.pdf2img import PdfImages
from pydantic import BaseModel, Field, field_validator
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
from docllm.caller.providers import Provider, ProviderFactory
from docllm.config import CallerConfig
from docllm.handlers import YamlFileHandler
from docllm.parser.input_parser import get_line_text
from docllm.parser.output_parser import parse_json
from sklearn.metrics import classification_report
from docllm.prompt.render import JinjaRenderer
from loguru import logger
from tenacity import retry, stop_after_attempt
from dotenv import load_dotenv
load_dotenv() 

# Module-level lock intended to serialize updates to shared token counters.
token_lock = threading.Lock()

# Point GOOGLE_APPLICATION_CREDENTIALS at your local Google credentials JSON.
# Update this path for your machine; ask a senior team member for the file if
# you do not have it.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/media/veracrypt1/GAC.json"
# from app.config import config_by_name




# Creating the classes for three different batch types
### The first class is `Classifer`, which provides the `classify` and `_classify` functions. It is inherited by the `LLMClassifier` class. Read the code thoroughly to understand the flow.
### Three classes, one per batching type (Non-Batched, Strict Batched and Spread Batched), are then created; each inherits from `LLMClassifier`.


In [2]:
class Classifer(ABC):
    """Base Classifier for Document Splitting."""

    @abstractmethod
    def _classify(self, *args, **kwargs):
        """Classify a list of per-page dataframes into their labels.

        ``classify`` calls this with one positional argument (the list of
        per-page dataframes) and unpacks the result as
        ``(doc_type_lst, pg_chunk_lst, total_input_tokens,
        total_output_tokens, elapsed)``.
        """
        raise NotImplementedError

    def classify(self, df: List[Dict]) -> Tuple[List, List, int, int, float]:
        """Main function where the df from process_image/request is passed into.

        Args:
            df (List): A list containing dictionary of OCR data. Each record
                must carry a ``page`` key.

        Returns:
            Tuple: ``(doc_type_lst, pg_chunk_lst, total_input_tokens,
            total_output_tokens, elapsed)`` as produced by ``_classify``.
            For empty input the tuple is ``([], [], 0, 0, 0.0)`` so callers
            can always unpack five values.
        """
        if not df:
            print("[DEBUG] No data, returning empty.")
            # Same shape as the non-empty path; the previous dict result
            # could not be unpacked by callers expecting five values.
            return [], [], 0, 0, 0.0

        # Convert the whole-document records into one dataframe per page.
        df_pages = pd.DataFrame(df)
        # Sort explicitly: set iteration order is an implementation detail,
        # and the page order determines the order of texts sent to the LLM.
        _pages = sorted(set(df_pages.page))
        logger.info(f"The pages in the parsed document is {_pages}")
        df_lst = [df_pages[df_pages.page == idx] for idx in _pages]
        doc_type_lst, pg_chunk_lst, total_input_tokens, total_output_tokens, elapsed = self._classify(df_lst)
        return doc_type_lst, pg_chunk_lst, total_input_tokens, total_output_tokens, elapsed
    



In [3]:
# This section of code is used to validate the fields.

def str_to_bool(v):
    """Return True when ``v``, lower-cased, is one of the accepted truthy strings."""
    truthy_values = ("yes", "true", "t", "1")
    lowered = v.lower()
    return lowered in truthy_values


# One document type the classifier may assign to a page.
# (Documented with comments rather than a class docstring so the pydantic
# model's generated schema description is left untouched.)
class DoctypeConfig(BaseModel):
    # Required identifier; LLMClassifier maps titles back to this id.
    doc_type_id: str = Field(..., description="Doctype ID")
    # Required title; the list of titles is what is rendered into the prompt.
    doc_type_title: str = Field(..., description="Doctype Title")
    # Optional per-doctype prompt; defaults to an empty string.
    prompt: str = Field("", description="Prompt for the doctype")


# Request payload consumed by the LLM classifiers.
# (Documented with comments rather than a class docstring so the pydantic
# model's generated schema description is left untouched.)
class RequestConfig(BaseModel):
    # Required: OCR records for the whole file, one dict per record.
    df: List[Dict] = Field(..., description="OCRed Dataframes for the file")
    # Required: the document types the classifier may choose from.
    doc_type_details: List[DoctypeConfig] = Field(..., description="Document Type Details")
    # Optional prompt used on the auto-classify path; defaults to "".
    auto_classification_prompt: str = Field("", description="Prompt setup in auto_classify doctype")
    # Selects the auto-classify path when True; defaults to False.
    auto_classify: bool = Field(False, description="Auto classify flag")

    @field_validator("auto_classify")
    @classmethod
    def convert_2_bool(cls, v):
        # Stringify first so both bools and strings go through str_to_bool.
        # NOTE(review): field_validator defaults to "after" mode, so `v` has
        # presumably already been coerced to bool by pydantic - confirm whether
        # mode="before" was intended.
        return str_to_bool(str(v))

In [4]:
# Default "nothing to split" result: no pages, no classifications, and the
# "classified" flag set to True.
DEFAULT_RES = {"pages": [], "classifications": [], "classified": True}

class LLMClassifier(Classifer):
    """Classifier that asks an LLM to label pages and groups them into chunks.

    Subclasses implement ``_classify`` (the batching strategy) and call
    ``process_doc_type_classification`` once per batch of page texts.
    """

    # Runs once, when the class body is evaluated: forces the provider name
    # read in ``__init__`` to "openrouter" for the whole process.
    os.environ["LLM_PROVIDER"] = "openrouter"

    def __init__(self, request_config: RequestConfig, model_name: str) -> None:
        """Create the LLM provider and cache the doctype titles and ids.

        Args:
            request_config: Validated request payload.
            model_name: Model identifier forwarded to the provider on each call.
        """
        # The API key is read from the OPENROUTER_API_KEY environment variable.
        self.LLM_PROVIDER = ProviderFactory.create_provider(
            provider_name=os.environ.get("LLM_PROVIDER", "openai"),
            api_key=os.getenv("OPENROUTER_API_KEY"),
        )
        self.request_config = request_config
        self.doc_types = [doc_type.doc_type_title for doc_type in request_config.doc_type_details]
        self.doc_types_mapper = {
            doc_type.doc_type_title: doc_type.doc_type_id for doc_type in request_config.doc_type_details
        }
        self.model_name = model_name

    def build_prompts(
        self, text_lst: List[str], product_prompt: str, is_auto_classify: bool, doc_types: Union[str, List[str]]
    ) -> Tuple[str, str]:
        """Load the YAML prompt config and render the user prompt.

        Args:
            text_lst: Page texts exposed to the template as ``ocr_text``.
            product_prompt: Extra prompt exposed to the template as ``prompt``.
            is_auto_classify: Selects ``auto_classify_config`` when True,
                otherwise ``custom_doctype_config``.
            doc_types: Doctype title(s) exposed to the template as ``doc_types``.

        Returns:
            Tuple[str, str]: ``(system_prompt, rendered_user_prompt)``.
        """
        config_name = "auto_classify_config" if is_auto_classify else "custom_doctype_config"
        # Relative path: resolved against the current working directory.
        config_path = f"objs/{config_name}.yaml"
        yaml_contents = YamlFileHandler(config_path).handle()
        system_prompt = yaml_contents.get("system_prompt", "")
        user_prompt = yaml_contents.get("user_prompt", "")
        final_user_prompt = JinjaRenderer().render(
            user_prompt,
            context={"ocr_text": text_lst, "doc_types": doc_types, "prompt": product_prompt},
        )
        return system_prompt, final_user_prompt

    # @retry(stop=stop_after_attempt(10))
    def call_llm(self, llm_provider: Provider, user_prompt: str, system_prompt: str) -> Tuple[Dict, int, int]:
        """Call the LLM and parse its answer as JSON.

        Returns:
            Tuple: ``(parsed_output, prompt_tokens, completion_tokens)``.
        """
        res = llm_provider.call(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            caller_config=CallerConfig(model_name=self.model_name, allow_model=True),
        )
        # The provider may return either a plain string or an object with `.text`.
        output = res if isinstance(res, str) else res.text
        # NOTE(review): the token counts are read back from the provider object
        # after the call. If one provider instance is shared by concurrent
        # threads these may belong to a different call - confirm with docllm.
        return parse_json(output), llm_provider.prompt_tokens, llm_provider.completion_tokens

    def create_chunk(self, classifier_out: List[Tuple]) -> Tuple[List, List]:
        """Creates chunk based on the classification output. It combines same doctype pages
        that are present adjancent to each other.

        Args:
            classifier_out (List[Tuple]): Each item in the list is the output from the
                `app.services.classify.classifier.DocClassifierV2._classify_df`.

        Returns:
            Tuple[List, List]: Tuple containing chunk lists for pages and doc_types.
                Both lists are empty when ``classifier_out`` is empty.

        Example:
            Case 1:
            _______
            If the classifier_out:
            [
             ('form_1040___start', 1), ('form_1040', 2), ('form_1040___start', 3),
             ('form_1040', 4), ('form_8995a___start', 5), ('form_8995a', 6), ('acord25', 7), ('auto_classify', 8),
             ('acord25', 9), ('auto_classify', 10)
            ]

            `create_chunk` output is:
            (
              [
               [1, 2, 3, 4], [5, 6], [7], [8], [9], [10]
              ],
              [
               ['form_1040___start', 'form_1040', 'form_1040___start', 'form_1040'],
               ['form_8995a___start', 'form_8995a'],
               ['acord25'],
               ['auto_classify'],
               ['acord25'], ['auto_classify']
              ]
            )

            Case 2:
            _______
            [
             ('form_1040___start', 1), ('form_1040', 2),
             ('form_1040___start', 3),('form_1040', 4),
             ('form_8995a___start', 5), ('form_8995a', 6),
             ('form_1040___start', 7),('form_1040', 8),
            ]

            Output
            _______
            (
             [[1, 2, 3, 4], [5, 6], [7, 8]],
             [['form_1040___start', 'form_1040', 'form_1040___start', 'form_1040'],
              ['form_8995a___start', 'form_8995a'],
              ['form_1040___start', 'form_1040']
             ]
            )
        """
        # Guard: indexing classifier_out[0] below would raise on empty input.
        if not classifier_out:
            return [], []

        n_pages = len(classifier_out)
        last_doc_type = classifier_out[0][0]

        page_chunk = []
        page_chunk_lst = []
        doc_types = []
        doc_types_lst = []
        start_keyword = "___start"
        chunked = 0  # flag to make sure if we miss out pages to chunk.

        for doc_type, page_num in classifier_out:
            if not last_doc_type == doc_type:
                # "<type>" and "<type>___start" belong to the same chunk;
                # last_doc_type is deliberately left unchanged in this case.
                if (last_doc_type + start_keyword == doc_type) or (doc_type + start_keyword == last_doc_type):
                    page_chunk.append(page_num)
                    doc_types.append(doc_type)
                    continue
                # The doctype changed: close the current chunk, start a new one.
                page_chunk = sorted(set(page_chunk))
                chunked += len(page_chunk)
                page_chunk_lst.append(page_chunk)
                doc_types_lst.append(doc_types)
                page_chunk = [page_num]
                doc_types = [doc_type]
            else:
                page_chunk.append(page_num)
                doc_types.append(doc_type)
            last_doc_type = doc_type

        # check the flag and append the remaining pages chunks.
        if chunked != n_pages and page_chunk:
            page_chunk_lst.append(page_chunk)
            doc_types_lst.append(doc_types)

        # almost a dead code due to the above `chunked` flag
        # If all pages is auto_classify/delete or single doc_type
        if not page_chunk_lst:
            page_chunk_lst.append(list(range(1, n_pages + 1)))
            doc_types_lst.append([doc_type] * n_pages)

        # If the last chunk does not end on page n_pages, pad a trailing chunk
        # labelled with the last classification.
        last_chunk_page = page_chunk_lst[-1][-1]
        if last_chunk_page != n_pages:
            page_chunk = list(range(last_chunk_page + 1, n_pages + 1))
            doc_types_lst.append([classifier_out[-1][0]] * len(page_chunk))
            page_chunk_lst.append(page_chunk)
        return page_chunk_lst, doc_types_lst

    def _handle_auto_classify(self, text_lst: List[str]) -> Tuple[List, List, int, int]:
        """Handling split when upload doctype is auto classify.

        Returns:
            Tuple: ``(pg_chunk_lst, doc_types_lst, input_tokens, output_tokens)``.
            Both lists are empty when the LLM returned no page classifications;
            the token counts are still reported in that case.
        """
        system_prompt, user_prompt = self.build_prompts(
            text_lst, self.request_config.auto_classification_prompt, is_auto_classify=True, doc_types=self.doc_types
        )

        output, input_tokens, output_tokens = self.call_llm(self.LLM_PROVIDER, user_prompt, system_prompt)
        page_classifications = output.get("page_classifications") if output else None
        if not page_classifications:
            logger.info(f"[INFO] [Handle AutoClassify] LLM output is empty || llm_output: {output}")
            # Always four values so callers can unpack the result uniformly.
            return [], [], input_tokens, output_tokens

        classification_output = [(el.get("label"), el.get("page")) for el in page_classifications]
        logger.info(f"[INFO] Classification Output: {classification_output}")
        pg_chunk_lst, doc_types_lst = self.create_chunk(classification_output)
        logger.info(f"[INFO] After create_chunk, Page Chunks: {pg_chunk_lst}, Doc types: {doc_types_lst}")
        return pg_chunk_lst, doc_types_lst, input_tokens, output_tokens

    def _handle_custom_doctype(
        self, text_lst: List[str], product_prompt: str, doc_type: str
    ) -> Tuple[List, List, int, int]:
        """Handle split when upload doctype is other than auto classify.

        Returns:
            Tuple: ``(pg_chunk_lst, doc_types_lst, input_tokens, output_tokens)``
            where each list holds a single chunk covering every classified
            page. Both lists are empty when the LLM returned no page
            classifications.
        """
        logger.info("WE are in the handle custom doctype part")
        system_prompt, user_prompt = self.build_prompts(
            text_lst, product_prompt, is_auto_classify=False, doc_types=doc_type
        )
        # The provider lives on the instance (set in __init__), and call_llm
        # returns a 3-tuple that must be unpacked before use.
        output, input_tokens, output_tokens = self.call_llm(self.LLM_PROVIDER, user_prompt, system_prompt)
        page_classifications = output.get("page_classifications") if output else None
        if not page_classifications:
            logger.info(f"[INFO] [Handle Custom Doctype] LLM output is empty || llm_output: {output}")
            return [], [], input_tokens, output_tokens
        doc_types_lst = [[el.get("label") for el in page_classifications]]
        pg_chunk_lst = [[el.get("page") for el in page_classifications]]
        return pg_chunk_lst, doc_types_lst, input_tokens, output_tokens

    def process_doc_type_classification(self, text_lst, single_batch_text, df_lst) -> Tuple[List, List, int, int]:
        """Classify one batch of pages via the auto-classify or single-doctype path.

        Args:
            text_lst: Texts of every page of the document.
            single_batch_text: Texts of the pages in the current batch; this
                is what the auto-classify path sends to the LLM.
            df_lst: Per-page dataframes, used to build the default page chunk
                on the single-doctype path.

        Returns:
            Tuple: ``(pg_chunk_lst, doc_types_lst, input_tokens,
            output_tokens)``. When there is nothing to classify the lists are
            empty and the token counts are whatever the LLM call consumed
            (0 when no call was made).
        """
        if self.request_config.auto_classify:  # Auto-classification logic
            pg_chunk_lst, doc_types_lst, input_tokens_obtained, output_tokens_obtained = self._handle_auto_classify(single_batch_text)
            logger.info(f"In the process doctype classification the total input and output tokens are:{input_tokens_obtained} and {output_tokens_obtained}")
            return pg_chunk_lst, doc_types_lst, input_tokens_obtained, output_tokens_obtained

        # Single doctype logic
        doc_type_config = self.request_config.doc_type_details[0]
        if not doc_type_config.prompt:
            logger.info(f"[INFO] Prompt not found for {doc_type_config.doc_type_title} || Doctype config: {doc_type_config}")
            return [], [], 0, 0

        if len(text_lst) == 1:
            logger.info(f"[INFO] Skipping custom doctype split as it has single page || len(text_lst): {len(text_lst)}")
            return [], [], 0, 0

        # Default: one chunk holding every page (1-based), all with this doctype.
        default_pages = [df.page.unique().tolist()[0] + 1 for df in df_lst]
        default_pg_chunk_lst = [default_pages]
        # One label per page, i.e. sized by the inner page list.
        default_doc_types_lst = [[doc_type_config.doc_type_title] * len(default_pages)]

        llm_pg_chunk_lst, llm_doc_types_lst, input_tokens, output_tokens = self._handle_custom_doctype(
            text_lst, doc_type_config.prompt, self.doc_types[0]
        )

        return (
            llm_pg_chunk_lst or default_pg_chunk_lst,
            llm_doc_types_lst or default_doc_types_lst,
            input_tokens,
            output_tokens,
        )

In [5]:
class LLMNonBatched(LLMClassifier):
    """Classifier that sends every page of the document in a single LLM call."""

    def __init__(self, request_config: Union[Dict, RequestConfig], model_name: str) -> None:
        """Accept either a raw payload dict or an already-built RequestConfig."""
        cfg = RequestConfig(**request_config) if isinstance(request_config, dict) else request_config
        super().__init__(cfg, model_name)

    def _classify(self, df_lst: List[pd.DataFrame]) -> Tuple[List, List, int, int, float]:
        """Classify all pages at once.

        Args:
            df_lst: One dataframe per page.

        Returns:
            Tuple: ``(doc_types_lst, pg_chunk_lst, total_input_tokens,
            total_output_tokens, elapsed_seconds)``.
        """
        # Empty pages get a placeholder so page positions stay aligned.
        text_lst = [get_line_text(df) if not df.empty else "EMPTY PAGE" for df in df_lst]

        # perf_counter is monotonic, so the measured duration is not affected
        # by system clock adjustments.
        start = time.perf_counter()
        pg_chunk_lst, doc_types_lst, total_input_tokens, total_output_tokens = self.process_doc_type_classification(
            text_lst, text_lst, df_lst
        )
        elapsed = time.perf_counter() - start
        logger.info(f"[INFO] The Total time taken for the non batched data is: {elapsed}")
        return doc_types_lst, pg_chunk_lst, total_input_tokens, total_output_tokens, elapsed

In [6]:

class LLMStrictBatched(LLMClassifier):
    """Classifier that splits the pages into fixed-size, non-overlapping batches."""

    def __init__(self, request_config: Union[Dict, RequestConfig], model_name: str) -> None:
        """Accept either a raw payload dict or an already-built RequestConfig."""
        cfg = RequestConfig(**request_config) if isinstance(request_config, dict) else request_config
        super().__init__(cfg, model_name)

    def split_strict_df(self, text_lst: List[str], batch_size: int = 25) -> List[List[str]]:
        """Split list of texts into strictly sized batches (the last may be shorter)."""
        return [text_lst[i : i + batch_size] for i in range(0, len(text_lst), batch_size)]

    def _classify(self, df_lst: List[pd.DataFrame]) -> Tuple[List, List, int, int, float]:
        """Classify the pages batch by batch, running batches in parallel.

        Args:
            df_lst: One dataframe per page.

        Returns:
            Tuple: ``(all_doc_types, all_pg_chunks, total_input_tokens,
            total_output_tokens, elapsed_seconds)``. Token totals are summed
            over all batches; a batch that fails contributes empty results
            and zero tokens.
        """
        # 1. Extract text from DataFrames (placeholder for empty pages).
        text_lst = [get_line_text(df) if not df.empty else "EMPTY PAGE" for df in df_lst]
        logger.info(f"The total length of obtained text list is {len(df_lst)} ")

        # 2. Batch the text
        batched_texts = self.split_strict_df(text_lst)
        num_batches = len(batched_texts)

        # Filled by batch index so the merged output keeps document order even
        # though batches complete in arbitrary order.
        results: List[tuple] = [None] * num_batches
        total_input_tokens = 0
        total_output_tokens = 0

        def run_batch(idx: int, batch: List[str]):
            """Wrapper to process one batch and tag with its index."""
            try:
                start = time.perf_counter()
                pg_chunks, doc_types, input_tokens, output_tokens = self.process_doc_type_classification(
                    text_lst, batch, df_lst
                )
                elapsed = time.perf_counter() - start
                logger.debug(f"Batch {idx} processed in {elapsed:.2f}s | Input tokens: {input_tokens}, Output tokens: {output_tokens}")
                return idx, pg_chunks, doc_types, input_tokens, output_tokens
            except Exception as e:
                # loguru has no `exc_info` switch: extra kwargs are treated as
                # str.format arguments. opt(exception=True) attaches the
                # traceback, and passing values as arguments keeps braces in
                # the exception text from being interpreted as placeholders.
                logger.opt(exception=True).error("Error processing batch {}: {}", idx, e)
                return idx, [], [], 0, 0  # Fallback to zero tokens

        # 3. Execute batches in parallel
        overall_start = time.perf_counter()
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(run_batch, i, batch) for i, batch in enumerate(batched_texts)]
            # This loop runs in the calling thread only, so the local totals
            # need no lock.
            for future in concurrent.futures.as_completed(futures):
                idx, pg_chunks, doc_types, input_tokens, output_tokens = future.result()
                results[idx] = (pg_chunks, doc_types)
                # Accumulate across batches (previously each batch overwrote
                # the total). `or 0` tolerates a provider reporting None.
                total_input_tokens += input_tokens or 0
                total_output_tokens += output_tokens or 0
        elapsed = time.perf_counter() - overall_start
        logger.info(f"Processed {num_batches} batches in {elapsed:.2f}s")
        logger.info(f"Total input tokens: {total_input_tokens}, Total output tokens: {total_output_tokens}")

        # 4. Flatten results in correct order
        # NOTE(review): page numbers inside each batch's chunks are whatever
        # the LLM returned for that batch; they are concatenated here without
        # re-offsetting - confirm this is what downstream code expects.
        all_pg_chunks: List = []
        all_doc_types: List = []
        for pg_chunks, doc_types in results:
            all_pg_chunks.extend(pg_chunks)
            all_doc_types.extend(doc_types)

        logger.info(f"After merging chunks: pages={len(all_pg_chunks)}, types={len(all_doc_types)}")

        # Returning tokens along with results
        return all_doc_types, all_pg_chunks, total_input_tokens, total_output_tokens, elapsed

In [7]:
class LLMSpreadBatch(LLMClassifier):
    """Classifier that splits the pages into overlapping ("spread") batches."""

    def __init__(self, request_config: Union[Dict, RequestConfig], model_name: str):
        """Accept either a raw payload dict or an already-built RequestConfig."""
        # Same handling as the other batching classes, which accept both forms.
        cfg = RequestConfig(**request_config) if isinstance(request_config, dict) else request_config
        super().__init__(cfg, model_name)

    def split_spread_text(self, text_lst, batch_size=25, step=20):
        """Split texts into windows of ``batch_size`` that start every ``step`` items.

        With the defaults, consecutive batches share ``batch_size - step`` = 5
        pages. The first window that reaches the end of the list takes all
        remaining items and ends the split.

        Returns:
            List[List[str]]: The batches, in document order; empty for empty input.
        """
        batched_texts = []
        total = len(text_lst)

        for start_idx in range(0, total, step):
            end_idx = start_idx + batch_size
            if end_idx >= total:
                # Final window: take everything that is left and stop, so no
                # batch made purely of already-covered pages is emitted.
                batched_texts.append(text_lst[start_idx:])
                break
            batched_texts.append(text_lst[start_idx:end_idx])

        return batched_texts

    def _classify(self, df_lst: List[pd.DataFrame]) -> Tuple[List, List, int, int, float]:
        """Classify the pages in overlapping batches, running batches in parallel.

        Args:
            df_lst: One dataframe per page.

        Returns:
            Tuple: ``(all_doc_types, all_pg_chunks, total_input_tokens,
            total_output_tokens, elapsed_seconds)``. Results of overlapping
            pages are NOT de-duplicated here. Token totals are summed over
            all batches; a batch that fails contributes empty results and
            zero tokens.
        """
        # Step 1: Extract text from all DataFrames (placeholder for empty pages).
        text_lst = [get_line_text(df) if not df.empty else "EMPTY PAGE" for df in df_lst]

        # Step 2: Batch the text list
        batched_texts = self.split_spread_text(text_lst)
        num_batches = len(batched_texts)

        # Filled by batch index so the merged output keeps document order even
        # though batches complete in arbitrary order.
        results: List[tuple] = [None] * num_batches
        total_input_tokens = 0
        total_output_tokens = 0

        def run_batch(idx: int, batch: List[str]):
            """Wrapper to process one batch and tag with its index."""
            try:
                start = time.perf_counter()
                pg_chunks, doc_types, input_tokens, output_tokens = self.process_doc_type_classification(
                    text_lst, batch, df_lst
                )
                elapsed = time.perf_counter() - start
                logger.debug(f"Batch {idx} processed in {elapsed:.2f}s | Input Tokens: {input_tokens}, Output Tokens: {output_tokens}")
                return idx, pg_chunks, doc_types, input_tokens, output_tokens
            except Exception as e:
                # loguru has no `exc_info` switch: extra kwargs are treated as
                # str.format arguments. opt(exception=True) attaches the
                # traceback, and passing values as arguments keeps braces in
                # the exception text from being interpreted as placeholders.
                logger.opt(exception=True).error("Error processing batch {}: {}", idx, e)
                return idx, [], [], 0, 0

        # Step 3: Execute batches in parallel
        overall_start = time.perf_counter()
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(run_batch, i, batch) for i, batch in enumerate(batched_texts)]
            # This loop runs in the calling thread only, so the local totals
            # need no lock.
            for future in concurrent.futures.as_completed(futures):
                idx, pg_chunks, doc_types, input_tokens, output_tokens = future.result()
                results[idx] = (pg_chunks, doc_types)
                # Accumulate across batches (previously each batch overwrote
                # the total). `or 0` tolerates a provider reporting None.
                total_input_tokens += input_tokens or 0
                total_output_tokens += output_tokens or 0
        elapsed = time.perf_counter() - overall_start
        logger.info(f"Processed {num_batches} batches in {elapsed:.2f}s")
        logger.info(f"Total input tokens: {total_input_tokens}, Total output tokens: {total_output_tokens}")

        # Step 4: Flatten results in correct order
        all_pg_chunks: List = []
        all_doc_types: List = []
        for pg_chunks, doc_types in results:
            all_pg_chunks.extend(pg_chunks)
            all_doc_types.extend(doc_types)

        logger.info(f"After merging chunks: pages={len(all_pg_chunks)}, types={len(all_doc_types)}")

        # Return all outputs and token counts
        return all_doc_types, all_pg_chunks, total_input_tokens, total_output_tokens, elapsed

### Here we initialize the run settings: which batching types to run and which model to use. You can change the model name and repeat the run for different LLM models.

#### Don't forget to update `test_folder_path` so that it points to the folder where your test data is located!

In [60]:
# Toggle which batching strategies run_classification_on_file executes.
is_spread_batched:bool = True
is_strict_batched:bool = True
is_non_batched: bool = True

# Model identifier passed to every LLM call; it must also be a key of
# COST_MAPPING below, which is indexed by this name to compute the cost.
model_name="google/gemini-2.5-flash-preview"

# Directory the notebook is run from.
BASE_DIR = os.getcwd()

# Folder containing the PDFs to evaluate. Update the folder name to match
# where your dataset lives (here: "<cwd>/test_real").
test_folder_path = os.path.join(BASE_DIR, "test_real")

# model name -> [input-token rate, output-token rate]. The cost is computed as
# input_tokens * rate[0] + output_tokens * rate[1].
# NOTE(review): presumably USD per token - confirm against provider pricing.
COST_MAPPING = {
        "meta-llama/llama-3-70b-instruct": [8.1e-7, 8.1e-7],
        "openai/gpt-4o": [0.000005, 0.000015],
        "openai/gpt-4-turbo": [0.00001, 0.00003],
        "anthropic/claude-3-opus": [0.000015, 0.000075],
        "google/gemini-2.5-flash-preview": [1.5e-7, 6e-7],
        "google/gemini-flash-1.5": [2.5e-7, 7.5e-7],
        "google/gemini-2.0-flash-001": [1e-7, 4e-7],
        "qwen/qwen-2-72b-instruct": [5.9e-7, 7.9e-7],
        "anthropic/claude-3.5-sonnet": [0.000003, 0.000015],
        "openai/gpt-4.1-mini": [0.0000004,0.0000016],
}

#### This section performs parallel processing and pdfs are parsed using google ocr provider independently

In [61]:

#this is the actual implementation of parsing pdf using google ocr provider
def get_google_ocr_raw_data(file_path: str) -> Tuple[List[np.ndarray], List[pd.DataFrame]]:
    """Run OCR on every page of a PDF, or on a single image file.

    Args:
        file_path: Path to a PDF (extension matched case-insensitively) or to
            an image readable by ``cv2.imread``.

    Returns:
        Tuple: ``(image_list, df_list)`` with one image and one OCR dataframe
        per page, both as returned by ``read_everything``.

    Raises:
        ValueError: If the path is not a PDF and the image cannot be read.
    """
    if file_path.lower().endswith("pdf"):
        images = PdfImages(file_path)
    else:
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside the OCR call.
        if image is None:
            raise ValueError(f"Could not read image file: {file_path}")
        images = [image]

    df_list = []
    image_list = []
    for page_image in images:
        # Only the first dataframe and the returned image are kept; the second
        # dataframe and the angle are not used here.
        (df, _cdf), (processed_image, _angle) = read_everything(page_image)
        df_list.append(df)
        image_list.append(processed_image)
    return image_list, df_list

def run_processing_pdfs(file_path):
    """OCR one file and return a single dataframe covering all of its pages.

    Each page's dataframe gets a 0-based ``page`` column before the pages are
    concatenated with a fresh index.

    Returns:
        pd.DataFrame: The combined OCR data, or an empty dataframe when the
        file produced no pages.
    """
    logger.info(f"\n[INFO] Processing: {os.path.basename(file_path)}")
    _images, df_list = get_google_ocr_raw_data(file_path)
    if not df_list:
        # pd.concat raises ValueError on an empty list; return an empty frame
        # so callers can test `.empty` instead.
        logger.warning(f"[WARN] No pages found in: {os.path.basename(file_path)}")
        return pd.DataFrame()
    for page_idx, df in enumerate(df_list):
        df["page"] = page_idx
    return pd.concat(df_list).reset_index(drop=True)

#execution of pdfs in parallel as pdfs are independent of each other and this can significantly improve the time 
def process_all_pdfs(test_folder_path):
    """OCR every ``.pdf`` file in a folder in parallel.

    PDFs are independent of each other, so they are parsed concurrently. A
    file whose parsing raises is logged and skipped, so one bad PDF does not
    discard the results of the others.

    Args:
        test_folder_path: Folder to scan (non-recursive).

    Returns:
        dict: Maps each file's base name to its OCR dataframe. Files that
        failed or produced an empty result are omitted.
    """
    pdf_files = [os.path.join(test_folder_path, f) for f in os.listdir(test_folder_path) if f.endswith(".pdf")]
    results_dict = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        future_to_file = {executor.submit(run_processing_pdfs, file): file for file in pdf_files}
        for future in concurrent.futures.as_completed(future_to_file):
            file = future_to_file[future]
            try:
                result = future.result()
            except Exception as exc:
                # Values are passed as arguments so braces in the exception
                # text are not interpreted as format placeholders.
                logger.opt(exception=True).error("[ERROR] Failed to parse {}: {}", file, exc)
                continue
            if result is not None and not result.empty:
                results_dict[os.path.basename(file)] = result
    return results_dict

#### To run the parallel processing, uncomment the line in the next cell. Parsing can take a while, depending on how much data is in your test folder.

In [62]:
# Only run this if you haven't prepared a pickle file yet; otherwise leave it commented out.
#all_results=process_all_pdfs(test_folder_path)

#### Uncomment the line if you want to store the result of parsing in the pickle file so you dont need to parse the pdfs again and again

In [63]:
# # this is also the same case

# import pickle

# with open("all_results_real.pkl", "wb") as f:
#     pickle.dump(all_results, f)


In [None]:
# Load the previously pickled OCR results instead of re-parsing the PDFs.
# Change the path below to wherever your .pkl file is stored.
# NOTE(security): unpickling can execute arbitrary code; only load pickle
# files you created yourself or otherwise trust.
with open("tmp/all_results_real.pkl", "rb") as f:
    all_results = pickle.load(f)

In [65]:
all_results.keys()

dict_keys(['merged_output_8.pdf', 'merged_output_2.pdf', 'merged_output_14.pdf', 'merged_output_17.pdf', 'merged_output_10.pdf', 'merged_output_5.pdf', 'merged_output_1.pdf', 'merged_output_9.pdf', 'merged_output_6.pdf', 'merged_output_19.pdf', 'merged_output_16.pdf', 'merged_output_12.pdf', 'merged_output_15.pdf', 'merged_output_3.pdf', 'merged_output_11.pdf', 'merged_output_18.pdf', 'merged_output_7.pdf', 'merged_output_20.pdf', 'merged_output_4.pdf', 'merged_output_13.pdf'])

In [66]:
len(all_results.keys())

20

## Function that builds the payload template and calls each batching type to get the document classification results

In [67]:
def _run_batching_strategy(classifier_cls, suffix, payload, records, model_name):
    """Run one batching strategy and return its result entries keyed by ``suffix``."""
    doc_types, _chunks, in_tok, out_tok, total_time = classifier_cls(payload, model_name).classify(df=records)
    rates = COST_MAPPING.get(model_name)
    if rates is None:
        # Unknown pricing should not discard a successful classification.
        logger.warning(f"No cost mapping for model {model_name}; cost_{suffix} set to None")
        cost = None
    else:
        input_rate, output_rate = rates
        cost = (in_tok * input_rate) + (out_tok * output_rate)
    return {
        suffix: [item for sub in doc_types for item in sub],
        f"total_input_tokens_{suffix}": in_tok,
        f"total_output_tokens_{suffix}": out_tok,
        f"total_time_{suffix}": total_time,
        f"cost_{suffix}": cost,
    }


def _failed_strategy_entries(suffix):
    """Return the result entries for a strategy that raised: every value is None."""
    return {
        suffix: None,
        f"total_input_tokens_{suffix}": None,
        f"total_output_tokens_{suffix}": None,
        f"total_time_{suffix}": None,
        f"cost_{suffix}": None,
    }


def run_classification_on_file(parsed_result, file_path, model_name):
    """Classify one parsed file with every enabled batching strategy.

    The strategies are selected by the module-level flags ``is_non_batched``,
    ``is_strict_batched`` and ``is_spread_batched``. Each strategy is run
    independently: if one raises, its entries are set to None and the
    remaining strategies still run.

    Args:
        parsed_result: OCR dataframe of the file (needs ``to_dict``).
        file_path: Path or name of the file; only its base name is reported.
        model_name: Model to call; also the key used to look up COST_MAPPING.

    Returns:
        Tuple: ``(file_name, results)``. ``results`` holds, per strategy
        suffix, the flattened labels plus ``total_input_tokens_*``,
        ``total_output_tokens_*``, ``total_time_*`` and ``cost_*``. It is
        None when the file could not be processed at all.
    """
    # Fallback name in case the base name cannot be computed below.
    file_name = file_path
    try:
        file_name = os.path.basename(file_path)
        logger.info(f"\n[INFO] Processing: {file_name}")
        df_digital_dict = parsed_result.to_dict(orient="records")

        payload_template = {
            "auto_classify": True,
            "auto_classification_prompt": "Split each document type into separate pages.",
            "df": df_digital_dict,
            "doc_type_details": [
                {"doc_type_id": "acord25", "doc_type_title": "Acord 25", "prompt": "Review every field extracted"},
                {"doc_type_id": "invoice", "doc_type_title": "Invoice", "prompt": "Review every field extracted"},
                {"doc_type_id": "form1040", "doc_type_title": "Form 1040", "prompt": "Review every field extracted."},
                {"doc_type_id": "form1040a", "doc_type_title": "Form 1040 A", "prompt": "Review every field extracted"},
                {"doc_type_id": "form1040b", "doc_type_title": "Form 1040 B", "prompt": "Review every field extracted"},
                {"doc_type_id": "form1040c", "doc_type_title": "Form 1040 C", "prompt": "Review every field extracted"},
                {"doc_type_id": "form1040d", "doc_type_title": "Form 1040 D", "prompt": "Review every field extracted"},
                {"doc_type_id": "form1040e", "doc_type_title": "Form 1040 E", "prompt": "Review every field extracted"},
                {"doc_type_id": "w9", "doc_type_title": "W9", "prompt": "Review every field extracted."},
            ],
        }

        # (enabled flag, result-key suffix, classifier class, label for logs)
        strategies = (
            (is_non_batched, "non_batched", LLMNonBatched, "Non-batched"),
            (is_strict_batched, "strict_batched", LLMStrictBatched, "Strict-batched"),
            (is_spread_batched, "spread_batched", LLMSpreadBatch, "Spread-batched"),
        )

        results = {}
        for enabled, suffix, classifier_cls, label in strategies:
            if not enabled:
                continue
            try:
                results.update(
                    _run_batching_strategy(classifier_cls, suffix, payload_template, df_digital_dict, model_name)
                )
            except Exception as e:
                results.update(_failed_strategy_entries(suffix))
                logger.error(f"{label} classification failed: {e}")

        return (file_name, results)

    except Exception as e:
        logger.error(f"[ERROR] Failed on {file_path}: {e}")
        # Keyed by the same name as the success path so callers can match
        # results to files uniformly.
        return (file_name, None)





### This section removes the overlapping classifications produced by spread batching. The overlap must be removed to get clean, one-label-per-page data.

In [68]:
# def remove_overlap_classifications(doc_type_lst_spread_batched, window_size=25, stride=5):
#     # Flatten the entire list of classifications
#     step = window_size - stride  # How much to move the window after each batch
#     batch_idx = 0
#     final_array = []
#     i = 0

#     while i < len(doc_type_lst_spread_batched):
#         # If we've reached the end of the current window
#         if i == batch_idx + window_size:
#             # Skip the overlap and move to the next batch
#             i += stride 
#             batch_idx += step + 5 
#             continue
        
#         # Otherwise, add the current item to final array
#         final_array.append(doc_type_lst_spread_batched[i])
#         i += 1

#     return final_array



def remove_overlap_classifications(doc_type_lst_spread_batched, window_size=25, stride=5):
    """Strip the repeated overlap labels from flattened spread-batch output.

    The input is walked in consecutive chunks of ``window_size`` labels. The
    first chunk is kept whole. From every later chunk the leading ``stride``
    labels are discarded; those labels are also compared, position by
    position, with the ``stride`` labels that sit immediately before the
    chunk, and each disagreement is counted.

    Args:
        doc_type_lst_spread_batched: Flat sequence of predicted labels.
        window_size: Number of labels per chunk.
        stride: Number of leading labels treated as overlap in every chunk
            after the first.

    Returns:
        Tuple ``(cleaned_labels, mismatch_count)``.
    """
    labels = doc_type_lst_spread_batched
    total = len(labels)
    cleaned = []
    mismatch_results = 0
    # Uncomment the line below to log how many labels spread batching produced.
    # logger.debug(f"The total doc type spread length is : {total}")

    start = 0
    while start < total:
        chunk = labels[start:start + window_size]

        if start == 0:
            # Nothing precedes the first chunk, so it has no overlap to drop.
            cleaned.extend(chunk)
        else:
            # The head of this chunk repeats the tail of the previous one.
            cleaned.extend(chunk[stride:])

            tail_of_previous = labels[start - stride:start]
            head_of_current = labels[start:start + stride]
            mismatch_results += sum(
                1
                for earlier, later in zip(tail_of_previous, head_of_current)
                if earlier != later
            )

        start += window_size

    logger.info(f"[INFO] The total mismatched results in spread batching are: {mismatch_results}")

    return cleaned, mismatch_results


In [69]:
def remove_overlap_pages(doc_type_lst_spread_batched, window_size=25, stride=5):
    """Return the indices to keep after removing spread-batch overlap.

    The input is walked in consecutive full windows of ``window_size``
    entries. For every full window that is followed by more entries, only the
    first ``window_size - stride`` indices are kept (the trailing ``stride``
    entries are treated as overlap with what follows). A full window that
    ends exactly at the end of the input is kept whole, because nothing
    after it can repeat its tail. Any trailing partial window is kept whole.

    Args:
        doc_type_lst_spread_batched: Flat sequence whose length determines
            the index range; its elements are not inspected.
        window_size: Number of entries per window. Must be positive.
        stride: Number of trailing entries of a window treated as overlap.

    Returns:
        List of integer indices into the input, in ascending order.

    Raises:
        ValueError: If ``window_size`` is not positive (the loop could never
            advance past the input).
    """
    if window_size <= 0:
        raise ValueError("window_size must be a positive integer")

    total = len(doc_type_lst_spread_batched)
    step = window_size - stride  # entries of a window that are not overlap
    final_array = []
    i = 0

    while i + window_size <= total:
        # Only drop the tail when a later window exists to supply it again.
        is_last_window = i + window_size == total
        keep = window_size if is_last_window else step
        final_array.extend(range(i, i + keep))

        # Windows are laid out back to back in the flattened input.
        i += window_size

    # Whatever is left is a partial window; keep all of it.
    final_array.extend(range(i, total))

    return final_array



### This is the final place where we call our classification to get the actual classification values

In [70]:
# Setup: predictions and both logs for this run are written under one folder.
output_dir = "./realgemini25_batching_and_splitting"
os.makedirs(output_dir, exist_ok=True)

# Paths of the append-only cost and time logs inside the output folder.
cost_log_path, time_log_path = (
    os.path.join(output_dir, log_name)
    for log_name in ("cost_log.txt", "time_log.txt")
)


# Log cost helper
def log_cost(batch_type, filename, input_tokens, output_tokens, cost):
    """Append one token-usage and cost record to the cost log.

    The record is prefixed with the module-level ``model_name`` and written
    to the file at the module-level ``cost_log_path``.

    Args:
        batch_type: Label of the batching mode to record.
        filename: Name of the processed file.
        input_tokens: Input token count to record.
        output_tokens: Output token count to record.
        cost: Cost in dollars; written with six decimal places.
    """
    with open(cost_log_path, "a") as cost_log:
        origin = f"{model_name} | {filename} | {batch_type} | "
        usage = f"Input Tokens: {input_tokens}, Output Tokens: {output_tokens}, "
        price = f"Cost: ${cost:.6f}\n"
        cost_log.write(origin + usage + price)

def log_time(batch_type, filename, total_time):
    """Append one elapsed-time record to the time log.

    The record is prefixed with the module-level ``model_name`` and written
    to the file at the module-level ``time_log_path``.

    Args:
        batch_type: Label of the batching mode to record.
        filename: Name of the processed file.
        total_time: Elapsed seconds; written with two decimal places.
    """
    with open(time_log_path, "a") as time_log:
        entry = f"{model_name} | {filename} | {batch_type} | Time Taken: {total_time:.2f} seconds\n"
        time_log.write(entry)


# Save result helper
def save_results(batch_type, result, pdf_basename, filename):
    """Log cost/time for one batching mode and write its predictions to disk.

    Does nothing when ``batch_type`` is not a key of ``result``. Otherwise it
    logs the token counts and cost, appends them to the cost log (and the
    elapsed time to the time log when one is recorded), and writes the
    predicted labels, one per line, to
    ``<output_dir>/<pdf_basename>_<batch_type>.txt``.

    For ``"spread_batched"`` the labels are first passed through
    ``remove_overlap_classifications``; the cleaned labels replace
    ``result[batch_type]``, the mismatch count is stored under
    ``result["unmatching_result"]``, and a non-zero count is appended to
    ``<output_dir>/all_unmatched.txt``.

    Args:
        batch_type: One of the batching-mode keys, e.g. ``"non_batched"``.
        result: Per-file result dict; mutated in place for spread batching.
        pdf_basename: PDF file name without extension, used for the output
            file name.
        filename: File name used in the log records.
    """
    if batch_type not in result:
        return

    input_tokens = result[f"total_input_tokens_{batch_type}"]
    output_tokens = result[f"total_output_tokens_{batch_type}"]
    cost = result[f"cost_{batch_type}"]

    logger.info(f"  {batch_type.replace('_', ' ').title()} - Input Tokens: {input_tokens}, "
                f"Output Tokens: {output_tokens}, Cost: ${cost:.6f}")

    log_label = batch_type.title().replace("_", "-")
    log_cost(log_label, filename, input_tokens, output_tokens, cost)

    total_time = result.get(f"total_time_{batch_type}")
    if total_time is not None:
        log_time(log_label, filename, total_time)

    predictions = result[batch_type]
    unmatching_result = None

    # Spread batching repeats labels in the window overlaps; clean them
    # before writing so the file is written once with the final content.
    if batch_type == "spread_batched":
        predictions, unmatching_result = remove_overlap_classifications(predictions)
        result[batch_type] = predictions
        result["unmatching_result"] = unmatching_result

    output_path = os.path.join(output_dir, f"{pdf_basename}_{batch_type}.txt")
    with open(output_path, "w") as f:
        f.write("\n".join(predictions))

    # Record the overlap mismatch count for this file when there is one.
    if unmatching_result:
        unmatched_output_path = os.path.join(output_dir, "all_unmatched.txt")
        with open(unmatched_output_path, "a") as f:
            f.write(f"\n=== {filename} ===\n")
            # The count is an int: write it whole. Joining str(count) with
            # newlines would put every digit on its own line.
            f.write(str(unmatching_result))
            f.write("\n" + "=" * 50 + "\n")

# Main loop: classify every PDF in the test folder and save the output of
# each enabled batching mode.
my_results = {}
ground_truth_files = []

for filename in os.listdir(test_folder_path):
    # Ground-truth label files are the .txt files that are not logs.
    if filename.endswith(".txt") and not filename.endswith("_log.txt"):
        ground_truth_files.append(os.path.splitext(filename)[0])

    if not filename.endswith(".pdf"):
        continue

    full_path = os.path.join(test_folder_path, filename)
    pdf_basename = os.path.splitext(filename)[0]
    parsed_result = all_results.get(filename)

    try:
        obtained_results = run_classification_on_file(parsed_result, full_path, model_name)
        file_name, result = obtained_results

        # run_classification_on_file logs its own failure and returns None as
        # the result; skip the file instead of failing on a membership test
        # against None and logging a second, misleading error.
        if result is None:
            continue

        if is_non_batched and "non_batched" in result:
            save_results("non_batched", result, pdf_basename, file_name)

        if is_strict_batched and "strict_batched" in result:
            save_results("strict_batched", result, pdf_basename, file_name)

        if is_spread_batched and "spread_batched" in result:
            save_results("spread_batched", result, pdf_basename, file_name)

    except Exception as e:
        logger.error(f"[ERROR] Failed on {filename}: {e}")


[32m2025-05-04 23:21:55.063[0m | [1mINFO    [0m | [36m__main__[0m:[36mrun_classification_on_file[0m:[36m3[0m - [1m
[INFO] Processing: merged_output_10.pdf[0m


[32m2025-05-04 23:21:55.969[0m | [1mINFO    [0m | [36m__main__[0m:[36mclassify[0m:[36m42[0m - [1mThe pages in the parsed document is [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182][0m
[32m2025-05-04 23:21:58.277[0m | [1mINFO   

In [71]:
# def process_model(model_name):
#     output_dir = "./parallel_batching_and_splitting"
#     os.makedirs(output_dir, exist_ok=True)
#     cost_log_path = os.path.join(output_dir, "cost_log.txt")
#     time_log_path = os.path.join(output_dir, "time_log.txt")

#     def log_cost(batch_type, filename, input_tokens, output_tokens, cost):
#         with open(cost_log_path, "a") as log_file:
#             log_file.write(
#                 f"{model_name} | {filename} | {batch_type} | "
#                 f"Input Tokens: {input_tokens}, Output Tokens: {output_tokens}, "
#                 f"Cost: ${cost:.6f}\n"
#             )

#     def log_time(batch_type, filename, total_time):
#         with open(time_log_path, "a") as log_file:
#             log_file.write(
#                 f"{model_name} | {filename} | {batch_type} | Time Taken: {total_time:.2f} seconds\n"
#             )

#     def save_results(batch_type, result, pdf_basename, filename):
#         input_key = f"total_input_tokens_{batch_type}"
#         output_key = f"total_output_tokens_{batch_type}"
#         cost_key = f"cost_{batch_type}"
#         output_key_main = batch_type

#         if batch_type in result:
#             logger.info(f"{model_name} | {filename} | {batch_type} - "
#                         f"Input: {result[input_key]}, Output: {result[output_key]}, "
#                         f"Cost: ${result[cost_key]:.6f}")
            
#             log_cost(batch_type.title().replace("_", "-"), filename,
#                      result[input_key], result[output_key], result[cost_key])
            
#             time_key = f"total_time_{batch_type}"
#             if time_key in result and result[time_key] is not None:
#                 log_time(batch_type.title().replace("_", "-"), filename, result[time_key])

#             output_path = os.path.join(output_dir, f"{pdf_basename}_{batch_type}.txt")
#             with open(output_path, "w") as f:
#                 f.write("\n".join(result[output_key_main]))

#             if batch_type == "spread_batched":
#                 spread_cleaned, unmatching_result = remove_overlap_classifications(result[output_key_main])
#                 result[output_key_main] = spread_cleaned
#                 result["unmatching_result"] = unmatching_result

#                 with open(output_path, "w") as f:
#                     f.write("\n".join(spread_cleaned))

#                 if unmatching_result:
#                     unmatched_output_path = os.path.join(
#                         output_dir, f"all_unmatched_{model_name.replace('/', '_')}.txt"
#                     )
#                     with open(unmatched_output_path, "a") as f:
#                         f.write(f"\n=== {filename} ===\n")
#                         f.write("\n".join(str(unmatching_result)))
#                         f.write("\n" + "=" * 50 + "\n")


#     for filename in os.listdir(test_folder_path):
#         if filename.endswith(".pdf"):
#             full_path = os.path.join(test_folder_path, filename)
#             pdf_basename = os.path.splitext(filename)[0]
#             parsed_result = all_results.get(filename)

#             try:
#                 file_name, result = run_classification_on_file(parsed_result, full_path, model_name)

#                 if is_non_batched and "non_batched" in result:
#                     save_results("non_batched", result, pdf_basename, file_name)

#                 if is_strict_batched and "strict_batched" in result:
#                     save_results("strict_batched", result, pdf_basename, file_name)

#                 if is_spread_batched and "spread_batched" in result:
#                     save_results("spread_batched", result, pdf_basename, file_name)

#             except Exception as e:
#                 logger.error(f"[ERROR] {model_name} | {filename} failed: {e}")

# # List of LLM models
# model_names = [
#         "google/gemini-2.5-flash-preview",
#         "google/gemini-2.0-flash-001",
#         "openai/gpt-4.1-mini"
# ]

# # Run in parallel using threads
# with concurrent.futures.ThreadPoolExecutor(max_workers=len(model_names)) as executor:
#     executor.map(process_model, model_names)


### Calculating the accuracy of classification for different batching types

In [72]:
# Directories: predictions are read from the run's output folder, ground
# truth from the test folder defined earlier.
predicted_dir = "./realgemini25_batching_and_splitting"
ground_truth_dir = test_folder_path
# "/" in the model name would be read as a path separator in file names.
model_name_clean = model_name.replace("/", "_")
# Batch types to evaluate
batching_modes = ["non_batched", "strict_batched", "spread_batched"]

# Ground-truth files are the .txt files that are not logs.
ground_truth_files = []
for entry in os.listdir(ground_truth_dir):
    if entry.endswith("_log.txt") or not entry.endswith(".txt"):
        continue
    ground_truth_files.append(entry)

# Per-mode accumulators for the overall evaluation and the misclassified logs.
all_y_true, all_y_pred, misclassified_logs = {}, {}, {}
for mode in batching_modes:
    all_y_true[mode] = []
    all_y_pred[mode] = []
    misclassified_logs[mode] = []

# Evaluate each ground-truth file against the prediction file of every
# batching mode, printing a per-file report and feeding the accumulators.
for gt_file in ground_truth_files:
    gt_basename = os.path.splitext(gt_file)[0]
    gt_path = os.path.join(ground_truth_dir, gt_file)

    # One label per line; blank lines are ignored.
    with open(gt_path, "r") as f:
        y_true = [label for label in map(str.strip, f) if label]

    print(f"\nEvaluating: {gt_file}")

    for batch_type in batching_modes:
        pred_filename = f"{gt_basename}_{batch_type}.txt"
        pred_path = os.path.join(predicted_dir, pred_filename)

        if not os.path.exists(pred_path):
            print(f"Missing prediction file for {batch_type}: {pred_filename}")
            continue

        with open(pred_path, "r") as f:
            y_pred = [label for label in map(str.strip, f) if label]

        # Labels are compared position by position, so the counts must match;
        # otherwise the file is left out for this mode.
        if len(y_pred) != len(y_true):
            print(f"Length mismatch for {batch_type} on {gt_file} | y_true: {len(y_true)}, y_pred: {len(y_pred)}")
            continue

        # Feed the overall accumulators.
        all_y_true[batch_type].extend(y_true)
        all_y_pred[batch_type].extend(y_pred)

        # Record every position where the prediction differs (1-based line).
        misclassified_logs[batch_type].extend(
            f"{gt_basename} [Line {line_no}]: TRUE: {yt} --> PRED: {yp}"
            for line_no, (yt, yp) in enumerate(zip(y_true, y_pred), start=1)
            if yt != yp
        )

        # Per-file report for this mode.
        report = classification_report(y_true, y_pred, zero_division=0)
        print(f"\nClassification Report for: {batch_type}")
        print(report)

# --- Overall reports ---
# One aggregate report per batching mode, printed and saved together with
# the list of misclassified lines.
print("\n===== OVERALL BATCH-TYPE REPORTS =====\n")

for batch_type in batching_modes:
    # Nothing was accumulated for this mode (no usable prediction files).
    if not (all_y_true[batch_type] and all_y_pred[batch_type]):
        print(f"\nNo data to generate report for {batch_type}.")
        continue

    overall_report = classification_report(
        all_y_true[batch_type],
        all_y_pred[batch_type],
        zero_division=0,
        digits=4,
    )
    print(f"\n=== Overall Report for {batch_type} ===\n")
    print(overall_report)

    # Save report
    overall_report_path = os.path.join(
        predicted_dir, f"overall_{batch_type}_{model_name_clean}_report.txt"
    )
    with open(overall_report_path, "w") as f:
        f.write(overall_report)

    # Save misclassified
    misclassified_path = os.path.join(
        predicted_dir, f"misclassified_{batch_type}_{model_name_clean}.txt"
    )
    with open(misclassified_path, "w") as f:
        f.write("\n".join(misclassified_logs[batch_type]))




Evaluating: merged_output_19.txt

Classification Report for: non_batched
              precision    recall  f1-score   support

    Acord 25       1.00      1.00      1.00        12
   Form 1040       1.00      1.00      1.00        30
 Form 1040 A       1.00      1.00      1.00        15
 Form 1040 B       1.00      1.00      1.00        15
 Form 1040 C       1.00      1.00      1.00        30
 Form 1040 D       1.00      1.00      1.00        30
 Form 1040 E       1.00      1.00      1.00        30
     Invoice       1.00      1.00      1.00        15
          W9       1.00      1.00      1.00        15

    accuracy                           1.00       192
   macro avg       1.00      1.00      1.00       192
weighted avg       1.00      1.00      1.00       192


Classification Report for: strict_batched
              precision    recall  f1-score   support

    Acord 25       1.00      1.00      1.00        12
   Form 1040       1.00      1.00      1.00        30
 Form 1040 A   