In [2]:
class GraphState(BaseModel):
    """
    Represents the state of the graph processing workflow.

    Attributes:
        choice (str): The user's choice of processing method.
        input_text (str): The original input text from the user.
        extra_instructions (str): Any additional instructions for processing.
        username (str): The username of the requester. Also used as the
            per-user sub-directory of the Chroma persist directory.
        dataset (str): The dataset to be used for retrieval (used as the
            Chroma collection name).
        broadness (Optional[int]): A parameter controlling the breadth of the
            search. Declared as an optional integer defaulting to None.
        selected_choices (Optional[List[str]]): Selected choices for multi-header processing.
        word_amounts (Optional[List[int]]): Word count targets for each section.
        model (Any): The language model to be used.
        context (Optional[str]): Retrieved context for the query.
        instructions (Optional[str]): Processed instructions.
        question (Optional[str]): The processed question.
        result (Optional[str]): The final result of the processing.
        relevant_prompt (Optional[str]): The relevant prompt template for the chosen processing method.
        retrieved_docs (Optional[List[str]]): Retrieved documents.
        relevant_docs (Optional[List[str]]): Presumably the retrieved
            documents that were judged relevant -- TODO confirm against the
            code that populates it.
        company_name (str): Company name supplied to the prompt template as
            the ``company_name`` variable.
    """
    # Inputs describing the request.
    choice: str
    input_text: str
    extra_instructions: str
    username: str
    dataset: str
    broadness: Optional[int] = None
    # Parallel lists: multi-header processing pairs them positionally.
    selected_choices: Optional[List[str]] = None
    word_amounts: Optional[List[int]] = None
    model: Any
    # Values with a None default; presumably filled in as the workflow
    # progresses -- verify against the graph nodes.
    context: Optional[str] = None
    instructions: Optional[str] = None
    question: Optional[str] = None
    result: Optional[str] = None
    relevant_prompt: Optional[str] = None
    retrieved_docs: Optional[List[str]] = None  # Stores retrieved documents
    relevant_docs: Optional[List[str]] = None
    company_name: str


async def process_multiple_headers(state: GraphState) -> str:
    """
    Generate one text section per selected sub-topic and join the sections.

    ``state.selected_choices`` and ``state.word_amounts`` are paired
    positionally. For each pair, sub-topic specific documents are retrieved
    from the Chroma collection ``state.dataset``, filtered through
    ``check_subtopic_relevance``, appended to ``state.context`` and passed,
    together with the other prompt variables, through a chain built from
    ``state.relevant_prompt`` and ``state.model``. The per-sub-topic chains
    are awaited together with ``asyncio.gather``.

    This is a coroutine function: it awaits the relevance check and the chain
    invocations, so callers must ``await`` it (or run it on an event loop).

    Args:
        state: Workflow state. Reads ``selected_choices``, ``word_amounts``,
            ``dataset``, ``username``, ``model``, ``context``,
            ``instructions``, ``question``, ``relevant_prompt`` and
            ``company_name``.

    Returns:
        The generated sections, each prefixed with its sub-topic and a colon,
        separated by blank lines and ordered like ``selected_choices``; or the
        fixed message "No sub-topics or word amounts provided" when either
        list is missing or empty.
    """
    # Guard first so the "nothing to do" path touches no external resources.
    if not (state.selected_choices and state.word_amounts):
        return "No sub-topics or word amounts provided"

    sub_topics = state.selected_choices

    def retrieve_subtopic_docs(sub_topic: str, k: int = RETRIEVE_SUBTOPIC_CHUNKS) -> List[str]:
        """Return the page contents of at most ``k`` MMR-retrieved documents."""
        vectorstore_primary = Chroma(
            collection_name=state.dataset,
            persist_directory=f"{CHROMA_FOLDER}/{state.username}",
            embedding_function=embedder,
        )
        try:
            # The result count has to be passed through ``search_kwargs``;
            # a bare ``k=`` keyword does not configure the search.
            retriever = vectorstore_primary.as_retriever(
                search_type="mmr",
                search_kwargs={"k": k},
            )
            docs = retriever.get_relevant_documents(sub_topic)
            return [doc.page_content for doc in docs[:k]]
        except Exception as e:
            # Best effort: a failed retrieval degrades to "no extra documents"
            # instead of aborting the whole generation.
            log.error(f"Error getting relevant documents for sub-topic: {e}")
            return []

    # The prompt and chain do not depend on the sub-topic, so build them once.
    prompt = ChatPromptTemplate.from_template(state.relevant_prompt)
    chain = RunnablePassthrough() | prompt | state.model | StrOutputParser()

    async def invoke_chain(sub_topic, words):
        """Produce the "<sub_topic>:" section for a single sub-topic."""
        # Retrieval is synchronous, so it is called directly; awaiting its
        # list result would raise TypeError.
        subtopic_docs = retrieve_subtopic_docs(sub_topic)

        # Check relevance of subtopic documents
        # (check_subtopic_relevance is defined elsewhere and is awaited here).
        relevant_subtopic_docs = await check_subtopic_relevance(subtopic_docs, sub_topic, state.model)

        # Combine the original context with the relevant sub-topic specific
        # documents; ``context`` is optional, so treat None as empty.
        combined_context = (state.context or "") + "\n\n" + "\n\n".join(relevant_subtopic_docs)

        chain_input = {
            "context": combined_context,
            "extra_instructions": state.instructions or "",
            "question": state.question or "",
            "sub_topic": sub_topic,
            "word_amounts": words,
            # Every other selected sub-topic, as a comma-separated string.
            "except_sub_topics": ",".join(choice for choice in sub_topics if choice != sub_topic),
            "company_name": state.company_name,
        }
        res = await chain.ainvoke(chain_input)
        return f"{sub_topic}:\n\n{res}"

    # zip() pairs the lists positionally and stops at the shorter one;
    # gather() returns the results in the same order as its arguments.
    results = await asyncio.gather(
        *(invoke_chain(sub_topic, words) for sub_topic, words in zip(sub_topics, state.word_amounts))
    )
    return "\n\n".join(results)


NameError: name 'BaseModel' is not defined