In [1]:
# Silence every Python warning for the rest of the session so that the cell
# outputs below stay readable. Remove this while debugging.
import warnings
warnings.filterwarnings("ignore")

In [2]:
%%capture
!pip install -qU torch==2.3.1 \
transformers==4.41.2 \
accelerate==0.31.0 \
pycaret==3.3.2 \
ipywidgets==8.1.3 \
transitions==0.9.2 \
graphviz==0.20.3

In [3]:
# Standard library imports
import copy
import traceback
from enum import Enum, auto
from pathlib import Path
from typing import Union

# Third-party library imports
import pandas as pd
import requests
import torch
import graphviz
from bs4 import BeautifulSoup
from pydantic import BaseModel, validator, HttpUrl, constr, ValidationError
from transitions import Machine
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.utils import logging

# Set verbosity 40+ to only display errors and critical
logging.set_verbosity_error()
# Set random seed for reproducibility
torch.random.manual_seed(0)

<torch._C.Generator at 0x7cfa78cb7b50>

In [4]:
class LanguageModel:
    """Thin wrapper around a Hugging Face causal LM used for chat-style generation.

    The model, the tokenizer and a single ``text-generation`` pipeline are
    created once in ``__init__`` and reused for every ``generate_text`` call.
    """

    def __init__(self, model_name):
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            model_name: Identifier (or path) handed to ``from_pretrained`` for
                both the model and the tokenizer.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype="auto",
            # NOTE(review): this allows code shipped with the model repository
            # to run locally; only use with model sources you trust.
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Set to evaluation mode
        self.model.eval()
        # Build the pipeline once instead of on every generate_text() call;
        # it only wraps the already-loaded model and tokenizer.
        self._pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def generate_text(self, chat_history, generation_args):
        """Generate a reply for ``chat_history``.

        Args:
            chat_history: Prompt passed to the pipeline (the callers in this
                notebook pass a list of ``{"role", "content"}`` dicts).
            generation_args: Keyword arguments forwarded to the pipeline call.

        Returns:
            The ``generated_text`` of the first result, with surrounding
            whitespace stripped.
        """
        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = self._pipe(chat_history, **generation_args)
        return output[0]['generated_text'].strip()

In [27]:
class DatasetLocationModel(BaseModel):
    """Validate that a dataset location is usable by the workflow.

    A location is accepted when it is either an existing local path or a URL,
    and in both cases ends in ``.csv`` or ``.parquet`` (case-insensitive).
    """

    location: Union[HttpUrl, constr(strip_whitespace=True)]

    @validator('location', pre=True)
    def check_location(cls, v):
        """Reject locations that are neither a local CSV/Parquet file nor a CSV/Parquet URL.

        Runs with ``pre=True``, so ``v`` is the raw, un-coerced input.

        Raises:
            ValueError: If ``v`` is not a string, is not an existing path or a
                valid URL, or does not have a supported extension.
        """
        # The helpers below assume a string; fail as a validation error rather
        # than letting a TypeError/AttributeError escape validate_location().
        if not isinstance(v, str):
            raise ValueError('The provided location is not a valid URL or local path')

        if cls.is_valid_local_path(v):
            if cls.has_valid_extension(v):
                return v
            raise ValueError('The local path does not point to a CSV or Parquet file')

        if cls.is_valid_url(v):
            if cls.has_valid_extension(v):
                return v
            raise ValueError('The URL does not point to a CSV or Parquet file')

        raise ValueError('The provided location is not a valid URL or local path')

    @staticmethod
    def is_valid_url(url: str) -> bool:
        """Return True when ``HttpUrl`` accepts ``url``, False on ``ValidationError``."""
        try:
            HttpUrl(url=url)
            return True
        except ValidationError:
            return False

    @staticmethod
    def is_valid_local_path(path: str) -> bool:
        """Return True when ``path`` exists on the local filesystem."""
        try:
            return Path(path).exists()
        except (OSError, ValueError):
            # Free-form text (e.g. a long model reply) can make the OS reject
            # the path outright, for instance "File name too long". That simply
            # means it is not a usable local path.
            return False

    @staticmethod
    def has_valid_extension(path: str) -> bool:
        """Return True when ``path`` ends in ``.csv`` or ``.parquet`` (case-insensitive)."""
        valid_extensions = ('.csv', '.parquet')
        return path.lower().endswith(valid_extensions)

    @classmethod
    def validate_location(cls, location: str) -> bool:
        """Return True when ``location`` passes validation, False otherwise."""
        try:
            cls(location=location)
            return True
        except ValidationError:
            return False

In [28]:
class NodeConfig:
    """Generation arguments and prompt templates shared by the workflow nodes.

    The templates stored on the instance are never modified by the
    ``get_*_prompt`` methods; each call works on a deep copy.
    """

    def __init__(self):
        # General generation arguments
        self.generation_args_template = {
            "return_full_text": False,
            "do_sample": False,
        }

        # One independent dict per task; they differ only in the token budget.
        self.entity_extraction_generation_args = {**self.generation_args_template, "max_new_tokens": 100}
        self.code_generation_args = {**self.generation_args_template, "max_new_tokens": 1000}
        self.code_fix_generation_args = {**self.generation_args_template, "max_new_tokens": 600}

        # Prompts
        self.entity_extraction_prompt_template = [
            {"role": "system", "content": "You are a helpful, and accurate, AI assistant. Always follow the instructions provided by user"},
            {"role": "user", "content": None},  # Placeholder for entity to be extracted
        ]
        self.code_gen_prompt_template = [
            {"role": "system", "content": "You are a helpful, and accurate, AI assistant, that generates bug free executable python code."},
            {"role": "user", "content": "Here is the documentation on how to use the pycaret library for finding best classification model and fit it on new dataset"},
        ]
        self.code_fix_prompt_template = [
            {"role": "system", "content": "You are a helpful, and accurate, AI assistant, that generates bug free executable python code. Follow the output format in the example"},
            {"role": "user", "content": None},  # Placeholder for example error message and code
            {"role": "assistant", "content": None},  # Placeholder for example fixed code
            {"role": "user", "content": None},  # Placeholder for actual error message and code
        ]

    def get_code_gen_prompt(self, library_doc, task, dataset_url, target_column):
        """Build the chat prompt asking the model to write PyCaret code.

        Args:
            library_doc: Library documentation text to include in the prompt.
                ``None`` (e.g. when the documentation download failed) is
                treated as an empty documentation section.
            task: Machine learning task name inserted into the instruction.
            dataset_url: Dataset location inserted into the instruction.
            target_column: Target column name inserted into the instruction.

        Returns:
            A new list of message dicts; the stored template is left untouched.
        """
        prompt = copy.deepcopy(self.code_gen_prompt_template)
        # Concatenating None would raise TypeError; fall back to no documentation.
        doc_text = library_doc if library_doc is not None else ""
        user_instruction = (
            f"Using the information from the documentation, write code to FIND AND EVALUATE best model for {task} on dataset located at url: {dataset_url} "
            f"and target column:{target_column} using pycaret library, don't fit it on new data. "
            "Only generate executable code and nothing else like explanation or reasoning. DO NOT INCLUDE ANY MARKDOWN FORMATTING SUCH AS TRIPLE BACKTICKS (```PYTHON). OUTPUT SHOULD BE PLAIN PYTHON CODE."
        )
        prompt[1]["content"] = prompt[1]["content"] + " \n " + "Library Documentation: " + doc_text + " \n " + "user_instruction: " + user_instruction
        return prompt

    def get_code_fix_prompt(self, example_err_msg, example_code_with_error, example_fixed_code, errors, code):
        """Build the one-shot chat prompt asking the model to repair failing code.

        Args:
            example_err_msg: Error message of the worked example.
            example_code_with_error: Faulty code of the worked example.
            example_fixed_code: Corrected code of the worked example (used as
                the assistant turn).
            errors: Actual error information to fix.
            code: Actual code that produced ``errors``.

        Returns:
            A new list of four message dicts (system, example request, example
            answer, actual request).
        """
        prompt = copy.deepcopy(self.code_fix_prompt_template)
        prompt[1]["content"] = f"Example: Only generate executable code and nothing else like explanation or reasoning. Fix this error: {example_err_msg} in the python code: {example_code_with_error}."
        prompt[2]["content"] = example_fixed_code
        prompt[3]["content"] = f"Only generate executable code and nothing else like explanation or reasoning. DO NOT INCLUDE ANY MARKDOWN FORMATTING SUCH AS TRIPLE BACKTICKS (```PYTHON). OUTPUT SHOULD BE PLAIN PYTHON CODE. Fix this error: {errors} in the python code: {code}."
        return prompt


In [29]:
class Conversation:
    """Collect the dataset location, ML task and target column from free-text user input.

    Each piece of information is extracted with one language-model call and
    kept on the instance once it has been validated, so later turns only ask
    for what is still missing.
    """

    def __init__(self, language_model, generation_args, max_retries=5):
        """
        Args:
            language_model: Object exposing ``generate_text(prompt, generation_args)``.
            generation_args: Generation arguments forwarded on every call.
            max_retries: Maximum number of user inputs read by ``chat``.
        """
        self.lm = language_model
        self.generation_args = generation_args
        self.max_retries = max_retries
        self.dataset_url = None
        self.machine_learning_task = None
        self.target_column = None
        self.supported_ml_tasks = ['classification', 'regression', 'clustering']

    @staticmethod
    def _build_prompt(prompt_template, content):
        """Return a copy of ``prompt_template`` with the user message set to ``content``."""
        prompt = copy.deepcopy(prompt_template)
        prompt[1]["content"] = content
        return prompt

    @staticmethod
    def _is_missing(reply):
        """True when the model reply is empty or the "False" sentinel requested by the prompts."""
        return not reply or reply.strip().strip(".").lower() == "false"

    def _read_columns(self, location):
        """Return the column index of the dataset at ``location``, or None if it cannot be read."""
        try:
            # Case-insensitive, to agree with DatasetLocationModel.has_valid_extension.
            if location.lower().endswith(".csv"):
                return pd.read_csv(location, nrows=10).columns
            return pd.read_parquet(location).head(10).columns
        except Exception as exc:
            # An unreadable dataset is a user-input problem, not a reason to
            # abort the whole conversation.
            print(f"Could not read dataset at {location}: {exc}")
            return None

    def extract_entities(self, user_input, entity_extraction_prompt_template):
        """Given user input, extract the dataset URL, machine learning task, and target column.

        Only entities that are still unset are requested from the model. The
        target column is checked against the dataset's columns as soon as a
        dataset is known. A dataset that cannot be read is discarded.

        Returns:
            None. Results are stored on ``dataset_url``,
            ``machine_learning_task`` and ``target_column``.
        """
        # Generate prompts for entity extraction
        dataset_input_prompt = self._build_prompt(
            entity_extraction_prompt_template,
            f"Given the context: {user_input}. If the context contains a url for a csv or parquet file, return the full url as response, otherwise only ouput one word False",
        )
        machine_learning_task_input_prompt = self._build_prompt(
            entity_extraction_prompt_template,
            f"Given the context: {user_input}. Identify if the context mentions a machine learning task on the target column in the dataset if yes then return the machine learning task as response, like regression or classification or clustering; otherwise only ouput one word False",
        )
        target_column_input_prompt = self._build_prompt(
            entity_extraction_prompt_template,
            f"Given the context: {user_input}. Identify if the context mentions a target column to be used for the machine leraning problem, if yes then return the target column  as response, otherwise only ouput one word False",
        )

        if not self.dataset_url:
            candidate = self.lm.generate_text(dataset_input_prompt, self.generation_args)
            # Check if the URL is valid
            self.dataset_url = candidate if DatasetLocationModel.validate_location(candidate) else None

        if not self.machine_learning_task:
            candidate = self.lm.generate_text(machine_learning_task_input_prompt, self.generation_args)
            # Check if the machine_learning_task is valid
            if not candidate or candidate.lower() not in self.supported_ml_tasks:
                candidate = None
            self.machine_learning_task = candidate

        if not self.target_column:
            candidate = self.lm.generate_text(target_column_input_prompt, self.generation_args)
            # Do not keep the "False" sentinel as if it were a column name.
            self.target_column = None if self._is_missing(candidate) else candidate

        # Check if the target_column is valid (only possible once a dataset is known)
        if self.dataset_url and self.target_column:
            columns = self._read_columns(self.dataset_url)
            if columns is None:
                self.dataset_url = None
            elif self.target_column not in columns:
                self.target_column = None

        return None

    def is_chat_successful(self):
        """Return True when dataset, task and target column have all been collected."""
        return bool(self.dataset_url and self.machine_learning_task and self.target_column)

    def chat(self, entity_extraction_prompt_template):
        """Read user input until everything is collected or ``max_retries`` inputs were read.

        After each input the current status of the three entities is printed.
        A failure message is printed only if the information is still
        incomplete when the loop ends.
        """
        retries = 0
        while retries < self.max_retries and not self.is_chat_successful():
            user_input = input("")
            self.extract_entities(user_input, entity_extraction_prompt_template)
            if self.dataset_url:
                print("Dataset URL:", self.dataset_url)
            else:
                print("Dataset location invalid try again")
            if self.machine_learning_task:
                print("Machine Learning Task:", self.machine_learning_task)
            else:
                print("Please choose machine task from the following: ", self.supported_ml_tasks)
            if self.target_column:
                print("Target:", self.target_column)
            else:
                print("Target columnn not found in the dataset.")

            retries += 1

        # Report failure only when the attempts were really exhausted without success.
        if not self.is_chat_successful():
            print("Failed to extract entities after multiple retries.")

        return None

In [35]:
# Define the states for the workflow
class NodeState(Enum):
    """States a workflow node can report as source or destination of a transition."""

    # Values are spelled out; they match what ``auto()`` assigns (1, 2, 3, ...).
    COLLECTING_INPUTS = 1
    GENERATING_CODE = 2
    EXECUTING_CODE = 3
    FIXING_ERRORS = 4
    FINISHED = 5
    MAX_RETRIES_REACHED = 6

# Shared context for passing data between nodes
class WorkflowContext:
    """Mutable bag of data that the workflow nodes read from and write to."""

    def __init__(self, lm, documentation):
        # Collaborators supplied by the caller
        self.lm = lm
        self.library_doc = documentation
        # Results produced while the workflow runs; nothing is known up front
        self.inputs = self.code = self.fixed_code = None
        self.execution_success = self.errors = None

# Base class for all nodes in the workflow
class Node:
    """Common base of the workflow nodes.

    Holds the node name, the shared context, a retry budget and the list of
    ``(source_state, dest_state)`` pairs the node has recorded.
    """

    def __init__(self, name, context, retries=5):
        self.name, self.context = name, context
        self.max_retries = retries
        # Recorded (source, destination) pairs, later used for visualization
        self.transitions = []

    def run(self):
        """Execute the node's work; concrete nodes must override this."""
        raise NotImplementedError("Each node must implement the run method")

    def log_transition(self, source_state, dest_state):
        """Remember that this node moved from ``source_state`` to ``dest_state``."""
        edge = (source_state, dest_state)
        self.transitions.append(edge)

# Node for collecting inputs
class CollectInputsNode(Node):
    """Node that gathers dataset location, ML task and target column from the user."""

    def run(self):
        """Collect the inputs, record the resulting transition and report success.

        Returns:
            True when all inputs were collected, False otherwise.
        """
        self.context.inputs = self.collect_inputs()
        collected = self.inputs_collected()
        # On failure the node loops back onto its own state.
        next_state = NodeState.GENERATING_CODE if collected else NodeState.COLLECTING_INPUTS
        self.log_transition(NodeState.COLLECTING_INPUTS, next_state)
        return collected

    def collect_inputs(self):
        """Run a ``Conversation`` and return the collected entities as a dict, or None."""
        settings = NodeConfig()
        dialogue = Conversation(
            self.context.lm,
            settings.entity_extraction_generation_args,
            self.max_retries,
        )
        dialogue.chat(settings.entity_extraction_prompt_template)
        if not dialogue.is_chat_successful():
            return None
        return {
            'dataset_url': dialogue.dataset_url,
            'machine_learning_task': dialogue.machine_learning_task,
            'target_column': dialogue.target_column,
        }

    def inputs_collected(self):
        """True once the context holds a set of collected inputs."""
        return self.context.inputs is not None

# Node for generating code
class GenerateCodeNode(Node):
    """Node that asks the language model to write code for the collected inputs."""

    def run(self):
        """Build the code-generation prompt, store the generated code and log the transition.

        Returns:
            Always True.
        """
        settings = NodeConfig()
        collected = self.context.inputs
        prompt = settings.get_code_gen_prompt(
            self.context.library_doc,
            collected['machine_learning_task'],
            collected['dataset_url'],
            collected['target_column'],
        )
        self.context.code = self.generate_code(prompt, settings.code_generation_args)
        self.log_transition(NodeState.GENERATING_CODE, NodeState.EXECUTING_CODE)
        return True

    def generate_code(self, code_gen_prompt, code_generation_args):
        """Return the model's reply to ``code_gen_prompt``."""
        language_model = self.context.lm
        return language_model.generate_text(code_gen_prompt, code_generation_args)

# Node for executing code
class ExecuteCodeNode(Node):
    """Node that executes the generated code and records success or the traceback."""

    def run(self):
        """Execute ``context.code`` and store the outcome on the context.

        Returns:
            True when the code ran without raising, False otherwise.
        """
        source_state = NodeState.EXECUTING_CODE
        success, errors = self.execute_code(self.context.code)
        self.context.execution_success = success
        self.context.errors = errors
        dest_state = NodeState.FINISHED if success else NodeState.FIXING_ERRORS
        self.log_transition(source_state, dest_state)

        return success

    def check_formatting(self, code):
        """Drop markdown fence lines from ``code``.

        Every line whose stripped text starts with triple backticks is
        removed; all other lines are kept unchanged. Code without any triple
        backticks is returned as is.
        """
        if "```" not in code:
            return code
        executable_lines = [
            line for line in code.split('\n')
            if not line.strip().startswith('```')
        ]
        return '\n'.join(executable_lines)

    def execute_code(self, code):
        """Run ``code`` and report the outcome.

        SECURITY NOTE: this executes model-generated code with full access to
        the current process. Only run it in an environment you are willing to
        expose to arbitrary code.

        Returns:
            ``(True, None)`` on success, or ``(False, [traceback_text])`` when
            the code raised an ``Exception``.
        """
        code = self.check_formatting(code)
        try:
            print(code)
            # Execute in ONE namespace. A bare exec(code) inside a method uses
            # separate globals and locals, so functions defined by the
            # generated code could not see names assigned at its top level.
            # Starting from a copy of the module globals keeps already
            # imported names available without modifying them.
            namespace = dict(globals())
            exec(code, namespace)
            return True, None  # Indicate successful execution
        except Exception:
            return False, [traceback.format_exc()]

# Node for fixing errors
class FixErrorsNode(Node):
    """Node that asks the language model to repair code that failed to execute."""

    def __init__(self, name, context, max_retries=3):
        """
        Args:
            name: Node name.
            context: Shared workflow context.
            max_retries: Maximum number of fix attempts that are generated.
        """
        super().__init__(name, context, retries=max_retries)
        self.retries = 0
        self.max_retries = max_retries

    def run(self):
        """Generate one more fix unless the retry budget is already used up.

        The budget is checked before calling the model, so every generated fix
        goes back to execution and no generation is wasted on a fix that would
        never be run.

        Returns:
            True when a new fix was stored in ``context.code`` (next state:
            executing), False when the budget is exhausted.
        """
        source_state = NodeState.FIXING_ERRORS
        if self.retries >= self.max_retries:
            self.log_transition(source_state, NodeState.MAX_RETRIES_REACHED)
            return False

        self.context.fixed_code = self.fix_errors(self.context.errors, self.context.code)
        self.context.code = self.context.fixed_code
        self.retries += 1
        self.log_transition(source_state, NodeState.EXECUTING_CODE)
        return True

    def fix_errors(self, errors, code):
        """Ask the model for a corrected version of ``code``.

        Args:
            errors: Error text, or a list/tuple of error texts. Sequences are
                joined with newlines so the prompt contains the readable
                traceback rather than a list repr with escaped newlines.
            code: The code that produced ``errors``.

        Returns:
            The model's reply.
        """
        example_err_msg = (
            "Traceback (most recent call last):\n"
            '  File "<ipython-input-48-a50281d5d318>", line 3, in <cell line: 2>\n'
            "    exec(code_2_run)\n"
            '  File "<string>", line 6, in <module>\n'
            "NameError: name 'adde_two_numbers' is not defined"
        )

        example_code_with_error = (
            "def add_two_numbers(a, b):\n"
            "    return a + b\n\n"
            "a = 10\n"
            "b = 5\n"
            "print(adde_two_numbers(a, b))"
        )

        example_fixed_code = (
            "# Here is the fixed code\n"
            "def add_two_numbers(a, b):\n"
            "    return a + b\n\n"
            "a = 10\n"
            "b = 5\n"
            "print(add_two_numbers(a, b))"
        )

        if isinstance(errors, (list, tuple)):
            error_text = "\n".join(str(item) for item in errors)
        else:
            error_text = errors

        config = NodeConfig()
        prompt = config.get_code_fix_prompt(
            example_err_msg,
            example_code_with_error,
            example_fixed_code,
            error_text,
            code
        )
        return self.context.lm.generate_text(prompt, config.code_fix_generation_args)

In [36]:
# The workflow graph using the transitions library
class Workflow:
    """State machine that drives the nodes from input collection to executed code.

    ``transitions.Machine`` manages ``self.state`` (one of the strings in
    ``states``) and adds the trigger methods named in the ``add_transition``
    calls to this object. Each trigger's condition callback runs the node that
    belongs to the current state.
    """

    states = [
        'collecting_inputs',
        'generating_code',
        'executing_code',
        'fixing_errors',
        'finished',
        'max_retries_reached'
    ]

    def __init__(self, lm, documentation):
        """
        Args:
            lm: Language model object shared with all nodes through the context.
            documentation: Library documentation text stored on the context.
        """
        self.context = WorkflowContext(lm, documentation)
        self.nodes = {
            'collecting_inputs': CollectInputsNode('collect_inputs', self.context),
            'generating_code': GenerateCodeNode('generate_code', self.context),
            'executing_code': ExecuteCodeNode('execute_code', self.context),
            'fixing_errors': FixErrorsNode('fix_errors', self.context),
        }
        self.current_node = self.nodes['collecting_inputs']
        self.transitions = []  # Track global transitions as (source, node name, destination)

        # Set up the state machine; this sets self.state to 'collecting_inputs'
        self.machine = Machine(model=self, states=Workflow.states, initial='collecting_inputs')

        # Define transitions between states
        self.machine.add_transition('collect_inputs', 'collecting_inputs', 'generating_code', conditions='run_collecting_inputs')
        self.machine.add_transition('generate_code', 'generating_code', 'executing_code', conditions='run_generating_code')
        self.machine.add_transition('execute_code', 'executing_code', 'finished', conditions='run_executing_code')
        self.machine.add_transition('execution_failed', 'executing_code', 'fixing_errors')
        self.machine.add_transition('fix_errors', 'fixing_errors', 'executing_code', conditions='run_fixing_errors')
        self.machine.add_transition('max_retries', '*', 'max_retries_reached')

    def run_collecting_inputs(self):
        """Condition callback: collect inputs; give up once the node's retries are exhausted."""
        if self.nodes['collecting_inputs'].run():
            return True
        # Without this the workflow would stay in 'collecting_inputs' and
        # prompt forever, even though the node already reported failure.
        self.max_retries()
        return False

    def run_generating_code(self):
        """Condition callback: generate code."""
        return self.nodes['generating_code'].run()

    def run_executing_code(self):
        """Condition callback: execute code; on failure move to 'fixing_errors'."""
        if self.nodes['executing_code'].run():
            return True
        self.execution_failed()
        return False

    def run_fixing_errors(self):
        """Condition callback: generate a fix; on exhausted retries move to 'max_retries_reached'."""
        if self.nodes['fixing_errors'].run():
            return True
        self.max_retries()
        return False

    def run(self):
        """Fire the trigger matching the current state until a terminal state is reached."""
        triggers = {
            'collecting_inputs': self.collect_inputs,
            'generating_code': self.generate_code,
            'executing_code': self.execute_code,
            'fixing_errors': self.fix_errors,
        }
        while self.state not in ['finished', 'max_retries_reached']:
            print('Inside Workflow, current state is:', self.state)
            source_state = self.state
            self.current_node = self.nodes[source_state]
            triggers[source_state]()

            # Record where the step started, which node ran, and where it ended
            self.transitions.append((source_state, self.current_node.name, self.state))

    def visualize_workflow(self, filename='workflow_graph'):
        """Render the transitions recorded by the nodes to ``<filename>.png``.

        Edge endpoints use the lower-cased ``NodeState`` names so that they
        refer to the same graph nodes that are declared from ``self.nodes``.
        """
        dot = graphviz.Digraph(comment='Workflow Execution')
        # Add nodes and edges to the graph
        for node_name, node in self.nodes.items():
            dot.node(node_name, node_name)
            for (source_state, dest_state) in node.transitions:
                dot.edge(source_state.name.lower(), dest_state.name.lower(), label=f'{node_name}')

        # Render the graph to a file
        dot.render(filename, format='png', cleanup=True)
        print(f"Workflow graph saved as {filename}.png")

In [10]:
# Model identifier handed to LanguageModel, which loads the model and tokenizer
model_name = "microsoft/Phi-3-mini-128k-instruct"
lm = LanguageModel(model_name)

config.json:   0%|          | 0.00/3.48k [00:00<?, ?B/s]

configuration_phi3.py:   0%|          | 0.00/11.2k [00:00<?, ?B/s]

modeling_phi3.py:   0%|          | 0.00/73.2k [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/16.3k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/3.44k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/599 [00:00<?, ?B/s]

In [37]:
def fetch_raw_html(url, timeout=30):
    """Download ``url`` and return the response body as text.

    Args:
        url: Address to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        The response text, or None when the request failed (the error is
        printed), including timeouts and non-success HTTP status codes.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.text
    except requests.exceptions.RequestException as e:
        print(f"Error fetching the URL: {e}")
        return None

def parse_html(html_content):
    """Extract paragraphs, code snippets and section divs from HTML as numbered plain text.

    Every ``<p>``, every ``<code>`` and every ``<div>`` whose class list
    contains ``section`` contributes one entry, in the order ``find_all``
    yields them. Each entry is rendered on its own line as
    ``"<Kind> <n>: <text>"``, with ``n`` counting all entries from 1.

    Returns:
        The concatenated entries, each preceded by a newline. Empty string
        when nothing matched.
    """
    # NOTE(review): 'html' leaves the parser choice to BeautifulSoup; consider
    # naming a specific parser if identical output across machines matters.
    soup = BeautifulSoup(html_content, 'html')

    # Keep all elements in document order
    elements = []
    for tag in soup.find_all(['p', 'code', 'div'], recursive=True):
        if tag.name == 'p':
            elements.append(('paragraph', tag.get_text()))
        elif tag.name == 'code':
            elements.append(('code', tag.get_text()))
        elif tag.name == 'div' and 'section' in tag.get('class', []):
            elements.append(('section', tag.get_text()))

    # Join once instead of growing a string inside the loop
    return "".join(
        f"\n{element_type.capitalize()} {i + 1}: {text}"
        for i, (element_type, text) in enumerate(elements)
    )


# Documentation source. The live page is kept for reference; the notebook
# reads the HTML copy hosted on GitHub instead.
#url = "https://pycaret.gitbook.io/docs/get-started/quickstart"
url = "https://raw.githubusercontent.com/abhimanyu729/GenAIPlayground/aisc_presentation/aisc_demo/assets/pycaret_documentation.html"

# Download the raw HTML (None when the request failed)
raw_html = fetch_raw_html(url)

# Convert the HTML to plain text; stays None when nothing was downloaded
documentation_context = None
if raw_html:
    documentation_context = parse_html(raw_html)

In [38]:
# Wire the language model and the parsed documentation into a new workflow
workflow = Workflow(lm, documentation_context)

In [39]:
workflow.run()
## Sample user inputs to type at the prompt while the workflow collects inputs:
# 1. The dataset I want to use is Titanic, and the column to classify on is Survived
# 2. I don't know where the data is
# 3. You can find data here: https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv
# 4. I'd like you to apply regression on the target column
# 5. Machine Learning Task: Clustering

Inside Workflow, current state is: collecting_inputs
The dataset I want to use is Titanic, and the column to classify on is Survived
Dataset location invalid try again
Machine Learning Task: Classification
Target: Survived
You can find data here: https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv
Dataset URL: https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv
Machine Learning Task: Classification
Target: Survived
Inside Workflow, current state is: generating_code
Inside Workflow, current state is: executing_code
import pandas as pd
from pycaret.classification import *

# Load dataset
data = pd.read_csv('https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv')

# Setup PyCaret environment
clf = setup(data=data, target='Survived')

# Compare and find the best model
best_model = compare_models()

# Tune the best model
tuned_model = tune_model(best_model)

# Evaluate the tuned model
evaluate_model(tuned_model)


Unnamed: 0,Description,Value
0,Session id,1198
1,Target,Survived
2,Target type,Binary
3,Original data shape,"(891, 12)"
4,Transformed data shape,"(891, 14)"
5,Transformed train set shape,"(623, 14)"
6,Transformed test set shape,"(268, 14)"
7,Numeric features,6
8,Categorical features,5
9,Rows with missing values,79.5%


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
lr,Logistic Regression,0.8123,0.8665,0.6656,0.8112,0.7289,0.5881,0.5964,0.256
ridge,Ridge Classifier,0.7337,0.8589,0.398,0.813,0.5307,0.3761,0.4238,0.117
et,Extra Trees Classifier,0.6886,0.8101,0.2303,0.8825,0.3497,0.2343,0.3293,0.37
nb,Naive Bayes,0.6725,0.7987,0.1799,0.8848,0.2912,0.1861,0.2877,0.101
rf,Random Forest Classifier,0.6388,0.847,0.063,0.8667,0.1155,0.0731,0.1745,0.258
lda,Linear Discriminant Analysis,0.626,0.5273,0.0304,0.0875,0.0452,0.0322,0.0402,0.1
knn,K Neighbors Classifier,0.6196,0.6047,0.3554,0.5108,0.4164,0.148,0.1547,0.123
dt,Decision Tree Classifier,0.6164,0.5,0.0,0.0,0.0,0.0,0.0,0.162
ada,Ada Boost Classifier,0.6164,0.5,0.0,0.0,0.0,0.0,0.0,0.102
gbc,Gradient Boosting Classifier,0.6164,0.5,0.0,0.0,0.0,0.0,0.0,0.167


Unnamed: 0_level_0,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,0.7778,0.8088,0.5833,0.7778,0.6667,0.5051,0.5168
1,0.8889,0.9476,0.75,0.9474,0.8372,0.7546,0.7665
2,0.7937,0.8333,0.625,0.7895,0.6977,0.5442,0.5528
3,0.8387,0.8896,0.7391,0.8095,0.7727,0.6481,0.6497
4,0.7903,0.8531,0.7083,0.7391,0.7234,0.5547,0.555
5,0.9194,0.8904,0.7917,1.0,0.8837,0.8233,0.8364
6,0.8548,0.9276,0.875,0.7778,0.8235,0.701,0.7045
7,0.8226,0.8695,0.7917,0.76,0.7755,0.6289,0.6293
8,0.8387,0.8553,0.6667,0.8889,0.7619,0.6437,0.6589
9,0.7742,0.8542,0.6667,0.7273,0.6957,0.5167,0.5179


Fitting 10 folds for each of 10 candidates, totalling 100 fits


interactive(children=(ToggleButtons(description='Plot Type:', icons=('',), options=(('Pipeline Plot', 'pipelin…

In [14]:
# Render the transitions recorded during the run to workflow_graph_execution.png
workflow.visualize_workflow(filename='workflow_graph_execution')

Workflow graph saved as workflow_graph_execution.png


In [None]:
# bug_understanding_prompt = [
#     {"role": "system", "content": "You are a helpful, and accurate, AI assistant, that's an expert in python programming"},
#     {"role": "user", "content": f"Given the following python code: {code} and the error: {errors}. Think and provide a clear explanation in natural language of the root cause of the error. The explanation should be short and concise."},
# ]
# bug_understanding_generation_args = {
#     "max_new_tokens": 200,
#     "return_full_text": False,
#     "temperature": 0.0,
#     "do_sample": False,
# }

In [None]:
# Inspect the state the workflow's state machine is currently in
workflow.state

'finished'

In [None]:
# Idea: first ask the model why it thinks the error exists and to provide its reasoning,
# then use that reasoning, the error, and
# the input code to fix it.

In [None]:
# class NodeConfig:
#     def __init__(self):
#         # General generation arguments
#         self.generation_args_template = {
#             "return_full_text": False,
#             "do_sample": False,
#         }

#         self.entity_extraction_generation_args = copy.deepcopy(self.generation_args_template)
#         self.code_generation_args = copy.deepcopy(self.generation_args_template)
#         self.code_fix_generation_args = copy.deepcopy(self.generation_args_template)

#         self.entity_extraction_generation_args.update({"max_new_tokens": 100})
#         self.code_generation_args.update({"max_new_tokens": 1000})
#         self.code_fix_generation_args.update({"max_new_tokens": 600})

#         # Prompts
#         self.entity_extraction_prompt_template = [
#           {"role": "system", "content": "You are a helpful, and accurate, AI assistant. Always follow the instructions provided by user"},
#           {"role": "user", "content": None}, # Placeholder for entity to be extracted
#           ]
#         self.code_gen_prompt_template = [
#             {"role": "system", "content": "You are a helpful, and accurate, AI assistant, that generates bug free executable python code. DO NOT INCLUDE ANY MARKDOWN FORMATTING SUCH AS TRIPLE BACKTICKS (```PYTHON). OUTPUT SHOULD BE PLAIN PYTHON CODE."},
#             {"role": "user", "content": "Here is the documentation on how to use the pycaret library for finding best classification model and fit it on new dataset"},
#             {"role": "assistant", "content": None},  # Placeholder for library doc
#             {"role": "user", "content": None},  # Placeholder for task-specific content
#         ]
#         self.code_fix_prompt_template = [
#             {"role": "system", "content": "You are a helpful, and accurate, AI assistant, that generates bug free executable python code. DO NOT INCLUDE ANY MARKDOWN FORMATTING SUCH AS TRIPLE BACKTICKS (```PYTHON). OUTPUT SHOULD BE PLAIN PYTHON CODE. Follow the output format in the example"},
#             {"role": "user", "content": None},  # Placeholder for example error message and code
#             {"role": "assistant", "content": None},  # Placeholder for example fixed code
#             {"role": "user", "content": None},  # Placeholder for actual error message and code
#         ]

#     def get_code_gen_prompt(self, library_doc, task, dataset_url, target_column):
#         prompt = copy.deepcopy(self.code_gen_prompt_template)
#         prompt[2]["content"] = library_doc
#         prompt[3]["content"] = (
#             f"Write code to find best model for {task} on dataset located at url: {dataset_url} "
#             f"and target column:{target_column} using pycaret library, don't fit it on new data. "
#             "Only generate executable code and nothing else like explanation or reasoning."
#         )
#         # user_instruction = (
#         #     f"Using the information from the documentation, write code to find best model for {task} on dataset located at url: {dataset_url} "
#         #     f"and target column:{target_column} using pycaret library, don't fit it on new data. "
#         #     "Only generate executable code and nothing else like explanation or reasoning. Do not include any markdown formatting such as triple backticks (```python). Only output the plain Python code directly."
#         # )
#         # prompt[1]["content"] = prompt[1]["content"] + "\n" + library_doc + "\n" + user_instruction
#         return prompt

#     def get_code_fix_prompt(self, example_err_msg, example_code_with_error, example_fixed_code, errors, code):
#         prompt = copy.deepcopy(self.code_fix_prompt_template)
#         prompt[1]["content"] = f"Example: Only generate executable code and nothing else like explanation or reasoning. Fix this error: {example_err_msg} in the python code: {example_code_with_error}."
#         prompt[2]["content"] = example_fixed_code
#         prompt[3]["content"] = f"Only generate executable code and nothing else like explanation or reasoning. DO NOT INCLUDE ANY MARKDOWN FORMATTING SUCH AS TRIPLE BACKTICKS (```PYTHON). OUTPUT SHOULD BE PLAIN PYTHON CODE. Fix this error: {errors} in the python code: {code}."
#         return prompt
