In [None]:
# | default_exp _code_generator.chat

In [None]:
# | export

from typing import *
import random
import logging
import time
from collections import defaultdict
from pathlib import Path

import openai
from fastcore.foundation import patch
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

from faststream_gen._code_generator.constants import (
    DEFAULT_PARAMS,
    MAX_RETRIES,
    STEP_LOG_DIR_NAMES,
    MAX_NUM_FIXES_MSG,
    INCOMPLETE_DESCRIPTION,
    DESCRIPTION_EXAMPLE,
    LOGS_DIR_NAME,
)
from faststream_gen._components.logger import get_logger, set_level
from faststream_gen._code_generator.prompts import SYSTEM_PROMPT
from faststream_gen._code_generator.helper import add_tokens_usage
from faststream_gen._components.package_data import get_root_data_path

In [None]:
from tempfile import TemporaryDirectory

import pytest

from faststream_gen._components.logger import suppress_timestamps
from faststream_gen._code_generator.constants import OpenAIModel

In [None]:
# | export

# Module logger, created with level WARNING. The logger.info calls in this module
# (retry notices, prompts sent to the model) are presumably hidden by default.
# The test cell below re-binds `logger` at INFO level to show them.
logger = get_logger(__name__, level=logging.WARNING)

In [None]:
suppress_timestamps()

# Re-bind the logger at INFO (== 20) so the following cells display their log output.
logger = get_logger(__name__, level=logging.INFO)
logger.info("ok")

[INFO] __main__: ok


In [None]:
# | export

# Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb


def _retry_with_exponential_backoff(
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    max_wait: float = 60,
    errors: tuple = (
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
        openai.error.APIError,
    ),
) -> Callable:
    """Build a decorator that retries the wrapped callable with exponential backoff.

    Whenever the wrapped callable raises one of ``errors``, the delay is multiplied
    by ``exponential_base`` (and, if ``jitter`` is enabled, by a random factor in
    ``[1, 2)``), capped at ``max_wait`` seconds, and the call is retried after
    sleeping for that delay. Any other exception propagates unchanged.

    Args:
        initial_delay: Starting delay in seconds; it is scaled before the first sleep.
        exponential_base: Multiplier applied to the delay on every retry.
        jitter: If True, multiply the delay by a random factor in ``[1, 2)``.
        max_retries: Maximum number of retries after the first failed call.
        max_wait: Upper bound, in seconds, for a single sleep.
        errors: Exception types that trigger a retry.

    Returns:
        A decorator applying the retry logic to a callable.

    Raises:
        Exception: When more than ``max_retries`` retries have failed. The last
            retryable error is chained as its ``__cause__``.
    """
    import functools

    def decorator(
        func: Callable[[str], Tuple[str, str]]
    ) -> Callable[[str], Tuple[str, str]]:
        # Preserve __name__/__doc__ of the wrapped callable (e.g. a decorated method).
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # type: ignore
            num_retries = 0
            delay = initial_delay

            while True:
                try:
                    return func(*args, **kwargs)

                except errors as e:
                    num_retries += 1
                    if num_retries > max_retries:
                        # Chain the last API error so the root cause is not lost.
                        raise Exception(
                            f"Maximum number of retries ({max_retries}) exceeded."
                        ) from e
                    delay = min(
                        delay
                        * exponential_base
                        * (1 + jitter * random.random()),  # nosec
                        max_wait,
                    )
                    logger.info(
                        f"Note: OpenAI's API rate limit reached. Command will automatically retry in {int(delay)} seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits",
                    )
                    time.sleep(delay)

        return wrapper

    return decorator

In [None]:
@_retry_with_exponential_backoff()
def mock_func():
    """Succeeds immediately, so the decorator should never retry."""
    return "Success"


expected = "Success"
actual = mock_func()

print(actual)
assert actual == expected, f"unexpected result: {actual!r}"

Success


In [None]:
# Test max retries exceeded.
# max_wait=0 caps every backoff sleep at 0 seconds, so the single retry does not
# make the test wait several seconds.
@_retry_with_exponential_backoff(max_retries=1, max_wait=0)
def mock_func_error():
    raise openai.error.RateLimitError


with pytest.raises(Exception) as e:
    mock_func_error()

print(e.value)
assert str(e.value) == "Maximum number of retries (1) exceeded."

[INFO] __main__: Note: OpenAI's API rate limit reached. Command will automatically retry in 3 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits
Maximum number of retries (1) exceeded.


In [None]:
# | export

def _get_relevant_document(query: str) -> str:
    """Return the content of the document most relevant to ``query``.

    The FAISS vector database stored under ``<root data path>/docs`` is loaded with
    OpenAI embeddings and queried using maximal marginal relevance search
    (1 result selected out of 3 fetched candidates).

    Args:
        query: The query for relevance-based document retrieval.

    Returns:
        The page content of the retrieved document(s), joined with newlines.
    """
    vector_db = FAISS.load_local(get_root_data_path() / "docs", OpenAIEmbeddings())  # type: ignore
    matches = vector_db.max_marginal_relevance_search(query, k=1, fetch_k=3)
    return "\n".join(match.page_content for match in matches)

In [None]:
query = "What is FastStream?"
actual = _get_relevant_document(query)
print(actual[:200])  # preview only; documents can be long
assert actual, "no document content was retrieved"

[INFO] faiss.loader: Loading faiss with AVX2 support.
[INFO] faiss.loader: Successfully loaded faiss with AVX2 support.
hide:
  - navigation
  - footer

Release Notes

FastStream is a new package based on the ideas and experiences gained from FastKafka and Propan. By joining our forces, we picked up the best from both 


In [None]:
# | export


class CustomAIChat:
    """Custom class for interacting with the OpenAI chat completion API.

    Attributes:
        model: Name of the OpenAI chat model to call.
        messages: The conversation sent to the model. It starts with the system
            prompt (SYSTEM_PROMPT), followed by the document retrieved for the
            semantic search query and the initial user prompt, when these are
            non-empty. Every successful call leaves its user prompt appended.
        params: Parameters for the chat completion request; only the
            "temperature" entry is used.
    """

    def __init__(
        self,
        model: str,
        user_prompt: Optional[str] = None,
        params: Dict[str, float] = DEFAULT_PARAMS,
        semantic_search_query: Optional[str] = None,
    ):
        """Instantiates a new CustomAIChat object.

        Args:
            model: The OpenAI model to use.
            user_prompt: Optional initial user prompt, sent after the system prompt.
            params: Parameters to use while calling the OpenAI chat model. DEFAULT_PARAMS used if not provided.
            semantic_search_query: Optional query used to fetch the most relevant
                document from the vector database. The document is sent as a user message.
        """
        self.model = model
        self.messages = [
            {"role": role, "content": content}
            for role, content in [
                ("system", SYSTEM_PROMPT),
                ("user", self._get_doc(semantic_search_query)),
                ("user", user_prompt),
            ]
            # Drop missing *and* empty content. _get_doc returns "" when no query
            # is given, which would otherwise add a blank user message.
            if content
        ]
        self.params = params

    @staticmethod
    def _get_doc(semantic_search_query: Optional[str] = None) -> str:
        """Return the most relevant document for the query, or "" when no query is given."""
        if semantic_search_query is None:
            return ""
        return _get_relevant_document(semantic_search_query)

    @_retry_with_exponential_backoff()
    def __call__(self, user_prompt: str) -> Tuple[str, Dict[str, int]]:
        """Call OpenAI API chat completion endpoint and generate a response.

        The prompt is appended to ``self.messages`` with a trailing
        "==== YOUR RESPONSE ====" marker, and the whole conversation is sent.

        Args:
            user_prompt: A string containing user's input prompt.

        Returns:
            A tuple with the AI's response message content and the token usage
            reported by the API (the response's "usage" field).
        """
        self.messages.append(
            {"role": "user", "content": f"{user_prompt}\n==== YOUR RESPONSE ====\n"}
        )
        prompt_str = "\n\n".join([f"===Role:{m['role']}===\n\nMessage:\n{m['content']}" for m in self.messages])
        logger.info(f"\n\nPrompt to the model: \n\n{prompt_str}")

        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=self.messages,
                temperature=self.params["temperature"],
            )
        except Exception:
            # Roll back the message appended above so a failed request leaves the
            # conversation unchanged. The retry decorator re-invokes this method
            # after retryable errors, which would otherwise append the same
            # prompt once per retry.
            self.messages.pop()
            raise

        return (
            response["choices"][0]["message"]["content"],
            response["usage"],
        )

In [None]:
# | notest

TEST_INITIAL_USER_PROMPT = """
You should respond with 0, 1 or 2 and nothing else. Below are your rules:

==== RULES: ====

If the ==== APP DESCRIPTION: ==== section is not related to FastKafka or contains violence, self-harm, harassment/threatening or hate/threatening information then you should respond with 0.

If the ==== APP DESCRIPTION: ==== section is related to FastKafka but focuses on what is it and its general information then you should respond with 1. 

If the ==== APP DESCRIPTION: ==== section is related to FastKafka but focuses how to use it and instructions to create a new app then you should respond with 2. 
"""

ai = CustomAIChat(user_prompt=TEST_INITIAL_USER_PROMPT, model=OpenAIModel.gpt3.value)
response, usage = ai("Name the tallest mountain in the world")

print(response)
print(usage)

# The model may pad its one-character answer with whitespace or newlines.
assert response.strip() == "0"

[INFO] __main__: 

Prompt to the model: 

===Role:system===

Message:

You are an expert Python developer, tasked to generate executable Python code as a part of your work with the FastStream framework. 

You are to abide by the following guidelines:

1. You must never enclose the generated Python code with ``` python. It is mandatory that the output is a valid and executable Python code. Please ensure this rule is never broken.

2. Some prompts might require you to generate code that contains async functions. For example:

async def app_setup(context: ContextRepo):
    raise NotImplementedError()

In such cases, it is necessary to add the "import asyncio" statement at the top of the code. 

You will encounter sections marked as:

==== APP DESCRIPTION: ====

These sections contain the description of the FastStream app you need to implement. Treat everything below this line, until the end of the prompt, as the description to follow for the app implementation.


===Role:user===

Message:

In [None]:
# | export


class ValidateAndFixResponse:
    """Generates a response from OpenAI, validates it and feeds errors back until it passes.

    Attributes:
        generate: Callable taking a prompt and returning ``(response, usage)``.
            ``fix`` also reads and edits its ``messages`` attribute.
        validate: Callable taking the response, the output directory and extra
            keyword arguments. It returns a list of error strings (empty when
            the response is valid).
        max_retries: Maximum number of generate/validate tries.
    """

    def __init__(
        self,
        generate: Callable[..., Any],
        validate: Callable[..., Any],
        max_retries: Optional[int] = MAX_RETRIES,
    ):
        self.generate = generate
        self.validate = validate
        self.max_retries = max_retries

    def fix(
        self,
        prompt: str,
        total_usage: List[Dict[str, int]],
        step_name: str,
        output_directory: str,
        **kwargs: Any,
    ) -> List[Dict[str, int]]:
        """Placeholder with the same signature as the implementation attached via fastcore's ``@patch``."""
        raise NotImplementedError()

In [None]:
# | export


def _save_log_results(
    step_name: str,
    log_dir_path: str,
    messages: List[Dict[str, str]],
    response: str,
    error_str: str,
    retry_cnt: int,
    **kwargs: Dict[str, int],
) -> None:
    if log_dir_path is not None and "attempt" in kwargs:
        step_dir = Path(log_dir_path) / step_name
        step_dir.mkdir(parents=True, exist_ok=True)

        attempt_dir = step_dir / f'attempt_{kwargs["attempt"] + 1}'  # type: ignore
        attempt_dir.mkdir(parents=True, exist_ok=True)

        try_dir = attempt_dir / f"try_{retry_cnt+1}"
        try_dir.mkdir(parents=True, exist_ok=True)

        formatted_msg = "\n".join(
            [f"===={m['role']}====\n\n{m['content']}\n\n" for m in messages]
        )

        with open((try_dir / "input.txt"), "w", encoding="utf-8") as f_input, open(
            (try_dir / "output.txt"), "w", encoding="utf-8"
        ) as f_output, open(
            (try_dir / "errors.txt"), "w", encoding="utf-8"
        ) as f_errors:
            f_input.write(formatted_msg)
            f_output.write(response)
            f_errors.write(error_str)

In [None]:
with TemporaryDirectory() as d:
    messages = [{"role": "role", "content": "content"}]
    kwargs = {"attempt": 2}
    for step_name in ["app", "test"]:
        _save_log_results(step_name, d, messages, "response", "error_str", 0, **kwargs)

        # attempt index 2 -> "attempt_3"; retry_cnt 0 -> "try_1"
        step_dir = Path(d) / step_name
        attempt_dir = step_dir / "attempt_3"
        try_dir = attempt_dir / "try_1"
        for expected_dir in (step_dir, attempt_dir, try_dir):
            assert expected_dir.exists()

        print(list(try_dir.glob("**/*")))
        for file_name in ("input.txt", "output.txt", "errors.txt"):
            assert (try_dir / file_name).exists()

[PosixPath('/tmp/tmpb96cmfze/app/attempt_3/try_1/errors.txt'), PosixPath('/tmp/tmpb96cmfze/app/attempt_3/try_1/output.txt'), PosixPath('/tmp/tmpb96cmfze/app/attempt_3/try_1/input.txt')]
[PosixPath('/tmp/tmpb96cmfze/test/attempt_3/try_1/errors.txt'), PosixPath('/tmp/tmpb96cmfze/test/attempt_3/try_1/output.txt'), PosixPath('/tmp/tmpb96cmfze/test/attempt_3/try_1/input.txt')]


In [None]:
# | export


def _construct_prompt_with_error_msg(
    response: str,
    errors: str,
) -> str:
    """Construct prompt message along with the error message.

    Args:
        prompt: The original prompt string.
        response: The invalid response string from OpenAI.
        errors: The errors which needs to be fixed in the invalid response.

    Returns:
        A string combining the original prompt, invalid response, and the error message.
    """
    prompt_with_errors = (
        f"\n\n==== YOUR RESPONSE (WITH ISSUES) ====\n\n{response}"
        + f"\n\nRead the contents of ==== YOUR RESPONSE (WITH ISSUES) ==== section and fix the below mentioned issues:\n\n{errors}"
    )
    return prompt_with_errors

In [None]:
response = "some response"
errors = "error 1\nerror 2\nerror 3\n"

expected = (
    "\n\n==== YOUR RESPONSE (WITH ISSUES) ====\n\n"
    "some response\n\n"
    "Read the contents of ==== YOUR RESPONSE (WITH ISSUES) ==== section and fix the below mentioned issues:\n\n"
    "error 1\nerror 2\nerror 3\n"
)
actual = _construct_prompt_with_error_msg(response, errors)
print(actual)

assert actual == expected



==== YOUR RESPONSE (WITH ISSUES) ====

some response

Read the contents of ==== YOUR RESPONSE (WITH ISSUES) ==== section and fix the below mentioned issues:

error 1
error 2
error 3



In [None]:
# | export


@patch  # type: ignore
def fix(
    self: ValidateAndFixResponse,
    prompt: str,
    total_usage: List[Dict[str, int]],
    step_name: str,
    output_directory: str,
    **kwargs: Any,
) -> List[Dict[str, int]]:
    """Generate and validate a response, feeding errors back until it passes or tries run out.

    Each try does the following:
    - calls ``self.generate`` with the current prompt;
    - validates the response with ``self.validate(response, output_directory, **kwargs)``;
    - logs the try under ``<output_directory>/<LOGS_DIR_NAME>/<step_name>``.

    On failure, the "==== YOUR RESPONSE ====" marker is stripped from the last
    message in ``self.generate.messages``. The next prompt then contains the
    invalid response together with its errors.

    Args:
        prompt: The initial prompt string.
        total_usage: List of token-usage dicts. The usage summed over all tries of
            this call is appended to it (the list is modified in place).
        step_name: Name of the generation step, used for the log directory.
        output_directory: Directory passed to ``self.validate``; logs are written
            under it.
        kwargs: Additional keyword arguments passed to ``self.validate`` and to the
            log writer (which uses an ``attempt`` index if present).

    Returns:
        The ``total_usage`` list with this call's token usage appended.

    Raises:
        ValueError: If no response passes validation within ``self.max_retries``
            tries. Its args are ``(total_usage, False)``, where ``False`` signals
            that the generated code still contains errors.
    """
    total_tokens_usage: Dict[str, int] = defaultdict(int)
    log_dir_path = Path(output_directory) / LOGS_DIR_NAME
    for i in range(self.max_retries):  # type: ignore
        response, usage = self.generate(prompt)
        total_tokens_usage = add_tokens_usage([total_tokens_usage, usage])

        errors = self.validate(response, output_directory, **kwargs)
        error_str = "\n".join(errors)
        # Every try is logged, whether or not it passed validation.
        _save_log_results(
            step_name,
            str(log_dir_path),
            self.generate.messages,  # type: ignore
            response,
            error_str,
            i,
            **kwargs,
        )
        if len(errors) == 0:
            total_usage.append(total_tokens_usage)
            return total_usage

        # Strip the last "==== YOUR RESPONSE ====" marker (and anything after it)
        # from the prompt that produced the invalid response. The following
        # prompt, if any, carries the response and its errors instead.
        self.generate.messages[-1]["content"] = self.generate.messages[-1][ # type: ignore
            "content"
        ].rsplit("==== YOUR RESPONSE ====", 1)[0]
        prompt = _construct_prompt_with_error_msg(response, error_str)
        logger.info(f"Validation failed, trying again...Errors:\n{error_str}")

    total_usage.append(total_tokens_usage)

    # we send False to notify the generated code contains bugs
    raise ValueError(total_usage, False)

In [None]:
fixture_initial_prompt = "some valid prompt"
max_retries = 3


class FixtureGenerate:
    """Stand-in for CustomAIChat: records prompts and returns a fixed response and usage."""

    def __init__(self, user_prompt):
        self.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    def __call__(self, prompt):
        self.messages.append({"role": "user", "content": prompt})
        usage = {"prompt_tokens": 129, "completion_tokens": 1, "total_tokens": 130}
        return fixture_initial_prompt, usage


def fixture_validate(response, output_directory, attempt):
    """Accept every response."""
    return []


# A response that validates on the first try costs exactly one call's tokens.
expected = [{"prompt_tokens": 129, "completion_tokens": 1, "total_tokens": 130}]

with TemporaryDirectory() as d:
    kwargs = {"attempt": 0}
    fixture_generate = FixtureGenerate(fixture_initial_prompt)
    v = ValidateAndFixResponse(fixture_generate, fixture_validate, max_retries)
    actual = v.fix(fixture_initial_prompt, [], STEP_LOG_DIR_NAMES["app"], d, **kwargs)
    print(actual)
    assert actual == expected

    try_dir = Path(d) / LOGS_DIR_NAME / STEP_LOG_DIR_NAMES["app"] / "attempt_1" / "try_1"
    assert try_dir.exists()
    for file_name in ("input.txt", "output.txt", "errors.txt"):
        assert (try_dir / file_name).exists()

    with open(try_dir / "input.txt", "r", encoding="utf-8") as f:
        print(f.read())

[defaultdict(<class 'int'>, {'prompt_tokens': 129, 'completion_tokens': 1, 'total_tokens': 130})]
====system====


You are an expert Python developer, tasked to generate executable Python code as a part of your work with the FastStream framework. 

You are to abide by the following guidelines:

1. You must never enclose the generated Python code with ``` python. It is mandatory that the output is a valid and executable Python code. Please ensure this rule is never broken.

2. Some prompts might require you to generate code that contains async functions. For example:

async def app_setup(context: ContextRepo):
    raise NotImplementedError()

In such cases, it is necessary to add the "import asyncio" statement at the top of the code. 

You will encounter sections marked as:

==== APP DESCRIPTION: ====

These sections contain the description of the FastStream app you need to implement. Treat everything below this line, until the end of the prompt, as the description to follow for the app

In [None]:
fixture_initial_prompt = "some invalid prompt"
max_retries = 3


class FixtureGenerate:
    """Stand-in for CustomAIChat: records prompts and returns a fixed response and usage."""

    def __init__(self, user_prompt):
        self.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    def __call__(self, prompt):
        self.messages.append({"role": "user", "content": prompt})
        usage = {"prompt_tokens": 129, "completion_tokens": 1, "total_tokens": 130}
        return fixture_initial_prompt, usage


fixture_generate = FixtureGenerate(fixture_initial_prompt)

with TemporaryDirectory() as d:
    def fixture_validate(response, output_path, attempt):
        """Reject every response."""
        return ["error 1", "error 2"]

    with pytest.raises(ValueError) as e:
        kwargs = {"attempt": 0}
        v = ValidateAndFixResponse(fixture_generate, fixture_validate, max_retries)
        v.fix(fixture_initial_prompt, [], STEP_LOG_DIR_NAMES["app"], d, **kwargs)

    print(f"{e.value=}")
    assert not e.value.args[1]
    # Usage of all max_retries tries is summed: 3 x (129, 1, 130).
    assert e.value.args[0] == [
        {"prompt_tokens": 387, "completion_tokens": 3, "total_tokens": 390}
    ]

    attempt_dir = Path(d) / LOGS_DIR_NAME / STEP_LOG_DIR_NAMES["app"] / "attempt_1"
    for i in range(max_retries):
        try_dir = attempt_dir / f"try_{i+1}"
        assert try_dir.exists()
        for file_name in ("input.txt", "output.txt", "errors.txt"):
            assert (try_dir / file_name).exists()

    with open(attempt_dir / "try_2" / "input.txt", "r", encoding="utf-8") as f:
        print(f.read())

[INFO] __main__: Validation failed, trying again...Errors:
error 1
error 2
[INFO] __main__: Validation failed, trying again...Errors:
error 1
error 2
[INFO] __main__: Validation failed, trying again...Errors:
error 1
error 2
e.value=ValueError([defaultdict(<class 'int'>, {'prompt_tokens': 387, 'completion_tokens': 3, 'total_tokens': 390})], False)
====system====


You are an expert Python developer, tasked to generate executable Python code as a part of your work with the FastStream framework. 

You are to abide by the following guidelines:

1. You must never enclose the generated Python code with ``` python. It is mandatory that the output is a valid and executable Python code. Please ensure this rule is never broken.

2. Some prompts might require you to generate code that contains async functions. For example:

async def app_setup(context: ContextRepo):
    raise NotImplementedError()

In such cases, it is necessary to add the "import asyncio" statement at the top of the code. 

You

In [None]:
fixture_initial_prompt = "some invalid prompt"
max_retries = 3


class FixtureGenerate:
    """Stand-in for CustomAIChat: records prompts and returns a fixed response and usage."""

    def __init__(self, user_prompt):
        self.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    def __call__(self, prompt):
        self.messages.append({"role": "user", "content": prompt})
        usage = {"prompt_tokens": 129, "completion_tokens": 1, "total_tokens": 130}
        return fixture_initial_prompt, usage


fixture_generate = FixtureGenerate(fixture_initial_prompt)

with TemporaryDirectory() as d:
    def fixture_validate(response, output_path, attempt):
        """Reject every response."""
        return ["error 1", "error 2"]

    with pytest.raises(ValueError) as e:
        kwargs = {"attempt": 0}
        v = ValidateAndFixResponse(fixture_generate, fixture_validate, max_retries)
        v.fix(fixture_initial_prompt, [], STEP_LOG_DIR_NAMES["skeleton"], d, **kwargs)

    print(e.value)
    assert not e.value.args[1]
    # Usage of all max_retries tries is summed: 3 x (129, 1, 130).
    assert e.value.args[0] == [
        {"prompt_tokens": 387, "completion_tokens": 3, "total_tokens": 390}
    ]

    attempt_dir = Path(d) / LOGS_DIR_NAME / STEP_LOG_DIR_NAMES["skeleton"] / "attempt_1"
    for i in range(max_retries):
        try_dir = attempt_dir / f"try_{i+1}"
        assert try_dir.exists()
        for file_name in ("input.txt", "output.txt", "errors.txt"):
            assert (try_dir / file_name).exists()

    with open(attempt_dir / "try_2" / "input.txt", "r", encoding="utf-8") as f:
        print(f.read())

[INFO] __main__: Validation failed, trying again...Errors:
error 1
error 2
[INFO] __main__: Validation failed, trying again...Errors:
error 1
error 2
[INFO] __main__: Validation failed, trying again...Errors:
error 1
error 2
([defaultdict(<class 'int'>, {'prompt_tokens': 387, 'completion_tokens': 3, 'total_tokens': 390})], False)
====system====


You are an expert Python developer, tasked to generate executable Python code as a part of your work with the FastStream framework. 

You are to abide by the following guidelines:

1. You must never enclose the generated Python code with ``` python. It is mandatory that the output is a valid and executable Python code. Please ensure this rule is never broken.

2. Some prompts might require you to generate code that contains async functions. For example:

async def app_setup(context: ContextRepo):
    raise NotImplementedError()

In such cases, it is necessary to add the "import asyncio" statement at the top of the code. 

You will encounter se