<div style="background-color: #ADD8E6; border: 1px solid gray; padding: 3px">
    This notebook consists of 2 agentic workflows:
        <h3>Data Generation Workflow</h3>
        <ul>
            <li><b>Data Augmentation</b>: Augments the provided image dataset.</li>
        </ul>
        <h3>Validation Workflow</h3>
        <ul>
            <li><b>Image Validator</b>: Identifies whether a valid driver's license exists in the given image.</li>
            <li><b>Data Extractor</b>: Extracts relevant metadata from the image.</li>
            <li><b>Application Validator</b>: Given the extracted metadata associated with the application, uses a set of predefined rules to validate the driver's license application.</li>
        </ul>
</div>

In [None]:
##############################################################################
# Imports
##############################################################################
# import pysqlite3 as sqlite3
# import sys
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import operator
from typing import Annotated, TypedDict, List, Optional, Literal
from langgraph.graph import StateGraph, END, START
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.graph.message import add_messages
from PIL import Image
import pytesseract
import io
import json
from datetime import datetime
import re
import os
import requests
from flow_extensions import CustomLLMMultimodalBlock, CustomDeleteColumnsBlock
from io import BytesIO
from dotenv import load_dotenv
import mimetypes
import base64
from urllib.parse import urlparse
from PIL import Image
from io import BytesIO
import requests
load_dotenv()
import traceback
from openai import OpenAI
import instructor
from pydantic import BaseModel, Field, TypeAdapter
from more_itertools import chunked
import utils
from datasets import load_dataset, DatasetDict, Dataset
from sdg_hub.core.flow import FlowRegistry, Flow
import pandas as pd
from typing import Any, Optional
import asyncio
import nest_asyncio
nest_asyncio.apply()

In [None]:
##############################################################################
# State Definitions
##############################################################################

class LicenseState(TypedDict):
    """State dictionary handed from node to node in the license graph."""

    # Repository passed to utils.group_files_by_id by the loading node.
    github_repo: str
    # Sub-folder of that repository, passed alongside github_repo.
    github_subfolder: str
    # Prefix used to look up the "<prefix>_LLM_NAME" / "_LLM_BASE" / "_LLM_KEY" environment variables.
    vision_model: str
    # Application table: produced by the loading node, replaced by the extraction node.
    analytics_data: pd.DataFrame
    # Message log; add_messages is the reducer that merges each node's messages into it.
    messages: Annotated[List[BaseMessage], add_messages]

In [None]:
##############################################################################
# Structured Output
##############################################################################

class DriversLicenseField(BaseModel):
    """A single driver's-license field together with its validation details.

    Every attribute has a default, so an instance can be created with no
    arguments.
    """

    # NOTE(review): the schema description reads "Name of field" although the
    # attribute is called `value` — confirm which meaning is intended.
    value: str = Field(default="", description="Name of field")
    # Explanation recorded when the field could not be found.
    missing_error_reason: str = Field(default="", description="Reason for missing field")
    # Tri-state flag: None until a validity decision has been recorded.
    is_valid: Optional[bool] = Field(default=None, description="Indicates whether the license is valid.")
    # Counterpart value taken from the submitted application.
    application_value: str = Field(default="", description="Value of the corresponding field in the application")
    # Explanation recorded when the field is considered invalid.
    invalid_error_reason: str = Field(default="", description="Reason for invalid field")

    

class DriversLicenseMetadata(BaseModel):
    """Structured metadata describing one driver's license.

    ``application_id`` and ``model`` default to empty strings. The five core
    license fields (name, date of birth, expiration date, issuing state and
    issuance date) are required. ``photo_orientation`` is optional and
    defaults to an empty ``DriversLicenseField``.
    """

    application_id: str = Field("", description="Unique identifier")

    model: str = Field("", description="Name of LLM used to generate metadata")

    name: DriversLicenseField = Field(description="Name of driver's license owner")

    date_of_birth: DriversLicenseField = Field(description="Date of birth of driver's license owner")

    expiration_date: DriversLicenseField = Field(description="Expiration date of driver's license")

    state_issued: DriversLicenseField = Field(description="State where the license was issued")

    issuance_date: DriversLicenseField = Field(description="Date when the license was issued")

    # The previous default was the string "", which does not match the declared
    # type. Defaults are not validated, so instances ended up holding a bare
    # str where a DriversLicenseField was expected. A factory gives every
    # instance its own empty, correctly typed field object.
    photo_orientation: DriversLicenseField = Field(
        default_factory=DriversLicenseField,
        description="The skew of the license in the photo",
    )

class LicenseApplication(BaseModel):
    """A submitted application paired with its license image.

    All three attributes are required; none has a default.
    """

    # Identifier of the application.
    application_id: str = Field(..., description="Unique identifier")
    # Location of the image belonging to the application.
    image_path: str = Field(..., description="Image path")
    # Data submitted with the application, as a plain dictionary.
    application_data: dict = Field(..., description="Submitted application data")
    

In [None]:
##############################################################################
# Tools
##############################################################################

def image_to_base64(image_path, encode_image_bytes=False):
    """Return a reference to an image, optionally as a base64 data URI.

    Args:
        image_path: Local file path or http(s) URL of the image. The MIME type
            is guessed from the path/URL via ``mimetypes.guess_type``.
        encode_image_bytes: When False (default) ``image_path`` is returned
            unchanged once the MIME check has passed. When True the image
            bytes are read (downloaded for http(s) URLs, opened from disk
            otherwise) and returned as ``data:<mime>;base64,<payload>``.

    Returns:
        ``image_path`` itself, a base64 data URI, or ``None`` when the MIME
        type is unknown or not ``image/*``, or when reading the image fails.
        Failures are printed together with a traceback; they are not raised.
    """

    def is_valid_http_url(input_path):
        """Returns whether or not the input is an http(s) URL with a host."""
        parsed_url = urlparse(input_path)
        return parsed_url.scheme in ("http", "https") and bool(parsed_url.netloc)

    try:
        start_time = datetime.now()

        mime_type, _ = mimetypes.guess_type(image_path)

        # guess_type yields None for unknown extensions, so test for that
        # explicitly. Matching on the "image/" prefix (instead of a substring)
        # also rejects types such as "application/x-iso9660-image".
        if mime_type is None or not mime_type.startswith("image/"):
            raise ValueError(f"Mime type {mime_type} not supported")

        if not encode_image_bytes:
            result = image_path
        else:
            if is_valid_http_url(image_path):
                # A timeout keeps an unresponsive host from blocking forever.
                response = requests.get(image_path, timeout=30)
                response.raise_for_status()
                raw_bytes = response.content
            else:
                with open(image_path, "rb") as image_file:
                    raw_bytes = image_file.read()

            encoded = base64.b64encode(raw_bytes).decode("utf-8")
            result = f"data:{mime_type};base64,{encoded}"

        # Report the timing before returning so the message is actually printed.
        processing_time = (datetime.now() - start_time).total_seconds()
        print(f"Image loaded: time: {processing_time:.2f}s")

        return result

    except Exception as e:
        error_msg = f"Image loading error: {str(e)}"
        print(f"- {error_msg}")
        traceback.print_exc()

    return None

In [None]:
##############################################################################
# Nodes
##############################################################################

def load_and_convert_from_github(state: LicenseState) -> LicenseState:
    """Graph node for step 1: build the application table from the repository.

    Reads ``github_repo`` and ``github_subfolder`` from the state, groups the
    files by ID with ``utils.group_files_by_id``, converts the groups with
    ``utils.convert_to_submitted_fields`` (using "patterns.json") and wraps
    the result in a DataFrame.

    Returns:
        A state update holding the DataFrame under ``analytics_data`` and a
        single status message under ``messages``.
    """
    print("✓ STEP 1: Conversion of Github Repositories into a paired image-to-application representation")
    print("=" * 60)

    repo, subfolder = state["github_repo"], state["github_subfolder"]

    # Files grouped by ID.
    grouped_files = utils.group_files_by_id(repo, subfolder)

    # Static rules from patterns.json turn the groups into submitted fields.
    field_records = utils.convert_to_submitted_fields(grouped_files, "patterns.json")

    # Tabular representation consumed by the following nodes.
    applications_table = pd.DataFrame(field_records)

    status = AIMessage(content="License applications retrieved from github and converted.")

    return {"analytics_data": applications_table, "messages": [status]}

def get_extracted_data(state: LicenseState) -> LicenseState:
    """Graph node for step 2: run the extraction flow over the application table.

    The model is selected through ``state["vision_model"]``, a prefix used to
    read the ``<prefix>_LLM_NAME``, ``<prefix>_LLM_BASE`` and
    ``<prefix>_LLM_KEY`` environment variables.

    Args:
        state: Current graph state; ``analytics_data`` and ``vision_model``
            are read.

    Returns:
        A state update with the flow's output DataFrame under
        ``analytics_data`` and a status message under ``messages``.

    Raises:
        ValueError: If ``<prefix>_LLM_NAME`` is not set, since every row would
            otherwise be stamped with an empty model name.
    """

    print("✓ STEP 2: AI-Powered Data Extraction")

    print("="*60)

    ############################################################################################################################
    # Resolve the model configuration once
    ############################################################################################################################
    model_prefix = state["vision_model"]

    model_name = os.getenv(f"{model_prefix}_LLM_NAME")

    if not model_name:
        raise ValueError(f"Environment variable {model_prefix}_LLM_NAME is not set")

    ############################################################################################################################
    # Tag the application data with the current model
    ############################################################################################################################
    # assign() returns a new frame, so the DataFrame held in the incoming
    # state is left untouched instead of gaining a column in place.
    applications_df = state["analytics_data"].assign(model_name=model_name)

    ############################################################################################################################
    # Build a dataset for sdg_hub
    ############################################################################################################################
    dataset = Dataset.from_pandas(applications_df)

    ############################################################################################################################
    # Run sdg_hub
    ############################################################################################################################
    flow_path = "flows/drivers_license_validation/flow.yaml"

    flow = Flow.from_yaml(flow_path)

    flow.set_model_config(
        model=model_name,
        api_base=os.getenv(f"{model_prefix}_LLM_BASE"),
        api_key=os.getenv(f"{model_prefix}_LLM_KEY"),
        temperature=0,
        max_tokens=8192,
        response_format={"type": "json_object"},
        top_k=1,
    )

    converted_dataset = flow.generate(dataset, max_concurrency=10)

    converted_df = converted_dataset.to_pandas()

    # The returned update is what carries the new table forward; the incoming
    # state dict is not modified.
    return {
        "analytics_data" : converted_df,
        "messages": [AIMessage(content="License data extracted.")]
    }


def validate_extracted_data(state: LicenseState) -> LicenseState:
    """Graph node for step 3 (validation) — currently a placeholder.

    Prints the step banner and hands the incoming state back as-is.
    """
    print("✓ STEP 3: Data Validation", "=" * 60, sep="\n")

    # TODO: implement the validation of the extracted data.

    return state


def compile_final_result(state: LicenseState) -> LicenseState:
    """Graph node for step 4: signal that the pipeline has finished.

    Prints the step banner and returns a state update whose only content is a
    "DONE" message.
    """
    print("✓ STEP 4: Compiling Final Results", "=" * 60, sep="\n")

    # TODO: serialize the application data to a file.

    done_marker = HumanMessage(content="DONE")

    return {"messages": [done_marker]}

In [None]:
##############################################################################
# Graph
##############################################################################
def create_license_extraction_graph():
    """Build and compile the linear four-node LangGraph workflow.

    The nodes run in the order load_from_github → extract_data →
    validate_data → compile_result, framed by START and END.
    """
    # Node name / handler pairs, listed in execution order.
    pipeline = (
        ("load_from_github", load_and_convert_from_github),
        ("extract_data", get_extracted_data),
        ("validate_data", validate_extracted_data),
        ("compile_result", compile_final_result),
    )

    builder = StateGraph(LicenseState)

    for node_name, handler in pipeline:
        builder.add_node(node_name, handler)

    # Connect every stop on the route to the one that follows it.
    route = [START, *(node_name for node_name, _ in pipeline), END]
    for source, target in zip(route, route[1:]):
        builder.add_edge(source, target)

    return builder.compile()

### Execute the License Validation Flow
Define the entry point that runs the graph for a single vision model, then execute it for each model.

In [None]:
##############################################################################
# Execute the Flow
##############################################################################

def extract_license_info(github_repo: str, github_subfolder: str, vision_model: str) -> Optional[pd.DataFrame]:
    """Run the license graph for one vision model and return its result table.

    Args:
        github_repo: Repository stored under ``github_repo`` in the initial state.
        github_subfolder: Sub-folder stored under ``github_subfolder``.
        vision_model: Model prefix stored under ``vision_model``.

    Returns:
        The ``analytics_data`` DataFrame of the first streamed state whose
        last message contains "DONE", or ``None`` if the stream ends without
        such a message.
    """

    start_time = datetime.now()

    print(f"Started pipeline: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Use the function's own parameter here. The previous code read the
    # notebook-level loop variable `model_prefix`, which only exists after the
    # driver cell has run and silently ignores the argument passed in.
    initial_state = {
        "github_repo": github_repo,

        "github_subfolder": github_subfolder,

        "vision_model": vision_model,

        "messages": [HumanMessage(f"Validate driver's license data for {vision_model} model...")]

    }

    graph = create_license_extraction_graph()

    # NOTE(review): LangGraph reads `recursion_limit` from the top level of the
    # config, not from "configurable", so this value is probably ignored.
    # Left unchanged to keep the current runtime behavior — confirm the intent.
    config = {"configurable": {"thread_id": 42, "recursion_limit": 5}}

    stream = graph.stream(initial_state, config, stream_mode="values")

    for event in stream:
        last_message = event['messages'][-1]

        last_message.pretty_print()

        if "DONE" in last_message.content:

            return event["analytics_data"]

    # The stream ended without the "DONE" marker.
    return None

In [None]:
# Environment-variable prefixes of the vision models to run.
vision_models = ["LLAMASCOUT4", "GEMMA27B", "GEMMA12B"]

# Directory the report cell writes its files to.
target_dir = "reports"

# One result DataFrame per model that finished successfully.
datasets = []

##############################################################################
# Generate Extracted Data and Evaluations
##############################################################################
for model_prefix in vision_models:

    try:

        df = extract_license_info("https://github.com/agapebondservant/dla_poc", "notebooks/data2", model_prefix)

        if df is None:
            # The pipeline finished without producing a table for this model.
            print(f"No data returned for model {model_prefix}; skipping")
        else:
            datasets.append(df)

    except Exception as e:

        # Keep going with the remaining models and report the actual exception;
        # the previous message printed the literal letter "e".
        print(f"Error occurred while processing with model {model_prefix}: {e}")

        traceback.print_exc()

In [None]:
##############################################################################
# Generate Reports
##############################################################################
# Stack the per-model result tables and write both report formats to target_dir.
# The utils.data_report_prep transformation is currently disabled.
combined_df = pd.concat(datasets)
for write_report in (utils.generate_csv_report, utils.generate_jsonl_report):
    write_report(combined_df, target_dir)