In [None]:
# | default_exp docstring_generator

# Docstring Generator

In [None]:
# | export

import time
import random
import ast
import textwrap
import os
import re
from typing import *
from pathlib import Path

import nbformat
import openai
import typer

In [None]:
import shutil
from tempfile import TemporaryDirectory
from contextlib import contextmanager

import pytest

In [None]:
# # | export


# def _get_end_line_for_class_or_func(source: str, lineno: int) -> int:

#     tree = ast.parse("\n".join(source.split("\n")[lineno - 1 :]))
#     if len(tree.body) == 1:
#         return len(source.split("\n")[:])

#     return tree.body[1].lineno - 1

In [None]:
# | export


def _get_end_line_for_class_or_func(source: str, lineno: int) -> int:

    lines = source.split("\n")[lineno - 1 :]
    first_indent = len(lines[0]) - len(lines[0].lstrip())
    end_line_in_source = 0

    for i, line in enumerate(lines[1:]):
        if len(line) - len(line.lstrip()) == first_indent and line.strip() != "":
            end_line_in_source = i
            break

    ret_val = (
        len(source.split("\n"))
        if end_line_in_source == 0
        else end_line_in_source + lineno
    )
    return ret_val - 1

In [None]:
source = """
class A:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')   
        
    async def drive(self):
        print(f'The {self.model} is now driving.')   
"""

lineno = 2
actual = _get_end_line_for_class_or_func(source, lineno)
expected = 11
print(actual)
assert actual == expected

11


In [None]:
source = """
class A:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
        
def stop(self):
    print(f'The {self.model} is now driving.')
"""
lineno = 7
actual = _get_end_line_for_class_or_func(source, lineno)
expected = 9
print(actual)
assert actual == expected

9


In [None]:
# | export


def _get_code_from_source(source: str, start_line_no: int, end_line_no: int) -> str:
    source_lines = source.split("\n")
    extracted_lines = source_lines[start_line_no - 1 : end_line_no]
    return "\n".join(extracted_lines)

In [None]:
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""
start_line_no = 7
end_line_no = 8

expected = """    async def drive(self):
        print(f'The {self.model} is now driving.')"""

actual = _get_code_from_source(source, start_line_no, end_line_no)
print(actual)

assert actual == expected

    async def drive(self):
        print(f'The {self.model} is now driving.')


In [None]:
# | export

# Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb


def _retry_with_exponential_backoff(
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 20,
    max_wait: float = 60,
    errors: tuple = (
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
        openai.error.APIError,
    ),
) -> Callable:
    """Build a decorator that retries a function with exponential backoff.

    Args:
        initial_delay: Starting delay in seconds. It is multiplied once before
            the first sleep, so the first wait is already longer than this.
        exponential_base: Factor the delay is multiplied by on every retry.
        jitter: When true, the factor is additionally scaled by a random value
            in [1, 2) on every retry.
        max_retries: Number of retries allowed before giving up.
        max_wait: Upper bound, in seconds, for a single wait.
        errors: Exception types that trigger a retry. Any other exception
            propagates immediately.

    Returns:
        A decorator that wraps a function with the retry loop.

    Raises:
        Exception: From the wrapped call, once more than `max_retries`
            retryable errors have occurred. The last retryable error is
            attached as `__cause__`.
    """
    # Imported locally so this block does not depend on the module's import cell.
    import functools

    def decorator(func):
        # Keep the wrapped function's name and docstring on the wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            num_retries = 0
            delay = initial_delay

            while True:
                try:
                    return func(*args, **kwargs)

                except errors as e:
                    num_retries += 1
                    if num_retries > max_retries:
                        raise Exception(
                            f"Maximum number of retries ({max_retries}) exceeded."
                        ) from e
                    delay = min(
                        delay
                        * exponential_base
                        * (1 + jitter * random.random()),  # nosec
                        max_wait,
                    )
                    typer.secho(
                        f"Note: OpenAI's API rate limit reached. Command will automatically retry in {int(delay)} seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits",
                        fg=typer.colors.BLUE,
                    )
                    time.sleep(delay)

        return wrapper

    return decorator


@_retry_with_exponential_backoff()
def _completions_with_backoff(**kwargs):
    """Call `openai.Completion.create`, retrying via the backoff decorator.

    The call is wrapped by `_retry_with_exponential_backoff` with its default
    settings, so the error types configured there trigger a retry.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to
            `openai.Completion.create`.

    Returns:
        Whatever `openai.Completion.create` returns.
    """

    return openai.Completion.create(**kwargs)

In [None]:
@_retry_with_exponential_backoff()
def mock_func():
    """This function is used to mock a function

    Args:
        None

    Returns:
        str: The string "Success"

    Raises:
        None


    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """

    return "Success"


assert mock_func() == "Success"

# Test max retries exceeded
@_retry_with_exponential_backoff(max_retries=1)
def mock_func_error():
    """Mock function that raises an error.

    Args:
        None

    Returns:
        None

    Raises:
        openai.error.RateLimitError


    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """

    raise openai.error.RateLimitError


with pytest.raises(Exception) as e:
    mock_func_error()

print(e.value)
assert str(e.value) == "Maximum number of retries (1) exceeded."

[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 3 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
Maximum number of retries (1) exceeded.


In [None]:
# | export


def _get_best_docstring(docstrings: List[str]) -> Optional[str]:
    """_get_best_docstring(docstrings: List[str]) -> Optional[str]

        Returns the best docstring from a list of docstrings.

        Args:
            docstrings: A list of docstrings.

        Returns:
            The best docstring.

        Raises:
            ValueError: If the list of docstrings is empty.


    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """

    docstrings = [d for d in docstrings if "Args:" in d]
    docstrings = [d for d in docstrings if "~~~~" not in d]
    return docstrings[0] if len(docstrings) > 0 else None

In [None]:
docstrings = [
    "    _check_and_add_docstrings_to_source(\n    source: str, include_auto_gen_txt: bool, **kwargs\n) -> str:\n    source = _remove_auto_generated_docstring(source)    \n    tree = ast.parse(source)\n    line_offset = 0\n\n    for node in tree.body:\n        if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):\n            continue\n        \n        if ast.get_docstring(node) is not None:\n            continue\n\n        source, line_offset = _add_docstring(\n            source, node, line_offset, include_auto_gen_txt, **kwargs\n        )\n        if not isinstance(node, ast.ClassDef):\n            continue\n        ",
    '    _check_and_add_docstrings_to_source\n    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    This function checks if the source code has docstrings for all the functions and classes.\n    If not, it adds a docstring to the function/class.\n\n    Args:\n        source (str): The source code to be checked for docstrings.\n        include_auto_gen_txt (bool): If True, the docstring will include the text "Auto-generated by nbdev".\n        **kwargs: Additional keyword arguments.\n\n    Returns:\n        str: The source code with docstrings added.\n',
    "    This function checks if the source code has docstrings for all the functions and classes.\n    If not, it adds the docstrings.\n    It also removes the auto generated docstring.\n    \n    Args:\n        source: The source code as a string.\n        include_auto_gen_txt: Whether to include the auto generated text in the docstring.\n        **kwargs: Other keyword arguments.\n    \n    Returns:\n        The source code with docstrings added.\n    \n    Raises:\n        ValueError: If the source code is not a string.\n",
]

actual = _get_best_docstring(docstrings)
expected = docstrings[2]

print(actual)
assert actual == expected

    This function checks if the source code has docstrings for all the functions and classes.
    If not, it adds the docstrings.
    It also removes the auto generated docstring.
    
    Args:
        source: The source code as a string.
        include_auto_gen_txt: Whether to include the auto generated text in the docstring.
        **kwargs: Other keyword arguments.
    
    Returns:
        The source code with docstrings added.
    
    Raises:
        ValueError: If the source code is not a string.



In [None]:
docstrings = [
    "    _check_and_add_docstrings_to_source(\n    source: str, include_auto_gen_txt: bool, **kwargs\n) -> str:\n    source = _remove_auto_generated_docstring(source)    \n    tree = ast.parse(source)\n    line_offset = 0\n\n    for node in tree.body:\n        if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):\n            continue\n        \n        if ast.get_docstring(node) is not None:\n            continue\n\n        source, line_offset = _add_docstring(\n            source, node, line_offset, include_auto_gen_txt, **kwargs\n        )\n        if not isinstance(node, ast.ClassDef):\n            continue\n        ",
    '    _check_and_add_docstrings_to_source\n    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    This function checks if the source code has docstrings for all the functions and classes.\n    If not, it adds a docstring to the function/class.\n\n    Args:\n        source (str): The source code to be checked for docstrings.\n        include_auto_gen_txt (bool): If True, the docstring will include the text "Auto-generated by nbdev".\n        **kwargs: Additional keyword arguments.\n\n    Returns:\n        str: The source code with docstrings added.\n',
    "    This function checks if the source code ~~~~~~~~~~~~~~~~ has docstrings for all the functions and classes.\n    If not, it adds the docstrings.\n    It also removes the auto generated docstring.\n    \n    Args:\n        source: The source code as a string.\n        include_auto_gen_txt: Whether to include the auto generated text in the docstring.\n        **kwargs: Other keyword arguments.\n    \n    Returns:\n        The source code with docstrings added.\n    \n    Raises:\n        ValueError: If the source code is not a string.\n",
]

actual = _get_best_docstring(docstrings)
expected = None

print(actual)
assert actual == expected

None


In [None]:
#######

# Tried the DEFAULT_PROMPTs below. The model took longer to generate a docstring and sometimes returned it in the wrong format.

# Write a concise and high quality docstring for the above function following the Google Python Style Guide.
# The docstring should include a one-line summary, overall description of the function's purpose, arguments,
# return value, an usage example, and any exceptions or errors raised by the function.

# Write a concise and high quality docstring for the above function following the Google Python Style Guide.
# Include the function's arguments, return value, an example of how to use it, and any errors or exceptions it may raise.

#######

# The below are a few examples of prompts that are already tried

# An elaborate, high quality docstring in Google style for the above function
# Write a concise and high quality docstring for the above python code, following the Google Python Style Guide, that accurately and clearly describes the code.
# Write a concise and high quality docstring for the above python code, following the Google Python Style Guide.
# Write a concise, high quality docstring for the above code by following the Google Python Style Guide. Include function args, return types, exceptions, and usage example.
# Write a concise, high quality docstring for the above code by following the Google Python Style Guide. Include function args, return types, and exceptions.

#######

In [None]:
# | export

# Maximum number of completion requests made for one definition before
# _generate_docstring_using_codex gives up and returns its failure note.
DOCSTRING_RETRY_ATTEMPTS = 5

# Completion prompt layout: the source code, then the instruction comment(s),
# then an opening triple quote.
# NOTE(review): the trailing triple quote presumably makes the model continue
# with the docstring body - confirm.
PROMPT_TEMPLATE = '''
# Python 3.7

{source}

{prompt}
"""
'''

# Multi-line prompts work best with the codex model.
# Note: The prompt must start with the # symbol
DEFAULT_PROMPT = """
# An elaborate, high quality docstring for the above function adhering to the Google python docstring format:
# Any deviation from the Google python docstring format will not be accepted
# Include one line description, args, returns and raises
"""


def _get_response(**kwargs: Union[int, float, Optional[str], List[str]]) -> Any:
    """Request completions and return the list of choices.

    Args:
        **kwargs: Keyword arguments forwarded to `_completions_with_backoff`.

    Returns:
        The `choices` attribute of the completion response.

    Raises:
        openai.error.AuthenticationError: If the API rejects the request for
            authentication reasons. The error carries a hint about setting
            OPENAI_API_KEY and keeps the original error as its `__cause__`.
    """
    try:
        response = _completions_with_backoff(**kwargs)
    except openai.error.AuthenticationError as e:
        # Chain the original error so the real reason stays in the traceback.
        raise openai.error.AuthenticationError(
            "No API key provided. Please set the API key in the environment variable OPENAI_API_KEY=<API-KEY>. You can generate API keys in the OpenAI web interface. See https://onboard.openai.com for details."
        ) from e
    return response.choices


def _generate_docstring_using_codex(
    source: str, **kwargs: Union[int, float, Optional[str], List[str]]
) -> str:
    """Ask the completion model for a docstring describing `source`.

    Args:
        source: Source code of the function or class to document.
        **kwargs: Completion parameters forwarded to `_get_response`. The
            optional `prompt` entry holds the instruction text; when it is
            missing or None, `DEFAULT_PROMPT` is used. It is replaced by the
            fully formatted prompt before the request is sent.

    Returns:
        The first acceptable docstring returned by the model, or a
        "Failed to generate docs" note when none of the
        `DOCSTRING_RETRY_ATTEMPTS` requests produced one.
    """
    # .get() so callers that omit `prompt` fall back to the default instead
    # of failing with a KeyError.
    custom_prompt = kwargs.get("prompt")
    prompt: str = DEFAULT_PROMPT if custom_prompt is None else custom_prompt  # type: ignore
    # The instruction is sent as a code comment. DEFAULT_PROMPT begins with a
    # newline, so it also receives this prefix.
    if not prompt.startswith("#"):
        prompt = f"# {prompt}"
    kwargs["prompt"] = PROMPT_TEMPLATE.format(source=source, prompt=prompt)

    for _ in range(DOCSTRING_RETRY_ATTEMPTS):
        choices = _get_response(**kwargs)
        best = _get_best_docstring([choice.text for choice in choices])

        if best is not None:
            return best

    # Written with explicit escapes so the whitespace-only middle line
    # survives editors that strip trailing whitespace.
    return "!!! note\n    \n    Failed to generate docs"

In [None]:
source = """
# | export

def _check_and_add_docstrings_to_source(
    source: str, include_auto_gen_txt: bool, **kwargs
) -> str:
    source = _remove_auto_generated_docstring(source)    
    tree = ast.parse(source)
    line_offset = 0

    for node in tree.body:
        if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        
        if ast.get_docstring(node) is not None:
            continue

        source, line_offset = _add_docstring(
            source, node, line_offset, include_auto_gen_txt, **kwargs
        )
        if not isinstance(node, ast.ClassDef):
            continue
        # Is a class and we need to check the functions inside
        # 29 - 36 make it as a recursive function
        for f in node.body:
            if not isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            
            if ast.get_docstring(f) is not None:
                continue

            # should be a function inside the class for which there is no docstring
            source, line_offset = _add_docstring(
                source, f, line_offset, include_auto_gen_txt, **kwargs
            )

    return source
"""

node = ast.parse(source).body[0]

docstring = _generate_docstring_using_codex(
    source,
    model="code-davinci-002",
    temperature=0.2,
    max_tokens=250,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["#", '"""'],
    prompt=None,
    n=3,
)

print(docstring)
assert "Args:" in docstring
assert "~~~" not in docstring

    This function is used to check and add docstrings to the source code.
    It will remove the auto generated docstring and add the docstring to the source code.
    It will also check if the docstring is already present or not.
    If the docstring is already present, it will not add the docstring.

    Args:
        source: The source code of the file.
        include_auto_gen_txt: A boolean value to check if the auto generated text is to be included or not.
        **kwargs: Keyword arguments.

    Returns:
        The source code with the docstring added.



In [None]:
source = """def add(x, y):
    return x + y
"""
node = ast.parse(source).body[0]

docstring = _generate_docstring_using_codex(
    source,
    model="code-davinci-002",
    temperature=0.2,
    max_tokens=250,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["#", '"""'],
    prompt=None,
)

print(docstring)
assert "Args:" in docstring
assert "~~~" not in docstring

    This function adds two numbers and returns the sum
    Args:
        x: first number
        y: second number
    Returns:
        sum of the two numbers
    Raises:
        TypeError: if inputs are not numbers



In [None]:
# | export

AUTO_GEN_PERFIX = """!!! note

"""
AUTO_GEN_BODY = "The above docstring is autogenerated by docstring-gen library"
AUTO_GEN_SUFFIX = "(https://github.com/airtai/docstring-gen)"
AUTO_GEN_TXT = AUTO_GEN_PERFIX + AUTO_GEN_BODY + " " + AUTO_GEN_SUFFIX


def _add_auto_gen_txt(docstring: str, indent: int) -> str:
    lines = AUTO_GEN_TXT.split("\n")
    auto_gen_txt = (
        textwrap.indent(lines[0], " " * (indent))
        + "\n"
        + textwrap.indent("\n".join(lines[1:]), " " * (indent + 4))
    )

    docstring = docstring + "\n" + auto_gen_txt + "\n"
    return docstring

In [None]:
docstring = """    
    Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
"""

expected = """    
    Sample docstring

    Args:
        s: sample args

    Returns:
        sample return

    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
"""
indent = 4
actual = _add_auto_gen_txt(docstring, indent)
print(actual)

assert actual == expected

    
    Sample docstring

    Args:
        s: sample args

    Returns:
        sample return

    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)



In [None]:
# | export


def _fix_docstring_indentation(
    docstring: str, col_offset: int, include_auto_gen_txt: bool
) -> str:
    indent = col_offset + 4
    lines = docstring.split("\n")
    matches = ["Args:", "Returns:", "Raises:"]
    ret_val = (
        textwrap.dedent(lines[0])
        + "\n"
        + "\n".join(
            textwrap.indent(textwrap.dedent(l), " " * indent)
            if any(x in l for x in matches)
            else textwrap.indent(textwrap.dedent(l), " " * (indent + 4))
            for l in lines[1:]
        )
    )

    ret_val = (
        ret_val if not include_auto_gen_txt else _add_auto_gen_txt(ret_val, indent)
    )

    return ret_val

In [None]:
docstring = """  Sample docstring

    Args:
            s: sample args

            Returns:
sample return
"""

expected = """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
"""
col_offset = 0
actual = _fix_docstring_indentation(docstring, col_offset, False)
print(actual)

assert actual == expected

Sample docstring

    Args:
        s: sample args

    Returns:
        sample return



In [None]:
docstring = """  Sample docstring

Args:
    s: sample args

Returns:
    sample return
"""

expected = '''Sample docstring

    Args:
        s: sample args

    Returns:
        sample return

    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
'''
col_offset = 0
actual = _fix_docstring_indentation(docstring, col_offset, True)
print(actual)

assert actual == expected

Sample docstring

    Args:
        s: sample args

    Returns:
        sample return

    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)



In [None]:
# | export


def _inject_docstring_to_source(
    source: str,
    docstring: str,
    node: Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef],
    line_offset: int,
) -> str:
    lineno = (node.body[0].lineno - 1) + line_offset
    indent = node.col_offset + 4
    lines = source.split("\n")
    lines.insert(
        lineno,
        f'{" " * indent}"""{docstring}{" " * indent}"""',
    )
    return "\n".join(lines)

In [None]:
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""

docstring = """        Sample docstring

            Args:
    s: sample args

    Returns:
            sample return
"""

expected = '''
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        """Sample docstring

        Args:
            s: sample args

        Returns:
            sample return
        """
        print(f'The {self.model} is now driving.')
'''

node = ast.parse(source).body[0].body[2]
line_offset = 0
docstring = _fix_docstring_indentation(docstring, node.col_offset, False)
actual = _inject_docstring_to_source(source, docstring, node, line_offset)
print(actual)

assert actual == expected


class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        """Sample docstring

        Args:
            s: sample args

        Returns:
            sample return
        """
        print(f'The {self.model} is now driving.')



In [None]:
# | export


def _add_docstring(
    source: str,
    node: Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef],
    line_offset: int,
    include_auto_gen_txt: bool,
    **kwargs: Union[int, float, Optional[str], List[str]],
) -> Tuple[str, int]:
    """Generate a docstring for `node` and insert it into `source`.

    Args:
        source: Current source text, possibly already modified by earlier
            insertions.
        node: AST node of the definition to document.
        line_offset: Number of lines already inserted into `source`. It is
            added to the node's line numbers.
        include_auto_gen_txt: Whether to append the auto-generated note to
            the docstring.
        **kwargs: Completion parameters forwarded to
            `_generate_docstring_using_codex`.

    Returns:
        A tuple of the updated source and the line offset increased by the
        number of lines the inserted docstring occupies.
    """
    start = node.lineno + line_offset
    snippet = _get_code_from_source(
        source, start, _get_end_line_for_class_or_func(source, start)
    )

    generated = _generate_docstring_using_codex(snippet, **kwargs)
    docstring = _fix_docstring_indentation(
        generated, node.col_offset, include_auto_gen_txt
    )
    updated_source = _inject_docstring_to_source(source, docstring, node, line_offset)

    return updated_source, line_offset + len(docstring.split("\n"))

In [None]:
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""

tree = ast.parse(source)
line_offset = 0

for node in tree.body:
    source, line_offset = _add_docstring(
        source,
        node,
        line_offset,
        include_auto_gen_txt=True,
        model="code-davinci-002",
        temperature=0,
        max_tokens=150,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["#", '"""'],
        prompt=None,
    )
    
    for f in node.body:
        if not isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        source, line_offset = _add_docstring(
            source,
            f,
            line_offset,
            include_auto_gen_txt=True,
            model="code-davinci-002",
            temperature=0,
            max_tokens=150,
            top_p=1.0,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=["#", '"""'],
            prompt=None,
        )


def assert_docstring(source):
    """Assert that top-level definitions and their methods have docstrings.

    Checks every top-level class and function, plus every function defined
    directly in a top-level class body, then prints the source.

    Args:
        source: Python source text to check.

    Raises:
        AssertionError: If any checked definition has no docstring.
    """
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            assert ast.get_docstring(node) is not None
        if isinstance(node, ast.ClassDef):
            for f in node.body:
                if isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    # Check the method itself, not the enclosing class again.
                    assert ast.get_docstring(f) is not None

    print(source)

assert_docstring(source)


class test:
    """This is a docstring for the above function.

    Args:
        a: An integer.

    Returns:
        None

    Raises:
        ValueError: If a is not an integer.

    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    CONST_VAL = 1
    def __init__(self, a):
        """This is a docstring for the above function

        Args:
            a (int): This is the first parameter

        Returns:
            int: This is a description of what is returned

        Raises:
            KeyError: Raises an exception

        !!! note

            The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
        """
        self.a = a
        
    async def drive(self):
        """This function is used to drive the car.

        Args:
            None

        Returns:
            None

        Raises:
            None

        !!! note

           

In [None]:
# | export


def _remove_auto_generated_docstring(source: str) -> str:
    """Remove triple-double-quoted docstrings that carry the auto-gen marker.

    Only the quoted text itself is removed. Whitespace around it, including
    the line it started on, is left in place.

    Args:
        source: Python source text.

    Returns:
        The source with every docstring containing `AUTO_GEN_BODY` removed.
    """
    # The negative lookahead keeps a match from spanning from an earlier,
    # unrelated triple quote into a later auto-generated docstring.
    pattern = f'"""((?!""").)*?({AUTO_GEN_BODY}).*?"""'
    return re.compile(pattern, flags=re.DOTALL).sub("", source)

In [None]:
source = '''
def decorator1(func):
    """Decorator function that takes a function as an argument and returns a function."""
    pass
    
def decorator2(func):
    
    pass
'''

expected = '''
def decorator1(func):
    """Decorator function that takes a function as an argument and returns a function."""
    pass
    
def decorator2(func):
    
    pass
'''

actual = _remove_auto_generated_docstring(source)
print(actual)

assert actual == expected


def decorator1(func):
    """Decorator function that takes a function as an argument and returns a function."""
    pass
    
def decorator2(func):
    
    pass



In [None]:
# | export


def _check_and_add_docstrings_to_source(
    source: str,
    include_auto_gen_txt: bool,
    recreate_auto_gen_docs: bool,
    **kwargs: Union[int, float, Optional[str], List[str]]
) -> str:
    """Add docstrings to the classes and functions in `source` that lack one.

    Top-level classes and functions are processed, and class bodies are
    walked recursively so methods and nested classes are covered too.
    Functions nested inside other functions are not visited.

    Args:
        source: Python source text.
        include_auto_gen_txt: Whether to append the auto-generated note to
            each new docstring.
        recreate_auto_gen_docs: When true, previously auto-generated
            docstrings are removed first so they are generated again.
        **kwargs: Completion parameters forwarded to `_add_docstring`.

    Returns:
        The source with the generated docstrings inserted.

    Raises:
        SyntaxError: If `source` is not valid Python.
    """
    if recreate_auto_gen_docs:
        source = _remove_auto_generated_docstring(source)

    tree = ast.parse(source)
    # Lines inserted so far. Nodes are visited in source order, so each
    # insertion point is the node's original line plus this offset.
    line_offset = 0

    def _add_docstrings_recursively(node: ast.AST) -> None:
        nonlocal source, line_offset
        if not isinstance(
            node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
        ):
            return

        if ast.get_docstring(node) is None:
            source, line_offset = _add_docstring(
                source, node, line_offset, include_auto_gen_txt, **kwargs
            )

        # Walk a class body even when the class already has a docstring, so
        # its undocumented methods are not skipped.
        if isinstance(node, ast.ClassDef):
            for child in node.body:
                _add_docstrings_recursively(child)

    for node in tree.body:
        _add_docstrings_recursively(node)

    return source

In [None]:
source = """
def decorator1(func):
    def inner():
        func()
    return inner

def decorator2(func):
    def inner():
        func()
    return inner

@decorator1
@decorator2
def outer_func():
    def inner_func():
        print("Hello, World!")
    inner_func()
    
class Test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
        
    class Inner:
        def __init__(self, b):
            self.b = b
            
        def stop(self):
            pass
"""
updated_source = _check_and_add_docstrings_to_source(
    source,
    include_auto_gen_txt=True,
    recreate_auto_gen_docs=True,
    model="code-davinci-002",
    temperature=0,
    max_tokens=250,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["#", '"""'],
    prompt=None,
)

assert_docstring(updated_source)

[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 2 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 2 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m

def decorator1(func):
    """This is a decorator function that takes a function as an argument and returns a function.

    Args:
        func: A function

    Returns:
        A function

    Raises:
        None

    !!! note

        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    def inner():
        func()
    return inner

def decorator2(func):
    """This is a decorator function that takes a function as an argument and returns a function.

    Args:
        func: A function

    Returns:
        A functio

In [None]:
# | export


def _get_files(nb_path: Path) -> List[Path]:
    exts = [".ipynb", ".py"]
    files = [
        f
        for f in nb_path.rglob("*")
        if f.suffix in exts
        and not any(p.startswith(".") for p in f.parts)
        and not f.name.startswith("_")
    ]

    if len(files) == 0:
        raise ValueError(
            f"The directory {nb_path.resolve()} does not contain any Python files or notebooks"
        )

    return files

In [None]:
with TemporaryDirectory() as d:
    nbs_path = Path(d) / "nbs"
    nbs_path.mkdir(parents=True)

    hidden_dir = nbs_path / ".hidden"
    hidden_dir.mkdir(parents=True)

    shutil.copyfile(Path("..") / "settings.ini", nbs_path / "settings.ini")
    shutil.copyfile(
        Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "_test.ipynb"
    )
    shutil.copyfile(
        Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "test.ipynb"
    )
    shutil.copyfile(
        Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "test_1.ipynb"
    )

    shutil.copyfile(
        Path("..") / "fixtures" / "Test_Data.ipynb", hidden_dir / "test.ipynb"
    )
    shutil.copyfile(
        Path("..") / "fixtures" / "Test_Data.ipynb", hidden_dir / "test_1.ipynb"
    )

    for f in nbs_path.rglob("*"):
        print(f)

    files = _get_files(nbs_path)

    assert len(files) == 2
    print(f"\n\n{files}")
    assert files == [nbs_path / "test_1.ipynb", nbs_path / "test.ipynb"]

/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/.hidden
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/test_1.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/settings.ini
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/test.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/_test.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/.hidden/test_1.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/.hidden/test.ipynb


[PosixPath('/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/test_1.ipynb'), PosixPath('/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpb0osodkg/nbs/test.ipynb')]


In [None]:
with pytest.raises(ValueError) as e:

    with TemporaryDirectory() as d:
        nbs_path = Path(d) / "nbs"
        nbs_path.mkdir(parents=True)

        _get_files(nbs_path)

print(e.value)

The directory /private/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpyikzf8qh/nbs does not contain any Python files or notebooks


In [None]:
# | export


def _add_docstring_to_nb(
    file: Path,
    version: int,
    include_auto_gen_txt: bool,
    recreate_auto_gen_docs: bool,
    **kwargs: Union[int, float, Optional[str], List[str]]
) -> None:
    """Add missing docstrings to every code cell of a notebook, in place.

    Args:
        file: Path of the notebook to read and overwrite.
        version: Notebook format version passed to `nbformat.read`.
        include_auto_gen_txt: Whether to append the auto-generated note to
            each new docstring.
        recreate_auto_gen_docs: Whether previously auto-generated docstrings
            are removed and generated again.
        **kwargs: Completion parameters forwarded to
            `_check_and_add_docstrings_to_source`.
    """
    notebook = nbformat.read(file, as_version=version)
    for cell in notebook.cells:
        if cell.cell_type != "code":
            continue
        updated = _check_and_add_docstrings_to_source(
            cell["source"], include_auto_gen_txt, recreate_auto_gen_docs, **kwargs
        )
        cell["source"] = updated
    nbformat.write(notebook, file)


def _add_docstring_to_py(
    file: Path,
    include_auto_gen_txt: bool,
    recreate_auto_gen_docs: bool,
    **kwargs: Union[int, float, Optional[str], List[str]]
) -> None:
    """Run a Python source file through the docstring checker and overwrite it with the result.

    Args:
        file: Path of the Python file; it is read and then overwritten.
        include_auto_gen_txt: Forwarded unchanged to
            ``_check_and_add_docstrings_to_source``.
        recreate_auto_gen_docs: Forwarded unchanged to
            ``_check_and_add_docstrings_to_source``.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``_check_and_add_docstrings_to_source``.
    """
    # Python source is UTF-8 by default (PEP 3120). Without an explicit
    # encoding, open() falls back to the locale's preferred encoding, which can
    # corrupt or fail on non-ASCII source text on non-UTF-8 systems.
    with file.open("r", encoding="utf-8") as f:
        source = f.read()
    source = _check_and_add_docstrings_to_source(
        source, include_auto_gen_txt, recreate_auto_gen_docs, **kwargs
    )
    # The file is only reopened for writing once the new source is fully built,
    # so a failure in the checker leaves the original file untouched.
    with file.open("w", encoding="utf-8") as f:
        f.write(source)


def add_docstring_to_source(
    path: Union[str, Path],
    version: int = 4,
    include_auto_gen_txt: bool = True,
    recreate_auto_gen_docs: bool = False,
    model: str = "code-davinci-002",
    temperature: float = 0.2,
    max_tokens: int = 250,
    top_p: float = 1.0,
    n: int = 3,
    prompt: Optional[str] = None,
) -> None:
    """Add docstrings to a single file or to every file found under a directory.

    When ``path`` is a directory the files to process come from ``_get_files``;
    otherwise ``path`` itself is the only file processed. Files with the
    ``.ipynb`` suffix are handled by ``_add_docstring_to_nb`` and everything
    else by ``_add_docstring_to_py``.

    Args:
        path: File or directory to process.
        version: Notebook format version; only passed on for ``.ipynb`` files.
        include_auto_gen_txt: Passed on to the per-file helper.
        recreate_auto_gen_docs: Passed on to the per-file helper.
        model: Passed on to the per-file helper as a keyword argument.
        temperature: Passed on to the per-file helper as a keyword argument.
        max_tokens: Passed on to the per-file helper as a keyword argument.
        top_p: Passed on to the per-file helper as a keyword argument.
        n: Passed on to the per-file helper as a keyword argument.
        prompt: Passed on to the per-file helper as a keyword argument.
    """
    target = Path(path)
    if target.is_dir():
        sources = _get_files(target)
    else:
        sources = [target]

    # Options common to both helpers. The penalties and stop sequences are
    # fixed here rather than exposed as parameters.
    shared_options: Dict[str, Any] = dict(
        include_auto_gen_txt=include_auto_gen_txt,
        recreate_auto_gen_docs=recreate_auto_gen_docs,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["#", '"""'],
        n=n,
        prompt=prompt,
    )

    for source_file in sources:
        if source_file.suffix == ".ipynb":
            # Only the notebook helper needs the notebook format version.
            _add_docstring_to_nb(file=source_file, version=version, **shared_options)
        else:
            _add_docstring_to_py(file=source_file, **shared_options)

In [None]:
# End-to-end check: run add_docstring_to_source over a scratch directory, then
# verify that every top-level class/function in test.ipynb carries a docstring.
with TemporaryDirectory() as d:
    nbs_path = Path(d) / "nbs"
    nbs_path.mkdir(parents=True)

    fixtures_dir = Path("..") / "fixtures"
    # (source file, name inside nbs_path), copied in this order.
    for src, dst_name in (
        (fixtures_dir / "Test_Data.ipynb", "test.ipynb"),
        (fixtures_dir / "Test_Data.ipynb", "_test.ipynb"),
        (fixtures_dir / "test_data.py", "test_data.py"),
        (Path("..") / "settings.ini", "settings.ini"),
    ):
        shutil.copyfile(src, nbs_path / dst_name)

    add_docstring_to_source(nbs_path)

    # Load the processed notebook before the temporary directory is removed.
    with (nbs_path / "test.ipynb").open("r") as f:
        nb = nbformat.read(f, as_version=4)

documentable = (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
for cell in nb.cells:
    if cell.cell_type != "code":
        print(cell["source"])
        continue

    # Parse before printing so an unparsable cell fails without being echoed.
    tree = ast.parse(cell["source"])
    print(cell["source"])
    for node in tree.body:
        if isinstance(node, documentable):
            assert ast.get_docstring(node)

[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 2 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 9 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 2 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 2 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 3 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-us