In [1]:
from typing import Annotated, TypedDict

from pydantic import ConfigDict, Field

class ExtractCode(TypedDict):
    """Structured output the code-generation LLM must return.

    Attributes:
        extracted_code (str): the code solution generated by the LLM.
        language (str | None): programming language of the generated code.
    """

    # TypedDict class bodies may only contain annotations (PEP 589), so the
    # per-key Field metadata lives inside Annotated, where pydantic reads it
    # when building the output schema. Both keys remain required.
    extracted_code: Annotated[str, Field(description="Solution of the code")]
    language: Annotated[
        str | None, Field(description="Programming language of the code")
    ]

In [2]:
# Patch asyncio so event loops can be nested inside Jupyter's running loop.
# NOTE(review): presumably required for the synchronous `Agent.run_sync` calls
# made from graph nodes later in this notebook — confirm before removing.
import nest_asyncio
nest_asyncio.apply()


In [3]:
from langgraph.graph.message import add_messages
from typing import Annotated,Dict


class State(TypedDict):
    """Graph state shared by the coder and judge subgraphs.

    Only the conversation is stored: generated code travels as the content of
    the AIMessage entries, and judge feedback as HumanMessage entries.
    """

    # add_messages is the reducer: message lists returned by nodes are merged
    # into this list rather than overwriting it.
    messages: Annotated[list, add_messages]

In [4]:
from pydantic.dataclasses import dataclass
from typing import Any

# Pydantic dataclasses take their config through the decorator's `config=`
# argument; a nested `class Config` is ignored by them.
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class AgentDeps:
    """Dependency bundle intended for the pydantic-ai agents.

    Attributes:
        api_key: API key string (this notebook passes the GROQ_API_KEY value).
        http_client: HTTP client object. Typed ``Any`` so pydantic does not
            validate it; arbitrary_types_allowed keeps working if the type is
            later narrowed to a concrete client class.
    """

    api_key: str
    http_client: Any

In [5]:
from typing import Optional, Type, Any, Literal, get_type_hints
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.graph.state import CompiledStateGraph
from langgraph.managed import RemainingSteps
from langchain_core.messages import HumanMessage


class MessagesWithSteps(MessagesState):
    """MessagesState extended with the managed `remaining_steps` value.

    Used as the state type that `end_or_reflect` reads.
    """

    # Managed value from langgraph.managed; presumably the number of steps left
    # before the recursion limit — TODO confirm against LangGraph docs.
    remaining_steps: RemainingSteps


def end_or_reflect(state: MessagesWithSteps) -> Literal[END, "graph"]:
    """Route after the reflection node.

    Returns END when fewer than 5 steps remain, when the conversation is
    empty, or when the newest message is not a HumanMessage. Otherwise it
    returns "graph" so the generator runs again on the feedback.
    """
    # Step budget is checked first, before the message list is touched.
    if state["remaining_steps"] < 5:
        return END
    history = state["messages"]
    if not history:
        return END
    # A HumanMessage from the reflection step is feedback to act on.
    return "graph" if isinstance(history[-1], HumanMessage) else END


def _schema_keys(schema: Type[Any]) -> set[str]:
    """Return the annotated field names of `schema`, including inherited ones."""
    keys: set[str] = set()
    # `__annotations__` on a plain class (dataclass, pydantic model) lists only
    # that class's own fields, so walk the MRO. TypedDict already merges its
    # bases' keys, so for it this yields the same set.
    for klass in getattr(schema, "__mro__", (schema,)):
        keys.update(getattr(klass, "__annotations__", None) or {})
    return keys


def create_reflection_graph(
    graph: CompiledStateGraph,
    reflection: CompiledStateGraph,
    state_schema: Optional[Type[Any]] = None,
    config_schema: Optional[Type[Any]] = None,
) -> StateGraph:
    """Wire a generator graph and a reflection graph into a feedback loop.

    Flow: START -> "graph" -> "reflection" -> end_or_reflect, which either
    routes back to "graph" or ends the run.

    Args:
        graph: Compiled generator subgraph, added as node "graph".
        reflection: Compiled critic subgraph, added as node "reflection".
        state_schema: State type for the outer graph. Defaults to the schema
            `graph` was built with. It must declare `messages` (directly or via
            a base class) and must not declare `remaining_steps`.
        config_schema: Passed through unchanged to StateGraph.

    Returns:
        The uncompiled outer StateGraph; call `.compile()` on it.

    Raises:
        ValueError: if the schema declares `remaining_steps` or lacks `messages`.
    """
    _state_schema = state_schema or graph.builder.schema
    declared = _schema_keys(_state_schema)

    if "remaining_steps" in declared:
        raise ValueError(
            "Has key 'remaining_steps' in state_schema, this shadows a built in key"
        )

    if "messages" not in declared:
        raise ValueError("Missing required key 'messages' in state_schema")

    # Extend the caller's schema with the managed step counter end_or_reflect reads.
    class StateSchema(_state_schema):
        remaining_steps: RemainingSteps

    rgraph = StateGraph(StateSchema, config_schema=config_schema)
    rgraph.add_node("graph", graph)
    rgraph.add_node("reflection", reflection)
    rgraph.add_edge(START, "graph")
    rgraph.add_edge("graph", "reflection")
    rgraph.add_conditional_edges("reflection", end_or_reflect)
    return rgraph

In [6]:
import os
from dotenv import load_dotenv

load_dotenv()


def _require_env(name: str) -> str:
    """Return environment variable `name`, failing with a clear message if unset.

    Raises:
        RuntimeError: if the variable is missing or empty.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} is not set; add it to your .env file."
        )
    return value


# Fail fast here. A missing value would otherwise surface much later as an
# opaque error (e.g. a None URL passed to requests.post, or a None system prompt).
CODE_SYSTEM_PROMPT = _require_env("CODE_SYSTEM_PROMPT")
JUDGE_SYSTEM_PROMPT = _require_env("JUDGE_SYSTEM_PROMPT")
GROQ_API_TOKEN = _require_env("GROQ_API_KEY")
COMPILER_URL = _require_env("COMPILER_URL")


In [20]:
from pydantic_ai import Agent
from httpx import AsyncClient

# `http_client` must be a client *instance*, not the AsyncClient class.
# NOTE(review): neither agent below declares a deps type and the run_sync calls
# in this notebook pass no deps, so this bundle is currently unused. Whoever
# uses it should eventually `await agent_deps.http_client.aclose()`.
agent_deps = AgentDeps(api_key=GROQ_API_TOKEN, http_client=AsyncClient())

# The coder runs cooler than the judge (0.3 vs 0.7).
code_model_settings = {"temperature": 0.3}
judge_model_settings = {"temperature": 0.7}

# Coder: returns structured ExtractCode output (code + language).
code_agent: Agent[None, ExtractCode] = Agent(
    model="groq:llama3-8b-8192",
    system_prompt=CODE_SYSTEM_PROMPT,
    model_settings=code_model_settings,
    output_type=ExtractCode,
)

# Judge: returns free-text feedback on the code and its execution output.
judge_agent: Agent[None, str] = Agent(
    model="groq:meta-llama/llama-4-scout-17b-16e-instruct",
    system_prompt=JUDGE_SYSTEM_PROMPT,
    model_settings=judge_model_settings,
)

In [21]:
from langchain_core.messages import AIMessage

def call_model(state: Dict[str, Any]):
    """Generate, or revise, a solution with `code_agent`.

    On the first turn the prompt is just the latest message (the task). In the
    reflection loop the history is [task, code_1, feedback_1, ...], so on later
    turns the latest message is judge feedback. The prompt then also carries
    the original task and the previous attempt; otherwise the coder would be
    asked to fix code it never saw, for a task it was never given.

    Returns:
        A partial state update {"messages": [AIMessage]} holding the extracted
        code. The add_messages reducer appends it to the conversation.
    """
    messages = state["messages"]
    latest = messages[-1].content
    if len(messages) >= 3:
        prompt = (
            f"Task:\n{messages[0].content}\n\n"
            f"Previous attempt:\n{messages[-2].content}\n\n"
            f"Feedback:\n{latest}"
        )
    else:
        prompt = latest
    code = code_agent.run_sync(prompt)
    # Return an update instead of mutating the input state in place.
    return {"messages": [AIMessage(code.output["extracted_code"])]}
    

In [22]:
import requests
from langchain_core.messages import HumanMessage

def _execute_code(code: str, language: str = "python", timeout: float = 60.0) -> str:
    """POST `code` to the service at COMPILER_URL and return its output as text.

    A 200 response whose JSON has an "output" key yields that value. Any other
    response (error status, missing key, non-JSON body) is returned whole, so
    the judge can see what went wrong. Network errors and timeouts propagate
    as requests.RequestException.
    """
    response = requests.post(
        COMPILER_URL,
        json={"code": code, "language": language},
        # Without a timeout, a stalled service would hang the whole graph.
        timeout=timeout,
    )
    try:
        result = response.json()
    except ValueError:  # non-JSON body, e.g. an HTML error page
        result = response.text
    if response.status_code == 200 and isinstance(result, dict) and "output" in result:
        output = result["output"]
    else:
        output = result
    # Always hand back text: error payloads are often dicts, which cannot be
    # concatenated into the judge prompt directly.
    return "" if output is None else str(output)


def compiler(state: Dict[str, Any]):
    """Execute the latest generated code and ask `judge_agent` to review it.

    The judge receives the code followed by the execution output. Its reply is
    added as a HumanMessage, which end_or_reflect treats as feedback to act on.

    NOTE(review): a HumanMessage is returned on every call, so the reflection
    loop only stops once remaining_steps runs low. Consider returning no
    message when the judge approves the solution.

    Returns:
        A partial state update {"messages": [HumanMessage]}.
    """
    code = state["messages"][-1].content
    output = _execute_code(code)
    judgement = judge_agent.run_sync(code + "\n\n\nOutput:\n\n" + output)
    return {"messages": [HumanMessage(judgement.output)]}
        

In [23]:
from langgraph.graph import StateGraph,START,END

# add_node takes (name, action). Passing (action, name) only worked because the
# name string was silently discarded and inferred from the function's __name__.

# Single-node subgraph: generate, or revise, code from the conversation.
agent_graph = (
    StateGraph(State)
    .add_node("call_model", call_model)
    .add_edge(START, "call_model")
    .add_edge("call_model", END)
    .compile()
)

# Single-node subgraph: execute the latest code and have the judge review it.
judge_graph = (
    StateGraph(State)
    .add_node("compiler", compiler)
    .add_edge(START, "compiler")
    .add_edge("compiler", END)
    .compile()
)

# Outer loop: graph -> reflection -> (graph | END), routed by end_or_reflect.
reflection_agent = create_reflection_graph(agent_graph, judge_graph).compile()

In [24]:
import pprint

state = {
    "messages": [
        {
            "role": "user",
            "content": (
                "Given an array of integers nums and an integer target, "
                "return indices of the two numbers such that they add up to target. "
                "You may assume that each input would have exactly one solution, "
                "and you may not use the same element twice. Return the answer in any order.\n\n"
                "Example:\nInput: nums = [2, 7, 11, 15], target = 9\nOutput: [0, 1]"
            )
        }
    ]
}

result = reflection_agent.invoke(state)

# The run ends after the judge node, so the final solution is the most recent
# AIMessage. Search for it rather than relying on a fixed index like [-2].
final_code = next(
    (msg.content for msg in reversed(result["messages"]) if isinstance(msg, AIMessage)),
    None,
)
print(final_code if final_code is not None else "No code was generated.")

def twoSum(nums: list[int], target: int) -> list[int]:
    """Returns the indices of the two numbers in the list that add up to the target.
    Args:
    - nums: A list of integers.
    - target: The target sum.
    Returns:
    - A list containing the indices of the two numbers that add up to the target.
    - If no solution is found, an empty list is returned."
    if not isinstance(nums, list) or len(nums) < 2:
        raise ValueError('Input must be a list with at least two elements')
    if not all(isinstance(num, int) for num in nums):
        raise ValueError('All elements in the list must be integers')
    if not isinstance(target, int):
        raise ValueError('Target must be an integer')
    num_indices = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in num_indices:
            return [num_indices[complement], i]
        num_indices[num] = i
    return []
