# Pips

In [2]:
!pip install -q langchain pytest black pycodestyle flake8 graphviz networkx matplotlib openai numpy pandas tqdm langchain-core langgraph
!pip install -q transformers flash-attn==2.5.5 bitsandbytes>=0.41.1 accelerate==0.27.2

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


# Imports

In [3]:
import ast
import csv
import importlib.util
import json
import logging
import os
import re
import subprocess
import tempfile
import warnings
from datetime import datetime
from typing import TypedDict, Annotated, List, Dict, Any, Union

import networkx as nx
import pandas as pd
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.DEBUG)

In [4]:
# Authenticate with the Hugging Face Hub.
# Never paste an access token into the notebook: a token that has been committed
# or shared must be treated as leaked and revoked in the Hub account settings.
# The token is read from the HF_TOKEN environment variable; when it is unset,
# login() falls back to its interactive prompt.
from huggingface_hub import login

login(token=os.environ.get("HF_TOKEN"))

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Model Installation

In [5]:
# def load_model(model_id: str):
#     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True,torch_dtype="auto")
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     return model, tokenizer

# model_id = "NousResearch/Hermes-2-Theta-Llama-3-8B"
# # model_id = "microsoft/Phi-3-mini-4k-instruct"
# model, tokenizer = load_model(model_id)

# Prompts

In [16]:
# Structured-output setup: every model reply is expected to be a JSON object
# with a single "code" field. `format_instructions` is the text that is
# appended to each prompt to request that shape.
code_schema = ResponseSchema(name="code", description="Python code")
output_parser = StructuredOutputParser.from_response_schemas([code_schema])
format_instructions = output_parser.get_format_instructions()
# Prompt templates, keyed by workflow step. Placeholders in braces are filled
# with str.format() by the agents defined later in this notebook.
PROMPT_TEMPLATES = {
    # Placeholders: {prompt}, {format_instructions}
    "code_generation": """
    Provide a Python single solution oriented code solution to the following prompt.

    Prompt: {prompt}

    {format_instructions}
    """,

    # Placeholders: {code}, {format_instructions}
    "unit_test_generation": """
    Generate 2 to 3 pytest test cases with appropriate assertions for the following code. Append the original code to your response. Your response must be in JSON format with a single key "code".

    Code to test:
    {code}

    {format_instructions}
    """,

    # Placeholders: {prompt}, {code}, {errors}, {format_instructions}
    "error_solution": """
    You are an ErrorSolutionAgent. Your task is to provide solutions to fix the errors in the given code. Use the provided prompt, code, and error details to formulate your solution. Your response must be in JSON format with a single key "code". Explanations are not needed.

    Prompt: {prompt}
    Code:
    {code}
    Errors:
    {errors}

    {format_instructions}
    """,

    # Placeholders: {prompt}, {previous_code}, {errors}, {solution}, {format_instructions}
    "code_regeneration": """
    You are a CodeRegeneratorAgent. Based on the given prompt and errors, regenerate the code using the provided solution. Ensure the regenerated code addresses the previous errors. Your response must be in JSON format with a single key "code".

    Prompt: {prompt}
    Previous Code:
    {previous_code}
    Errors:
    {errors}
    Solution:
    {solution}

    {format_instructions}
    """
}

# State

In [17]:
class AgentGraphState(TypedDict):
    """Schema of the state dictionary shared by the nodes of the agent graph."""

    # Task text; the graph passes it to the code generator as the prompt.
    row_data: str
    # Outcome flag. NOTE(review): annotated as str, but the agents in this
    # notebook store the boolean False here on failure - confirm intended type.
    status: str
    # Latest generated (and later reformatted) code, stored under the "code" key.
    code_generator_response: Dict[str, str]
    # Error messages collected by the compiler / solution steps.
    compiler_errors: List[str]
    # Reply of the solution agent, stored under the "solution" key.
    error_solver_response: Dict[str, str]
    # Regenerated code snippets.
    regenerated_code: List[str]
    # Set to "completed" by the end node.
    end_node_status: str

# Blank initial state: one entry per AgentGraphState key, in the column
# order used when the state is logged to CSV.
state = dict(
    row_data="",
    code_generator_response=dict(code=""),
    compiler_errors=list(),
    error_solver_response=dict(solution=""),
    regenerated_code=[""],
    status="",
    end_node_status="",
)

def log_state_to_csv(state, log_path='log.csv'):
    """Append one state snapshot as a row of a CSV log file.

    Args:
        state: Mapping of state keys to values. The keys become the CSV
            columns, in the mapping's iteration order.
        log_path: Destination file. Defaults to ``'log.csv'`` in the current
            working directory, which is what earlier callers relied on.

    The header row is written when the file does not exist yet *or* exists
    but is empty, so a pre-created/truncated log still gets its header.
    """
    fieldnames = list(state.keys())
    needs_header = not os.path.isfile(log_path) or os.path.getsize(log_path) == 0

    # Explicit UTF-8 so generated code containing non-ASCII text can be logged
    # regardless of the platform's default encoding.
    with open(log_path, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow(state)


In [18]:
class ModelManager:
    """Class-level cache of loaded ``(model, tokenizer)`` pairs, keyed by model id."""

    _instances = {}

    @classmethod
    def get_model(cls, model_id: str):
        """Return the ``(model, tokenizer)`` pair for ``model_id``.

        The pair is loaded with ``from_pretrained`` on the first request and
        served from ``_instances`` on every later request for the same id.
        """
        if model_id in cls._instances:
            logging.info(f"Using cached model '{model_id}'.")
            return cls._instances[model_id]

        logging.info(f"Loading model '{model_id}' for the first time.")
        load_options = dict(device_map="cuda", trust_remote_code=True, torch_dtype="auto")
        # The model is loaded before the tokenizer, as before.
        pair = (
            AutoModelForCausalLM.from_pretrained(model_id, **load_options),
            AutoTokenizer.from_pretrained(model_id),
        )
        cls._instances[model_id] = pair
        return pair


# Util

In [19]:
# ----------------------------------------------------------------------------------------------------------------
# Reformat Code
# ----------------------------------------------------------------------------------------------------------------

def reformat_code(code: str) -> str:
    """Format ``code`` with the ``black`` command-line tool.

    The code is written to a temporary ``.py`` file, ``black`` is run on that
    file, and the rewritten contents are returned.

    Args:
        code: Python source text.

    Returns:
        The formatted source, or ``code`` unchanged when ``black`` fails
        (for example on a syntax error) or cannot be run at all.
    """
    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file:
            temp_file.write(code.encode("utf-8"))
            temp_file_path = temp_file.name

        # Capture black's output so it does not flood the notebook; stderr is
        # logged below when formatting fails.
        subprocess.run(
            ["black", "--quiet", temp_file_path],
            check=True,
            capture_output=True,
            text=True,
        )

        # Read back with the same encoding the file was written in.
        with open(temp_file_path, 'r', encoding="utf-8") as formatted_file:
            return formatted_file.read()
    except subprocess.CalledProcessError as e:
        logging.error(f"Formatting failed with error: {e}")
        if e.stderr:
            logging.error(f"black output:\n{e.stderr}")
        return code
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
        return code
    finally:
        # Always clean up: previously the file was left behind on every failure.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
# ----------------------------------------------------------------------------------------------------------------
# Run Command
# ----------------------------------------------------------------------------------------------------------------
def run_command(command: str, code: str, options: list = None):
    """Run a command-line tool against ``code`` saved in a temporary file.

    The command line is ``[command, *options, <temp file path>]``.

    Args:
        command: Executable to run (e.g. ``"pycodestyle"`` or ``"pytest"``).
        code: Python source written to the temporary ``.py`` file.
        options: Extra command-line arguments placed before the file path.

    Returns:
        A dict with:
          * ``"success"``: True when the exit code is 0.
          * ``"errors"``: list of non-blank report lines when the command
            failed, otherwise ``[]``. stderr is used when it has content;
            otherwise stdout is used, because tools such as pycodestyle and
            pytest print their findings on stdout.
          * ``"output"``: the command's raw stdout.
        When the command cannot be started, ``success`` is False and
        ``errors`` holds the exception message.
    """
    temp_file_path = None
    try:
        logging.debug(f"Running command: {command} with options: {options}")
        logging.debug(f"Code to be executed:\n{code}")
        logging.debug(f"~---------------------------------------------~")

        with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file:
            temp_file.write(code.encode("utf-8"))
            temp_file_path = temp_file.name

        cmd = [command] + list(options or []) + [temp_file_path]
        logging.debug(f"Executing command: {' '.join(cmd)}")

        result = subprocess.run(cmd, capture_output=True, text=True)
        logging.debug(f"Command completed with return code: {result.returncode}")

        logging.debug(f"Command output:\n{result.stdout}")
        succeeded = result.returncode == 0
        if not succeeded:
            logging.error(f"Command error output:\n{result.stderr}")

        debug_info = {
            "command": ' '.join(cmd),
            "return_code": result.returncode,
            "stdout": result.stdout,
            "stderr": result.stderr
        }
        logging.debug(f"Debug Info: {debug_info}")

        if succeeded:
            errors = []
        else:
            # Prefer stderr, but fall back to stdout so a failing linter or
            # test run never yields an empty error list.
            report = result.stderr if result.stderr.strip() else result.stdout
            errors = [line for line in report.splitlines() if line.strip()]

        return {
            "success": succeeded,
            "errors": errors,
            "output": result.stdout
        }
    except Exception as e:
        logging.error(f"Command '{command}' failed: {e}")
        return {
            "success": False,
            "errors": [str(e)],
            "output": ""
        }
    finally:
        # Remove the temp file on every path, including a command that could
        # not be started.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
# ----------------------------------------------------------------------------------------------------------------
# Strip ANSI Escape Codes
# ----------------------------------------------------------------------------------------------------------------
def strip_ansi_escape_codes(text: str) -> str:
    """Return ``text`` with every ``ESC [ ... <final byte>`` escape sequence removed."""
    return re.sub(r'\x1b\[[0-?]*[ -/]*[@-~]', '', text)

# ----------------------------------------------------------------------------------------------------------------
# Is Valid Python
# ----------------------------------------------------------------------------------------------------------------
def is_valid_python(code: str) -> bool:
    """Report whether ``code`` compiles as a Python module.

    Args:
        code: Python source text.

    Returns:
        True when ``compile`` accepts the source, False otherwise. The code is
        only compiled, never executed.
    """
    try:
        compile(code, '<string>', 'exec')
        return True
    except (SyntaxError, ValueError) as e:
        # ValueError: some Python versions reject source containing NUL bytes
        # with ValueError instead of SyntaxError; both mean "not valid Python".
        logging.error(f"Syntax error in test code: {e}")
        return False

# ----------------------------------------------------------------------------------------------------------------
# Handle Error
# ----------------------------------------------------------------------------------------------------------------
def handle_error(state, errors, filter_blank_lines=False):
    """Record de-duplicated error messages in ``state["compiler_errors"]``.

    Args:
        state: Either a plain mapping (the graph state dict) or an object that
            exposes ``update_state(key, value)``.
        errors: A single message, a list of messages, or a list that mixes
            messages with nested lists/tuples/sets of messages.
        filter_blank_lines: When True, messages for which
            ``is_blank_line_issue`` returns True are discarded.

    Returns:
        The same ``state`` object, after the update.

    For a plain mapping the stored list is *replaced* with the current
    errors (first-seen order, duplicates and empty strings removed).
    """
    if isinstance(errors, str):
        errors = [errors]

    # Flatten one level without splitting strings into characters.
    flattened = []
    for entry in errors or []:
        if isinstance(entry, (list, tuple, set)):
            flattened.extend(str(item) for item in entry)
        else:
            flattened.append(str(entry))

    if filter_blank_lines:
        flattened = [error for error in flattened if not is_blank_line_issue(error)]

    # dict.fromkeys keeps first-seen order, unlike set().
    unique_errors = list(dict.fromkeys(error for error in flattened if error))

    if unique_errors:
        formatted_errors = "\n".join(f"- {error}" for error in unique_errors)
        logging.error(f"Compiler errors encountered:\n{formatted_errors}")

    if hasattr(state, "update_state"):
        state.update_state("compiler_errors", unique_errors)
    else:
        state["compiler_errors"] = unique_errors
    return state

# ----------------------------------------------------------------------------------------------------------------
# Is Blank Line Issue
# ----------------------------------------------------------------------------------------------------------------
def is_blank_line_issue(error: str) -> bool:
    """Return True when ``error`` mentions one of the codes E302, E305 or W292."""
    for ignorable_code in ("E302", "E305", "W292"):
        if ignorable_code in error:
            return True
    return False

# ----------------------------------------------------------------------------------------------------------------
# Append Missing Functions and Imports
# ----------------------------------------------------------------------------------------------------------------
def append_missing_functions_and_imports(pre_code: str, post_code: str) -> str:
    """Make ``post_code`` self-contained by copying what it lacks from ``pre_code``.

    Steps:
      1. Drop import lines from ``post_code`` that ``clean_invalid_imports`` rejects.
      2. Append every function found in ``pre_code`` (except ``main``) whose
         name is not defined in ``post_code``.
      3. Prepend every import found in ``pre_code`` that is missing from
         ``post_code`` and passes ``is_valid_import``.
      4. Drop a single leading line that does not look like the start of a
         Python module (stray text in front of the code).
      5. Strip trailing whitespace from every line.

    A syntax error in either input is logged and that input is treated as
    containing no functions or imports.

    Args:
        pre_code: The original code (the code under test).
        post_code: The generated code (e.g. unit tests) to complete.

    Returns:
        The completed source text.
    """
    post_code = clean_invalid_imports(post_code)

    logging.debug(f"pre code before appending missing functions --->: {pre_code}")
    logging.debug(f"~---------------------------------------------~")

    logging.debug(f"post code before appending missing functions --->: {post_code}")
    logging.debug(f"~---------------------------------------------~")

    try:
        pre_functions, pre_imports = extract_functions_and_imports(pre_code)
    except SyntaxError as e:
        logging.error(f"Syntax error in pre_code: {e}")
        pre_functions, pre_imports = {}, []

    try:
        post_functions, post_imports = extract_functions_and_imports(post_code)
    except SyntaxError as e:
        logging.error(f"Syntax error in post_code: {e}")
        post_functions, post_imports = {}, []

    for func_name, func_code in pre_functions.items():
        if func_name != 'main' and func_name not in post_functions:
            post_code += f"\n\n{func_code}"

    for imp in pre_imports:
        if imp not in post_imports and is_valid_import(imp):
            post_code = imp + "\n" + post_code

    # Prefixes that legitimately open a module. 'from', 'class', decorators,
    # comments and docstrings are included so that a valid first line - such
    # as a `from x import y` prepended just above - is not thrown away.
    valid_first_line_starts = (
        'import', 'from', 'def', 'class', 'async def', '@', '#', '"""', "'''",
    )
    lines = post_code.strip().split('\n')
    if lines and not lines[0].strip().startswith(valid_first_line_starts):
        lines.pop(0)
        post_code = '\n'.join(lines)

    post_code = '\n'.join(line.rstrip() for line in post_code.splitlines())

    logging.debug(f"post code after appending missing functions --->: {post_code}")
    logging.debug(f"~---------------------------------------------~")

    return post_code

# ----------------------------------------------------------------------------------------------------------------
# Import Validation and Function Extraction
# ----------------------------------------------------------------------------------------------------------------
def is_valid_import(import_statement: str) -> bool:
    """Report whether every module named by ``import_statement`` can be found.

    Modules are looked up with ``importlib.util.find_spec`` rather than being
    imported, so checking a statement does not run the target module's code
    (parent packages of a dotted name may still be imported by the lookup).

    Args:
        import_statement: Source text, normally a single ``import ...`` or
            ``from ... import ...`` line.

    Returns:
        True when the text parses and all referenced modules are resolvable
        (text without any import statement counts as valid). False on a parse
        error, for a relative import, or when any module cannot be located.
    """
    try:
        tree = ast.parse(import_statement)
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                module_names = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                # Relative imports cannot be resolved without a package context.
                if node.level or not node.module:
                    return False
                module_names = [node.module]
            else:
                continue

            for module_name in module_names:
                if importlib.util.find_spec(module_name) is None:
                    return False
        return True
    except Exception:
        # Parse errors, and lookup errors such as a missing parent package.
        return False

def clean_invalid_imports(code: str) -> str:
    """Return ``code`` without the import lines that ``is_valid_import`` rejects.

    Only lines that begin (at column 0) with ``import `` or ``from `` are
    checked; all other lines are kept as they are. Lines are re-joined with
    ``\\n``.
    """
    kept = [
        line
        for line in code.splitlines()
        if not line.startswith(('import ', 'from ')) or is_valid_import(line)
    ]
    return '\n'.join(kept)


def extract_functions_and_imports(code: str):
    """Collect the function definitions and import statements found in ``code``.

    Every node of the parsed tree is visited, so definitions and imports at
    any nesting depth are included.

    Returns:
        ``(functions, imports)`` where ``functions`` maps each function name
        to its unparsed source (a later definition of the same name replaces
        an earlier one) and ``imports`` is the list of unparsed import
        statements.

    Raises:
        SyntaxError: when ``code`` cannot be parsed.
    """
    functions, imports = {}, []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.FunctionDef):
            functions[node.name] = ast.unparse(node)
            continue
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            imports.append(ast.unparse(node))
    return functions, imports

def extract_function_names(code: str) -> set:
    """Return the set of names that follow ``def `` and precede ``(`` in ``code``."""
    return {found.group(1) for found in re.finditer(r"def (\w+)\s*\(", code)}

# ----------------------------------------------------------------------------------------------------------------
# Ensure Single Main Function
# ----------------------------------------------------------------------------------------------------------------
def ensure_single_main(test_code: str) -> str:
    """Keep only the last top-level ``main`` function in ``test_code``.

    When the source defines ``def main(...)`` more than once at column 0,
    every definition except the last is removed together with its indented
    body. Source with zero or one ``main`` is returned unchanged.

    Args:
        test_code: Python source text.

    Returns:
        The source with at most one top-level ``main`` definition.
    """
    main_definition = re.compile(
        # Header line, e.g. "def main():" or "def main(x: int) -> None:".
        r"^def main\s*\([^)]*\)[^:\n]*:[^\n]*\n"
        # Body: indented lines, with blank lines allowed only *between* them.
        # Using [ \t] instead of \s keeps the match from running across a
        # blank line into the next top-level definition.
        r"(?:(?:[ \t]*\n)*[ \t]+[^\n]*\n)*",
        re.MULTILINE,
    )
    # Counting and removal use the same case-sensitive pattern, so a function
    # named e.g. "Main" is neither counted nor removed.
    definitions = main_definition.findall(test_code)
    if len(definitions) <= 1:
        return test_code
    return main_definition.sub("", test_code, count=len(definitions) - 1)

# ----------------------------------------------------------------------------------------------------------------
# Method: extract_code_from_json
# ----------------------------------------------------------------------------------------------------------------
def extract_code_from_json(text: str) -> str:
    """Pull the value of the ``"code"`` field out of a JSON object in ``text``.

    The first ``{"code": "..."}`` object found anywhere in ``text`` is used
    and its string value is returned *decoded* (``\\n`` becomes a newline,
    ``\\"`` becomes a quote, and so on), i.e. as source code ready to run.

    Args:
        text: Model output that may contain a JSON object, optionally wrapped
            in other text such as a Markdown code fence.

    Returns:
        * the decoded, stripped code on success;
        * ``'No JSON block found'`` when no ``{"code": "..."}`` object exists;
        * a message starting with ``'Invalid JSON format'`` when extraction
          raises unexpectedly (for example ``text`` is not a string).
    """
    try:
        json_pattern = re.compile(r'{\s*"code":\s*"((?:[^"\\]|\\.)*?)"\s*}', re.MULTILINE | re.DOTALL)
        match = json_pattern.search(text)

        if not match:
            logging.debug('No JSON block found')
            return 'No JSON block found'

        raw_json_code = match.group(1)
        logging.debug(f"Raw JSON code: {raw_json_code}")

        try:
            # Decode the value as a JSON string literal. strict=False also
            # accepts literal newlines/tabs, which models often emit unescaped.
            code = json.loads(f'"{raw_json_code}"', strict=False)
        except json.JSONDecodeError as e:
            # e.g. a backslash sequence that is not valid JSON: fall back to
            # un-escaping only newlines and quotes.
            logging.debug(f'Invalid JSON format: {e}')
            code = raw_json_code.replace('\\n', '\n').replace('\\"', '"')

        code = code.strip()
        logging.debug(f"Extracted code: {code}")
        return code
    except Exception as e:
        logging.debug(f'Error during extraction: {e}')
        return f'Invalid JSON format: {e}'
# ----------------------------------------------------------------------------------------------------------------
# Method: remove_prompt_from_response
# ----------------------------------------------------------------------------------------------------------------
def remove_prompt_from_response(prompt: str, response: str) -> str:
    """Isolate the first fenced JSON block of a model response.

    Every occurrence of ``prompt`` is first deleted from ``response``; the
    first block fenced with three backticks and the ``json`` tag in what
    remains is then returned, re-wrapped in the same kind of fence with its
    content stripped.

    Returns:
        The re-fenced block, ``'No JSON block found'`` when there is none, or
        ``'Error during extraction: ...'`` when an exception occurs.
    """
    try:
        without_prompt = response.replace(prompt, "").strip()
        fenced = re.search(r'```json(.*?)```', without_prompt, re.DOTALL)

        if fenced is None:
            logging.debug('No JSON block found')
            return 'No JSON block found'

        json_block = fenced.group(1).strip()
        logging.debug(f"response cleaned: {json_block}")
        logging.debug(f"--------------------------")
        return f"```json\n{json_block}\n```"
    except Exception as e:
        logging.debug(f'Error during extraction: {e}')
        return f"Error during extraction: {e}"

# ----------------------------------------------------------------------------------------------------------------


# Agent

In [20]:
class Agent:
    """Base class for workflow agents backed by a locally loaded causal LM.

    An agent owns a reference to the shared workflow ``state`` dict and a
    ``(model, tokenizer)`` pair obtained from ``ModelManager``.
    """

    # Used when a caller passes ``model_id=None`` (the subclasses' default).
    DEFAULT_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
    # Prompts longer than this many tokens are truncated before generation.
    MAX_PROMPT_TOKENS = 512
    # Return values of the backup extractor that mean "no code was found".
    _FAILURE_PREFIXES = (
        'No JSON block found',
        'Empty JSON block',
        'Malformed JSON block',
        'Invalid JSON format',
        '{"error"',
    )

    def __init__(self, state: dict, model_id: str = "microsoft/Phi-3-mini-4k-instruct", cache_dir: str = "my_models/phi_3_mini", temperature: float = 0, max_tokens: int = 500):
        """Store the configuration and load (or reuse) the model.

        Args:
            state: Shared workflow state; mutated in place by ``update_state``.
            model_id: Hugging Face model id. ``None`` selects ``DEFAULT_MODEL_ID``.
            cache_dir: Accepted for interface compatibility; not used here.
            temperature: Sampling temperature. ``0`` selects greedy decoding.
            max_tokens: Default budget of newly generated tokens.
        """
        self.state = state
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model_id = model_id or self.DEFAULT_MODEL_ID
        self.model, self.tokenizer = ModelManager.get_model(self.model_id)

    def get_llm(self):
        """Return the ``(model, tokenizer)`` pair used by this agent."""
        return self.model, self.tokenizer

    def _extract_code(self, result: str) -> str:
        """Return the code contained in ``result``, or ``""`` when none is found.

        The module-level ``output_parser`` is tried first (when defined);
        ``extract_code_from_json`` is the backup.
        """
        code = ""
        try:
            if 'output_parser' in globals():
                parsed_output = output_parser.parse(result)
                code = parsed_output.get("code", "")
            else:
                logging.warning("No output_parser available. Attempting to extract JSON block.")
        except Exception as parse_exception:
            logging.error(f"Parsing error: {parse_exception}")

        if code:
            return code

        logging.info("Attempting backup extraction of JSON block.")
        extracted = extract_code_from_json(result)
        if not extracted or extracted.startswith(self._FAILURE_PREFIXES):
            return ""
        return extracted

    def complete(self, prompt: str, max_tokens: int, max_retries: int = 5):
        """Generate a completion for ``prompt`` and extract the code from it.

        Args:
            prompt: Full prompt text (truncated to ``MAX_PROMPT_TOKENS`` tokens).
            max_tokens: Maximum number of new tokens to generate.
            max_retries: Maximum number of generation attempts when sampling.

        Returns:
            ``{"choices": [{"text": <code>}]}``. ``<code>`` is the empty
            string when no code could be produced, so callers can test the
            result for truthiness instead of receiving an error sentence
            that would later be treated as source code.
        """
        failure = {"choices": [{"text": ""}]}

        model, tokenizer = self.get_llm()
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.MAX_PROMPT_TOKENS).to(model.device)
        prompt_token_count = inputs["input_ids"].shape[-1]

        # Sampling requires a strictly positive temperature; with the default
        # temperature of 0, decode greedily instead of passing do_sample=True.
        sampling = bool(self.temperature) and self.temperature > 0
        generation_kwargs = {"max_new_tokens": max_tokens, "do_sample": sampling}
        if sampling:
            generation_kwargs["temperature"] = self.temperature

        # Greedy decoding repeats the same output, so retrying only makes
        # sense when sampling.
        attempts = max_retries if sampling else min(max_retries, 1)

        for attempt in range(1, attempts + 1):
            try:
                outputs = model.generate(**inputs, **generation_kwargs)
                # Decode only the newly generated tokens. The prompt itself
                # contains a fenced JSON example (the format instructions)
                # that must not be mistaken for the model's answer.
                generated_tokens = outputs[0][prompt_token_count:]
                result = tokenizer.decode(generated_tokens, skip_special_tokens=True)
                result = remove_prompt_from_response(prompt, result)

                code = self._extract_code(result)
                if code:
                    logging.info(f"Code generated: {code}")
                    return {"choices": [{"text": code}]}

                logging.info(f"No code extracted from model output. Attempt {attempt}/{attempts}.")
            except Exception as e:
                logging.error(f"An error occurred during inference: {e}")
                return failure

        logging.error("Failed to generate valid code after multiple attempts.")
        return failure

    def update_state(self, key: str, value: Any):
        """Store ``value`` under ``key`` in the shared state.

        Unknown keys are rejected with an error log. When the current value
        is a list, ``value`` is appended to it; otherwise it replaces the
        current value.
        """
        if key not in self.state:
            logging.error(f"State key '{key}' not found in Agent state.")
            return

        if isinstance(self.state[key], list):
            self.state[key].append(value)
        else:
            self.state[key] = value
        logging.info(f"Updated state '{key}' with value: {value}")

# Code Generator

In [21]:
class CodeGeneratorAgent(Agent):
    """Agent that asks the local model for code solving a given task."""

    def __init__(self, state: Dict[str, Any], model_id: str = None, cache_dir: str = None, temperature: float = 0, max_tokens: int = 500):
        super().__init__(
            state,
            model_id=model_id,
            cache_dir=cache_dir,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def invoke(self, prompt: str):
        """Generate code for ``prompt`` and store it in the state.

        The "code_generation" template is filled with ``prompt`` and the
        shared format instructions, sent to the model, and the returned text
        is saved as ``state["code_generator_response"] = {"code": ...}``.

        Returns:
            The agent's state dict.
        """
        logging.info(f"Invoking CodeGeneratorAgent with prompt: {prompt}")

        template = PROMPT_TEMPLATES["code_generation"]
        formatted_prompt = template.format(prompt=prompt, format_instructions=format_instructions)
        logging.info(f"Formatted Prompt: {formatted_prompt}")

        completion = self.complete(formatted_prompt, max_tokens=self.max_tokens)
        first_choice = completion["choices"][0]
        code = first_choice["text"].strip()

        if not code:
            logging.warning("No valid code found in response.")

        self.update_state("code_generator_response", {"code": code})
        return self.state


# Compiler Agent

In [22]:
class CompilerAgent(Agent):
    """Agent that checks generated code in three stages.

    1. Formatting with ``reformat_code``.
    2. Static analysis with ``pycodestyle``.
    3. Dynamic analysis: the model writes pytest tests for the code and the
       tests are executed with ``pytest``.
    """

    def __init__(self, state: AgentGraphState, model_id: str = None, cache_dir: str = None, temperature: float = 0, max_tokens: int = 500):
        super().__init__(state, model_id, cache_dir, temperature, max_tokens)
        logging.basicConfig(level=logging.INFO)

    @staticmethod
    def _failure_lines(result: dict) -> list:
        """Return the report lines of a failed ``run_command`` result.

        ``result["errors"]`` is used when it has content. Otherwise the
        non-blank lines of the captured stdout are used, because pycodestyle
        and pytest print their findings on stdout and the error list can
        therefore be empty even though the command failed.
        """
        if result["errors"]:
            return result["errors"]
        cleaned = strip_ansi_escape_codes(result.get("output", "") or "")
        return [line for line in cleaned.splitlines() if line.strip()]

    def invoke(self, code: str):
        """Run all checks on ``code`` and record the outcome in the state.

        On any failure ``state["status"]`` is set to False and the errors are
        recorded through ``handle_error``. The reformatted code replaces the
        stored ``code_generator_response``.

        Returns:
            The agent's state dict.
        """
        if not code:
            logging.error("No code provided to CompilerAgent")
            self.update_state("status", False)
            return handle_error(self.state, ["No code provided"])

        # Format Testing
        formatted_code = reformat_code(code)
        logging.debug(f"Reformatted code:\n{formatted_code}")
        self.update_state("code_generator_response", {"code": formatted_code})

        # Static Analysis
        static_analysis = run_command("pycodestyle", formatted_code)
        logging.info(f"Static analysis result: {static_analysis}")
        if not static_analysis["success"]:
            static_errors = self._failure_lines(static_analysis)
            logging.warning(f"Static analysis failed with errors: {static_errors}")
            self.update_state("status", False)
            return handle_error(self.state, static_errors, filter_blank_lines=True)

        # Dynamic Analysis
        dynamic_analysis = self.perform_dynamic_analysis(formatted_code)
        logging.info(f"Dynamic analysis result: {dynamic_analysis}")
        if not dynamic_analysis["success"]:
            logging.warning(f"Dynamic analysis failed with errors: {dynamic_analysis['errors']}")
            self.update_state("status", False)
            return handle_error(self.state, dynamic_analysis["errors"])

        return self.state

    def generate_unit_test(self, code: str):
        """Ask the model for pytest tests covering ``code``.

        Returns:
            ``(test_code, formatted_prompt)``; ``(None, None)`` when the model
            returns nothing or an exception occurs. ``test_code`` has been
            passed through ``append_missing_functions_and_imports`` so that it
            contains the functions under test.
        """
        formatted_prompt = PROMPT_TEMPLATES["unit_test_generation"].format(code=code,
            format_instructions=format_instructions)
        try:
            response = self.complete(formatted_prompt, max_tokens=self.max_tokens)
            test_code = response["choices"][0]["text"].strip()

            if not test_code:
                logging.warning("Unit test code is empty.")
                return None, None

            logging.debug(f"Generated unit test before --->: {test_code}")
            logging.debug(f"~---------------------------------------------~")
            test_code = append_missing_functions_and_imports(code, test_code)

            logging.debug(f"Unit test after appending missing functions --->: {test_code}")
            logging.debug(f"~---------------------------------------------~")

            return test_code, formatted_prompt

        except Exception as e:
            logging.error(f"An error occurred during inference for unit test: {e}")
            return None, None

    def perform_dynamic_analysis(self, code: str):
        """Generate and run pytest tests for ``code``, trying up to 3 times.

        An attempt is skipped when no test code is produced or the test code
        is not valid Python.

        Returns:
            ``{"success": bool, "errors": [...], "details": {...}}``. On a
            pytest failure ``errors`` holds the report lines of the last
            attempt.
        """
        logging.info("Performing dynamic analysis using pytest.")

        for attempt in range(3):
            unit_test, formatted_prompt = self.generate_unit_test(code)
            logging.debug(f"Attempt {attempt + 1}: Generated unit test --->: ")
            logging.debug(f"~---------------------------------------------~")

            if not unit_test:
                logging.warning("Unit test code is empty.")
                continue

            formatted_test = reformat_code(unit_test)
            if not is_valid_python(formatted_test):
                logging.error("Formatted test code is not valid Python.")
                continue

            pytest_result = run_command("pytest", formatted_test, options=["--disable-warnings", "--maxfail=1", "--tb=short"])

            cleaned_output = strip_ansi_escape_codes(pytest_result["output"])

            # NOTE(review): "unit_test" carries the generation prompt rather
            # than the generated test code - confirm this is intended.
            details = {"unit_test": formatted_prompt, "pytest_output": cleaned_output}

            if pytest_result["success"]:
                logging.info("Dynamic analysis completed successfully.")
                return {"success": True, "errors": [], "details": details}

            errors = self._failure_lines(pytest_result)
            logging.error(f"Dynamic analysis failed with errors: {errors}")
            if attempt == 2:
                return {"success": False, "errors": errors, "details": details}

        return {"success": False, "errors": ["Failed to generate valid unit tests after 3 attempts."], "details": {"unit_test": formatted_prompt, "pytest_output": ""}}


# Solution Agent

In [23]:
class SolutionGeneratorAgent(Agent):
    """Agent that asks an OpenAI chat model how to fix failing code.

    ``Agent.__init__`` is deliberately not called: this agent talks to the
    OpenAI API and does not load a local model. ``update_state`` is inherited
    from ``Agent``.
    """

    def __init__(self, state, solution_model="gpt-4o", solution_model_api_key="your_openai_api_key_here", temperature=0, max_tokens=500):
        """
        Args:
            state: Shared workflow state.
            solution_model: Name of the OpenAI chat model.
            solution_model_api_key: API key handed to the OpenAI client.
            temperature: Sampling temperature sent with each request.
            max_tokens: Stored on the instance.
                NOTE(review): not forwarded to the API call - confirm whether
                the reply length should be capped.
        """
        self.state = state
        self.client = OpenAI(api_key=solution_model_api_key)
        self.model = solution_model
        self.temperature = temperature
        self.max_tokens = max_tokens
        logging.basicConfig(level=logging.INFO)

    def complete(self, prompt):
        """Send ``prompt`` as a chat request; return the response or None on error."""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that responds in python code only."},
                    {"role": "user", "content": prompt}
                ],
                temperature=self.temperature
            )
            return response
        except Exception as e:
            logging.error(f"Failed to generate completion: {e}")
            return None

    def invoke(self, prompt, code, errors):
        """Request a fix for ``code`` and store it in the state.

        The "error_solution" template is filled with the task ``prompt``, the
        failing ``code`` and the ``errors`` (one per line), so the model sees
        the task description as well as the failure.

        On success the reply is stored as
        ``state["error_solver_response"] = {"solution": ...}``. On failure
        ``state["status"]`` is set to False and a message is added to
        ``state["compiler_errors"]``.

        Returns:
            The agent's state dict.
        """
        logging.info(f"Invoking SolutionGeneratorAgent with errors.")

        if isinstance(errors, str):
            errors_text = errors
        else:
            errors_text = "\n".join(str(error) for error in (errors or []))

        formatted_prompt = PROMPT_TEMPLATES["error_solution"].format(
            prompt=prompt,
            code=code,
            errors=errors_text,
            format_instructions=format_instructions
        )
        response = self.complete(formatted_prompt)

        if response is None:
            self.update_state("status", False)
            self.update_state("compiler_errors", "Failed to generate solution.")
            return self.state

        try:
            response_message = response.choices[0].message.content
            solution = response_message.strip()
        except (AttributeError, KeyError, IndexError, TypeError) as e:
            # IndexError: empty ``choices``; AttributeError: missing message
            # or ``content`` that is None.
            logging.error(f"Error accessing response: {e}")
            self.update_state("status", False)
            self.update_state("compiler_errors", "Error accessing response.")
            return self.state

        self.update_state("error_solver_response", {"solution": solution})
        return self.state

# Regenerator Agent

In [24]:
class RegeneratorAgent(Agent):
    """Agent that rewrites failing code with the local model, guided by a proposed solution."""

    def __init__(self, state: AgentGraphState, model_id: str = None, cache_dir: str = None, temperature: float = 0, max_tokens: int = 500):
        super().__init__(state, model_id, cache_dir, temperature, max_tokens)

    def invoke(self, prompt: str, previous_code: str, errors: List[str], solution: str):
        """Regenerate the code and record it in the state.

        The result is stored under the ``"regenerated_code"`` state key (a
        list, so the new code is appended). The previously used key
        ``"code_regenerator_response"`` is not part of the state, which
        caused ``update_state`` to discard the regenerated code.

        Args:
            prompt: Original task description.
            previous_code: The code that failed.
            errors: Error messages (a single string is also accepted).
            solution: Proposed fix from the solution agent.

        Returns:
            The agent's state dict. ``state["status"]`` is set to False when
            no code could be regenerated.
        """
        logging.info(f"Invoking RegeneratorAgent with previous code: {previous_code}")

        if isinstance(errors, str):
            errors_text = errors
        else:
            # str() guards against non-string entries in the error list.
            errors_text = "\n".join(str(error) for error in (errors or []))

        formatted_prompt = PROMPT_TEMPLATES["code_regeneration"].format(
            prompt=prompt,
            previous_code=previous_code,
            errors=errors_text,
            solution=solution,
            format_instructions=format_instructions
        )
        logging.debug(f"Prompt : {formatted_prompt}")
        logging.debug(f"--------------------------")
        response = self.complete(formatted_prompt, self.max_tokens)

        if response is None:
            logging.error("Failed to generate regenerated code.")
            self.update_state("status", False)
            return self.state

        try:
            regenerated_code = response["choices"][0]["text"].strip()
        except (AttributeError, KeyError, IndexError, TypeError) as e:
            logging.error(f"Error accessing response: {e}")
            self.update_state("status", False)
            return self.state

        if not regenerated_code:
            logging.error("Failed to generate regenerated code.")
            self.update_state("status", False)
            return self.state

        self.update_state("regenerated_code", regenerated_code)
        return self.state

class EndNodeAgent:
    """Final workflow node: marks the run as completed, prints the state and logs it to CSV."""

    def __init__(self, state: dict):
        self.state = state

    def update_state(self, key: str, value: str) -> None:
        """Overwrite ``state[key]`` with ``value``; unknown keys only produce a warning."""
        if key not in self.state:
            logging.warning(f"Key {key} not found in state.")
            return
        self.state[key] = value

    def invoke(self):
        """Set ``end_node_status`` to "completed", print every state entry, log the state and return it."""
        logging.info("EndNodeAgent invoked. Workflow complete.")
        self.update_state("end_node_status", "completed")

        logging.info("Printing all states:")
        for key in self.state:
            print(f"{key}: {self.state[key]}")

        log_state_to_csv(self.state)
        return self.state

# Graph

In [25]:
def create_graph(server=None, model_id=None, model_endpoint=None, solution_model=None, solution_model_api_key=None, temperature=0, max_tokens=500, cache_dir=None):
    """Assemble the code-generation / repair workflow as an uncompiled ``StateGraph``.

    Flow:
        code_generator     -> compiler
        compiler           -> end                 (no ``compiler_errors`` or truthy ``status``)
        compiler           -> solution_generator  (otherwise)
        solution_generator -> regenerator
        regenerator        -> compiler            (non-empty regenerated code in state)
        regenerator        -> end                 (regeneration produced nothing)

    Every node builds a fresh agent around the current state and returns
    whatever that agent's ``invoke`` returns.

    Args:
        server: Accepted for caller compatibility; not used here.
        model_id: Passed to the code generator, compiler and regenerator agents.
        model_endpoint: Accepted for caller compatibility; not used here.
        solution_model: Passed to the solution generator agent.
        solution_model_api_key: Passed to the solution generator agent.
        temperature: Passed to every agent.
        max_tokens: Passed to every agent.
        cache_dir: Passed to the agents that are constructed with ``model_id``.

    Returns:
        The populated, not yet compiled, ``StateGraph``.

    Raises:
        Exception: Any error raised while wiring the graph is logged and re-raised.
    """

    def regenerated_code(state):
        # The regenerator stores its output nested under "code_regenerator_response".
        return (state.get("code_regenerator_response") or {}).get("regenerated_code")

    def latest_code(state):
        # Prefer the most recent regeneration so the repair loop checks the new
        # attempt instead of re-compiling the first draft forever.
        # NOTE(review): assumes the compiler accepts the regenerated text as-is — confirm.
        return regenerated_code(state) or (state.get("code_generator_response") or {}).get("code", "")

    def run_code_generator(state):
        return CodeGeneratorAgent(
            state=state,
            model_id=model_id,
            cache_dir=cache_dir,
            temperature=temperature,
            max_tokens=max_tokens
        ).invoke(
            prompt=state.get("row_data", "")
        )

    def run_compiler(state):
        return CompilerAgent(
            state=state,
            model_id=model_id,
            cache_dir=cache_dir,
            temperature=temperature,
            max_tokens=max_tokens
        ).invoke(
            code=latest_code(state)
        )

    def run_solution_generator(state):
        return SolutionGeneratorAgent(
            state=state,
            solution_model=solution_model,
            solution_model_api_key=solution_model_api_key,
            temperature=temperature,
            max_tokens=max_tokens
        ).invoke(
            errors=state.get("compiler_errors", []),
            code=latest_code(state),
            prompt=state.get("row_data", "")
        )

    def run_regenerator(state):
        return RegeneratorAgent(
            state=state,
            model_id=model_id,
            cache_dir=cache_dir,
            temperature=temperature,
            max_tokens=max_tokens
        ).invoke(
            prompt=state.get("row_data", ""),
            previous_code=latest_code(state),
            errors=state.get("compiler_errors", []),
            solution=state.get("error_solver_response", {}).get("solution", "")
        )

    def route_after_compiler(state):
        if not state.get("compiler_errors", []) or state.get("status", False):
            return "end"
        return "solution_generator"

    def route_after_regenerator(state):
        # With nothing new to compile, finish so the end node can still log the
        # state rather than looping until the recursion limit.
        return "compiler" if regenerated_code(state) else "end"

    graph = StateGraph(AgentGraphState)

    try:
        graph.add_node("code_generator", run_code_generator)
        graph.add_node("compiler", run_compiler)
        graph.add_node("solution_generator", run_solution_generator)
        graph.add_node("regenerator", run_regenerator)
        graph.add_node("end", lambda state: EndNodeAgent(state).invoke())

        graph.set_entry_point("code_generator")
        graph.set_finish_point("end")

        graph.add_edge("code_generator", "compiler")
        graph.add_conditional_edges("compiler", route_after_compiler)
        graph.add_edge("solution_generator", "regenerator")

        # Exactly one outgoing route from the regenerator: an extra unconditional
        # edge to "compiler" would fire alongside the conditional one.
        graph.add_conditional_edges("regenerator", route_after_regenerator)

    except Exception as e:
        logging.error(f"Error creating graph: {e}")
        raise

    return graph

def compile_workflow(graph):
    """Compile ``graph`` and return whatever ``graph.compile()`` produces."""
    return graph.compile()

# Runner

In [26]:
# ----------------------------------------------------------------------------------------------------------------
# Main Function
def main(seed_csv='seed.csv', verbose=False):
    """
    Build the workflow and stream every question from the seed CSV through it.

    Args:
        seed_csv: Path of a CSV file with a ``Questions`` column, one prompt per row.
        verbose: When true, print each streamed event; otherwise print a blank
            separator per event.

    Raises:
        RuntimeError: If the ``OPENAI_API_KEY`` environment variable is not set.

    A failure while processing one question is logged and does not stop the
    remaining questions.
    """
    # ----------------------------------------------------------------------------------------------------------------
    # Configuration
    server = 'meta-llama'
    model_endpoint = 'https://api.huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct'
    temperature = 0.7
    iterations = 40
    max_tokens = 500

    solution_model = "gpt-4o"
    # Credentials are read from the environment so they never end up in the
    # notebook or its history.
    # NOTE(review): the key that used to be hardcoded here should be revoked.
    solution_model_api_key = os.environ.get("OPENAI_API_KEY")
    if not solution_model_api_key:
        raise RuntimeError("Set the OPENAI_API_KEY environment variable before running main().")

    # ----------------------------------------------------------------------------------------------------------------
    # Model Configuration
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # ----------------------------------------------------------------------------------------------------------------
    # Create Graph and Compile Workflow
    print("Creating graph and compiling workflow...")
    graph = create_graph(
        server=server,
        model_id=model_id,
        solution_model=solution_model,
        solution_model_api_key=solution_model_api_key,
        model_endpoint=model_endpoint,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    workflow = compile_workflow(graph)
    print("Graph and workflow created.")

    # ----------------------------------------------------------------------------------------------------------------
    seeds_df = pd.read_csv(seed_csv, header=0)
    # The recursion limit is the same for every query, so build it once.
    limit = {"recursion_limit": iterations}
    # ----------------------------------------------------------------------------------------------------------------
    for query in seeds_df['Questions']:
        dict_inputs = {"row_data": [query]}

        try:
            for event in workflow.stream(dict_inputs, limit):
                if verbose:
                    print("\nState Dictionary:", event)
                else:
                    print("\n")
        except Exception as e:
            logging.error(f"An error occurred while processing query '{query}': {e}")

    # ----------------------------------------------------------------------------------------------------------------
    print("Processing complete. Log data saved to logs.csv.")

if __name__ == "__main__":
    main()


INFO:root:Loading model 'meta-llama/Meta-Llama-3.1-8B-Instruct' for the first time.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /meta-llama/Meta-Llama-3.1-8B-Instruct/resolve/main/config.json HTTP/1.1" 200 0


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /meta-llama/Meta-Llama-3.1-8B-Instruct/resolve/main/generation_config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /meta-llama/Meta-Llama-3.1-8B-Instruct/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
INFO:root:Invoking CodeGeneratorAgent with prompt: ['Wirte a python code to multiply two numbers']
INFO:root:Formatted Prompt: 
    Provide a Python single solution oriented code solution to the following prompt.

    Prompt: ['Wirte a python code to multiply two numbers']

    The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":

```json
{
	"code": string  // Python code
}
```
    
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
DEBUG:root:response cleaned: {
	"code": "def multiply_numbers(num1, num2):\n    return num1 * num2\n\n# Test the function\nprint(multiply_numbers(5, 10))