In [1]:
# System
import os
from dotenv import load_dotenv
import tempfile
from titlecase import titlecase
import ast

# LLM Models
from langchain_openai import ChatOpenAI
from openai import OpenAI

# Template
from langchain_core.prompts import ChatPromptTemplate

# OutputParsers
from langchain.schema.output_parser import StrOutputParser

# Gradio frontend
import gradio as gr

In [2]:
# Load variables from a local .env file (if any) into the environment; the model setup
# below reads OPENAI_API_KEY through os.getenv, so this cell has to run before it.
load_dotenv()

True

In [3]:
# Single chat model shared by every prompt chain defined in this notebook.
# The API key comes from the environment rather than being hard-coded here.
# temperature=0.0 asks for the least random sampling; max_completion_tokens=1024
# bounds the length of each reply.
chat_model = ChatOpenAI(model="gpt-4o-mini-2024-07-18",
                        max_completion_tokens=1024,
                        api_key=os.getenv("OPENAI_API_KEY"),
                        temperature=0.0)

In [95]:
def strip_double_quotes(s1):
    """Return ``s1`` with every triple-double-quote delimiter removed.

    Only the three-character sequence is deleted; isolated double quotes
    (and runs shorter than three) are left untouched.

    Parameters:
        s1 (str): Text that may contain docstring delimiters.

    Returns:
        str: A copy of ``s1`` without those delimiters.
    """
    return s1.replace('"""', '')

In [133]:
def code_block_comment(code_block):
    """Ask the chat model for '#'-prefixed comments describing a block of Python code.

    Parameters:
        code_block (str): Source text of the block to be commented.

    Returns:
        str: The model's reply as plain text (whatever StrOutputParser extracts).
    """
    template = ChatPromptTemplate.from_template("""
                                              Provide comments on the following block of Python codes.
                                              Provide only comments. Do not include the function in your response.
                                              Start each line of comments with '#'. Do not include triple quotes.
                                              Python code: {raw_text}.
                                              """)

    # Pipeline: fill the template -> query the shared model -> reduce to a string.
    pipeline = template | chat_model | StrOutputParser()
    return pipeline.invoke({"raw_text": code_block})

In [134]:
# Manual check: comment a typical script entry-point guard (this sends a request to the model).
ans = code_block_comment("if __name__ == '__main__':\n    main()")

In [135]:
# print() renders the reply's line breaks; a bare `ans` would show the escaped repr instead.
print(ans)

# This line checks if the current script is being run as the main program.
# If the script is executed directly (not imported as a module), the following block will execute.
# The 'main()' function is called to start the program's execution.


In [126]:
def import_comment(import_line):
    """Ask the chat model for a one-line '#' comment describing an import statement.

    Parameters:
        import_line (str): Source text of the import statement(s).

    Returns:
        str: The model's reply as plain text (whatever StrOutputParser extracts).
    """
    template = ChatPromptTemplate.from_template("""
                                              Provide a concise, one-line comment on the imports in Python.
                                              Start the line with '#'. 
                                              Python imports: {raw_text}.
                                              """)

    # Pipeline: fill the template -> query the shared model -> reduce to a string.
    pipeline = template | chat_model | StrOutputParser()
    return pipeline.invoke({"raw_text": import_line})

In [100]:
# Manual check of import_comment on a single `from ... import ...` line (live model call).
import_comment('from tqdm import tqdm')

'# This import statement brings in the tqdm library, allowing for the use of a progress bar in Python applications.'

In [101]:
def get_docstring(function_text):
    """Ask the chat model to write a docstring for a Python function.

    Parameters:
        function_text (str): Source text of the function to document.

    Returns:
        str: The model's reply with every triple-double-quote delimiter removed,
        so the caller can wrap the text in its own delimiters.
    """
    # The backslashes are doubled on purpose: this is a normal (non-raw) string, so a
    # single \n or \t would be turned into a real newline / tab before the prompt is
    # sent, and the model would never see the two-character sequences being described.
    prompt = ChatPromptTemplate.from_template("""
                                              Create a high quality docstring for the given python function.
                                              Break up any line more than 80 characters into multiple
                                              lines separated by '\\n'. Do not have any '\\t' in the output.
                                              Instead have four spaces '    '.
                                              Do not include the function in your response.
                                              Python function: {raw_text}.
                                              """)

    # Pipeline: fill the template -> query the shared model -> reduce to a string.
    chat_chain = prompt | chat_model | StrOutputParser()
    response = chat_chain.invoke({"raw_text": function_text})

    return strip_double_quotes(response)

In [102]:
# Manual check of get_docstring on a tiny tab-indented function (live model call).
docstring1 = get_docstring("def hello(name):\n\treturn f'Hello, World. My name is {name}'")

In [103]:
# print() shows the generated docstring with its real line breaks and indentation.
print(docstring1)


    Greets the user with a personalized message.

    This function takes a name as an input and returns a greeting string
    that includes the provided name. The greeting is formatted to say
    "Hello, World. My name is {name}", where {name} is replaced by the
    actual name passed to the function.

    Parameters:
        name (str): The name of the user to be included in the greeting.

    Returns:
        str: A greeting message that includes the user's name.

    Example:
        >>> hello("Alice")
        'Hello, World. My name is Alice'




In [104]:
def _signature_end(function_source):
    """Locate the colon that terminates the ``def`` signature.

    Returns the ``(row, col)`` position just past that colon (row is 1-based),
    or ``None`` when no ``def`` header can be tokenized.
    """
    import io
    import tokenize

    depth = 0  # nesting level of (), [] and {} seen after ``def``
    seen_def = False
    try:
        for tok in tokenize.generate_tokens(io.StringIO(function_source).readline):
            if tok.type == tokenize.NAME and tok.string == "def":
                seen_def = True
            elif seen_def and tok.type == tokenize.OP:
                if tok.string in ("(", "[", "{"):
                    depth += 1
                elif tok.string in (")", "]", "}"):
                    depth -= 1
                elif tok.string == ":" and depth == 0:
                    # Colons inside brackets (annotations, slices, lambdas in
                    # defaults) are skipped; this one closes the signature.
                    return tok.end
    except (tokenize.TokenError, SyntaxError):
        # Source that cannot be tokenized: let the caller use its fallback.
        pass
    return None


def merge_docstring_and_function(original_function, docstring):
    """Insert ``docstring`` between a function's signature and its body.

    The signature may span several lines. A one-line definition such as
    ``def f(): return 1`` is split so that the statement after the colon
    moves to its own line below the docstring. If the header cannot be
    located, the first line is treated as the whole signature.

    The docstring is wrapped in triple double quotes at a fixed four-space
    indent, and every docstring line after the first is prefixed with four
    spaces. A leading tab on the first body line is replaced by a blank line
    plus four spaces; later body lines are left exactly as given.

    Parameters:
        original_function (str): Source text of the function, starting at
            its ``def`` line.
        docstring (str): Docstring text without surrounding delimiters.

    Returns:
        str: The function source with the docstring inserted.
    """
    lines = original_function.split("\n")
    # Fallback when the header cannot be located: assume a one-line signature.
    header_lines, body_lines = lines[:1], lines[1:]

    end = _signature_end(original_function)
    if end is not None and end[0] <= len(lines):
        row, col = end
        remainder = lines[row - 1][col:].strip()
        if remainder and not remainder.startswith("#"):
            # Code follows the colon on the same line: move it into the body so
            # the docstring still ends up as the first statement.
            header_lines = lines[:row - 1] + [lines[row - 1][:col]]
            body_lines = ["    " + remainder] + lines[row:]
        else:
            header_lines, body_lines = lines[:row], lines[row:]

    # Guarded so that a header with no body lines no longer raises IndexError.
    if body_lines and body_lines[0].startswith('\t'):
        body_lines[0] = '\n    ' + body_lines[0][1:]

    indented_doc = '    '.join(docstring.splitlines(True))
    return ("\n".join(header_lines) + "\n" + '    """' + indented_doc
            + '    """\n\n' + "\n".join(body_lines))

In [136]:
def _statement_source(source_lines, node):
    """Return ``(decorator_text, statement_text)`` for a top-level AST node."""
    first = node.lineno - 1  # AST line numbers are 1-based
    # A node without end_lineno yields None, i.e. "slice to the end of the file".
    last = getattr(node, "end_lineno", None)
    # Where lineno points at the def/class keyword, decorator lines sit above it.
    decorators = getattr(node, "decorator_list", [])
    deco_first = min([deco.lineno - 1 for deco in decorators] + [first])
    return ("\n".join(source_lines[deco_first:first]),
            "\n".join(source_lines[first:last]))


def docstring_generator(file_path):
    """Write an annotated copy of a Python file and return the copy's file name.

    Only the module's top-level statements are processed, in source order:

    * functions get a model-written docstring (decorators are kept above them);
    * ``import`` / ``from ... import`` statements get a one-line comment;
    * ``if`` blocks get a block of comments above them;
    * every other statement is copied through unchanged.

    Nested statements stay inside their parent's source text and are not
    emitted a second time. Comments and blank lines that sit between
    top-level statements are not carried over, because only the line ranges
    of the statements themselves are copied.

    Parameters:
        file_path (str): Path of the Python source file to annotate.

    Returns:
        str: Name of the written file, ``<name>_withdocstring<ext>``, created
        relative to the current working directory.
    """
    # Accept both '/' and '\\' as separators, and keep dots that are part of the
    # base name (for example "my.module.py") instead of cutting at the first dot.
    base_name = os.path.basename(file_path.replace('\\', '/'))
    filename, extension = os.path.splitext(base_name)

    with open(file_path, "r", encoding="utf-8") as source_file:
        file_content = source_file.read()

    # Parse the file content into an Abstract Syntax Tree
    tree = ast.parse(file_content)
    source_lines = file_content.splitlines()

    result = ""
    for node in tree.body:
        decorators, code = _statement_source(source_lines, node)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            combined_code = merge_docstring_and_function(code, get_docstring(code))
            if decorators:
                combined_code = decorators + "\n" + combined_code
            # Keep the gap before a function constant whatever preceded it.
            separator = "\n" if result.endswith('\n\n') else "\n\n"
            result += separator + combined_code + "\n"
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            result += import_comment(code) + '\n' + code + '\n\n'
        elif isinstance(node, ast.If):
            result += '\n\n' + code_block_comment(code) + '\n\n' + code + '\n'
        else:
            # Classes, assignments, expressions, ...: keep the code as written.
            if decorators:
                code = decorators + "\n" + code
            result += code + "\n"

    output_filename = f"{filename}_withdocstring{extension}"
    with open(output_filename, "w", encoding="utf-8") as output_file:
        output_file.write(result)

    return output_filename

In [137]:
# Gradio front end: one file input and one file output, both exchanged as file paths,
# with docstring_generator as the handler that maps the uploaded path to the result path.
demo = gr.Interface(fn=docstring_generator,
                    inputs=[gr.File(type='filepath')],
                    outputs=[gr.File(type='filepath')])

In [138]:
# Start the local web server for the interface (the URL is printed in the cell output).
demo.launch()

* Running on local URL:  http://127.0.0.1:7862
* To create a public link, set `share=True` in `launch()`.




In [139]:
# Shut the server down again so the port is released before the cells below are run.
demo.close()

Closing server running on port: 7862


In [141]:
class LimitedDepthVisitor(ast.NodeVisitor):
    """Print the node types of an AST down to a maximum nesting depth.

    Each visited node is printed as ``Visiting: <TypeName>``, indented by two
    spaces per level. The root passed to ``visit`` is level 0; nodes deeper
    than ``max_depth`` are neither printed nor descended into.

    Note: children are walked directly, so ``visit_<NodeType>`` methods added
    by a subclass are not dispatched to.
    """

    def __init__(self, max_depth):
        # Deepest level (inclusive) that is still printed.
        self.max_depth = max_depth

    def visit(self, node, current_depth=0):
        """Print ``node`` and recurse into its children, one level deeper each time."""
        if current_depth > self.max_depth:
            return  # Stop traversal if the depth exceeds the limit
        print(f"{'  ' * current_depth}Visiting: {type(node).__name__}")
        # Recurse by hand: NodeVisitor.generic_visit() re-enters visit() with the
        # node only, which reset current_depth to 0 and made the limit ineffective.
        for child in ast.iter_child_nodes(node):
            self.visit(child, current_depth + 1)

In [None]:
def tree_view_depth(file_path, max_depth=2):
    """Print the AST node types of a Python file down to ``max_depth`` levels.

    Parameters:
        file_path (str): Path of the Python source file to inspect.
        max_depth (int): Limit handed to LimitedDepthVisitor. Defaults to 2,
            the value that used to be hard-coded.

    Returns:
        None. The outline is printed by the visitor.
    """
    with open(file_path, "r") as file:
        file_content = file.read()

    # Parse the file content into an Abstract Syntax Tree
    tree = ast.parse(file_content)

    # Traverse the AST with the requested depth limit
    visitor = LimitedDepthVisitor(max_depth=max_depth)
    visitor.visit(tree)

In [119]:
def tree_view(file_path):
    """Print the complete AST of a Python file as an indented dump.

    Parameters:
        file_path (str): Path of the Python source file to inspect.

    Returns:
        None. The dump (four-space indent) is written to standard output.
    """
    with open(file_path, "r") as handle:
        parsed = ast.parse(handle.read())

    print(ast.dump(parsed, indent=4))

In [123]:
# Dump the AST of one sample script to see which node types appear at module level.
# NOTE(review): absolute, user-specific path -- this cell only runs on the author's
# machine; consider a path relative to the notebook or a configurable directory.
tree_view('c:/Users/JimYi/PyCharmProjects/ProjectEuler/Problem022.py')

Module(
    body=[
        Import(
            names=[
                alias(name='time')]),
        FunctionDef(
            name='letter_value',
            args=arguments(
                posonlyargs=[],
                args=[
                    arg(arg='s')],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[]),
            body=[
                Assign(
                    targets=[
                        Name(id='d', ctx=Store())],
                    value=Dict(
                        keys=[
                            Constant(value='A'),
                            Constant(value='B'),
                            Constant(value='C'),
                            Constant(value='D'),
                            Constant(value='E'),
                            Constant(value='F'),
                            Constant(value='G'),
                            Constant(value='H'),
                            Constant(value='I'),
             