# Prompt Template

In [1]:
CODE_ANALYSIS_TEMPLATE = """
You are an expert code analyst. Analyze the following Python code and provide a structured summary
that will be useful for generating unit tests. Identify:
1.  Main functions and classes.
2.  For each function/method:
    - Its purpose or a brief description.
    - Input parameters (name and type if inferable).
    - Return type (if inferable).
    - Key logic branches or behaviors.
    - Potential edge cases or interesting scenarios to test.
3.  Any global variables or external dependencies that might affect testing.

Respond in a JSON format with the following structure:
{{  
    "summary": "Overall summary of the code's purpose.",
    "components": [
    {{  
        "type": "function" | "class",
        "name": "component_name",
        "description": "...",
        "methods": [ // Only if type is class
        {{  
            "name": "method_name",
            "signature": "def method_name(param1, param2): ...",
            "description": "...",
            "parameters": [{{"name": "param_name", "type": "inferred_type"}}, ...], 
            "returns": "inferred_return_type",
            "key_behaviors": ["behavior1", "behavior2"],
            "edge_cases": ["edge_case1", ...]
        }} 
        ],
        // For functions directly under components, structure similar to methods above
        "signature": "def function_name(param1): ...", // If type is function
        "parameters": [...], // If type is function
        "returns": "...", // If type is function
        "key_behaviors": [...], // If type is function
        "edge_cases": [...] // If type is function
    }} 
    ],
    "dependencies": ["dep1", "dep2"]
}} 
Remember to include JSON strings without any extra formatting or signs.
Code to analyze:
{code_to_analyze}
"""

TEST_GENERATION_TEMPLATE = """
You are an expert Python test developer. Based on the following analysis of a Python code component
and the original code context, write comprehensive unit tests using the `unittest` framework.

Original Code Context (for reference, ensure your tests would import/access this correctly):
```python
{original_code_snippet}
```

Code Component Analysis:
Name: {component_name}
Type: {component_type}
Signature: `{component_signature}`
Description: {component_description}
Key Behaviors to Test: {key_behaviors}
Potential Edge Cases to Test: {edge_cases}

{feedback}

Your task:
1.  Create a Python class that inherits from `unittest.TestCase`. Name it descriptively.
2.  Write test methods (starting with `test_`) within this class.
3.  Each test method should target one behavior or edge case identified.
4.  Use appropriate `self.assertXXX` methods from `unittest` for assertions.
5.  Ensure the tests are self-contained and clearly written.
6.  Assume necessary functions/classes from the original code are importable or accessible in the test execution scope.
    (For example, if testing a function `my_function` from `original_code`, your test might call `source_module.my_function(...)`
    or assume `my_function` is directly available if `original_code` was executed in the global scope of tests.
    For now, assume the component `{component_name}` is directly callable/instantiable.)
7.  Do NOT include the `if __name__ == '__main__': unittest.main()` block.
8.  Only provide the Python code for the test class. Do not add any explanatory text before or after the code block.

Component to generate tests for: `{component_name}`
Test Class Code:
```python
# [Your generated unittest.TestCase class for {component_name} goes here]
```
"""

EVALUATION_TEMPLATE = """
You are an expert Senior QA Engineer and Python Developer. Your task is to review a suite of generated unit tests
against the original Python code and its prior analysis. You should assess the quality, completeness,
and likely effectiveness of these tests. DO NOT execute the code.

Original Python Code:
```python
{original_code}
```

Prior Code Analysis (identifying key components, behaviors, and edge cases that should be tested):
```json
{code_analysis_json}
```
Generated Unit Tests (using unittest framework):
```python
{test_code}
```

Based on your review of these three inputs, please provide:
1.  An overall qualitative assessment of the test suite's likely coverage and quality. Choose one: "low", "medium", "high".
2.  A numeric score from 1 (very poor) to 10 (excellent) representing your confidence in these tests.
3.  Specific feedback:
    - What aspects of the original code (based on the analysis) seem well-tested?
    - What specific functions, methods, logic paths, or edge cases (from the analysis or your own observation of the original code) appear to be untested or inadequately tested by the provided test suite?
    - Any other suggestions for improving these tests (e.g., missing assertions, incorrect mocking (if inferable), unclear test names, testing anti-patterns).

Respond in a JSON format with the following structure:
{{
    "qualitative_assessment": "low|medium|high",
    "confidence_score": <float_from_1_to_10>,
    "positive_feedback": ["Aspect 1 well-tested...", "Aspect 2..."],
    "areas_for_improvement": ["Function X seems untested for Y case...", "Edge case Z is missing..."],
    "other_suggestions": ["Consider testing X...", "Test Y could be clearer if..."]
}}
Remember to include JSON strings without any extra formatting or signs.
"""
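
The templates double every literal brace (`{{`, `}}`) so that only the named placeholders such as `{code_to_analyze}` are substituted. `ChatPromptTemplate.from_template` uses f-string-style formatting by default, which follows the same escaping rule as `str.format`. A minimal sketch of that rule, using plain `str.format` and a shortened, hypothetical `MINI_TEMPLATE` rather than the real prompts:

```python
# MINI_TEMPLATE is a hypothetical, shortened stand-in for the templates above.
MINI_TEMPLATE = """Respond in JSON:
{{
    "summary": "..."
}}
Code to analyze:
{code_to_analyze}
"""

# Doubled braces come out as single literal braces; the placeholder is filled in.
rendered = MINI_TEMPLATE.format(code_to_analyze="def add(a, b):\n    return a + b")
print(rendered)
```

A single unescaped `{` in the JSON skeleton would instead be read as the start of a placeholder and make formatting fail.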

# State

In [2]:
from langgraph.graph import MessagesState
from typing_extensions import TypedDict, List, Dict, Any, Optional, Annotated
import operator

class UnitTestWorkflowState(MessagesState):
    original_code: str
    analyzed_code: Dict[str, Any]  # parsed JSON analysis returned by the analysis chain
    test_code: str
    evaluation: dict
    flow: Annotated[List[str], operator.add]  # reducer: each node's entries are appended
    max_generation_attempts: int
    generation_attempts: int
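
The `Annotated[List[str], operator.add]` hint on `flow` marks that key as accumulated rather than overwritten when a node returns an update. The sketch below imitates that merge rule with the standard library only. `MiniState` and `apply_update` are hypothetical names, and this is a simplification for illustration, not LangGraph's actual implementation:

```python
import operator
from typing import Annotated, List, TypedDict, get_type_hints


class MiniState(TypedDict):
    test_code: str
    flow: Annotated[List[str], operator.add]


def apply_update(state: dict, update: dict) -> dict:
    """Merge a node's partial update into the state, honoring reducers."""
    hints = get_type_hints(MiniState, include_extras=True)
    merged = dict(state)
    for key, value in update.items():
        # Annotated[...] hints carry their extra arguments in __metadata__.
        metadata = getattr(hints.get(key), "__metadata__", ())
        if metadata and key in merged:
            merged[key] = metadata[0](merged[key], value)  # e.g. list + list
        else:
            merged[key] = value  # plain keys are overwritten
    return merged


state = {"test_code": "old", "flow": ["analyze"]}
new_state = apply_update(state, {"test_code": "new", "flow": ["generate"]})
print(new_state)
```

This is why each node returns only its own `flow` entry in a one-element list.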

# Schema

In [3]:
from typing import List, Optional, Literal, Union, Any
from pydantic import BaseModel, Field

class Parameter(BaseModel):
    name: str 
    type: Optional[str] = None 

class MethodDetails(BaseModel):
    signature: Optional[str] = None 
    description: Optional[str] = None 
    parameters: List[Parameter] = Field(default_factory=list)
    returns: Optional[str] = None
    key_behaviors: List[str] = Field(default_factory=list) 
    edge_cases: List[str] = Field(default_factory=list)

class Method(MethodDetails):
    name: str 

class FunctionComponent(MethodDetails):
    type: Literal["function"] 
    name: str 

class ClassComponent(BaseModel):
    type: Literal["class"] 
    name: str 
    description: Optional[str] = None 
    methods: List[Method] = Field(default_factory=list) 

Component = Union[FunctionComponent, ClassComponent]

class CodeAnalysis(BaseModel):
    summary: Optional[str] = None 
    components: List[Component] = Field(default_factory=list)  # Defaults to an empty list
    dependencies: List[str] = Field(default_factory=list) 
    
class TestCodeEvaluation(BaseModel):
    qualitative_assessment: Literal["low", "medium", "high"]
    confidence_score: float 
    positive_feedback: List[str] = Field(default_factory=list) 
    areas_for_improvement: List[str] = Field(default_factory=list) 
    other_suggestions: List[str] = Field(default_factory=list)
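
`JsonOutputParser` returns a plain `dict` and uses the Pydantic model only to describe the expected format, so the fields of an evaluation are not checked automatically. A minimal hand-rolled check that mirrors the `TestCodeEvaluation` fields is sketched below. `find_evaluation_problems` is a hypothetical helper, and the 1 to 10 range is taken from the evaluation prompt, not from the schema:

```python
from typing import Any, Dict, List

ALLOWED_ASSESSMENTS = {"low", "medium", "high"}
LIST_FIELDS = ("positive_feedback", "areas_for_improvement", "other_suggestions")


def find_evaluation_problems(payload: Dict[str, Any]) -> List[str]:
    """Return a list of human-readable problems; empty means the payload looks fine."""
    problems = []
    if payload.get("qualitative_assessment") not in ALLOWED_ASSESSMENTS:
        problems.append("qualitative_assessment must be low, medium, or high")
    score = payload.get("confidence_score")
    is_number = isinstance(score, (int, float)) and not isinstance(score, bool)
    if not is_number or not 1 <= score <= 10:
        problems.append("confidence_score must be a number from 1 to 10")
    for field in LIST_FIELDS:
        value = payload.get(field, [])  # list fields default to empty, as in the schema
        if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
            problems.append(f"{field} must be a list of strings")
    return problems


good = {"qualitative_assessment": "medium", "confidence_score": 6.0}
bad = {"qualitative_assessment": "great", "confidence_score": 11}
print(find_evaluation_problems(good))
print(find_evaluation_problems(bad))
```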


# Helper

In [4]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.runnables import Runnable


def get_model(model: str, temperature: float = 0.0) -> ChatGoogleGenerativeAI | ChatOpenAI:
    # Pick the chat model class from the model name.
    if "gemini" in model:
        return ChatGoogleGenerativeAI(model=model, temperature=temperature)
    if "gpt" in model:
        return ChatOpenAI(model=model, temperature=temperature)
    raise ValueError(f"Model {model} not supported")


def create_code_analysis_chain(model: str, temperature: float = 0.0) -> Runnable:
    # `prompt | llm | parser` builds a runnable sequence, not a legacy LLMChain.
    llm = get_model(model, temperature)
    prompt = ChatPromptTemplate.from_template(CODE_ANALYSIS_TEMPLATE)
    return prompt | llm | JsonOutputParser(pydantic_object=CodeAnalysis)


def create_test_generation_chain(model: str, temperature: float = 0.0) -> Runnable:
    # No parser here: the raw message content is the generated test code.
    llm = get_model(model, temperature)
    prompt = ChatPromptTemplate.from_template(TEST_GENERATION_TEMPLATE)
    return prompt | llm


def create_evaluation_chain(model: str, temperature: float = 0.0) -> Runnable:
    # The evaluation output follows TestCodeEvaluation, not CodeAnalysis.
    llm = get_model(model, temperature)
    prompt = ChatPromptTemplate.from_template(EVALUATION_TEMPLATE)
    return prompt | llm | JsonOutputParser(pydantic_object=TestCodeEvaluation)


def decision_to_end_workflow(state: UnitTestWorkflowState) -> str:
    # Stop once the attempt budget is spent or the evaluator rates the tests "high".
    out_of_attempts = state["generation_attempts"] >= state["max_generation_attempts"]
    evaluation = state.get("evaluation") or {}
    if out_of_attempts or evaluation.get("qualitative_assessment") == "high":
        return "end"
    return "regenerate"
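
The routing rule can be exercised on plain dictionaries without any model call. The sketch below restates it as a standalone, hypothetical `route` function (treating a missing evaluation as "not high", which is an assumption) and runs it on three hand-written states:

```python
def route(state: dict) -> str:
    """Return "end" or "regenerate" using the same rule as the helper above."""
    evaluation = state.get("evaluation") or {}
    out_of_attempts = state["generation_attempts"] >= state["max_generation_attempts"]
    if out_of_attempts or evaluation.get("qualitative_assessment") == "high":
        return "end"
    return "regenerate"


first_try = {"generation_attempts": 1, "max_generation_attempts": 3,
             "evaluation": {"qualitative_assessment": "medium"}}
good_enough = {"generation_attempts": 2, "max_generation_attempts": 3,
               "evaluation": {"qualitative_assessment": "high"}}
budget_spent = {"generation_attempts": 3, "max_generation_attempts": 3,
                "evaluation": {"qualitative_assessment": "low"}}

for name, state in [("first_try", first_try), ("good_enough", good_enough),
                    ("budget_spent", budget_spent)]:
    print(name, "->", route(state))
```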


# Node Func

In [5]:
from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv())

True

## Code Analysis

In [6]:
from typing import TypedDict, Dict, Any, List, Optional
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage , AIMessage

# TODO: import UnitTestWorkflowState

# --- Constants for analyze_code_node ---
ANALYZE_NODE_NAME = "Analyze Code Node"
ANALYZE_FLOW_SUCCESS = f"{ANALYZE_NODE_NAME}: Analysis successful"

ANALYZE_FLOW_FAILED_NO_CODE = f"{ANALYZE_NODE_NAME}: Failed (No original code provided)"
ANALYZE_FLOW_FAILED_LLM = f"{ANALYZE_NODE_NAME}: Failed (LLM analysis error)"

MESSAGE_NO_ORIGINAL_CODE = "No original code found in state for analysis."
MESSAGE_FLOW_FAILED_LLM = "The original code provided has a syntax error, which prevents analysis."

def code_analysis(state: UnitTestWorkflowState, code_analysis_chain) -> Dict[str, Any]:
    """
    Analyze the code and provide a structured summary for generating unit tests.
    """
    original_code = state.get("original_code")
    if not original_code:
        return {
            "messages": [SystemMessage(content=MESSAGE_NO_ORIGINAL_CODE)],
            "flow": [ANALYZE_FLOW_FAILED_NO_CODE],
        }
    analyzed_code = code_analysis_chain.invoke({"code_to_analyze": original_code})
    if not analyzed_code:
        return {
            "messages": [SystemMessage(content=MESSAGE_FLOW_FAILED_LLM)],
            "flow": [ANALYZE_FLOW_FAILED_LLM],
        }
    # Return only the keys this node updates; the graph merges them into the state.
    return {
        "analyzed_code": analyzed_code,
        "messages": [AIMessage(content=str(analyzed_code))],
        "flow": [ANALYZE_FLOW_SUCCESS],
    }
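
`MESSAGE_FLOW_FAILED_LLM` attributes a failed analysis to a syntax error in the input, but the node never checks for one. One way to test that directly, before spending a model call, is to parse the source with the standard `ast` module. This is a sketch under that assumption, and `find_syntax_error` is a hypothetical helper, not part of the notebook:

```python
import ast
from typing import Optional


def find_syntax_error(source: str) -> Optional[str]:
    """Return a short description of the first syntax error, or None if the code parses."""
    try:
        ast.parse(source)
    except SyntaxError as exc:
        return f"line {exc.lineno}: {exc.msg}"
    return None


print(find_syntax_error("def ok():\n    return 1\n"))
print(find_syntax_error("def broken(:\n    pass\n"))
```

A node could return the failure message when this helper reports a problem and call the chain only otherwise.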

## Generate tests node

In [7]:
from typing import Dict, Any, List, Optional
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, HumanMessage
import json
# (Assumes UnitTestWorkflowState has already been defined elsewhere)
# from .state import UnitTestWorkflowState

# --- Constants for generate_tests_node ---
GENERATE_NODE_NAME = "Generate Tests Node"
GENERATE_FLOW_SUCCESS = f"{GENERATE_NODE_NAME}: Test generation successful"
GENERATE_FLOW_FAILED_NO_ANALYSIS = f"{GENERATE_NODE_NAME}: Failed (No code analysis found)"
GENERATE_FLOW_FAILED_INVALID_ANALYSIS = f"{GENERATE_NODE_NAME}: Failed (Invalid code analysis format)"
GENERATE_FLOW_FAILED_LLM = f"{GENERATE_NODE_NAME}: Failed (LLM test generation error)"
GENERATE_FLOW_FAILED_SYNTAX_IN_TESTS = f"{GENERATE_NODE_NAME}: Failed (Generated tests have syntax errors)"
GENERATE_FLOW_FAILED_UNEXPECTED = f"{GENERATE_NODE_NAME}: Failed (Unexpected Error)"

MESSAGE_NO_CODE_ANALYSIS = "No code analysis found in state. Cannot generate tests."
MESSAGE_INVALID_ANALYSIS_FORMAT = "Code analysis data is missing 'components' or has an invalid format."
MESSAGE_SYNTAX_ERROR_IN_GENERATED_TESTS = "The LLM-generated test code has a syntax error."



def generate_tests_node(state: UnitTestWorkflowState, test_generation_chain ) -> Dict[str, Any]:
    """
    Generates unit test code based on the code analysis provided.
    Uses an LLM to write test cases for identified components.
    """

    analyzed_code: Optional[Dict[str, Any]] = state.get('analyzed_code')
    original_code: Optional[str] = state.get('original_code') 
    generation_attempts: int = state.get('generation_attempts', 0)    
    
    if not analyzed_code:
        error_system_message = SystemMessage(content=MESSAGE_NO_CODE_ANALYSIS)
        return {
            "messages": [error_system_message],
            "flow": [GENERATE_FLOW_FAILED_NO_ANALYSIS],
        }

    components_to_test = analyzed_code.get("components")
    if not components_to_test or not isinstance(components_to_test, list):
        error_system_message = SystemMessage(content=MESSAGE_INVALID_ANALYSIS_FORMAT)
        return {
            "messages":  [error_system_message],
            "flow":  [GENERATE_FLOW_FAILED_INVALID_ANALYSIS],
        }
    last_test_code = state.get("test_code", "")
    last_evaluation = state.get("evaluation")
    if last_test_code and last_evaluation:
        # On a retry, show the model its previous attempt and the evaluator's critique.
        feedback_msg = (
            f"Previous generated test code:\n{last_test_code}\n\n"
            f"Previous evaluation for the test code:\n{json.dumps(last_evaluation)}"
        )
    else:
        feedback_msg = ""
    
    test_codes = []

    for component in components_to_test:
        try:
            test_code = test_generation_chain.invoke({
                "original_code_snippet": original_code,
                "component_name": component["name"],
                "component_type": component["type"],
                "component_signature": component.get("signature", ""),
                "component_description": component.get("description", ""),
                "key_behaviors": component.get("key_behaviors", []),
                "edge_cases": component.get("edge_cases", []),
                "feedback": feedback_msg
            })
            test_codes.append(test_code.content)
        except Exception as e:
            error_system_message = SystemMessage(
                content=f"Error generating tests for component {component['name']}: {type(e).__name__} - {e}"
                )
            return {
                "messages": [error_system_message],
                "flow": [GENERATE_FLOW_FAILED_LLM],
            }

    if not test_codes:
        # This case means no components to test, or LLM failed for all without throwing an exception caught above
        no_tests_msg = "No test classes were generated. Check code analysis or LLM responses."
        error_system_message = SystemMessage(content=no_tests_msg)
        return {
            "messages": [error_system_message],
            "flow": [GENERATE_FLOW_FAILED_LLM], 
        }

    # Combine all generated test classes and necessary imports
    all_generated_tests_str =  "\n\n".join(test_codes)
    all_generated_tests_str += "\n\n# You might want to add a way to run these tests if needed for standalone execution\n"
    all_generated_tests_str += "# e.g., if __name__ == '__main__': unittest.main()\n"
    # NOTE: But for programmatic use (like coverage.py), this main block is often not needed or added later.
    return {
        "messages": [AIMessage(content=all_generated_tests_str)],
        "test_code": all_generated_tests_str,  # a plain string, matching the state schema
        "flow": [GENERATE_FLOW_SUCCESS],
        "generation_attempts": generation_attempts + 1,
    }
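
The node defines `GENERATE_FLOW_FAILED_SYNTAX_IN_TESTS` but never checks the generated code, and the model's reply arrives wrapped in a Markdown code fence. The sketch below shows one possible way to unwrap the fence and check that the result compiles. `extract_code_blocks` and `has_valid_syntax` are hypothetical helpers, and the fence handling assumes the reply uses triple-backtick fences:

```python
import re
from typing import List

FENCE = "`" * 3  # built this way so the example itself stays fence-safe
FENCE_PATTERN = re.compile(FENCE + r"(?:python)?[ \t]*\n(.*?)" + FENCE, re.DOTALL)


def extract_code_blocks(llm_text: str) -> List[str]:
    """Return the contents of all fenced blocks, or the whole text if none are fenced."""
    blocks = [block.strip() for block in FENCE_PATTERN.findall(llm_text)]
    return blocks or [llm_text.strip()]


def has_valid_syntax(source: str) -> bool:
    """Compile without executing; True if the source is syntactically valid Python."""
    try:
        compile(source, "<generated tests>", "exec")
    except SyntaxError:
        return False
    return True


response = FENCE + "python\nimport unittest\n\nclass T(unittest.TestCase):\n    pass\n" + FENCE
blocks = extract_code_blocks(response)
print(blocks[0])
print(has_valid_syntax(blocks[0]))
```

`compile` only parses the code, so imports of modules that are not installed do not cause a failure here.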
    

## Evaluate tests node

In [8]:
import json
from typing import TypedDict, Dict, Any, List, Optional


EVALUATE_LLM_NODE_NAME = "Evaluate Tests Quality (LLM)"
# EVALUATE_LLM_FLOW_ASSESSMENT_MET = f"{EVALUATE_LLM_NODE_NAME}: Test quality assessment target met"
# EVALUATE_LLM_FLOW_ASSESSMENT_NOT_MET = f"{EVALUATE_LLM_NODE_NAME}: Test quality assessment target NOT met"
EVALUATE_LLM_FLOW_ASSESSMENT_GENERATED = f"{EVALUATE_LLM_NODE_NAME}: Test quality assessment generated with qualitative_assessment"
EVALUATE_LLM_FLOW_MAX_ATTEMPTS_REACHED = f"{EVALUATE_LLM_NODE_NAME}: Max attempts reached, assessment target NOT met" 
EVALUATE_LLM_FLOW_FAILED_NO_TESTS = f"{EVALUATE_LLM_NODE_NAME}: Failed (No generated tests to evaluate)"
EVALUATE_LLM_FLOW_FAILED_NO_CODE = f"{EVALUATE_LLM_NODE_NAME}: Failed (No original code to evaluate against)"
EVALUATE_LLM_FLOW_FAILED_NO_ANALYSIS = f"{EVALUATE_LLM_NODE_NAME}: Failed (No code analysis provided for evaluation)"
EVALUATE_LLM_FLOW_FAILED_LLM_EVAL = f"{EVALUATE_LLM_NODE_NAME}: Failed (LLM evaluation error)"
EVALUATE_LLM_FLOW_FAILED_UNEXPECTED = f"{EVALUATE_LLM_NODE_NAME}: Failed (Unexpected Error)"

MESSAGE_NO_GENERATED_TESTS = "No generated test code found in state. Cannot evaluate." # Updated message
MESSAGE_NO_ORIGINAL_CODE_FOR_EVALUATION = "No original code found in state. Cannot evaluate." # Updated message
MESSAGE_NO_CODE_ANALYSIS_FOR_EVALUATION = "Code analysis is missing, which is crucial for LLM-based test evaluation."
MESSAGE_LLM_EVALUATION_ERROR = "An error occurred during LLM-based test evaluation."

def evaluate_tests_llm_node(state: UnitTestWorkflowState, eval_chain) -> Dict[str, Any]:
    """
    Evaluates the generated test code using an LLM to assess its quality,
    completeness, and likely coverage against the original code and its analysis.
    """
    messages: List[BaseMessage] = state.get('messages', [])
    test_code: Optional[str] = state.get('test_code')
    original_code: Optional[str] = state.get('original_code')
    analyzed_code: Optional[Dict[str, Any]] = state.get('analyzed_code')
    generation_attempts: int = state.get('generation_attempts', 0)
    max_generation_attempts: int = state.get('max_generation_attempts', 3)
    current_flow: List[str] = state.get('flow', [])
    
    if not original_code:
        return {
            "messages": [SystemMessage(content=MESSAGE_NO_ORIGINAL_CODE_FOR_EVALUATION)],
            "flow": [EVALUATE_LLM_FLOW_FAILED_NO_CODE],
        }
    if not test_code:
        # Start from the base message so it is always defined before being extended.
        error_msg = MESSAGE_NO_GENERATED_TESTS
        if generation_attempts >= max_generation_attempts:
            error_msg = f"{error_msg} Max attempts ({max_generation_attempts}) reached."
        return {
            "messages": [SystemMessage(content=error_msg)],
            "flow": [EVALUATE_LLM_FLOW_FAILED_NO_TESTS],
        }
    if not analyzed_code:
        return {
            "messages": [SystemMessage(content=MESSAGE_NO_CODE_ANALYSIS_FOR_EVALUATION)],
            "flow": [EVALUATE_LLM_FLOW_FAILED_NO_ANALYSIS],
        }
    try: 
        code_analysis_str = json.dumps(analyzed_code, indent=2)
    except TypeError as e:
        error_msg = f"Could not serialize analyzed_code to JSON: {e}"
        return {
            "messages": [SystemMessage(content=error_msg)],
            "flow": [EVALUATE_LLM_FLOW_FAILED_UNEXPECTED],
        }
        
    llm_eval_result = eval_chain.invoke(
        {
            "original_code": original_code,
            "code_analysis_json": code_analysis_str,  # valid JSON text, not a Python repr
            "test_code": test_code,
        }
    )
    if not llm_eval_result or not isinstance(llm_eval_result, dict):
        error_msg = "LLM evaluation returned no data or invalid format."
        return {
            "messages": [SystemMessage(content=error_msg)],
            "flow": [EVALUATE_LLM_FLOW_FAILED_LLM_EVAL],
        }

    return {
        "messages": [AIMessage(content=json.dumps(llm_eval_result))],
        "evaluation": llm_eval_result,
        "flow": [f"{EVALUATE_LLM_FLOW_ASSESSMENT_GENERATED} [{llm_eval_result.get('qualitative_assessment')}]" ],
    }
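
The evaluation prompt places the analysis inside a `json` code fence, so the text passed in should be real JSON. The short sketch below, using a hypothetical `analysis` dict, shows how `json.dumps` differs from `str()` on the same data:

```python
import json

# Hypothetical miniature analysis result.
analysis = {"summary": "demo", "components": [{"type": "function", "name": "reflect"}]}

as_repr = str(analysis)                   # Python repr: single-quoted keys and values
as_json = json.dumps(analysis, indent=2)  # JSON text: double-quoted, parseable

print(as_repr)
print(as_json)

try:
    json.loads(as_repr)
    repr_is_json = True
except json.JSONDecodeError:
    repr_is_json = False
print("str() output parses as JSON:", repr_is_json)
```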
    

# Test Flow

In [9]:
code = ''' from langchain.chains.llm import LLMChain
from langchain.schema.messages import AIMessage
from core.utils.schema import Code
from core.graph.utils.state import CodeGenState
from typing_extensions import List
from langchain.schema.messages import BaseMessage

STEP_NAME = "Reflect"
def reflect(state: CodeGenState, code_gen_chain: LLMChain, framework: str) -> dict:
    """
    Performs a reflection step upon encountering an error during code generation.

    This node is typically triggered after a failed `code_check`. It invokes
    the provided `code_gen_chain` with the current message history (which should
    include the error message from the check) and context.

    It assumes the `code_gen_chain`, when prompted with the error context,
    will provide reflective text or analysis within the `prefix` field of its
    structured `Code` output.

    This reflection text is then appended to the message history as a new
    `AIMessage`, preserving the conversation context for the subsequent
    generation attempt. The iteration count remains unchanged.

    Args:
        state: The current graph state, containing messages, iterations,
            documentation, and previous generations.
        code_gen_chain: The LLMChain instance configured to generate
                        structured `Code` output. It's reused here for reflection.
        framework: The target coding framework (e.g., 'python').

    Returns:
        A dictionary containing updates for the graph state:
        - 'messages': The original messages list appended with the new
                    reflection AIMessage.
        - 'flow': A list containing the name of this node ("Reflect").
    """
    messages: List[BaseMessage] = state['messages']
    documentation_list: List[str] = [doc.page_content for doc in state.get('documentation', []) if hasattr(doc, 'page_content')]
    documentation: str = "\n".join(documentation_list)

    reflection_code: Code = code_gen_chain.invoke(
        {"context": documentation, "question": messages, "framework": framework}
    )
    reflection_message = AIMessage(
        content=f"Reflection on the error: {reflection_code.prefix}"
    )
    return {
        "messages": reflection_message, 
        "flow": [STEP_NAME]
        }
'''

## Test Flow

In [10]:
def merge_dict(dict1, dict2):
    merged_dict = {}
    for key, value in dict1.items():
        merged_dict[key] = value
    for key, value in dict2.items():
        if key in merged_dict:
            if isinstance(merged_dict[key], str) and isinstance(value, str):
                merged_dict[key] += value
            elif isinstance(merged_dict[key], list) and isinstance(value, list):
                merged_dict[key] = merged_dict[key] + value  # new list; dict1's list stays untouched
            else:
                merged_dict[key] = value
        else:
            merged_dict[key] = value
    return merged_dict
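
When state dictionaries are merged by hand, copying a dict does not copy the lists stored in it, so in-place list methods on the copy also change the source. A minimal sketch of the difference between `extend` and `+`, using a hypothetical `original` state:

```python
original = {"flow": ["analyze"]}

shallow = dict(original)               # copies the mapping, not the list inside it
shallow["flow"].extend(["generate"])   # mutates the list shared with `original`

fresh = dict(original)
fresh["flow"] = fresh["flow"] + ["evaluate"]  # builds a new list instead

print(original["flow"])
print(fresh["flow"])
```

Here `original["flow"]` picks up `"generate"` from the `extend` call but not `"evaluate"` from the concatenation.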

In [11]:
# Input
input_state = UnitTestWorkflowState(
    original_code=code,
    analyzed_code="",
    flow=[]
)

In [12]:
# Create chains 
code_analysis_chain = create_code_analysis_chain(model="gemini-2.0-flash",temperature=0.0)
test_generation_chain = create_test_generation_chain(model="gemini-2.0-flash",temperature=0.0)
eval_chain = create_evaluation_chain(model="gemini-2.0-flash",temperature=0.0)

In [13]:
code_analysis_state = code_analysis(input_state, code_analysis_chain)

In [14]:
code_analysis_state

{'analyzed_code': {'summary': 'The code defines a reflection step in a code generation process. It uses an LLMChain to generate a reflection on an error encountered during code generation and appends this reflection to the message history.',
  'components': [{'type': 'function',
    'name': 'reflect',
    'description': 'Performs a reflection step upon encountering an error during code generation. It invokes the provided `code_gen_chain` with the current message history (which should include the error message from the check) and context. The reflection text is then appended to the message history as a new `AIMessage`, preserving the conversation context for the subsequent generation attempt.',
    'signature': 'def reflect(state: CodeGenState, code_gen_chain: LLMChain, framework: str) -> dict:',
    'parameters': [{'name': 'state', 'type': 'CodeGenState'},
     {'name': 'code_gen_chain', 'type': 'LLMChain'},
     {'name': 'framework', 'type': 'str'}],
    'returns': 'dict',
    'key_be

In [15]:
test_generation_state = generate_tests_node(code_analysis_state,test_generation_chain)

In [16]:
test_generation_state

{'messages': [AIMessage(content='```python\nimport unittest\nfrom unittest.mock import MagicMock, patch\n\nclass TestReflect(unittest.TestCase):\n\n    def test_reflect_basic(self):\n        """Tests the basic functionality of the reflect function."""\n        from promptflow.tools.open_model_llm import CodeGenState, reflect\n        mock_code_gen_chain = MagicMock()\n        mock_code_gen_chain.run.return_value = "This is a reflection."\n        state = CodeGenState(messages=["initial message"], documentation=["doc1", "doc2"])\n        state.documentation = [{"page_content": "doc1"}, {"page_content": "doc2"}]\n        framework = "test_framework"\n\n        result = reflect(state, mock_code_gen_chain, framework)\n\n        self.assertEqual(len(result["messages"]), 2)\n        self.assertEqual(result["messages"][-1].content, "This is a reflection.")\n        self.assertEqual(result["flow"], "REFLECTION")\n        mock_code_gen_chain.run.assert_called_once()\n\n    def test_reflect_empt

In [17]:
concated_state = merge_dict(input_state, code_analysis_state)
concated_state = merge_dict(concated_state, test_generation_state)
concated_state

{'original_code': ' from langchain.chains.llm import LLMChain\nfrom langchain.schema.messages import AIMessage\nfrom core.utils.schema import Code\nfrom core.graph.utils.state import CodeGenState\nfrom typing_extensions import List\nfrom langchain.schema.messages import BaseMessage\n\nSTEP_NAME = "Reflect"\ndef reflect(state: CodeGenState, code_gen_chain: LLMChain, framework: str) -> dict:\n    """\n    Performs a reflection step upon encountering an error during code generation.\n\n    This node is typically triggered after a failed `code_check`. It invokes\n    the provided `code_gen_chain` with the current message history (which should\n    include the error message from the check) and context.\n\n    It assumes the `code_gen_chain`, when prompted with the error context,\n    will provide reflective text or analysis within the `prefix` field of its\n    structured `Code` output.\n\n    This reflection text is then appended to the message history as a new\n    `AIMessage`, preserving

In [18]:
eval_state = evaluate_tests_llm_node(
    concated_state,
    eval_chain
)
eval_state

{'messages': [AIMessage(content='{"qualitative_assessment": "medium", "confidence_score": 6.0, "positive_feedback": ["The tests cover basic functionality, including cases with empty documentation, empty reflection string, missing documentation attribute, and missing \'page_content\' within documentation elements.", "The tests use mocks effectively to isolate the `reflect` function and control the behavior of the `code_gen_chain`."], "areas_for_improvement": ["The tests do not verify the arguments passed to `code_gen_chain.invoke`. It\'s important to ensure the context, question, and framework are passed correctly.", "The tests use `CodeGenState` directly, which might not be the intended usage. It should be initialized with the correct parameters, or a mock should be used to control its behavior more precisely.", "The tests do not assert the type of the returned \'messages\'. It should be a list containing an `AIMessage`.", "The tests do not check if the documentation is correctly extra

In [19]:
concated_state = merge_dict(concated_state, eval_state)
concated_state

{'original_code': ' from langchain.chains.llm import LLMChain\nfrom langchain.schema.messages import AIMessage\nfrom core.utils.schema import Code\nfrom core.graph.utils.state import CodeGenState\nfrom typing_extensions import List\nfrom langchain.schema.messages import BaseMessage\n\nSTEP_NAME = "Reflect"\ndef reflect(state: CodeGenState, code_gen_chain: LLMChain, framework: str) -> dict:\n    """\n    Performs a reflection step upon encountering an error during code generation.\n\n    This node is typically triggered after a failed `code_check`. It invokes\n    the provided `code_gen_chain` with the current message history (which should\n    include the error message from the check) and context.\n\n    It assumes the `code_gen_chain`, when prompted with the error context,\n    will provide reflective text or analysis within the `prefix` field of its\n    structured `Code` output.\n\n    This reflection text is then appended to the message history as a new\n    `AIMessage`, preserving

In [20]:
# Regenerate
test_generation_state = generate_tests_node(concated_state,test_generation_chain)
test_generation_state

{'messages': [AIMessage(content='```python\nimport unittest\nfrom unittest.mock import MagicMock\nfrom langchain.schema.messages import AIMessage, BaseMessage\nfrom core.utils.schema import Code\nfrom core.graph.utils.state import CodeGenState\nfrom typing import List, Dict, Any\n\nclass TestReflect(unittest.TestCase):\n\n    def test_reflect_basic(self):\n        """Tests the basic functionality of the reflect function."""\n        mock_code_gen_chain = MagicMock()\n        mock_code = Code(prefix="This is a reflection.", code="")\n        mock_code_gen_chain.invoke.return_value = mock_code\n        state: Dict[str, Any] = {\'messages\': [AIMessage(content="initial message")], \'documentation\': [{"page_content": "doc1"}, {"page_content": "doc2"}]}\n        framework = "test_framework"\n\n        result = reflect(state, mock_code_gen_chain, framework)\n\n        self.assertIsInstance(result["messages"], AIMessage)\n        self.assertEqual(result["messages"].content, "Reflection on th

# Graph

In [23]:
from langgraph.graph import StateGraph, END, START

workflow = StateGraph(UnitTestWorkflowState)
workflow.add_node("analyze_code_node", 
                lambda state: code_analysis(state, code_analysis_chain))
workflow.add_node("generate_tests_node",
                lambda state: generate_tests_node(state, test_generation_chain))
workflow.add_node("evaluate_tests_node",
                lambda state: evaluate_tests_llm_node(state, eval_chain))
workflow.add_edge(START, "analyze_code_node")
workflow.add_edge("analyze_code_node", "generate_tests_node")
workflow.add_edge("generate_tests_node", "evaluate_tests_node")
# evaluate_tests_node routes only through the conditional edges below
workflow.add_conditional_edges(
    "evaluate_tests_node",
    decision_to_end_workflow,
    {
        "end": END,
        "regenerate": "generate_tests_node",
    }
)
workflow = workflow.compile()
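
The control flow of this graph (analyze once, then alternate between generation and evaluation until the router says "end") can be traced without LangGraph or a model. The sketch below is a plain-Python simulation with canned assessments. `run_loop` is a hypothetical function, not part of the workflow:

```python
from typing import List


def run_loop(assessments: List[str], max_generation_attempts: int = 3) -> List[str]:
    """Simulate analyze -> (generate -> evaluate)* using pre-set evaluator ratings.

    `assessments` must hold at least as many ratings as attempts are made.
    """
    trace = ["analyze"]
    attempts = 0
    ratings = iter(assessments)
    while True:
        attempts += 1
        trace.append(f"generate#{attempts}")
        assessment = next(ratings)
        trace.append(f"evaluate[{assessment}]")
        # Same stopping rule as decision_to_end_workflow.
        if attempts >= max_generation_attempts or assessment == "high":
            return trace


print(" -> ".join(run_loop(["medium", "medium", "high"])))
print(" -> ".join(run_loop(["high"])))
```

With ratings of medium, medium, high the simulation makes three generation attempts before stopping.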

In [24]:
wf_result= workflow.invoke(
    {
        "original_code": code,
        "analyzed_code": "",
        "test_code": "",
        "flow": [],
        "max_generation_attempts": 3,
        "generation_attempts": 0
    }
)

In [25]:
def pretty_print(state):
    print("Flow:")
    for flow in state["flow"]:
        print(flow, end=" -> ")
    print("\n")
    print("Messages:")
    for message in state["messages"]:
        if isinstance(message, AIMessage):
            print("="*20, "AIMessage", "="*20)
        if isinstance(message, HumanMessage):
            print("="*20, "HumanMessage", "="*20)
        if isinstance(message, SystemMessage):
            print("="*20, "SystemMessage", "="*20)
        print(message.content)


In [26]:
pretty_print(wf_result)

Flow:
Analyze Code Node: Analysis successful -> Generate Tests Node: Test generation successful -> Evaluate Tests Quality (LLM): Test quality assessment generated with qualitative_assessment [medium] -> Generate Tests Node: Test generation successful -> Evaluate Tests Quality (LLM): Test quality assessment generated with qualitative_assessment [medium] -> Generate Tests Node: Test generation successful -> Evaluate Tests Quality (LLM): Test quality assessment generated with qualitative_assessment [high] -> 

Messages:
{'summary': 'The code defines a reflection step in a code generation process. It uses an LLMChain to analyze errors and generate reflective text, which is then added to the message history to guide subsequent code generation attempts.', 'components': [{'type': 'function', 'name': 'reflect', 'description': 'Performs a reflection step upon encountering an error during code generation. It invokes the provided `code_gen_chain` with the current message history (which should inc