In [None]:
# | default_exp docstring_generator

# Docstring Generator

In [None]:
# | export

import ast
import tokenize
from typing import *
from pathlib import Path
from io import BytesIO

import nbformat

In [None]:
import shutil
from tempfile import TemporaryDirectory

## Docstring generator functions

In [None]:
# | export


def _generate_docstring_using_codex(code: str) -> str:
    return """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
"""

In [None]:
# | export


def _inject_docstring_to_source(
    source: str, docstring: str, lineno: int, node_col_offset: int
) -> str:
    """Inject a docstring into the source code at a specified line number.

    Args:
        source: the source code
        docstring: the docstring to be added
        lineno: the line number at which the docstring will be inserted
        node_col_offset: the number of spaces to indent the docstring

    Returns:
        The updated source code with the docstring injected
    """
    lines = source.split("\n")
    indented_docstring = "\n".join(
        [
            line
            if i == 0 or i == len(docstring.split("\n")) - 1
            else f"{' ' * node_col_offset}{line}"
            for i, line in enumerate(docstring.split("\n"))
        ]
    )
    indent = node_col_offset + 4
    lines.insert(lineno, f'{" " * indent}"""{indented_docstring}{" " * indent}"""')
    return "\n".join(lines)

In [None]:
# Insert the sample docstring below line 2 (the `async def` line) of a
# top-level function. A col offset of 0 means no extra indentation is applied.
source = """
async def drive(self):
    print(f'The {self.model} is now driving.')
"""

docstring = """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
"""

expected = """
async def drive(self):
    \"""Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    \"""
    print(f'The {self.model} is now driving.')
"""

lineno = 2  # line 1 is the empty line right after the opening quotes
node_col_offset = 0
actual = _inject_docstring_to_source(source, docstring, lineno, node_col_offset)
print(actual)

assert actual == expected


async def drive(self):
    """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    """
    print(f'The {self.model} is now driving.')



In [None]:
# | export

def _get_code_from_source(source: str, start_line_no: int, end_line_no: int) -> str:
    """Get a block of lines from a given source string.
    
    Args:
        source: The source string.
        start_line_no: The line number of the start of the block of lines.
        end_line_no: The line number of the end of the block of lines.
    
    Returns:
        The extracted block of lines from the source
    """
    source_lines = source.split("\n")
    extracted_lines = source_lines[start_line_no-1:end_line_no]
    return "\n".join(extracted_lines)

In [None]:
# Extract lines 7-8 (the `drive` method) from a small class. Line 1 is the
# empty line that follows the opening triple quote.
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""
start_line_no = 7
end_line_no = 8

expected = """    async def drive(self):
        print(f'The {self.model} is now driving.')"""

actual = _get_code_from_source(source, start_line_no, end_line_no)
print(actual)

assert actual == expected


    async def drive(self):
        print(f'The {self.model} is now driving.')


In [None]:
# | export


def _calculate_end_lineno(source: str, start_line_no: int) -> int:
    """Calculate the end line number of a function in a Python source code.
    
    Args:
        source: The source code string.
        start_line_no: The line number of the start of the function.
        
    Returns:
        The end line number of the function.
    """
    lines = source.split("\n")[start_line_no - 1 :]
    first_indent = len(lines[0]) - len(lines[0].lstrip())
    end_line_in_source = 0

    for i, line in enumerate(lines[1:]):
        if len(line) - len(line.lstrip()) == first_indent and line.strip() != "":
            end_line_in_source = i
            break

    ret_val = (
        len(source.split("\n"))
        if end_line_in_source == 0
        else end_line_in_source + start_line_no
    )
    return ret_val - 1

In [None]:
# Check the end-line heuristic for every top-level node and for the methods of
# each class. The expected values are the heuristic's own results. For the last
# method of each class (A.drive, B.stop, C.drive) they are one line later than
# ast's end_lineno (9/20/28 vs 8/19/27). This is because the heuristic stops
# two lines above the next line with the same indentation.
source = """
class A:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')

class B:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
        
    async def stop(self):
        print(f'The {self.model} is now stopped.')
        
class C:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a

    async def drive(self):
        print(f'The {self.model} is now driving.')
        
async def drive(self):
    print(f'The {self.model} is now driving.')
"""

tree = ast.parse(source)
expected = [8, 5, 9, 19, 13, 16, 20, 27, 24, 28, 30]
actual = []
for i, node in enumerate(tree.body):
    res = _calculate_end_lineno(source, node.lineno)
    actual.append(res)

    if isinstance(node, ast.ClassDef):
        for f in node.body:
            if not isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            res = _calculate_end_lineno(source, f.lineno)
            actual.append(res)
print(actual)
assert actual == expected

[8, 5, 9, 19, 13, 16, 20, 27, 24, 28, 30]


In [None]:
# | export


def _line_has_class_or_method(source: str, lineno: int) -> bool:
    """Check if a line in the source code contains a class or method definition.
    
    Args:
        source: The source code as a string.
        lineno: The line number to check.
    
    Returns:
        True if the line contains a class or method definition, False otherwise.
    """
    line = "".join(source.split("\n")[lineno - 1])
    tokens = list(tokenize.tokenize(BytesIO(line.encode("utf-8")).readline))
    return tokens[1].type == tokenize.NAME and tokens[1].string in {
        "class",
        "def",
        "async",
    }

def _get_start_line_for_class_or_func(source: str, lineno: int) -> int:
    """Locate the first ``class``/``def`` line at or after ``lineno``.

    Used to step over decorator lines, which Python 3.7's ``ast`` reports as a
    decorated definition's ``lineno``. ``lineno`` itself is tried first, then
    every following line up to the end of ``source`` is checked in order.

    Args:
        source: The source code as a string.
        lineno: The 1-based line number to start from.

    Returns:
        ``lineno`` if that line holds a class or function definition, otherwise
        the first later line that does, or ``lineno`` unchanged if none does.
    """
    if _line_has_class_or_method(source, lineno):
        return lineno

    line_count = len(source.split("\n"))
    for candidate in range(lineno + 1, lineno + line_count + 1):
        if candidate > line_count:
            break
        if _line_has_class_or_method(source, candidate):
            return candidate
    return lineno

In [None]:
# Line 28 is `@decorator1`. The search must skip both decorator lines and land
# on `def outer_func():` at line 30 (this source starts directly with `def`,
# so line 1 is `def decorator1(func):`).
source = """def decorator1(func):
    \"""Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    \"""
    def inner():
        func()
    return inner

def decorator2(func):
    \"""Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    \"""
    def inner():
        func()
    return inner


@decorator1
@decorator2
def outer_func():
    def inner_func():
        print("Hello, World!")
    inner_func()
"""

line_no = 28
expected = 30
actual = _get_start_line_for_class_or_func(source, line_no)
print(actual)
assert actual == expected

30


In [None]:
# A line that already holds a definition is returned unchanged.
source = """
async def drive(self):
    print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 2
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

# Decorator lines (with or without arguments) are skipped until the
# `async def` line is reached.
source = """
@foo(x=5)
@decorator2
async def drive(self):
    print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 4
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

# A single decorator in front of an async function.
source = """
@foo(x=5)
async def drive(self):
    print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 3
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

# Several decorators in front of a class.
source = """
@foo(x=5)
@bar(x=5)
@zar(x=5)
class A:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""
lineno = 2
expected = 5
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected

# No definition follows the decorators, so the starting line is returned.
source = """
@foo(x=5)
@bar(x=5)

"""
lineno = 2
expected = 2
actual = _get_start_line_for_class_or_func(source, lineno)
print(actual)

assert actual == expected



2
4
3
5
2


In [None]:
# | export


def _add_docstring(
    source: str,
    node: Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef],
    line_offset: int,
) -> Tuple[str, int]:
    """Add a docstring to the given node and update the source code.

    The docstring is placed directly below the node's header. The header's last
    line is located from the node's first body statement. A definition whose
    signature (or base-class list) spans several lines therefore gets the
    docstring after the line that closes the header, not inside the parameter
    list.

    Args:
        source: the source code from the notebook cell
        node: the AST node representing a class definition, function definition,
            or async function definition
        line_offset: the number of lines added before the current
            node in the source

    Returns:
        A tuple containing the updated source code and the new line number offset.
        When the body starts on the header line itself (e.g. ``def f(): pass``)
        there is nowhere to put a docstring, and ``source`` and ``line_offset``
        are returned unchanged.
    """
    if getattr(node, "end_lineno", None) is not None:
        # Python 3.8+: lineno already points at the `def`/`class` line, even for
        # decorated definitions.
        line_no = node.lineno + line_offset
        end_line_no = node.end_lineno + line_offset
    else:
        # Python 3.7: lineno of a decorated definition is its first decorator,
        # and end_lineno does not exist, so both are recovered from the text.
        line_no = _get_start_line_for_class_or_func(source, node.lineno + line_offset)
        end_line_no = _calculate_end_lineno(source, line_no)

    # First line occupied by the body, including decorators of a nested def/class.
    first_stmt = node.body[0]
    body_start = (
        min(
            [first_stmt.lineno]
            + [dec.lineno for dec in getattr(first_stmt, "decorator_list", [])]
        )
        + line_offset
    )
    if body_start <= line_no:
        # One-liner: inserting an indented line after it would not parse.
        return source, line_offset

    # Walk back from the body over blank and comment-only lines. What remains is
    # the last line of the (possibly multi-line) header.
    source_lines = source.split("\n")
    header_end = body_start - 1
    while header_end > line_no and (
        not source_lines[header_end - 1].strip()
        or source_lines[header_end - 1].lstrip().startswith("#")
    ):
        header_end -= 1

    code = _get_code_from_source(source, line_no, end_line_no)
    docstring = _generate_docstring_using_codex(code)

    updated = _inject_docstring_to_source(source, docstring, header_end, node.col_offset)
    # Count the lines actually added instead of assuming a formula.
    line_offset += updated.count("\n") - source.count("\n")
    return updated, line_offset

In [None]:
# Add docstrings to a class and its methods by calling _add_docstring directly,
# threading line_offset through the calls. Blank lines inside the method
# docstrings carry 4 spaces because _inject_docstring_to_source prefixes the
# middle lines with the method's column offset.
source = """
class test:
    CONST_VAL = 1
    def __init__(self, a):
        self.a = a
        
    async def drive(self):
        print(f'The {self.model} is now driving.')
"""

expected = """
class test:
    \"""Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    \"""
    CONST_VAL = 1
    def __init__(self, a):
        \"""Sample docstring
    
        Args:
            s: sample args
    
        Returns:
            sample return
        \"""
        self.a = a
        
    async def drive(self):
        \"""Sample docstring
    
        Args:
            s: sample args
    
        Returns:
            sample return
        \"""
        print(f'The {self.model} is now driving.')
"""

tree = ast.parse(source)
line_offset = 0
for node in tree.body:
    source, line_offset = _add_docstring(source, node, line_offset)

    for f in node.body:
        if not isinstance(f, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        source, line_offset = _add_docstring(source, f, line_offset)

print(source)
assert source == expected


class test:
    """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    """
    CONST_VAL = 1
    def __init__(self, a):
        """Sample docstring
    
        Args:
            s: sample args
    
        Returns:
            sample return
        """
        self.a = a
        
    async def drive(self):
        """Sample docstring
    
        Args:
            s: sample args
    
        Returns:
            sample return
        """
        print(f'The {self.model} is now driving.')



In [None]:
# | export


def _check_and_add_docstrings_to_source(source: str) -> str:
    """Check for missing docstrings in the source code and add them if necessary.

    Every top-level class, function and async function without a docstring gets
    one. For each top-level class, its methods (sync and async) are checked as
    well, whether or not the class itself already had a docstring. Deeper
    nesting (functions inside functions, nested classes) is not visited.

    Args:
        source: the source code from the notebook cell

    Returns:
        The updated source code with added docstrings

    Raises:
        SyntaxError: if ``source`` cannot be parsed by ``ast.parse``.
    """
    tree = ast.parse(source)
    line_offset = 0
    func_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    for node in tree.body:
        if not isinstance(node, (ast.ClassDef, *func_types)):
            continue
        if ast.get_docstring(node) is None:
            source, line_offset = _add_docstring(source, node, line_offset)
        if not isinstance(node, ast.ClassDef):
            continue
        # Methods must be checked even when the class is documented. They are
        # visited in source order, after the class header, so line_offset stays
        # in sync.
        for member in node.body:
            if isinstance(member, func_types) and ast.get_docstring(member) is None:
                source, line_offset = _add_docstring(source, member, line_offset)

    return source

In [None]:
# End-to-end run on top-level functions, including a decorated one. The result
# is only printed (shown in the output below), not asserted.
source = """
def decorator1(func):
    def inner():
        func()
    return inner

def decorator2(func):
    def inner():
        func()
    return inner

@decorator1
@decorator2
def outer_func():
    def inner_func():
        print("Hello, World!")
    inner_func()
"""
actual = _check_and_add_docstrings_to_source(source)
print(actual)


def decorator1(func):
    """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    """
    def inner():
        func()
    return inner

def decorator2(func):
    """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    """
    def inner():
        func()
    return inner

@decorator1
@decorator2
def outer_func():
    """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    """
    def inner_func():
        print("Hello, World!")
    inner_func()



In [None]:
# | export


def add_docstring_to_notebook(nb_path: Union[str, Path], version: int = 4):
    """Add docstrings to the source

    This function reads through a Jupyter notebook cell by cell and
    adds docstrings for classes and methods that do not have them.
    The notebook file is then rewritten in place.

    Code cells that ``ast.parse`` rejects are left unchanged instead of
    aborting the whole run. Examples are cells using IPython magics such as
    ``%time`` or shell escapes such as ``!pip install``.

    Args:
        nb_path: The notebook file path
        version: The version of the Jupyter notebook format
    """
    nb_path = Path(nb_path)
    nb = nbformat.read(nb_path, as_version=version)

    for cell in nb.cells:
        if cell.cell_type != "code":
            continue
        try:
            cell["source"] = _check_and_add_docstrings_to_source(cell["source"])
        except SyntaxError:
            # Not plain Python (magic / shell escape): keep the cell as-is.
            continue

    nbformat.write(nb, nb_path)

In [None]:
# End-to-end check on a copy of the fixture notebook.
# NOTE(review): assumes the working directory's parent contains
# `fixtures/Test_Data.ipynb` (see the copyfile path) — TODO confirm.
with TemporaryDirectory() as d:
    nbs_path = Path(d) / "nbs"
    nbs_path.mkdir(parents=True)

    nb_path = nbs_path / "test.ipynb"
    shutil.copyfile(Path("..") / "fixtures" / "Test_Data.ipynb", nb_path)
    
    assert nb_path.exists()

    add_docstring_to_notebook(nb_path)

    with nb_path.open("r") as f:
        nb = nbformat.read(f, as_version=4)

# `nb` is already in memory, so it can be inspected after the temp dir is gone.
# Only top-level classes/functions are asserted; methods are not checked here.
for cell in nb.cells:
    
    if cell.cell_type == "code":
        tree = ast.parse(cell["source"])        
        print(cell["source"])
        for node in tree.body:
            if not isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            assert ast.get_docstring(node)
            
    else:
        print(cell["source"])



# Test notebook for docstring generator

> Test notebook for docstring generator
# | export

from typing import *
import os
from pathlib import Path
# from contextlib import contextmanager

from contextlib import contextmanager

import shutil
from tempfile import TemporaryDirectory

# Title
# | export

# Vehicle class
class Vehicle:
    """Sample docstring

    Args:
        s: sample args

    Returns:
        sample return
    """
    # Constructor function
    def __init__(self, brand, model, type):
        """Constructor function
        
        Args:
            brand: Vehicle's brand
            model: Vehicle's model
            type: Vehicle's type
        """
        self.brand = brand
        self.model = model
        self.type = type
        self.gas_tank_size = 14
        self.fuel_level = 0
    
    # fuel_up function
    def fuel_up(self):
        """Sample docstring
    
        Args:
            s: sample args
    
        Returns:
            sample return
        ""