#Goal: 
This system is a workflow that triggers upon receiving a PDF. For the demo, we want the agents to receive this PDF, determine what it is, then take the corresponding action depending on what the original request was. Refer to the reference architecture below to understand how the agents are connected.

When running the query below, we see only one row for John Smith. We want to update this entry with new information from a medical discharge summary about John Smith.

This demo uses DSPy because of its lightweight, pure-Python architecture. It integrates easily with other Python libraries and lets us modularize our system without locking us into framework-specific abstractions.

In [0]:
from config import catalog, schema

In [0]:
query = f"SELECT * FROM {catalog}.{schema}.patient_visits WHERE first_name = 'John' AND last_name = 'Smith'"
df_john_smith = spark.sql(query)
display(df_john_smith)

#Reference Architecture

![ref](ref.jpeg)

#Caveat!! 

1. It's a lot in one notebook! To make this work as a demo, a lot is compacted into one notebook. Please use software development best practices when organizing your own code, especially for the more modular parts of this notebook.
2. The Model Serving endpoints, Vector Search endpoints, and Genie Spaces were all created ahead of this demo by notebooks 01 to 04. If you have not run those notebooks, this one will not work.
3. This CANNOT run on serverless. You must use a Databricks Runtime ML 16.X or equivalent

In [0]:
%pip install --upgrade --force-reinstall dspy mlflow transformers databricks-vectorsearch databricks-sdk requests pdf2image pillow markdown gradio

In [0]:
%pip install --pre -U dspy

In [0]:
%sh
sudo apt clean
sudo apt update --fix-missing -y
sudo apt-get install -y libpoppler-cpp-dev pkg-config poppler-utils

In [0]:
dbutils.library.restartPython()

In [0]:
import logging
logging.getLogger("py4j.clientserver").setLevel(logging.ERROR)
logging.getLogger("mlflow.tracking.client").setLevel(logging.ERROR)

In [0]:
import dspy
import os

import mlflow
import mlflow.deployments
from mlflow.models.signature import ModelSignature
from mlflow.pyfunc import PythonModel
from mlflow.types.schema import Schema, ColSpec, TensorSpec

from PIL import Image
import base64
import io
from io import BytesIO

import numpy as np
import time
import pandas as pd 
import json
import requests

from typing import Literal, Any

from databricks.vector_search.client import VectorSearchClient
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.dashboards import GenieAPI

from pdf2image import convert_from_path
from pyspark.sql.types import StringType, StructType, StructField, IntegerType
from pyspark.sql.functions import col
from pyspark.sql import Row

from IPython.display import Markdown

mlflow.dspy.autolog()

In [0]:
from config import volume_label, volume_name, catalog, schema, model_name, model_endpoint_name, embedding_table_name, embedding_table_name_index, registered_model_name, vector_search_endpoint_name, beit_model_name, tesseract_model_name

#The Development Lifecycle

#Step 1: Ingest the Medical Summary

Ideally this would run in a Databricks Workflow that processes the PDF as soon as it arrives via Auto Loader or another ingestion method.

##Find PDFs and Install Poppler

Poppler needs an extra step to work properly on Databricks
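
Before converting any pages, it can help to confirm that the Poppler command-line tools are actually on the `PATH`. The snippet below is a minimal sketch: `find_poppler_binaries` is a hypothetical helper (not part of this notebook or of `pdf2image`), and it only assumes the standard Poppler tool names `pdftoppm` and `pdfinfo`. It checks the machine it runs on, so on a cluster it reports on the driver only.

```python
import shutil

def find_poppler_binaries(names=("pdftoppm", "pdfinfo")):
    """Map each Poppler tool name to its resolved path, or None if it is missing."""
    return {name: shutil.which(name) for name in names}

poppler_status = find_poppler_binaries()
missing = [name for name, path in poppler_status.items() if path is None]
if missing:
    print(f"Poppler tools not found on PATH: {missing}")
else:
    print(f"Poppler looks installed: {poppler_status}")
```

If any tool is reported missing, rerun the install cells before calling `convert_from_path`.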

In [0]:
import os
notebook_path = os.path.dirname(dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get())
workspace_folder = f"/Workspace{notebook_path}/sample_pdf_sbc"
files = os.listdir(workspace_folder)
file_paths_list = [os.path.join(workspace_folder, f) for f in files if f.endswith('.pdf')]
print(f"Found {len(file_paths_list)} PDF files:")
for pdf in file_paths_list:
    print(f"  - {pdf}")

In [0]:
def install_poppler_on_nodes():
    """
    Install poppler on all cluster nodes
    """
    import subprocess
    import os
    
    try:
        subprocess.run(['apt-get', 'update'], check=True)
        subprocess.run(['apt-get', 'install', '-y', 'poppler-utils'], check=True)
        print("Poppler installed successfully")
    except subprocess.CalledProcessError as e:
        print(f"Error installing poppler: {e}")

sc.range(1).foreach(lambda x: install_poppler_on_nodes())

##Process the PDFs found in the file
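
The processing cell below shrinks each page so that its short side is at most 768 pixels and its long side at most 2000 pixels, and never upscales. As a rough sketch of that arithmetic on its own, the hypothetical helper `scaled_size` (a name introduced here for illustration) returns the target dimensions without touching any image:

```python
def scaled_size(width, height, max_short=768, max_long=2000):
    """Return the (width, height) a page would be resized to, never upscaling."""
    if width > height:
        # Landscape: width is the long side
        factor = min(max_long / width, max_short / height)
    else:
        # Portrait or square: width is the short side
        factor = min(max_short / width, max_long / height)
    if factor < 1:
        return int(width * factor), int(height * factor)
    return width, height

print(scaled_size(1536, 2000))  # (768, 1000): portrait page halved to fit the short side
print(scaled_size(4000, 1000))  # (2000, 500): landscape page halved to fit the long side
print(scaled_size(600, 800))    # (600, 800): already small enough, left unchanged
```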

In [0]:
def process_all_pdfs(pdf_paths):
    """
    Process all PDFs on driver node to avoid UDF distribution issues
    """
    all_pages = []
    
    def resize_image(image, max_short_dimension=768, max_long_dimension=2000):
        """Resize image while maintaining aspect ratio"""
        width, height = image.size
        
        if width > height:
            scaling_factor = min(max_long_dimension / width, max_short_dimension / height)
        else:
            scaling_factor = min(max_short_dimension / width, max_long_dimension / height)
        
        if scaling_factor < 1:
            new_width = int(width * scaling_factor)
            new_height = int(height * scaling_factor)
            return image.resize((new_width, new_height), Image.LANCZOS)
        
        return image
    
    for pdf_path in pdf_paths:
        try:
            if not os.path.exists(pdf_path):
                print(f"File not found: {pdf_path}")
                continue
            
            if not os.access(pdf_path, os.R_OK):
                print(f"File not readable: {pdf_path}")
                continue
            
            print(f"Processing: {pdf_path}")
            
            images = convert_from_path(
                pdf_path, 
                dpi=100,
                fmt='JPEG',
                poppler_path='/usr/bin'  
            )
            
            for i, image in enumerate(images):
                resized_image = resize_image(image)
                
                if resized_image.mode != 'RGB':
                    resized_image = resized_image.convert('RGB')
                
                quantized_image = resized_image.quantize(colors=256)
     
                quantized_image = quantized_image.convert('RGB')
                
                img_buffer = io.BytesIO()
                quantized_image.save(img_buffer, format='JPEG', quality=70, optimize=True)
                img_bytes = img_buffer.getvalue()
                

                base64_string = base64.b64encode(img_bytes).decode('utf-8')
                
                all_pages.append({
                    'pdf_path': pdf_path,
                    'page_number': i + 1,
                    'base64_image': base64_string
                })
            
            print(f"Successfully processed {len(images)} pages from {pdf_path}")
            
        except Exception as e:
            print(f"Error processing {pdf_path}: {str(e)}")
            import traceback
            traceback.print_exc()
            continue
    
    return all_pages
  

print(f"Processing {len(file_paths_list)} PDFs...")
all_page_data = process_all_pdfs(file_paths_list)
print(f"Total pages processed: {len(all_page_data)}")


pdf_schema = StructType([
    StructField("pdf_path", StringType(), True),
    StructField("page_number", IntegerType(), True),
    StructField("base64_image", StringType(), True)
])

df_pages = spark.createDataFrame(all_page_data, pdf_schema)

In [0]:
df_pages.show()

#Step 2: Build your tools on Databricks

Databricks offers a broad suite of infrastructure that lets you manage and govern everything in one place. This allows us to use many services and custom code/models as tools for our agents. Below, we walk through setting up tools that call Model Serving, Genie Spaces, and Vector Search.

For DSPy, all you need to do is define a python function. There is no need to attach it to a framework specific type. All you need is some python
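
To illustrate why a plain function is enough, the sketch below shows how much a function already says about itself: its name, docstring, and typed parameters can all be read with the standard `inspect` module. `lookup_deductible` and `describe_tool` are hypothetical names with made-up demo data, and this is not a description of DSPy's internals, only of the kind of metadata a framework could pass to an LLM.

```python
import inspect

def lookup_deductible(plan_name: str) -> str:
    """Return the deductible for an insurance plan (hypothetical demo data)."""
    demo_plans = {"gold": "$500", "silver": "$1,500"}
    return demo_plans.get(plan_name.lower(), "unknown")

def describe_tool(func):
    """Collect the name, docstring, and parameter types of a plain function."""
    signature = inspect.signature(func)
    return {
        "name": func.__name__,
        "description": inspect.getdoc(func),
        "parameters": {
            name: getattr(param.annotation, "__name__", str(param.annotation))
            for name, param in signature.parameters.items()
        },
    }

print(describe_tool(lookup_deductible))
```

This is why clear docstrings and type hints on your tool functions matter: they are the description the agent has to work with.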

## Model Serving
We host our vision models on Databricks Model Serving so that we can use them alongside our data. The two we have ready are: 
1. Microsoft BEiT, an open source vision transformer
2. Good Old Tesseract
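
The two endpoints in this notebook are called with different request bodies: the OCR model takes a one-row table in the `dataframe_split` layout, while the classifier takes a plain `inputs` list. The sketch below builds both bodies without sending anything; the helper names are hypothetical, and the column name `image` is taken from the tool code that follows.

```python
import json

def build_ocr_payload(base64_image: str) -> dict:
    """One-row, one-column table in the dataframe_split layout."""
    return {"dataframe_split": {"columns": ["image"], "data": [[base64_image]]}}

def build_classification_payload(base64_image: str) -> dict:
    """Plain list of inputs, as sent to the classification endpoint."""
    return {"inputs": [base64_image]}

sample = "aGVsbG8="  # base64 of b"hello", standing in for a real page image
print(json.dumps(build_ocr_payload(sample)))
print(json.dumps(build_classification_payload(sample)))
```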

In [0]:

def vision_transformer_tool(url, task):
    """Used to classify an image. Use this to answer the user's question about an image"""
    API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
    DATABRICKS_URL = dbutils.notebook.entry_point.getDbutils().notebook().getContext().browserHostName().get()
    if task == 'ocr':
      model = tesseract_model_name
      encoded_image = url
      input_data = pd.DataFrame({'image': [encoded_image]})

      input_json = input_data.to_json(orient='split')

      payload = {
          "dataframe_split": json.loads(input_json)
      }

      headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}

      response = requests.post(
          url=f"https://{DATABRICKS_URL}/serving-endpoints/{model}/invocations", json=payload, headers=headers
      )

      result2 = response.json()
      print(result2['predictions'])
      return result2['predictions']
    
    if task == 'classification':
      model = beit_model_name
      data = {"inputs": [url]}

      headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}

      response = requests.post(
          url=f"https://{DATABRICKS_URL}/serving-endpoints/{model}/invocations", json=data, headers=headers
      )

      result = response.json()
      print(result)
      return result['predictions'][0]['0']['label']  

Let's test they work using the first page of the PDF

In [0]:
one_base64_image = df_pages.select("base64_image").limit(1).collect()[0]["base64_image"]

In [0]:
classification = vision_transformer_tool(url=one_base64_image, task="classification")
ocr = vision_transformer_tool(url=one_base64_image, task="ocr")

##Genie Spaces

Databricks provides an API to interact with your Genie Spaces, allowing you to run natural language SQL queries on your structured data. We can rely on the integration with Databricks to ensure we make the most accurate query for our data. 

There is currently no programmatic way to create a Genie Space. Make sure you have created one based on the tables in notebook 01.
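
The Genie tool in this notebook returns its result as a list of rows, each a list of values. A small helper that pairs those values with column names makes the result easier to inspect. This is only a sketch: `row_to_record` is a hypothetical name, the column order is an assumption based on how `final_write_to_table` indexes the row later in this notebook, and the sample row is made up.

```python
# Column order is an assumption taken from final_write_to_table later in this notebook
PATIENT_COLUMNS = [
    "first_name", "last_name", "insurance_provider_name", "insurance_type",
    "insurance_policy_number", "email", "city", "practice_visited_practice_id",
]

def row_to_record(row, columns=PATIENT_COLUMNS):
    """Pair one result row with column names, failing loudly on a length mismatch."""
    if len(row) != len(columns):
        raise ValueError(f"Expected {len(columns)} values, got {len(row)}")
    return dict(zip(columns, row))

# Made-up sample row for illustration only
sample_row = ["John", "Smith", "Example Health", "PPO", "POL-0001",
              "john.smith@example.com", "Austin", "7"]
record = row_to_record(sample_row)
print(record["first_name"], record["insurance_type"])
```

Checking the row length up front surfaces a changed Genie answer immediately, rather than as an `IndexError` at write time.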

In [0]:
def hls_patient_genie(patient_name):
  """This function queries a genie space for more information about a patient""" 
  w = WorkspaceClient()
  genie_space_id = "01effef4c7e113f9b8952cf568b49ac7"

  # Start a conversation
  conversation = w.genie.start_conversation_and_wait(
      space_id=genie_space_id,
      content=f"{patient_name} always limit to one result"
  )

  response = w.genie.get_message_attachment_query_result(
    space_id=genie_space_id,
    conversation_id=conversation.conversation_id,
    message_id=conversation.message_id,
    attachment_id=conversation.attachments[0].attachment_id
  )

  return response.statement_response.result.data_array

In [0]:
hls_patient_genie("what kind of insurance does Sarah Roberts have")

##Vector Search

For this example, insurance information is also hidden in other PDF files. To save time, we have already generated the vector search index for these PDFs and use it in this demo.
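
As a mental model for what a similarity search does, the toy sketch below ranks a few hand-written vectors against a query vector using cosine similarity. It is for illustration only: the managed index uses its own distance metric and approximate search, and the document names and vectors here are made up.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k(query_vector, documents, k=3):
    """Rank (doc_id, vector) pairs by similarity to the query, best first."""
    scored = [(doc_id, cosine_similarity(query_vector, vec)) for doc_id, vec in documents]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:k]

# Made-up two-dimensional "embeddings" for three documents
toy_index = [
    ("plan_a.pdf", [1.0, 0.0]),
    ("plan_b.pdf", [0.0, 1.0]),
    ("plan_c.pdf", [1.0, 1.0]),
]
for doc_id, score in top_k([1.0, 0.1], toy_index, k=2):
    print(doc_id, round(score, 3))
```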

In [0]:
vs_client = VectorSearchClient()

vector_search_endpoint_name = "one-env-shared-endpoint-4"
index_name = f"{embedding_table_name}_index"
index = vs_client.get_index(endpoint_name=vector_search_endpoint_name, index_name=f"{catalog}.{schema}.{index_name}")

def vector_search_for_patient_pdf(text_query):
    """Pulls matching Insurance Documents based on the text_query"""
    client = mlflow.deployments.get_deploy_client("databricks") 
    response = client.predict(
              endpoint=model_endpoint_name,
              inputs={"dataframe_split": {
                      "columns": ["text"],
                      "data": [[text_query]]
                      }
              }
            )
    text_embedding = response['predictions']['predictions']['embedding']
    index = vs_client.get_index(endpoint_name=vector_search_endpoint_name, index_name=f"{catalog}.{schema}.{index_name}")
    results = index.similarity_search(num_results=3, columns=["base64_image"], query_vector=text_embedding)
    return results['result']['data_array'][0][0]

#Step 3: Set up DSPy

We use DSPy for its lightweight framework and declarative approach. It has minimal dependencies, and we can build our program modularly, just as we would in ordinary Python code. 

There's no need to learn framework-specific classes or wait for framework-specific integrations. If there's a pure Python approach or an SDK, you can use it with DSPy.

##DSPy LLM Configuration

DSPy uses LiteLLM in the backend to give universal access to LLMs no matter the provider or if it's local (through ollama for example). 

All you need to do is change the string! No new libraries needed!
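
The model strings used in this notebook follow a `provider/model-name` convention. The sketch below only illustrates that convention by splitting on the first slash; `split_model_string` is a hypothetical helper, and the actual routing done by LiteLLM is more involved than this.

```python
def split_model_string(model: str):
    """Split 'provider/model-name' on the first slash only."""
    provider, separator, model_name = model.partition("/")
    if not separator or not provider or not model_name:
        raise ValueError(f"Expected 'provider/model-name', got {model!r}")
    return provider, model_name

print(split_model_string("databricks/databricks-claude-sonnet-4"))
# ('databricks', 'databricks-claude-sonnet-4')
print(split_model_string("anthropic/claude-sonnet-4-20250514"))
# ('anthropic', 'claude-sonnet-4-20250514')
```

Switching providers is then a matter of changing the prefix and supplying whatever credentials that provider needs.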

In [0]:
claude = dspy.LM('databricks/databricks-claude-sonnet-4', cache=False)
claude_anthropic = dspy.LM('anthropic/claude-sonnet-4-20250514', api_key=dbutils.secrets.get(scope = "groq_key", key = "anthropic"), cache=False)
llama8b = dspy.LM("databricks/databricks-meta-llama-3-1-8b-instruct", cache=False)
llama4 = dspy.LM("databricks/databricks-llama-4-maverick", cache=False)
# openai_example = dspy.LM('openai/gpt-4o-mini', api_key='YOUR_OPENAI_API_KEY')
# bedrock_example = dspy.LM('bedrock/anthropic.claude-3-sonnet-20240229-v1:0', AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME=AWS_REGION_NAME)
dspy.configure(lm=claude)

##DSPy BaseType

DSPy BaseTypes are essentially Pydantic BaseModels that help us with data validation. For our case, we define a memory type to capture the conversation so far. 

Because DSPy is a pure python framework, you can really bring any typing or library you like to handle this like Langmem. 
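
To show the idea of a typed memory object without any dependencies, here is a minimal sketch using a standard-library `dataclass` as a stand-in for a Pydantic model. `MemorySketch` and its `record` method are hypothetical names; the real type in this notebook is defined in the next cell, and a dataclass does not validate field types the way Pydantic does.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class MemorySketch:
    """Running log of agent responses plus the latest summary."""
    history: List[str] = field(default_factory=list)
    summary_so_far: str = ""

    def record(self, response: str, summary: str) -> "MemorySketch":
        # Append the newest response and overwrite the summary with the latest one
        self.history.append(response)
        self.summary_so_far = summary
        return self

memory = MemorySketch()
memory.record("OCR text summarized", "Discharge summary for one patient")
memory.record("Patient found", "Patient found; deductible still missing")
print(len(memory.history), "-", memory.summary_so_far)
```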

In [0]:
from typing import List, Any

class memoryHistory(dspy.BaseType):
  history: List[dict] 
  last_message: List[str]
  summary_so_far: str
  # placeholder: str

  def format(self) -> list[dict[str, Any]]:
    return [
      {
      "type": "memory", 
      "memory": {
        "history": self.history, 
        "message": self.last_message, 
        "summary": self.summary_so_far,
        # "placeholder" is left out: the field is commented out above, so reading self.placeholder would raise AttributeError
        }
      }
      ]


In [0]:
memory_history = memoryHistory(
  history=[],
  last_message=[],
  summary_so_far=""
)

#Step 4: Build your agents
Great, now we have our resources ready. So let's build out our Agents that will interact with each other to complete the task. 

Remember, the goal is to update our patient database with up-to-date information about this patient based on the ambiguous document we received. 

DSPy makes multi-agent development a breeze thanks to its declarative, pure Python approach. 

##DSPy Signatures 

Signatures allow you to enforce typing and modularize your code, allowing you to identify and tweak specific parts of your code, instead of a wall of text. 

DSPy will use EVERYTHING within a signature to adapt to a prompt. Your docstring is where you can do some prompt engineering if you wish but it is not necessary. You should use it like you would when writing good documentation (assuming you do write good documentation) 

Recommendation: Organize your signatures into their own file so you can import them wherever you like, or use the MLflow Prompt Registry.
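
The signatures below use `Literal` to restrict which agent can come next. As a sketch of what that constraint means in plain Python, the snippet checks a routing value against the allowed names using `typing.get_args`. `NextAgent` and `validate_route` are hypothetical names introduced here; how DSPy itself enforces the type is not shown.

```python
from typing import Literal, get_args

# The same three route names used by the signatures in this notebook
NextAgent = Literal["text_processing_agent", "patient_lookup_genie_agent", "final_agent"]

def validate_route(value: str) -> str:
    """Return the value if it is an allowed route, otherwise raise ValueError."""
    allowed = get_args(NextAgent)
    if value not in allowed:
        raise ValueError(f"{value!r} is not one of {allowed}")
    return value

print(validate_route("final_agent"))
try:
    validate_route("billing_agent")
except ValueError as error:
    print(f"Rejected: {error}")
```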

In [0]:
class text_summarizer_extraction(dspy.Signature): 
  """Agent to summarize the ocr output and find keywords based on the original query."""

  ocr_input: str = dspy.InputField()
  original_query: str = dspy.InputField()
  memory_so_far: memoryHistory = dspy.InputField(desc="a history of the workflow so far")
  response: str = dspy.OutputField()
  summary_so_far: str = dspy.OutputField()
  keywords: str = dspy.OutputField()
  next_agent_or_tool: Literal["text_processing_agent", "patient_lookup_genie_agent", "final_agent"] = dspy.OutputField() 

class genie_agent(dspy.Signature): 
  """Agent to use Databricks Genie Space to find information about a patient. It creates a question based on the provided keywords in patient_information or memory_history to query the genie_space with only the patient's name. Then, it takes the genie_output, makes a text_query based on insurance type, insurance name and keyterms like deductible found in both genie_outputs and original_query and sends the text_query to patient_insurance_lookup"""

  patient_information: str = dspy.InputField(desc="Find the patient's name")
  original_query: str = dspy.InputField()
  memory_so_far: memoryHistory = dspy.InputField(desc="a history of the workflow so far")
  genie_output: str = dspy.OutputField()
  insurance_details: str = dspy.OutputField()
  response: str = dspy.OutputField()
  summary_so_far: str = dspy.OutputField()
  deductible: str = dspy.OutputField(desc="this is the result of patient_insurance_lookup")
  next_agent_or_tool: Literal["text_processing_agent", "patient_lookup_genie_agent", "final_agent"] = dspy.OutputField() 

class final_agent(dspy.Signature):
  """Agent to convert the collected information and write to a delta table based on the original_query.""" 

  original_query: str = dspy.InputField()
  genie_output: list = dspy.InputField() 
  ocr_summary: str = dspy.InputField() 
  deductible: str = dspy.InputField()
  completed_response: str = dspy.OutputField()

class document_analyzer(dspy.Signature):
  """Agent to analyze the document provided by reviewing the outputs of the model and determining if there's enough information to go to the next agent or try analyzing the document again with a different vision model""" 

  vision_model_output: str = dspy.InputField()
  response: str = dspy.OutputField() 
  next_agent_or_tool: Literal["text_processing_agent", "patient_lookup_genie_agent", "final_agent"] = dspy.OutputField() 

class insurance_finder(dspy.Signature):
  """Find the relevant information based on the text_query within the image"""

  image: dspy.Image = dspy.InputField()
  text_query: str = dspy.InputField()
  deductible: str = dspy.OutputField()
  other_information: str = dspy.OutputField()

#Step 5: Test your Agents independently to ensure they work 

Before putting our agents together, we can use prebuilt DSPy modules like dspy.Predict and dspy.ReAct to ensure that the signatures are performing as we expect them to. 

Let's try them out below

In [0]:
ocr_processing = dspy.Predict(text_summarizer_extraction)
outputs = ocr_processing(ocr_input=ocr, original_query="Update the patient's information based on this document and find their deductible if it's not in the table", memory_so_far=memory_history)

In [0]:
print(outputs)

In [0]:
patient_lookup = dspy.ReAct(genie_agent, tools=[hls_patient_genie], max_iters=1)
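genie_outputs = patient_lookup(patient_information=outputs.keywords, original_query="Update the patient's information based on this document and find their deductible if it's not in the table", memory_so_far=memory_history)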

In [0]:
print(genie_outputs)

#Step 6: Put it all together in a custom DSPy Module

DSPy gives you the freedom to program your own module. Below, you'll see the custom module that packages everything you saw above into one class. We can then execute that class in one line. 

Everything follows object-oriented programming and familiar patterns like PyTorch's forward method. Additionally, you can add any custom Python logic within this module so that you don't have to rely on the agent to do everything. You know exactly what is happening and can control the flow.
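
The control flow in the module below boils down to a dispatch loop: each step does its work and names the step that should run next. The sketch below shows that pattern with stub functions in place of agents. All names here (`run_pipeline`, the step functions, `max_steps`) are hypothetical; the step cap is an addition for safety and is not part of the notebook's module.

```python
def run_pipeline(handlers, start, state, max_steps=10):
    """Run handlers in sequence; each returns the next handler's name, or None to stop."""
    current = start
    for _ in range(max_steps):
        if current not in handlers:
            raise ValueError(f"Unknown step: {current!r}")
        current = handlers[current](state)
        if current is None:
            return state
    raise RuntimeError(f"Stopped after {max_steps} steps without finishing")

# Stub steps standing in for the text, lookup, and final agents
def text_step(state):
    state["trace"].append("text")
    return "lookup"

def lookup_step(state):
    state["trace"].append("lookup")
    return "final"

def final_step(state):
    state["trace"].append("final")
    return None

handlers = {"text": text_step, "lookup": lookup_step, "final": final_step}
result = run_pipeline(handlers, start="text", state={"trace": []})
print(result["trace"])  # ['text', 'lookup', 'final']
```

Because the routing is ordinary Python, you can add guards like the step cap or the unknown-step check wherever the agents' choices should not be trusted blindly.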

In [0]:
class dais_document_ingestor(dspy.Module):
  def __init__(self):
    super().__init__()
    self.patient_lookup_genie_agent = dspy.ReAct(genie_agent, tools=[self.hls_patient_genie, self.patient_insurance_lookup], max_iters=2)
    self.text_processing_agent = dspy.Predict(text_summarizer_extraction)
    self.final_agent = dspy.ReAct(final_agent, tools=[self.final_write_to_table], max_iters=1) #likely a tool call
    self.document_analyzer_agent = dspy.Predict(document_analyzer)
    self.insurance_finder = dspy.Predict(insurance_finder)
    self.memory_history = memoryHistory(history=[],
                                        last_message=[],
                                        summary_so_far=""
                                        )

  def patient_insurance_lookup(self, text_query):
    """Pulls matching Insurance Documents based on the text_query"""
    print(f"Starting patient insurance look up with results: {text_query}\n\n")
    # patient_insurance_extraction = dspy.Predict("genie_outputs-> patient_name_and_insurance_details: str")
    # text_query = patient_insurance_extraction(genie_outputs=genie_outputs).patient_name_and_insurance_details
    print(f"The Text query being sent for vector search: {text_query}\n\n")
    vs_client = VectorSearchClient()
    vector_search_endpoint_name = "one-env-shared-endpoint-4"
    index_name = f"{embedding_table_name}_index"
    client = mlflow.deployments.get_deploy_client("databricks") 
    response = client.predict(
              endpoint=model_endpoint_name,
              # endpoint=model_name,
              inputs={"dataframe_split": {
                      "columns": ["text"],
                      "data": [[text_query]]
                      }
              }
            )
    text_embedding = response['predictions']['predictions']['embedding']
    index = vs_client.get_index(endpoint_name=vector_search_endpoint_name, index_name=f"{catalog}.{schema}.{index_name}")
    results = index.similarity_search(num_results=3, columns=["base64_image"], query_vector=text_embedding)
    base64_string = results['result']['data_array'][0][0]
    image_data = base64.b64decode(base64_string) 
    pil_image = Image.open(io.BytesIO(image_data))
    dspy_image = dspy.Image.from_PIL(pil_image)
    with dspy.context(lm=claude_anthropic):
      print(f"Reviewing Vector Search Results\n\nCurrent Model: {claude_anthropic}\n\n")
      deductible = self.insurance_finder(image=dspy_image, text_query=text_query)
    return deductible.deductible
  
  def hls_patient_genie(self, patient_name):
    """This function queries a genie space for more information about a patient""" 
    w = WorkspaceClient()
    genie_space_id = "01effef4c7e113f9b8952cf568b49ac7"

    # Start a conversation
    conversation = w.genie.start_conversation_and_wait(
        space_id=genie_space_id,
        content=f"{patient_name} always limit to one result"
    )

    response = w.genie.get_message_attachment_query_result(
      space_id=genie_space_id,
      conversation_id=conversation.conversation_id,
      message_id=conversation.message_id,
      attachment_id=conversation.attachments[0].attachment_id
    )

    return response.statement_response.result.data_array
  

  def document_analyzer_tool(self, task, medical_summary_df):
    """Used to analyze a document. Medical_Summary_df should be a dataframe and task is either ocr or classification"""
    API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
    DATABRICKS_URL = dbutils.notebook.entry_point.getDbutils().notebook().getContext().browserHostName().get()
    encoded_image = medical_summary_df.select("base64_image").limit(1).collect()[0]["base64_image"]
    if task == 'ocr':
      model = tesseract_model_name
      full_text = []
      # Collect once from the DataFrame passed in (not the global df_pages) and OCR every page
      page_rows = medical_summary_df.select("base64_image").collect()
      headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}
      for row in page_rows:
        input_data = pd.DataFrame({'image': [row["base64_image"]]})

        input_json = input_data.to_json(orient='split')

        payload = {
            "dataframe_split": json.loads(input_json)
        }

        response = requests.post(
            url=f"https://{DATABRICKS_URL}/serving-endpoints/{model}/invocations", json=payload, headers=headers
        )

        result2 = response.json()
        full_text.append(result2['predictions'])
      return full_text
    
    if task == 'classification':
      model = beit_model_name
      data = {"inputs": [encoded_image]}

      headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}

      response = requests.post(
          url=f"https://{DATABRICKS_URL}/serving-endpoints/{model}/invocations", json=data, headers=headers
      )

      result = response.json()
      return result['predictions'][0]['0']['label']  
    
  def process_agent_response(self, memory_history: memoryHistory, response: str, summary_so_far: str):
    """Helper method to process agent responses and determine next steps"""
    memory_history.history.append(response)
    memory_history.summary_so_far = summary_so_far
    return memory_history
  
  def final_write_to_table(self, original_query, genie_output, ocr_summary, delta_table, deductible):
    """Spark code to write to the table the original_query specified""" 
    print(f"This is the final_write function\n\nGenie Output {genie_output}\n\nOriginal Query: {original_query}\n\nOCR Summary: {ocr_summary}\n\nDeductible: {deductible}\n\nDelta Table Name: {delta_table}")
    insert_data = [
      Row(first_name = genie_output[0],
      last_name = genie_output[1],
      insurance_provider_name = genie_output[2],
      insurance_type = genie_output[3],
      insurance_policy_number=genie_output[4],
      email=genie_output[5],
      city=genie_output[6],
      practice_visited_practice_id=int(genie_output[7]),
      doctor_notes=ocr_summary,
      deductible=deductible)
    ]
    df_with_new_column = spark.createDataFrame(insert_data)
    df_with_new_column.write.format("delta").mode("append").option("mergeSchema", "true").saveAsTable(delta_table)
    
    return f"Appended 1 row to {delta_table}"

  
  def handle_question(self, original_query, medical_summary_df, next_agent, memory_history, agent_response, ocr_result):
    """Processes a document given to the agent. Multiple agents work together using a variety of LLMs to solve the workflow"""
    next_agent = next_agent
    print("Beginning Agent Interaction\n\n")
    while True:
        if next_agent == 'text_processing_agent':
          print("Text Extract AGENT STARTING\n\n")
          with dspy.context(lm=llama4):
            print(f"Current Model: {llama4}\n\n")
            text_agent_response = self.text_processing_agent(
                ocr_input=ocr_result,
                original_query=original_query,
                memory_so_far=memory_history
            )
          memory_history = self.process_agent_response(memory_history=memory_history, response=text_agent_response.response, summary_so_far=text_agent_response.summary_so_far)
          agent_response = text_agent_response.response
          next_agent = text_agent_response.next_agent_or_tool
          print(f"Completed Text AGENT... Moving to {next_agent}\n\n")
          continue

        elif next_agent == 'patient_lookup_genie_agent':
          print("Patient Look Up Genie AGENT STARTING\n")
          with dspy.context(lm=claude):
            print(f"Current Model: {claude}\n\n")
            genie_agent_response = self.patient_lookup_genie_agent(
                patient_information=ocr_result,
                original_query=original_query,
                memory_so_far=memory_history
            )
          memory_history = self.process_agent_response(memory_history=memory_history, response=genie_agent_response.response,summary_so_far=genie_agent_response.summary_so_far)
          agent_response = genie_agent_response.response
          next_agent = genie_agent_response.next_agent_or_tool
          print(f"Completed Patient Look Up Genie... Moving to {next_agent}\n\n")
          continue

        elif next_agent == 'final_agent':
          print("Wrapping up with the final_agent\n")
          with dspy.context(lm=claude):
            print(f"Current Model: {claude}\n\n")
            final_agent_response = self.final_agent(
                original_query=original_query,
                genie_output=genie_agent_response.genie_output,
                ocr_summary=text_agent_response.summary_so_far,
                deductible=genie_agent_response.deductible
            )
          
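          break

        else:
          # Any other value would loop forever, so fail loudly instead
          raise ValueError(f"Unknown next agent or tool: {next_agent!r}")

    return final_agent_response.completed_response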
  
  def forward(self, initial_query: str, medical_summary_df):
    """Main interaction loop"""
    next_agent = ""
    memory_history = memoryHistory(
      history=[],
      last_message=[],
      summary_so_far=""
    )
    vision_model_output = self.document_analyzer_tool(task='classification',medical_summary_df=medical_summary_df)
    with dspy.context(lm=llama4):
      document_analyzer_results = self.document_analyzer_agent(vision_model_output=vision_model_output)
    print(f"Image Analyzer Agent Response: {document_analyzer_results.response}\n\n")

    vision_model_output = self.document_analyzer_tool(task='ocr',medical_summary_df=medical_summary_df)
    with dspy.context(lm=llama4):
      document_analyzer_results = self.document_analyzer_agent(vision_model_output=vision_model_output)
    print(f"Image Analyzer Agent Response: {document_analyzer_results.response}\n\n")
    print("Completed Image Analysis. Handing off to Handle_Question to begin Agent Interaction\n\n")
    results = self.handle_question(original_query=initial_query, next_agent=document_analyzer_results.next_agent_or_tool, medical_summary_df=medical_summary_df, memory_history=memory_history, agent_response=document_analyzer_results.response, ocr_result = vision_model_output)
    return results



In [0]:
original_query = f"Update the {catalog}.{schema}.patient_visits table with information from the document. Also, add their insurance deductible to the table."

dais_ingestor_agent = dais_document_ingestor()
dais_output = dais_ingestor_agent(initial_query=original_query, medical_summary_df=df_pages)
Markdown(dais_output)

In [0]:
query = f"SELECT * FROM {catalog}.{schema}.patient_visits WHERE first_name = 'John' AND last_name = 'Smith'"
df_john_smith = spark.sql(query)
display(df_john_smith)

In [0]:
query = f"ALTER TABLE {catalog}.{schema}.patient_visits DROP COLUMN deductible"
df_john_smith = spark.sql(query)
display(df_john_smith)

In [0]:
query = f"DELETE FROM {catalog}.{schema}.patient_visits WHERE first_name = 'John' AND last_name = 'Smith' AND reason_for_visit is NULL"
df_john_smith = spark.sql(query)
display(df_john_smith)

#Congrats! The demo is complete

If you want to see the demo in a Gradio UI, run the code below! You'll need to manually upload the sample PDF located in the sample_pdf_sbc folder.

In [0]:
import gradio as gr
import pandas as pd
import base64
from PIL import Image
import io
from typing import List, Tuple
import tempfile
import os
from pdf2image import convert_from_path
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
import traceback

class dais_document_ingestor(dspy.Module):
  def __init__(self):
    super().__init__()
    self.patient_lookup_genie_agent = dspy.ReAct(genie_agent, tools=[self.hls_patient_genie, self.patient_insurance_lookup], max_iters=2)
    self.text_processing_agent = dspy.Predict(text_summarizer_extraction)
    self.final_agent = dspy.ReAct(final_agent, tools=[self.final_write_to_table], max_iters=1) #likely a tool call
    self.document_analyzer_agent = dspy.Predict(document_analyzer)
    self.insurance_finder = dspy.Predict(insurance_finder)
    self.memory_history = memoryHistory(history=[],
                                        last_message=[],
                                        summary_so_far=""
                                        )

  def patient_insurance_lookup(self, text_query):
    """Pulls matching Insurance Documents based on the text_query"""
    print(f"Starting patient insurance look up with results: {text_query}\n\n")
    # patient_insurance_extraction = dspy.Predict("genie_outputs-> patient_name_and_insurance_details: str")
    # text_query = patient_insurance_extraction(genie_outputs=genie_outputs).patient_name_and_insurance_details
    print(f"The Text query being sent for vector search: {text_query}\n\n")
    vs_client = VectorSearchClient()
    vector_search_endpoint_name = "one-env-shared-endpoint-4"
    index_name = f"{embedding_table_name}_index"
    client = mlflow.deployments.get_deploy_client("databricks") 
    response = client.predict(
              endpoint=model_endpoint_name,
              # endpoint=model_name,
              inputs={"dataframe_split": {
                      "columns": ["text"],
                      "data": [[text_query]]
                      }
              }
            )
    text_embedding = response['predictions']['predictions']['embedding']
    index = vs_client.get_index(endpoint_name=vector_search_endpoint_name, index_name=f"{catalog}.{schema}.{index_name}")
    results = index.similarity_search(num_results=3, columns=["base64_image"], query_vector=text_embedding)
    base64_string = results['result']['data_array'][0][0]
    image_data = base64.b64decode(base64_string) 
    pil_image = Image.open(io.BytesIO(image_data))
    dspy_image = dspy.Image.from_PIL(pil_image)
    with dspy.context(lm=claude_anthropic):
      print(f"Reviewing Vector Search Results\n\nCurrent Model: {claude_anthropic}\n\n")
      deductible = self.insurance_finder(image=dspy_image, text_query=text_query)
    return deductible.deductible
  
  def hls_patient_genie(self, patient_name):
    """This function queries a genie space for more information about a patient""" 
    w = WorkspaceClient()
    genie_space_id = "01effef4c7e113f9b8952cf568b49ac7"

    # Start a conversation
    conversation = w.genie.start_conversation_and_wait(
        space_id=genie_space_id,
        content=f"{patient_name} always limit to one result"
    )

    response = w.genie.get_message_attachment_query_result(
      space_id=genie_space_id,
      conversation_id=conversation.conversation_id,
      message_id=conversation.message_id,
      attachment_id=conversation.attachments[0].attachment_id
    )

    return response.statement_response.result.data_array
  

  def document_analyzer_tool(self, task, medical_summary_df):
    """Used to analyze a document. Medical_Summary_df should be a dataframe and task is either ocr or classification"""
    API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
    DATABRICKS_URL = dbutils.notebook.entry_point.getDbutils().notebook().getContext().browserHostName().get()
    encoded_image = medical_summary_df.select("base64_image").limit(1).collect()[0]["base64_image"]
    if task == 'ocr':
      model = tesseract_model_name
      full_text = []
      # Collect the pages once instead of re-running the Spark job per page.
      # Use the DataFrame passed in, not a notebook-level df_pages variable.
      pages = medical_summary_df.select("base64_image").collect()
      for page in pages:
        input_data = pd.DataFrame({'image': [page["base64_image"]]})

        input_json = input_data.to_json(orient='split')

        payload = {
            "dataframe_split": json.loads(input_json)
        }

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}

        response = requests.post(
            url=f"https://{DATABRICKS_URL}/serving-endpoints/{model}/invocations", json=payload, headers=headers
        )

        result2 = response.json()
        full_text.append(result2['predictions'])
      return full_text
    
    if task == 'classification':
      model = beit_model_name
      data = {"inputs": [encoded_image]}

      headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_TOKEN}"}

      response = requests.post(
          url=f"https://{DATABRICKS_URL}/serving-endpoints/{model}/invocations", json=data, headers=headers
      )

      result = response.json()
      return result['predictions'][0]['0']['label']

    # Fail loudly instead of silently returning None for an unsupported task
    raise ValueError(f"Unsupported task: {task!r}")
    
  def process_agent_response(self, memory_history: memoryHistory, response: str, summary_so_far: str):
    """Helper method to process agent responses and determine next steps"""
    memory_history.history.append(response)
    memory_history.summary_so_far = summary_so_far
    return memory_history
  
  def final_write_to_table(self, original_query, genie_output, ocr_summary, delta_table, deductible):
    """Spark code to write to the table the original_query specified""" 
    print(f"This is the final_write function\n\nGenie Output {genie_output}\n\nOriginal query: {original_query}\n\nOCR Summary: {ocr_summary}\n\nDeductible: {deductible}\n\nDelta Table Name: {delta_table}")
    insert_data = [
      Row(first_name = genie_output[0],
      last_name = genie_output[1],
      insurance_provider_name = genie_output[2],
      insurance_type = genie_output[3],
      insurance_policy_number=genie_output[4],
      email=genie_output[5],
      city=genie_output[6],
      practice_visited_practice_id=int(genie_output[7]),
      doctor_notes=ocr_summary,
      deductible=deductible)
    ]
    df_with_new_column = spark.createDataFrame(insert_data)
    df_with_new_column.write.format("delta").mode("append").option("mergeSchema", "true").saveAsTable(delta_table)
    
    # Return a short status string so the calling agent can see what happened
    return f"Appended 1 row to {delta_table}"

  
  def handle_question(self, original_query, medical_summary_df, next_agent, memory_history, agent_response, ocr_result, progress=gr.Progress()):
    """Processes a document given to the agent. Multiple agents work together using a variety of LLMs to solve the workflow"""
    progress_number = 0.55
    print("Beginning Agent Interaction\n\n")
    while True:
        if next_agent == 'text_processing_agent':
          print("Text Extract AGENT STARTING\n\n")
          with dspy.context(lm=llama4):
            # print(f"Current Model: {llama4}\n\n")
            text_agent_response = self.text_processing_agent(
                ocr_input=ocr_result,
                original_query=original_query,
                memory_so_far=memory_history
            )
          memory_history = self.process_agent_response(memory_history=memory_history, response=text_agent_response.response, summary_so_far=text_agent_response.summary_so_far)
          agent_response = text_agent_response.response
          next_agent = text_agent_response.next_agent_or_tool
          # print(f"Completed Text AGENT... Moving to {next_agent}\n\n")
          progress(progress_number+0.1, f"Completed Text AGENT... Moving to {next_agent}\n\n")
          continue

        elif next_agent == 'patient_lookup_genie_agent':
          print("Patient Look Up Genie AGENT STARTING\n")
          with dspy.context(lm=claude):
            # print(f"Current Model: {claude}\n\n")
            genie_agent_response = self.patient_lookup_genie_agent(
                patient_information=ocr_result,
                original_query=original_query,
                memory_so_far=memory_history
            )
          memory_history = self.process_agent_response(memory_history=memory_history, response=genie_agent_response.response,summary_so_far=genie_agent_response.summary_so_far)
          agent_response = genie_agent_response.response
          next_agent = genie_agent_response.next_agent_or_tool
          # print(f"Completed Patient Look Up Genie... Moving to {next_agent}\n\n")
          progress(progress_number+0.3, f"Completed Patient Look Up Genie... Moving to {next_agent}\n\n")
          continue

        elif next_agent == 'final_agent':
          # print("Wrapping up with the final_agent\n")
          with dspy.context(lm=claude):
            # print(f"Current Model: {claude}\n\n")
            final_agent_response = self.final_agent(
                original_query=original_query,
                genie_output=genie_agent_response.genie_output,
                ocr_summary=text_agent_response.summary_so_far,
                deductible=genie_agent_response.deductible
            )
            progress(0.95, f"Completed Final Agent... Finalizing...\n\n")
          
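          break

        else:
          # Guard against an unrecognized hand-off, which would otherwise loop forever
          raise ValueError(f"Unknown next_agent: {next_agent!r}")

    return final_agent_response.completed_response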
  
  def forward(self, initial_query: str, medical_summary_df, progress=gr.Progress()):
    """Main interaction loop"""
    next_agent = ""
    memory_history = memoryHistory(
      history=[],
      last_message=[],
      summary_so_far=""
    )
    vision_model_output = self.document_analyzer_tool(task='classification',medical_summary_df=medical_summary_df)
    with dspy.context(lm=llama4):
      document_analyzer_results = self.document_analyzer_agent(vision_model_output=vision_model_output)
    # print(f"Image Analyzer Agent Response: {document_analyzer_results.response}\n\n")
    progress(0.4, desc=f"Image Analyzer Agent Response: {document_analyzer_results.response}\n\n")

    vision_model_output = self.document_analyzer_tool(task='ocr',medical_summary_df=medical_summary_df)
    with dspy.context(lm=llama4):
      document_analyzer_results = self.document_analyzer_agent(vision_model_output=vision_model_output)
    # print(f"Image Analyzer Agent Response: {document_analyzer_results.response}\n\n")
    progress(0.45, desc=f"Image Analyzer Agent Response: {document_analyzer_results.response}\n\n")
    # print("Completed Image Analysis. Handing off to Handle_Question to begin Agent Interaction\n\n")
    progress(0.5, desc="Completed Image Analysis. Handing off to Handle_Question to begin Agent Interaction\n\n")
    results = self.handle_question(original_query=initial_query, next_agent=document_analyzer_results.next_agent_or_tool, medical_summary_df=medical_summary_df, memory_history=memory_history, agent_response=document_analyzer_results.response, ocr_result = vision_model_output)
    return results
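
The hand-off pattern in `handle_question` (each agent names the next agent until `final_agent` runs) can be sketched in isolation. This is a simplified illustration, not the notebook's implementation: `run_pipeline`, the `*_stub` handlers, and the `max_steps` guard are hypothetical names introduced here, and the stubs stand in for the real DSPy modules.

```python
from typing import Callable, Dict, List, Tuple

# Each handler takes the running trace and returns (trace, name_of_next_agent).
Handler = Callable[[List[str]], Tuple[List[str], str]]

def run_pipeline(first_agent: str, handlers: Dict[str, Handler],
                 terminal: str = "final_agent", max_steps: int = 10) -> List[str]:
    """Follow agent hand-offs until the terminal agent has run."""
    trace: List[str] = []
    agent = first_agent
    for _ in range(max_steps):
        if agent not in handlers:
            # An unrecognized name would otherwise spin forever in a while-loop
            raise ValueError(f"Unknown agent: {agent!r}")
        trace, next_agent = handlers[agent](trace)
        if agent == terminal:
            return trace
        agent = next_agent
    raise RuntimeError(f"No terminal agent reached within {max_steps} steps")

# Stub handlers standing in for the real DSPy modules (illustration only).
def text_stub(trace):
    return trace + ["text_processing_agent"], "patient_lookup_genie_agent"

def genie_stub(trace):
    return trace + ["patient_lookup_genie_agent"], "final_agent"

def final_stub(trace):
    return trace + ["final_agent"], ""

handlers = {
    "text_processing_agent": text_stub,
    "patient_lookup_genie_agent": genie_stub,
    "final_agent": final_stub,
}

print(run_pipeline("text_processing_agent", handlers))
```

A dictionary dispatch with a step limit keeps the routing table in one place and turns a bad `next_agent_or_tool` value from the LLM into an immediate error rather than a hung cell.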



In [0]:
# Initialize the document ingestor
dais_ingestor_agent = dais_document_ingestor()

def process_all_pdfs(pdf_paths):
    """
    Process all PDFs on driver node to avoid UDF distribution issues
    """
    all_pages = []
    
    def resize_image(image, max_short_dimension=768, max_long_dimension=2000):
        """Resize image while maintaining aspect ratio"""
        width, height = image.size
        if width > height:
            scaling_factor = min(max_long_dimension / width, max_short_dimension / height)
        else:
            scaling_factor = min(max_short_dimension / width, max_long_dimension / height)
        
        if scaling_factor < 1:
            new_width = int(width * scaling_factor)
            new_height = int(height * scaling_factor)
            return image.resize((new_width, new_height), Image.LANCZOS)
        return image
    
    for pdf_path in pdf_paths:
        try:
            if not os.path.exists(pdf_path):
                print(f"File not found: {pdf_path}")
                continue
                
            if not os.access(pdf_path, os.R_OK):
                print(f"File not readable: {pdf_path}")
                continue
                
            print(f"Processing: {pdf_path}")
            images = convert_from_path(
                pdf_path,
                dpi=100,
                fmt='JPEG',
                poppler_path='/usr/bin'
            )
            
            for i, image in enumerate(images):
                resized_image = resize_image(image)
                
                if resized_image.mode != 'RGB':
                    resized_image = resized_image.convert('RGB')
                
                quantized_image = resized_image.quantize(colors=256)
                quantized_image = quantized_image.convert('RGB')
                
                img_buffer = io.BytesIO()
                quantized_image.save(img_buffer, format='JPEG', quality=70, optimize=True)
                img_bytes = img_buffer.getvalue()
                base64_string = base64.b64encode(img_bytes).decode('utf-8')
                
                all_pages.append({
                    'pdf_path': pdf_path,
                    'page_number': i + 1,
                    'base64_image': base64_string
                })
                
            print(f"Successfully processed {len(images)} pages from {pdf_path}")
            
        except Exception as e:
            print(f"Error processing {pdf_path}: {str(e)}")
            traceback.print_exc()
            continue
    
    return all_pages

def process_uploaded_pdfs(files):
    """Convert uploaded PDF files to base64 encoded images using the provided function"""
    if not files:
        return None
    
    # Convert Gradio NamedString objects to regular strings
    file_paths_list = []
    for f in files:
        if f:
            # Convert to string to avoid Spark serialization issues
            file_path = str(f)
            if os.path.exists(file_path):
                file_paths_list.append(file_path)
    
    if not file_paths_list:
        return None
    
    print(f"Processing {len(file_paths_list)} PDFs...")
    
    # Use the provided PDF processing function
    all_page_data = process_all_pdfs(file_paths_list)
    
    print(f"Total pages processed: {len(all_page_data)}")
    
    # Create Spark DataFrame
    try:
        pdf_schema = StructType([
            StructField("pdf_path", StringType(), True),
            StructField("page_number", IntegerType(), True),
            StructField("base64_image", StringType(), True)
        ])
        
        df_pages = spark.createDataFrame(all_page_data, pdf_schema)
        return df_pages
    except Exception as e:
        print(f"Error creating Spark DataFrame: {str(e)}")
        # Fallback for non-Spark environments
        return pd.DataFrame(all_page_data)

def process_document(files, query, progress=gr.Progress()):
    """Main processing function for Gradio interface"""
    
    # Validate inputs
    if not files:
        return "⚠️ Please upload at least one PDF document.", ""
    
    if not query:
        return "⚠️ Please enter a query.", ""
    
    try:
        # Update progress
        progress(0.1, desc="Converting PDFs to images...")
        
        # Convert uploaded PDFs to DataFrame format
        medical_summary_df = process_uploaded_pdfs(files)
        
        progress(0.2, desc="PDF conversion finished. Checking results...")

        if medical_summary_df is None:
            return "❌ Error processing uploaded PDF files.", ""
        
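        # Get page count for feedback. Check for pandas first: a pandas
        # DataFrame also has .count(), but it returns per-column counts.
        if isinstance(medical_summary_df, pd.DataFrame):
            page_count = len(medical_summary_df)
        else:
            page_count = medical_summary_df.count()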
        
        # Update progress
        progress(0.3, desc=f"Processing {page_count} pages through AI agents...")
        
        # Run the document ingestor
        result = dais_ingestor_agent(
            initial_query=query, 
            medical_summary_df=medical_summary_df
        )
        
        # Update progress
        progress(1.0, desc="Complete!")
        
        # Format the output
        status = f"✅ Document processed successfully! ({page_count} pages analyzed)"
        
        # Convert the result to a string so it can be rendered as markdown
        result_markdown = str(result)
            
        return status, result_markdown
        
    except Exception as e:
        error_msg = f"❌ Error during processing: {str(e)}"
        traceback.print_exc()
        return error_msg, ""

def create_interface():
    """Create and configure the Gradio interface"""
    
    with gr.Blocks(title="DAIS Document Ingestor", theme=gr.themes.Soft()) as demo:
        
        # Header
        gr.Markdown(
            """
            # 🏥 DAIS Medical Document Ingestor
            
            This system processes medical PDF documents using multiple AI agents to:
            - Extract patient information from PDFs
            - Look up insurance details
            - Analyze document content
            - Update database tables
            
            Upload your medical PDFs and specify what actions to take.
            """
        )
        
        with gr.Row():
            with gr.Column(scale=1):
                # File upload
                file_input = gr.File(
                    label="Upload Medical PDFs",
                    file_count="multiple",
                    file_types=[".pdf"],
                    type="filepath"
                )
                
                # Query input
                query_input = gr.Textbox(
                    label="Query / Instructions",
                    placeholder="e.g., Update the patient_visits table with information from the document. Also, add their insurance deductible to the table.",
                    value="Update the austin_choi_demo_catalog.agents.patient_visits table with information from the document. Also, add their insurance deductible to the table.",
                    lines=3
                )
                
                # Process button
                process_btn = gr.Button("🚀 Process PDFs", variant="primary", size="lg")
            
            with gr.Column(scale=1):
                # Status output
                status_output = gr.Textbox(
                    label="Status",
                    lines=1,
                    interactive=False
                )
                
                # Result output
                result_output = gr.Markdown(
                    label="Processing Results",
                    value="*Results will appear here after processing...*"
                )
        
        # Additional information
        with gr.Accordion("ℹ️ How it works", open=False):
            gr.Markdown(
                """
                ### Processing Pipeline:
                1. **PDF Conversion**: Converts PDF pages to images with optimized sizing and quality
                2. **Document Analysis**: Classifies and extracts text from converted images using OCR
                3. **Patient Lookup**: Searches for patient information in the Genie space
                4. **Insurance Verification**: Finds matching insurance documents via vector search
                5. **Text Processing**: Summarizes and extracts key information
                6. **Database Update**: Writes the processed information to Delta tables
                
                ### PDF Processing Details:
                - PDFs are converted at 100 DPI to balance quality and performance
                - Images are resized maintaining aspect ratio (short side at most 768 px, long side at most 2000 px)
                - Each page is processed individually and encoded as base64
                - Multi-page PDFs are fully supported
                
                ### Required Components:
                - Databricks environment with appropriate permissions
                - Poppler utilities installed for PDF conversion
                - Access to Genie spaces and vector search endpoints
                - Configured LLM models (Claude, Llama4)
                - Delta table write permissions
                """
            )
        
        # System requirements info
        with gr.Accordion("⚙️ System Requirements", open=False):
            gr.Markdown(
                """
                ### Required Python packages:
                ```bash
                pip install pdf2image
                pip install pillow
                ```
                
                ### System dependencies:
                - Poppler utilities (for PDF conversion)
                  - Ubuntu/Debian: `apt-get install poppler-utils`
                  - CentOS/RHEL: `yum install poppler-utils`
                """
            )
        
        # Event handlers
        process_btn.click(
            fn=process_document,
            inputs=[file_input, query_input],
            outputs=[status_output, result_output],
            show_progress=True
        )
        
        # Clear button
        gr.Button("🗑️ Clear").click(
            fn=lambda: (None, 
                       "Update the austin_choi_demo_catalog.agents.patient_visits table with information from the document. Also, add their insurance deductible to the table.",
                       "",
                       "*Results will appear here after processing...*"),
            outputs=[file_input, query_input, status_output, result_output]
        )
    
    return demo

# Launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        share=True,  # Creates a public share link; set to False to keep the app private
        server_name="0.0.0.0",  # For Databricks environments
        server_port=8080,
        debug=True
    )
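
The sizing rule used by `resize_image` in the PDF-processing cell can be checked without Pillow. The sketch below mirrors the same arithmetic on plain integers. `target_size` is a hypothetical helper added here for illustration, and the sample dimensions are made up rather than taken from the demo PDF.

```python
def target_size(width, height, max_short=768, max_long=2000):
    """Return the (width, height) that the resize rule would produce."""
    if width > height:
        # Landscape: width is the long side
        scale = min(max_long / width, max_short / height)
    else:
        # Portrait or square: height is the long side
        scale = min(max_short / width, max_long / height)
    if scale < 1:
        return int(width * scale), int(height * scale)
    # Images already within both limits are never upscaled
    return width, height

print(target_size(1536, 3000))  # portrait page, limited by the short side
print(target_size(4000, 1000))  # wide landscape page, limited by the long side
print(target_size(600, 800))    # already small enough, left unchanged
```

This prints `(768, 1500)`, `(2000, 500)`, and `(600, 800)`: the smaller of the two scale factors always wins, so neither limit is exceeded.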