In [40]:
import ast
import re
import os
import json

from abc import ABC
from pathlib import Path
from enum import Enum, auto

from codetf.models import load_model_pipeline

In [3]:
# Absolute path of the directory the notebook kernel was started in.
root_dir = Path(os.getcwd())

In [35]:
class FileType(Enum):
    """Kind of source file a docstring utility operates on."""

    PY = auto()
    IPYNB = auto()
    UNKNOWN = auto()

class AbstractDocStringUtil(ABC):
    """Base class that inserts model-generated docstrings into Python code.

    Subclasses set ``FILETYPE`` to select which files ``process_files`` picks
    up, and override ``_process_file`` to read one file, run
    ``_add_docstring`` on its code and write the result.
    """

    FILETYPE: FileType = FileType.UNKNOWN
    # Not used by this class; kept so existing references keep working.
    DEF_REGEX = re.compile(r"(\bdef .*\(.*\).*:)")

    def __init__(self, model_name: str = "codet5", model_type: str = "sum_python",
                 task: str = "base") -> None:
        """Load the CodeTF pipeline used to summarise functions.

        NOTE: the parameter names are crossed relative to
        ``load_model_pipeline``. This class's ``model_type`` is forwarded as
        the pipeline's ``task``, and ``task`` as its ``model_type``. With the
        defaults that requests task="sum_python", model_type="base". The names
        are kept unchanged so existing keyword callers keep working.
        """
        super().__init__()

        self._model = load_model_pipeline(model_name=model_name, task=model_type, model_type=task)

    def gen_docstring(self, method: str) -> list[str]:
        """Summarise the source code of one function with the loaded model.

        Returns the pipeline's ``predict`` output for a one-element batch.
        ``_add_docstring`` treats it as a sequence of strings (presumably one
        summary per input; verify against the CodeTF version in use).
        """
        return self._model.predict([method])

    def _get_files(self, path: Path):
        """Return a recursive glob of the files under *path* matching FILETYPE.

        Raises:
            NotImplementedError: if ``FILETYPE`` is ``FileType.UNKNOWN``.
        """
        if self.FILETYPE == FileType.UNKNOWN:
            raise NotImplementedError()

        pattern = "*.ipynb" if self.FILETYPE == FileType.IPYNB else "*.py"
        return path.rglob(pattern)

    def process_files(self, path: Path, inplace=False):
        """Add docstrings to every matching file under *path*.

        Args:
            path: Root directory searched recursively.
            inplace: When true, each file is overwritten. Otherwise the result
                is written to the same relative location under
                ``path / "docs"`` (sub-directories are created as needed), and
                files already inside that directory are skipped.
        """
        new_base_dir = path / "docs"

        # Materialise the listing first so files written into ``docs`` during
        # this run are not fed back into the recursive glob.
        for file in list(self._get_files(path)):
            if inplace:
                self._process_file(file, file)
                continue

            # Skip outputs produced by a previous run.
            if new_base_dir in file.parents:
                continue

            new_path = new_base_dir / file.relative_to(path)
            # Nested sources (path/pkg/mod.py -> path/docs/pkg/mod.py) need
            # their intermediate directories created before writing.
            new_path.parent.mkdir(parents=True, exist_ok=True)
            self._process_file(file, new_path)

    def _process_file(self, file: Path, new_path: Path):
        """Read *file*, add docstrings, and write the result to *new_path*.

        Subclasses must override this.
        """
        raise NotImplementedError()

    @staticmethod
    def _leading_whitespace(line: str) -> str:
        """Return the run of whitespace characters that starts *line*."""
        return line[:len(line) - len(line.lstrip())]

    @staticmethod
    def _escape_docstring(text: str) -> str:
        """Escape backslashes and double quotes so *text* can sit safely
        inside a triple-double-quoted string literal."""
        return text.replace("\\", "\\\\").replace('"', '\\"')

    def _add_docstring(self, code: str):
        """Return *code* with a generated docstring added to each function.

        Handles ``def`` and ``async def``, including nested and decorated
        functions and multi-line signatures. These functions are left
        unchanged:
        - functions that already have a docstring;
        - functions whose body starts on the same line as other code (e.g.
          ``def f(): return 1``), since there is nowhere to put a docstring
          without reformatting.

        Raises:
            SyntaxError: if ``ast.parse`` rejects *code*.
        """
        lines = code.split("\n")
        insertions = []  # (0-based line index, indent string, function node)

        for func in ast.walk(ast.parse(code)):
            if not isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            if ast.get_docstring(func, clean=False) is not None:
                continue

            first = func.body[0]
            # A decorated statement's lineno is its def/class line, so start at
            # its first decorator to insert above the whole statement.
            start = min([first.lineno] + [d.lineno for d in getattr(first, "decorator_list", [])])
            # Reuse the body's own indentation string (tabs stay tabs).
            indent = self._leading_whitespace(lines[start - 1])
            if len(indent) != first.col_offset:
                continue
            insertions.append((start - 1, indent, func))

        # Insert bottom-up so the indices of earlier insertions stay valid.
        for index, indent, func in sorted(insertions, key=lambda item: item[0], reverse=True):
            summaries = self.gen_docstring(ast.unparse(func))
            if isinstance(summaries, str):
                summaries = [summaries]
            text = f"\n{indent}".join(self._escape_docstring(s) for s in summaries)
            lines.insert(index, f'{indent}"""{text}"""')

        return "\n".join(lines)

In [37]:
class PythonFileDocStringUtil(AbstractDocStringUtil):
    """Adds generated docstrings to the functions in ``.py`` files."""

    FILETYPE: FileType = FileType.PY

    def _process_file(self, file: Path, new_path: Path):
        """Read *file*, add docstrings, and write the result to *new_path*.

        The new source is fully computed before *new_path* is opened for
        writing. If ``_add_docstring`` raises (e.g. SyntaxError), the target
        is therefore not truncated, which matters in in-place mode where both
        paths are the same file.
        """
        # Plain read mode: "r+" needlessly required write permission on input.
        with file.open("r", encoding="utf-8") as f:
            code = f.read()

        new_code = self._add_docstring(code)

        with new_path.open("w", encoding="utf-8") as f:
            f.write(new_code)


# Write docstring-annotated copies of the .py files under ./test into ./test/docs.
py_util = PythonFileDocStringUtil()
py_util.process_files(root_dir / "test", inplace=False)

In [48]:
class IPYNBFileDocStringUtil(AbstractDocStringUtil):
    """Adds generated docstrings to the code cells of Jupyter notebooks."""

    FILETYPE: FileType = FileType.IPYNB

    def _process_file(self, file: Path, new_path: Path):
        """Rewrite each parsable code cell of *file* and save to *new_path*.

        Cells that ``ast.parse`` rejects, such as cells using IPython magics
        (``%matplotlib``, ``!pip``), are left unchanged instead of aborting
        the whole notebook. Each cell's ``source`` is written back in the
        shape it was read in (list of lines or single string).
        """
        # Plain read mode: "r+" needlessly required write permission on input.
        with file.open("r", encoding="utf-8") as f:
            notebook = json.load(f)

        for cell in notebook.get("cells", []):
            if cell.get("cell_type", "code") != "code":
                continue

            source = cell.get("source", "")
            try:
                new_source = self._add_docstring("".join(source))
            except SyntaxError:
                continue

            if isinstance(source, list):
                cell["source"] = new_source.splitlines(keepends=True)
            else:
                cell["source"] = new_source

        # Same layout nbformat itself writes: indent=1, raw UTF-8, final newline.
        with new_path.open("w", encoding="utf-8") as f:
            json.dump(notebook, f, indent=1, ensure_ascii=False)
            f.write("\n")
    
# Write docstring-annotated copies of the notebooks under ./test into ./test/docs.
ipynb_util = IPYNBFileDocStringUtil()
ipynb_util.process_files(root_dir / "test", inplace=False)