In [2]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Dict, Any

from agents_nodes.clear_valid_input_validator.input_validator import validate_medical_input_agent
from vision_models.input_image_classification.image_classifier import classify_image


In [3]:
class _AgentStateOptional(TypedDict, total=False):
    """State keys that are absent until some step of the workflow fills them in.

    They are declared so that LangGraph, which derives the graph's state
    channels from the schema's annotations, has a channel for every key the
    nodes below read or write.
    """

    # Written by first_input_validation_node (the return value of
    # validate_medical_input_agent); first_input_not_valid_fallback_node
    # compares it against status strings such as
    # "TEXT_VALID_ATTACHMENT_NOT_VALID".
    input_validation_result: str
    # Read by input_image_classification_node: maps each image title to the
    # image's path. Which step populates it is not shown here.
    input_images_titles_and_paths: Dict[str, str]
    # Written by input_image_classification_node: one
    # [title, path, classification] list per input image.
    input_images_classification_results: List[List[Any]]


class AgentState(_AgentStateOptional):
    """State shared by the nodes of the input-processing workflow graph.

    ``input_text`` and ``attatchments`` are the required inputs; the optional
    keys inherited from the base class are produced or consumed by the nodes.
    """

    # Free-text input from the user.
    input_text: str
    # NOTE: the key is misspelled ("attatchments"); it is kept as-is because
    # first_input_validation_node and any callers use this exact spelling.
    attatchments: List[Dict[str, Any]]

In [4]:
def first_input_validation_node(state: AgentState):
    """Run the medical-input validator and record its verdict on the state.

    Passes ``state["input_text"]`` and ``state["attatchments"]`` (in that
    order) to ``validate_medical_input_agent`` and stores whatever it returns
    under ``state["input_validation_result"]``. The given dict is mutated in
    place and the same object is returned.
    """
    verdict = validate_medical_input_agent(
        state["input_text"],
        state["attatchments"],
    )
    state["input_validation_result"] = verdict
    return state

In [6]:
def first_input_not_valid_fallback_node(state: AgentState):
    """Fallback step for input that did not pass first-pass validation.

    Reads ``state["input_validation_result"]`` (so a missing key raises
    ``KeyError``) and returns *state* unchanged. The three invalid outcomes
    are recognised, but none of them has a handler yet.
    """
    outcome = state["input_validation_result"]

    # TODO: add a recovery path for each invalid outcome; at present every
    # recognised outcome is a no-op.
    if outcome in (
        "TEXT_VALID_ATTACHMENT_NOT_VALID",
        "TEXT_NOT_VALID_ATTACHMENT_VALID",
        "TEXT_NOT_VALID_ATTACHMENT_NOT_VALID",
    ):
        pass

    return state

In [7]:
def input_image_classification_node(state: AgentState):
    """Classify every input image and store the results on the state.

    ``state["input_images_titles_and_paths"]`` should be a dictionary mapping
    each image title to its path. ``classify_image(title, path)`` is called
    once per entry, in the mapping's iteration order, and the results are
    written to ``state["input_images_classification_results"]`` as a list of
    ``[title, path, classification]`` lists. The given dict is mutated in
    place and the same object is returned.
    """
    titles_to_paths = state["input_images_titles_and_paths"]
    state["input_images_classification_results"] = [
        [title, path, classify_image(title, path)]
        for title, path in titles_to_paths.items()
    ]
    return state

In [8]:
# Build the workflow graph over the shared AgentState schema and register
# each node under its name, in the same order as the functions are defined.
workflow = StateGraph(AgentState)

for _node_name, _node_fn in (
    ("first_input_validation_node", first_input_validation_node),
    ("first_input_not_valid_fallback_node", first_input_not_valid_fallback_node),
    ("input_image_classification_node", input_image_classification_node),
):
    workflow.add_node(_node_name, _node_fn)
# Drop the loop variables so they do not linger in the module namespace.
del _node_name, _node_fn

# NOTE(review): no edges or entry point are added in this cell, and `END` is
# imported but not used here; the graph cannot be compiled until they exist.


<langgraph.graph.state.StateGraph at 0x1c618300680>