In [None]:
# | default_exp docstring_generator

# Docstring Generator

In [None]:
# | export

import ast
import functools
import os
import random
import re
import time
import tokenize
from configparser import ConfigParser
from io import BytesIO
from pathlib import Path
from typing import *

import nbformat
import openai
import typer

In [None]:
import shutil
from tempfile import TemporaryDirectory
from contextlib import contextmanager

import pytest

In [None]:
# | export

def _get_code_from_source(source: str, start_line_no: int, end_line_no: int) -> str:
    """    This function takes in the source code of a python file, the start line number and the end line number
        and returns the code snippet from the source code.
        
        Args:
            source (str): The source code of the python file
            start_line_no (int): The start line number
            end_line_no (int): The end line number
            
        Returns:
            str: The code snippet from the source code
            
        Raises:
            ValueError: If the start line number is greater than the end line number


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    
    
    
    
    source_lines = source.split("\n")
    extracted_lines = source_lines[start_line_no-1:end_line_no]
    return "\n".join(extracted_lines)

In [None]:
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""
start_line_no = 7
end_line_no = 8

expected = """    async def drive(self):
        print(f'The {self.model} is now driving.')"""

actual = _get_code_from_source(source, start_line_no, end_line_no)
print(actual)

assert actual == expected


    async def drive(self):
        print(f'The {self.model} is now driving.')


In [None]:
# | export


def _calculate_end_lineno(source: str, start_line_no: int) -> int:
    """Calculate the end line number of a function.
    
    Args:
        source: The source code of the file.
        start_line_no: The line number of the function definition.
    
    Returns:
        The line number of the end of the function.
    
    Raises:
        ValueError: If the function definition is not found.


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    
    
    
    
    lines = source.split("\n")[start_line_no - 1 :]
    first_indent = len(lines[0]) - len(lines[0].lstrip())
    end_line_in_source = 0

    for i, line in enumerate(lines[1:]):
        if len(line) - len(line.lstrip()) == first_indent and line.strip() != "":
            end_line_in_source = i
            break

    ret_val = (
        len(source.split("\n"))
        if end_line_in_source == 0
        else end_line_in_source + start_line_no
    )
    return ret_val - 1

In [None]:
source = """
class A:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')

class B:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
        
    async def stop(self):
        print(f'The {self.model} is now stopped.')
        
class C:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a

    async def drive(self):
        print(f'The {self.model} is now driving.')
        
async def drive(self):
    print(f'The {self.model} is now driving.')
"""

tree = ast.parse(source)
expected = [8, 5, 9, 19, 13, 16, 20, 27, 24, 28, 30]
actual = []
for i, node in enumerate(tree.body):
    res = _calculate_end_lineno(source, node.lineno)
    actual.append(res)
#     print(node.end_lineno)

    if isinstance(node, ast.ClassDef):
        for f in node.body:
            if not isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            res = _calculate_end_lineno(source, f.lineno)
            actual.append(res)
#             print(f.end_lineno)
print(actual)
assert actual == expected

[8, 5, 9, 19, 13, 16, 20, 27, 24, 28, 30]


In [None]:
# | export


def _line_has_decorator(source: str, lineno: int) -> bool:
    """    This function checks if a line has a decorator.
        Args:
            source: The source code of the file
            lineno: The line number of the line to be checked
        Returns:
            True if the line has a decorator, False otherwise
        Raises:
            None


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    
    
    
    
    line = "".join(source.split("\n")[lineno - 1])
    return line.startswith("@") or line.strip() == ""

def _get_start_line_for_class_or_func(source: str, lineno: int) -> int:
    """


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    
    
    
    
    if not _line_has_decorator(source, lineno):
        return lineno

    original_lineno = lineno
    total_lines = source.split("\n")
    for i in total_lines:
        lineno += 1
        if lineno > len(total_lines):
            break
        if not _line_has_decorator(source, lineno):
            return lineno
    return original_lineno

In [None]:
source = """def decorator1(func):
    \"""Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    \"""
    def inner():
        func()
    return inner

def decorator2(func):
    \"""Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    \"""
    def inner():
        func()
    return inner


@decorator1
@decorator2
def outer_func():
    def inner_func():
        print("Hello, World!")
    inner_func()
"""

line_no = 28
expected = 30
actual = _get_start_line_for_class_or_func(source, line_no)
print(actual)
assert actual == expected

30


In [None]:
source = """
async def drive(self
):
    print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 2
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

source = """
@foo(x=5)
@decorator2
async def drive(self):
    print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 4
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

source = """
@foo(x=5)
async def drive(self):
    print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 3
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

source = """
@foo(x=5)
@bar(x=5)
@zar(x=5)
class A:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 5
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

source = """
@foo(x=5)
@bar(x=5)

"""
lineno = 2
expected = 2
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected, f"actual = {actual}"



2
4
3
5
2


In [None]:
# | export

def _get_lineno_to_append_docstring(source: str, lineno: int) -> int:
    """    This function takes in a source code and a line number and returns the line number where the docstring should be appended.
        It does this by checking if the source code is tokenized. If it is not, it raises a token error.
        If it is, it returns the line number where the docstring should be appended.
        
        Args:
            source: The source code of the function
            lineno: The line number of the function
            
        Returns:
            The line number where the docstring should be appended
            
        Raises:
            TokenError: If the source code is not tokenized


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    
    
    
    
    line_offset = 0
    is_src_tokenized = False
    lines = source.split("\n")[lineno - 1:]
    
    for i in range(len(lines)):
        line = "".join(source.split("\n")[lineno - 1:][:i+1])
        if line != "":
            try:
                list(tokenize.tokenize(BytesIO(line.encode("utf-8")).readline))
                is_src_tokenized = True
                break
            except tokenize.TokenError as e:
                line_offset +=1
                continue
    if not is_src_tokenized:
        raise tokenize.TokenError(f"TokenError: {source}")
    
    ret_val = line_offset + lineno
    return ret_val

In [None]:
source = """
def gen(
    path: str = typer.Option(
      ...,
        help="The path to the Jupyter notebook or Python file, or a directory containing these files",
    ),
    prompt: Optional[str] = typer.Option(
        None,
 help="Text that will be given as input to the GPT-3 model to generate the docstring. If no text is provided, the docstring will be generated according to the Google Python Style Guide.",
    ),
    include_auto_gen_txt: bool = typer.Option(
        True,
        help="If set to true, a note indicating that the docstring was autogenerated will be added to the end of the docstring.",
    ),
    model: str = typer.Option(
        "code-davinci-002",
        help="The name of the GPT-3 model to use for docstring generation.",
    ),
    temperature: int= typer.Option(
        0, help="The temperature parameter for the GPT-3 model."
    ),
    max_tokens: int = typer.Option(
        250, help="The maximum number of tokens to generate in the docstring."
    ),
    top_p: float = typer.Option(1.0, help="The top-p parameter for the GPT-3 model."),
    frequency_penalty: float = typer.Option(
        0.0, help="The frequency penalty parameter for the GPT-3 model."
    ),
    presence_penalty: float = typer.Option(
        0.0, help="The presence penalty parameter for the GPT-3 model."
    )
):
    pass
"""
lineno = 2
expected = 32

actual = _get_lineno_to_append_docstring(source, lineno)
print(actual)

assert actual == expected

32


In [None]:
source = """
async def drive(self):
    print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 2

actual = _get_lineno_to_append_docstring(source, lineno)
print(actual)

assert actual == expected

source = """
async def drive(self,
a,
b,
c,
"""
lineno = 2
with pytest.raises(tokenize.TokenError) as e:
    _get_lineno_to_append_docstring(source, lineno)
print(e.value)

2
TokenError: 
async def drive(self,
a,
b,
c,



In [None]:
# | export

# First line of the note appended to generated docstrings, followed by a blank line.
# NOTE(review): "PERFIX" looks like a typo of "PREFIX"; the name is left as is
# because it is part of the exported module.
AUTO_GEN_PERFIX = '''!!! note

'''

# AUTO_GEN_BODY is interpolated into the regular expression used by
# _remove_auto_generated_docstring to find and strip the docstrings that were
# autogenerated by a previous run.
AUTO_GEN_BODY = "The above docstring is autogenerated by docstring-gen library"

# Link to the project, appended after AUTO_GEN_BODY.
AUTO_GEN_SUFFIX = "(https://github.com/airtai/docstring-gen)"

# Complete note: the prefix, then the body and suffix indented by four spaces.
AUTO_GEN_TXT =  AUTO_GEN_PERFIX + " " * 4 + AUTO_GEN_BODY + " "+ AUTO_GEN_SUFFIX

In [None]:
expected = '''!!! note

    The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)'''

print(AUTO_GEN_TXT)
assert AUTO_GEN_TXT == expected

!!! note

    The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)


In [None]:
# | export

def _inject_docstring_to_source(
    source: str,
    docstring: str,
    lineno: int,
    node_col_offset: int,
    include_auto_gen_txt: bool,
) -> str:
    """Insert ``docstring`` below the header of the definition at ``lineno``.

    The docstring is wrapped in triple double quotes and indented four
    columns deeper than the definition. Its first line follows the opening
    quotes directly and its last line is emitted without indentation, right
    before the indented closing quotes (so a docstring ending in a newline
    puts the closing quotes on their own line).

    Args:
        source: The source code to modify.
        docstring: The docstring text, without surrounding quotes.
        lineno: 1-based line number on which the definition header starts.
        node_col_offset: Column offset of the definition.
        include_auto_gen_txt: If true, two newlines, the indented
            ``AUTO_GEN_TXT`` note and one more newline are appended after the
            docstring text.

    Returns:
        The source code with the quoted docstring inserted after the last
        line of the definition header.

    Raises:
        tokenize.TokenError: Propagated from
            ``_get_lineno_to_append_docstring`` when the end of the header
            cannot be located.
    """
    insert_at = _get_lineno_to_append_docstring(source, lineno)
    pad = " " * (node_col_offset + 4)

    # Indent every docstring line except the first and the last one.
    doc_lines = docstring.split("\n")
    last_index = len(doc_lines) - 1
    body_parts = []
    for index, text in enumerate(doc_lines):
        if index in (0, last_index):
            body_parts.append(text)
        else:
            body_parts.append(pad + text)
    body = "\n".join(body_parts)

    if include_auto_gen_txt:
        note = "\n".join(pad + text for text in AUTO_GEN_TXT.split("\n"))
        footer = "\n\n" + note + "\n"
    else:
        footer = ""

    source_lines = source.split("\n")
    source_lines.insert(insert_at, pad + '"""' + body + footer + pad + '"""')
    return "\n".join(source_lines)

In [None]:
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""

docstring = """Sample docstring

Args:
    s: sample args

Returns:
    sample return
"""

expected = '''
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        """Sample docstring
        
        Args:
            s: sample args
        
        Returns:
            sample return
        """
        print(f'The {self.model} is now driving.')
'''

lineno = 7
node_col_offset = 4
include_auto_gen_txt = False
actual = _inject_docstring_to_source(source, docstring, lineno, node_col_offset, include_auto_gen_txt)
print(actual)

assert actual == expected


class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        """Sample docstring
        
        Args:
            s: sample args
        
        Returns:
            sample return
        """
        print(f'The {self.model} is now driving.')



In [None]:
source = """
async def drive(
self,
a,
b,
c
):
    print(f'The {self.model} is now driving.')
"""

docstring = """Sample docstring

Args:
    s: sample args

Returns:
    sample return
"""

expected = """
async def drive(
self,
a,
b,
c
):
    \"""Sample docstring
    
    Args:
        s: sample args
    
    Returns:
        sample return
    \"""
    print(f'The {self.model} is now driving.')
"""

lineno = 2
node_col_offset = 0
include_auto_gen_txt = False
actual = _inject_docstring_to_source(source, docstring, lineno, node_col_offset, include_auto_gen_txt)
print(actual)

assert actual == expected


async def drive(
self,
a,
b,
c
):
    """Sample docstring
    
    Args:
        s: sample args
    
    Returns:
        sample return
    """
    print(f'The {self.model} is now driving.')



In [None]:
# | export

# Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb

def retry_with_exponential_backoff(
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    max_wait: float = 60,
    errors: tuple = (openai.error.RateLimitError, openai.error.ServiceUnavailableError),
):
    """Build a decorator that retries a function with exponential backoff.

    Each time the wrapped function raises one of ``errors``, the current
    delay is multiplied by ``exponential_base`` (and, with ``jitter``, by an
    extra random factor between 1 and 2), capped at ``max_wait``, reported to
    the user and slept before the call is repeated. Any other exception
    propagates immediately.

    Args:
        initial_delay: Delay in seconds that the first multiplication starts from.
        exponential_base: Factor by which the delay grows on every retry.
        jitter: Whether to add a random component to the growth factor.
        max_retries: Number of retries allowed before giving up.
        max_wait: Upper bound in seconds for a single delay.
        errors: Exception types that trigger a retry.

    Returns:
        A decorator that wraps a callable with the retry loop. The wrapper
        keeps the wrapped callable's name and docstring.

    Raises:
        Exception: From the wrapper, once more than ``max_retries`` retries
            would be needed; chained to the last caught error.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            num_retries = 0
            delay = initial_delay

            while True:
                try:
                    return func(*args, **kwargs)

                except errors as e:
                    num_retries += 1
                    if num_retries > max_retries:
                        # Chain the last API error so the cause is not lost.
                        raise Exception(
                            f"Maximum number of retries ({max_retries}) exceeded."
                        ) from e
                    delay = min(
                        delay * exponential_base * (1 + jitter * random.random()), # nosec
                        max_wait,
                    )
                    typer.secho(
                        f"Note: OpenAI's API rate limit reached. Command will automatically retry in {int(delay)} seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits",
                        fg=typer.colors.BLUE,
                    )
                    time.sleep(delay)

        return wrapper

    return decorator


@retry_with_exponential_backoff()
def completions_with_backoff(**kwargs):
    """Call ``openai.Completion.create`` with ``kwargs``, retried by ``retry_with_exponential_backoff`` using its default settings."""
    return openai.Completion.create(**kwargs)

In [None]:
@retry_with_exponential_backoff()
def mock_func():
    return "Success"

assert mock_func() == "Success"

# Test max retries exceeded
@retry_with_exponential_backoff(max_retries=1)
def mock_func_error():
    raise openai.error.RateLimitError

with pytest.raises(Exception) as e:
    mock_func_error()
    
print(e.value)
assert str(e.value) == "Maximum number of retries (1) exceeded."

[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 3 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
Maximum number of retries (1) exceeded.


In [None]:
#######

# The DEFAULT_PROMPT candidates below were tried, but the model took longer to generate docstrings and sometimes returned them in the wrong format

# Write a concise and high quality docstring for the above function following the Google Python Style Guide.
# The docstring should include a one-line summary, overall description of the function's purpose, arguments,
# return value, an usage example, and any exceptions or errors raised by the function.

# Write a concise and high quality docstring for the above function following the Google Python Style Guide.
# Include the function's arguments, return value, an example of how to use it, and any errors or exceptions it may raise.

#######

# Below are a few examples of prompts that have already been tried

# An elaborate, high quality docstring in Google style for the above function
# Write a concise and high quality docstring for the above python code, following the Google Python Style Guide, that accurately and clearly describes the code.
# Write a concise and high quality docstring for the above python code, following the Google Python Style Guide.
# Write a concise, high quality docstring for the above code by following the Google Python Style Guide. Include function args, return types, exceptions, and usage example.
# Write a concise, high quality docstring for the above code by following the Google Python Style Guide. Include function args, return types, and exceptions.

#######

In [None]:
def _get_ast_tree(
    source: str,
) -> Union[ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef]:
    tree = ast.parse(source)
    return tree.body[0]

In [None]:
source = """def add(x, y):
    return x + y
"""

actual = _get_ast_tree(source)

print(actual)
# Assert the node type; a bare comparison is only displayed and never fails the cell.
assert isinstance(actual, ast.FunctionDef)

<ast.FunctionDef object>


True

In [None]:
# | export

# Number of additional completion requests made by _generate_docstring_using_codex
# when a response does not contain an "Args:" section.
DOCSTRING_RETRY_ATTEMPTS = 1

# Text sent to the model: the source code, then the prompt, then an opening
# triple quote (presumably so that the completion continues as docstring text).
PROMPT_TEMPLATE = '''
# Python 3.7

{source}

{prompt}
"""
'''

# Multi-line prompts work best with the codex model.
# Note: The prompt must start with the # symbol
DEFAULT_PROMPT = '''
# An elaborate, high quality docstring for the above function adhering to the Google python docstring format:
# Any deviation from the Google python docstring format will not be accepted
# Include one line description, args, returns and raises
'''


def _get_response(**kwargs):
    """Request a completion and return the text of its first choice.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to
            ``completions_with_backoff``.

    Returns:
        The ``text`` attribute of the first entry in the response's
        ``choices``.

    Raises:
        openai.error.AuthenticationError: Raised in place of any
            authentication error from the request, with a message that
            explains how to provide the API key.
    """
    try:
        completion = completions_with_backoff(**kwargs)
    except openai.error.AuthenticationError:
        raise openai.error.AuthenticationError(
            "No API key provided. Please set the API key in the environment variable OPENAI_API_KEY=<API-KEY>. You can generate API keys in the OpenAI web interface. See https://onboard.openai.com for details."
        )
    else:
        return completion.choices[0].text


def _generate_docstring_using_codex(source: str, **kwargs) -> str:
    """Generate a docstring for ``source`` with the completion model.

    The ``prompt`` keyword argument (or ``DEFAULT_PROMPT`` when it is missing
    or ``None``) is prefixed with ``"# "`` unless its first non-whitespace
    character is already ``#``, then combined with ``source`` through
    ``PROMPT_TEMPLATE``. If the response contains no ``"Args:"`` section the
    request is repeated up to ``DOCSTRING_RETRY_ATTEMPTS`` more times.

    Args:
        source: The source code for which the docstring is generated.
        **kwargs: Keyword arguments for the completion request. ``prompt`` is
            replaced by the rendered template; everything else is forwarded
            unchanged to ``_get_response``.

    Returns:
        The text of the last response received.

    Raises:
        openai.error.AuthenticationError: Propagated from ``_get_response``.
    """
    custom_prompt = kwargs.get("prompt")
    prompt = DEFAULT_PROMPT if custom_prompt is None else custom_prompt
    # DEFAULT_PROMPT begins with a newline, so leading whitespace has to be
    # ignored here; otherwise a stray "# " line is put in front of it.
    if not prompt.lstrip().startswith("#"):
        prompt = "# " + prompt
    kwargs["prompt"] = PROMPT_TEMPLATE.format(source=source, prompt=prompt)

    ret_val: str = _get_response(**kwargs)
    for _ in range(DOCSTRING_RETRY_ATTEMPTS):
        if "Args:" in ret_val:
            break
        ret_val = _get_response(**kwargs)

    return ret_val

In [None]:
source = '''
# | export

def _check_and_add_docstrings_to_source(
    source: str, include_auto_gen_txt: bool, **kwargs
) -> str:
    source = _remove_auto_generated_docstring(source)    
    tree = ast.parse(source)
    line_offset = 0

    for node in tree.body:
        if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        
        if ast.get_docstring(node) is not None:
            continue

        source, line_offset = _add_docstring(
            source, node, line_offset, include_auto_gen_txt, **kwargs
        )
        if not isinstance(node, ast.ClassDef):
            continue
        # Is a class and we need to check the functions inside
        # 29 - 36 make it as a recursive function
        for f in node.body:
            if not isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            
            if ast.get_docstring(f) is not None:
                continue

            # should be a function inside the class for which there is no docstring
            source, line_offset = _add_docstring(
                source, f, line_offset, include_auto_gen_txt, **kwargs
            )

    return source
'''

node = _get_ast_tree(source)

docstring = _generate_docstring_using_codex(
    source,
    model="code-davinci-002",
    temperature=0,
    max_tokens=250,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["#", '"""'],
    prompt = None
)
source_with_docstring = _inject_docstring_to_source(
    source, docstring, node.lineno, node.col_offset, include_auto_gen_txt=True
)

print(source_with_docstring)
assert ast.get_docstring(ast.parse(source_with_docstring).body[0]) is not None




# | export

def _check_and_add_docstrings_to_source(
    source: str, include_auto_gen_txt: bool, **kwargs
) -> str:
    """    _check_and_add_docstrings_to_source(source: str, include_auto_gen_txt: bool, **kwargs) -> str:
        source = _remove_auto_generated_docstring(source)    
        tree = ast.parse(source)
        line_offset = 0
    
        for node in tree.body:
            if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            
            if ast.get_docstring(node) is not None:
                continue
    
            source, line_offset = _add_docstring(
                source, node, line_offset, include_auto_gen_txt, **kwargs
            )
            if not isinstance(node, ast.ClassDef):
                continue
        

    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    source = _remove_auto_generated_docstring(s

In [None]:
source = """def add(x, y):
    return x + y
"""
node = _get_ast_tree(source)

docstring = _generate_docstring_using_codex(
    source,
    model="code-davinci-002",
    temperature=0,
    max_tokens=250,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["#", '"""'],
    prompt = None
)
source_with_docstring = _inject_docstring_to_source(
    source, docstring, node.lineno, node.col_offset, include_auto_gen_txt=True
)

print(source_with_docstring)
assert ast.get_docstring(ast.parse(source_with_docstring).body[0]) is not None

def add(x, y):
    """This function adds two numbers
    
    Args:
        x (int): first number
        y (int): second number
    
    Returns:
        int: sum of x and y
    
    Raises:
        TypeError: if x or y is not an integer


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    return x + y



In [None]:
@contextmanager
def unset_env_var(name: str):
    """Temporarily remove an environment variable and the OpenAI API key.

    While the context is active, ``name`` is absent from ``os.environ`` and
    ``openai.api_key`` is ``None``. On exit, the environment variable is
    restored if it was defined before, and ``openai.api_key`` is restored if
    it was not ``None`` before.

    Args:
        name: Name of the environment variable to unset.

    Yields:
        None.
    """
    # Environment values are always strings, so None reliably means "not set".
    saved_env_value = os.environ.pop(name, None)
    saved_api_key = openai.api_key
    if saved_api_key is not None:
        openai.api_key = None

    try:
        yield
    finally:
        if saved_env_value is not None:
            os.environ[name] = saved_env_value
        if saved_api_key is not None:
            openai.api_key = saved_api_key


source = """def add(x, y):
    return x + y
"""
node = _get_ast_tree(source)

with pytest.raises(openai.error.AuthenticationError) as e:
    with unset_env_var("OPENAI_API_KEY"):
        _generate_docstring_using_codex(
            source,
            model="code-davinci-002",
            temperature=0,
            max_tokens=150,
            top_p=1.0,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=["#", '"""'],
            prompt=None,
        )

print(e.value)
assert (
    "Please set the API key in the environment variable OPENAI_API_KEY=<API-KEY>"
    in str(e.value)
)

No API key provided. Please set the API key in the environment variable OPENAI_API_KEY=<API-KEY>. You can generate API keys in the OpenAI web interface. See https://onboard.openai.com for details.


In [None]:
# | export


def _add_docstring(
    source: str,
    node: Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef],
    line_offset: int,
    include_auto_gen_txt: bool,
    **kwargs,
) -> Tuple[str, int]:
    """Generate a docstring for ``node`` and insert it into ``source``.

    Args:
        source: The current source code, possibly already containing
            docstrings inserted for earlier nodes.
        node: The class or function node, whose line numbers refer to the
            source as it was when it was parsed.
        line_offset: Number of lines inserted into ``source`` since parsing;
            added to the node's line numbers.
        include_auto_gen_txt: Whether the ``AUTO_GEN_TXT`` note is appended
            to the docstring.
        **kwargs: Keyword arguments forwarded to
            ``_generate_docstring_using_codex``.

    Returns:
        A tuple of the updated source code and the line offset increased by
        the number of lines that were inserted.
    """
    # Fix for  Python 3.7
    # Delete the below line once support for Python 3.7 is dropped
    start_line = _get_start_line_for_class_or_func(source, node.lineno + line_offset)

    node_end = getattr(node, "end_lineno", None)
    if node_end is None:
        # The node carries no end line, so estimate it from the indentation.
        end_line = _calculate_end_lineno(source, start_line)
    else:
        end_line = node_end + line_offset

    snippet = _get_code_from_source(source, start_line, end_line)
    docstring = _generate_docstring_using_codex(snippet, **kwargs)

    updated_source = _inject_docstring_to_source(
        source, docstring, start_line, node.col_offset, include_auto_gen_txt
    )

    inserted_lines = len(docstring.split("\n"))
    if include_auto_gen_txt:
        # the 2 is for the \n characters at the beginning
        inserted_lines += len(AUTO_GEN_TXT.split("\n")) + 2
    return updated_source, line_offset + inserted_lines

In [None]:
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""

tree = ast.parse(source)
line_offset = 0

for node in tree.body:
    source, line_offset = _add_docstring(
        source,
        node,
        line_offset,
        include_auto_gen_txt= True,
        model="code-davinci-002",
        temperature=0,
        max_tokens=150,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["#", '"""'],
        prompt=None,
    )

    for f in node.body:
        if not isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        source, line_offset = _add_docstring(
            source,
            f,
            line_offset,
            include_auto_gen_txt= True,
            model="code-davinci-002",
            temperature=0,
            max_tokens=150,
            top_p=1.0,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=["#", '"""'],
            prompt=None,
        )
        
def assert_docstring(source):
    """Assert that the definitions in ``source`` have docstrings, then print it.

    Every top-level class and function, and every function defined directly
    in the body of a top-level class, must have a docstring.

    Args:
        source: Python source code to check.

    Raises:
        AssertionError: If one of the checked definitions has no docstring.
        SyntaxError: If ``source`` cannot be parsed.
    """
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    for node in ast.parse(source).body:
        if not isinstance(node, (ast.ClassDef,) + function_types):
            continue
        assert ast.get_docstring(node) is not None, f"{node.name} has no docstring"

        if isinstance(node, ast.ClassDef):
            for member in node.body:
                if isinstance(member, function_types):
                    # Check the method itself, not the enclosing class.
                    assert (
                        ast.get_docstring(member) is not None
                    ), f"{node.name}.{member.name} has no docstring"

    print(source)
assert_docstring(source)

[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 2 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m

class test:
    """This is a docstring for the above function.
    
    Args:
        a: An integer
    
    Returns:
        None
    
    Raises:
        ValueError: If a is not an integer


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    CONST_VAL = 1
    def __init__(self, a):
        """This is a docstring for the above function
        
        Args:
            a (int): This is the first parameter
        
        Returns:
            int: This is a description of what is returned
        
        Raises:
            KeyError: Raises an exception


        !!! note
        
            The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring

In [None]:
# | export

def _remove_auto_generated_docstring(source: str) -> str:
    """Strip every auto-generated docstring from the source code.

    A docstring is considered auto-generated when the text between its
    opening and closing triple double quotes contains `AUTO_GEN_BODY`.
    The whole docstring, quotes included, is replaced by an empty string.
    Source without such a docstring is returned unchanged.

    Args:
        source (str): The source code.

    Returns:
        str: The source code with the auto-generated docstrings removed.
    """
    # NOTE(review): AUTO_GEN_BODY is interpolated as-is, so it is interpreted
    # as a regular expression rather than as literal text - confirm it contains
    # no unescaped metacharacters.
    auto_gen_docstring = re.compile(
        f'"""((?!""").)*?({AUTO_GEN_BODY}).*?"""', flags=re.DOTALL
    )
    return auto_gen_docstring.sub('', source)

In [None]:
# Source that contains only a hand-written docstring (no auto-generated one)
# must come back from _remove_auto_generated_docstring unchanged.
source = '''
def decorator1(func):
    """Decorator function that takes a function as an argument and returns a function."""
    pass
    
def decorator2(func):
    
    pass
'''

expected = '''
def decorator1(func):
    """Decorator function that takes a function as an argument and returns a function."""
    pass
    
def decorator2(func):
    
    pass
'''

actual = _remove_auto_generated_docstring(source)
print(actual)

assert actual == expected


def decorator1(func):
    """Decorator function that takes a function as an argument and returns a function."""
    pass
    
def decorator2(func):
    
    pass



In [None]:
# | export

def _check_and_add_docstrings_to_source(
    source: str, include_auto_gen_txt: bool, **kwargs
) -> str:
    """Add docstrings to the top-level definitions and class methods lacking one.

    Auto-generated docstrings are stripped from the source first. Every
    class, function and async function in the module body that then has no
    docstring is passed to `_add_docstring`. The same is done for every
    function defined directly inside a top-level class, whether or not the
    class itself already has a docstring. Deeper nesting (functions inside
    functions, classes inside classes) is not processed.

    Args:
        source (str): The source code to process.
        include_auto_gen_txt (bool): Forwarded to `_add_docstring`.
        **kwargs: Forwarded unchanged to `_add_docstring`.

    Returns:
        str: The source code with the docstrings added.

    Raises:
        SyntaxError: If the source cannot be parsed by `ast.parse`.
    """
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    source = _remove_auto_generated_docstring(source)
    tree = ast.parse(source)
    # Threaded through every _add_docstring call; presumably the number of
    # lines inserted so far, needed because `tree` still holds the original
    # line numbers - confirm against _add_docstring.
    line_offset = 0

    for node in tree.body:
        if not isinstance(node, (ast.ClassDef, *function_types)):
            continue

        if ast.get_docstring(node) is None:
            source, line_offset = _add_docstring(
                source, node, line_offset, include_auto_gen_txt, **kwargs
            )

        if not isinstance(node, ast.ClassDef):
            continue

        # Methods are inspected even when the class already has a docstring;
        # a documented class can still contain undocumented methods.
        for f in node.body:
            if not isinstance(f, function_types):
                continue

            if ast.get_docstring(f) is not None:
                continue

            source, line_offset = _add_docstring(
                source, f, line_offset, include_auto_gen_txt, **kwargs
            )

    return source

In [None]:
# Top-level functions (one of them decorated) that each contain a nested
# function. assert_docstring only inspects the top-level definitions, so the
# nested functions are not required to have a docstring.
source = """
def decorator1(func):
    def inner():
        func()
    return inner

def decorator2(func):
    def inner():
        func()
    return inner

@decorator1
@decorator2
def outer_func():
    def inner_func():
        print("Hello, World!")
    inner_func()
"""
actual = _check_and_add_docstrings_to_source(
    source,
    include_auto_gen_txt = True,
    model="code-davinci-002",
    temperature=0,
    max_tokens=250,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["#", '"""'],
    prompt=None,
)
assert_docstring(actual)


def decorator1(func):
    """This is a decorator function that takes a function as an argument and returns a function.
    
    Args:
        func: A function
    
    Returns:
        A function
    
    Raises:
        None


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    def inner():
        func()
    return inner

def decorator2(func):
    """This is a decorator function that takes a function as an argument and returns a function.
    
    Args:
        func: A function
    
    Returns:
        A function
    
    Raises:
        None


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    def inner():
        func()
    return inner

@decorator1
@decorator2
def outer_func():
    """This function prints "Hello, World!"
    
    Args:
        None
    
    Returns:
        None
    
    Raises:
     

In [None]:
# | export


def _get_files(nb_path: Path) -> List[Path]:
    """    This function returns a list of files in the given directory.
        The files returned are either .ipynb or .py files.
        The files returned do not start with a '.' or '_'
        
        Args:
            nb_path: The path to the directory to be searched
            
        Returns:
            A list of files in the given directory
            
        Raises:
            ValueError: If the directory does not contain any Python files or notebooks


    !!! note
    
        The above docstring is autogenerated by docstring-gen library (https://github.com/airtai/docstring-gen)
    """
    exts = [".ipynb", ".py"]
    files = [
        f
        for f in nb_path.rglob("*")
        if f.suffix in exts
        and not any(p.startswith(".") for p in f.parts)
        and not f.name.startswith("_")
    ]
    
    if len(files) == 0:
        raise ValueError(f"The directory {nb_path.resolve()} does not contain any Python files or notebooks")

    return files

In [None]:
# _get_files must ignore the settings file, the "_"-prefixed notebook and
# everything inside the hidden directory, leaving exactly two notebooks.
with TemporaryDirectory() as d:
    nbs_path = Path(d) / "nbs"
    nbs_path.mkdir(parents=True)
    
    hidden_dir = nbs_path / ".hidden"
    hidden_dir.mkdir(parents=True)

    shutil.copyfile(Path("..") / "settings.ini", nbs_path / "settings.ini")
    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "_test.ipynb")
    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "test.ipynb")
    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "test_1.ipynb")
    
    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", hidden_dir / "test.ipynb")
    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", hidden_dir / "test_1.ipynb")
    
    for f in nbs_path.rglob("*"):
        print(f)

    files = _get_files(nbs_path)

    assert len(files) == 2
    print(f"\n\n{files}")
    # The order in which rglob yields entries depends on the filesystem, so
    # compare the results independently of their order.
    assert sorted(files) == sorted([nbs_path / "test_1.ipynb", nbs_path / "test.ipynb"])
    
    

/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/.hidden
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/test_1.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/settings.ini
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/test.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/_test.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/.hidden/test_1.ipynb
/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/.hidden/test.ipynb


[PosixPath('/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/test_1.ipynb'), PosixPath('/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpgpkwtzsi/nbs/test.ipynb')]


In [None]:
# An empty directory contains no notebooks or Python files, so _get_files
# is expected to raise ValueError; the message is printed for inspection.
with pytest.raises(ValueError) as e:

    with TemporaryDirectory() as d:
        nbs_path = Path(d) / "nbs"
        nbs_path.mkdir(parents=True)

        _get_files(nbs_path)
        
print(e.value)

The directory /private/var/folders/6n/3rjds7v52cd83wqkd565db0h0000gn/T/tmpl4d1k23u/nbs does not contain any Python files or notebooks


In [None]:
# | export


def _add_docstring_to_nb(
    file: Path, version: int, include_auto_gen_txt: bool, **kwargs
):
    """Add docstrings to the code cells of a notebook and rewrite it in place.

    The notebook is read with `nbformat`, the source of every code cell is
    replaced by the result of `_check_and_add_docstrings_to_source`, and the
    notebook is written back to the same path. Non-code cells are untouched.

    Args:
        file: Path to the notebook file.
        version: Notebook format version, passed to `nbformat.read` as
            `as_version`.
        include_auto_gen_txt: Forwarded to
            `_check_and_add_docstrings_to_source`.
        **kwargs: Forwarded unchanged to
            `_check_and_add_docstrings_to_source`.
    """
    notebook = nbformat.read(file, as_version=version)
    code_cells = (c for c in notebook.cells if c.cell_type == "code")
    for code_cell in code_cells:
        updated_source = _check_and_add_docstrings_to_source(
            code_cell["source"], include_auto_gen_txt, **kwargs
        )
        code_cell["source"] = updated_source
    nbformat.write(notebook, file)


def _add_docstring_to_py(file: Path, include_auto_gen_txt: bool, **kwargs):
    """Add docstrings to a Python source file, rewriting it in place.

    The file content is passed through `_check_and_add_docstrings_to_source`
    and the result is written back to the same path using the encoding the
    file was read with.

    Args:
        file: Path to the Python file.
        include_auto_gen_txt: Forwarded to
            `_check_and_add_docstrings_to_source`.
        **kwargs: Forwarded unchanged to
            `_check_and_add_docstrings_to_source`.
    """
    # tokenize.open detects the source encoding from a BOM or coding cookie
    # and falls back to UTF-8, instead of relying on the locale default.
    with tokenize.open(file) as f:
        encoding = f.encoding
        source = f.read()
    source = _check_and_add_docstrings_to_source(source, include_auto_gen_txt, **kwargs)
    # Write with the detected encoding so the file keeps its original one.
    with file.open("w", encoding=encoding) as f:
        f.write(source)


def add_docstring_to_source(
    path: Union[str, Path],
    version: int = 4,
    include_auto_gen_txt: bool = True,
    model: str = "code-davinci-002",
    temperature: int = 0,
    max_tokens: int = 250,
    top_p: float = 1.0,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    stop: List[str] = ["#", "\"\"\""],
    prompt: Optional[str] = None,
) -> None:
    """Add docstrings to a source file, or to the source files in a directory.

    When `path` is a directory, the files to process are those returned by
    `_get_files`; otherwise `path` itself is the only file processed. Files
    with the `.ipynb` suffix are handled by `_add_docstring_to_nb`, all other
    files by `_add_docstring_to_py`.

    Args:
        path: The path to the source code file or directory.
        version: The notebook format version, forwarded to
            `_add_docstring_to_nb`. Not used for other files.
        include_auto_gen_txt: Forwarded to the per-file helpers.
        model: Forwarded to the per-file helpers as a generation setting.
        temperature: Forwarded to the per-file helpers as a generation setting.
        max_tokens: Forwarded to the per-file helpers as a generation setting.
        top_p: Forwarded to the per-file helpers as a generation setting.
        frequency_penalty: Forwarded to the per-file helpers as a generation
            setting.
        presence_penalty: Forwarded to the per-file helpers as a generation
            setting.
        stop: Forwarded to the per-file helpers as a generation setting.
        prompt: Forwarded to the per-file helpers as a generation setting.

    Raises:
        ValueError: Propagated from `_get_files` when `path` is a directory
            in which it finds no file to process.
    """
    # Arguments that both per-file helpers receive, collected once.
    shared_kwargs = dict(
        include_auto_gen_txt=include_auto_gen_txt,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        stop=stop,
        prompt=prompt,
    )

    path = Path(path)
    if path.is_dir():
        files = _get_files(path)
    else:
        files = [path]

    for file in files:
        if file.suffix != ".ipynb":
            _add_docstring_to_py(file=file, **shared_kwargs)
        else:
            _add_docstring_to_nb(file=file, version=version, **shared_kwargs)

In [None]:
# End-to-end check: run add_docstring_to_source on a directory holding a
# notebook, a "_"-prefixed notebook, a Python file and a settings file, then
# read the processed notebook back.
with TemporaryDirectory() as d:
    nbs_path = Path(d) / "nbs"
    nbs_path.mkdir(parents=True)

    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "test.ipynb")
    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", nbs_path / "_test.ipynb")
    
    shutil.copyfile(Path("..") / "fixtures" / "test_data.py", nbs_path / "test_data.py")
    shutil.copyfile(Path("..") / "settings.ini", nbs_path / "settings.ini")

    add_docstring_to_source(nbs_path)
    
    with (nbs_path / "test.ipynb").open("r") as f:
        nb = nbformat.read(f, as_version=4)

# Every top-level class/function in the code cells of the processed notebook
# must now have a non-empty docstring; all cell sources are printed.
for cell in nb.cells:
#     print(cell["source"])
    
    if cell.cell_type == "code":
        tree = ast.parse(cell["source"])        
        print(cell["source"])
        for node in tree.body:
            if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            assert ast.get_docstring(node)
            
    else:
        print(cell["source"])

[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 3 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 2 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 7 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
[34mNote: OpenAI's API rate limit reached. Command will automatically retry in 29 seconds. For more information visit: https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits[0m
# Test notebook for docstring generator

> Test notebook for docstring generator
# | export

from typing import *
import os
from pathlib import Path
# from contextlib impo