Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

See descriptions of functions generated using GPT-3! #3

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions gct/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
version=f"GCT version: {__version__}",
)

parser.add_argument(
"--summarize",
"-s",
action="store_true",
help="Use GPT-3 to summarize functions (requires API key)"
)


def main():
args = parser.parse_args()
Expand All @@ -45,12 +52,12 @@ def main():
if not status["valid"]:
path = args.input

graph, _ = api.run(path)
graph, _ = api.run(path, summarize=args.summarize)

api.render(
graph,
file_path=f"{args.destination_folder}/{GRAPH_FOLDER_DEFAULT_NAME}",
output_format="pdf",
output_format="svg"
)


Expand Down
5 changes: 3 additions & 2 deletions gct/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@
from gct.constants import TEMP_FOLDER, GRAPH_FOLDER_DEFAULT_NAME


def run(resource_name: str) -> "list[graphviz.Digraph, str]":
def run(resource_name: str, summarize: bool) -> "list[graphviz.Digraph, str]":
"""
Runs GCT on a given resource and returns the graphviz object.
@Parameter:
1. resource_name: str = Path to the file/URL to generate graph for.
2. summarize: bool = Boolean indicating whether or not to add function descriptions to the graph
@Returns:
1. graphviz.Digraph object. To render the graph, call the render() method on the object.
2. str: The raw code corresponding to `resource_name`.
Expand All @@ -51,7 +52,7 @@ def run(resource_name: str) -> "list[graphviz.Digraph, str]":
# Get the AST and raw code
tree, raw_code = utils.parse_file(resource_name)
# Extract relevant components -- node connection and edge mapping
node_representation, edge_representation = extract(tree, raw_code)
node_representation, edge_representation = extract(tree, raw_code, summarize)
# Hierarchical clustering
node_representation.group_nodes_by_level()
# Define graphviz graph
Expand Down
24 changes: 24 additions & 0 deletions gct/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,27 @@
SELF_NODE_NAME = "self"
TEMP_FOLDER = "temp"
GRAPH_FOLDER_DEFAULT_NAME = "gct_graph"
# Few-shot prompt used to ask the completion model for a one-sentence
# description of a function. It contains two worked examples, each terminated
# by a "###" separator, followed by a final slot whose "<function_code>"
# placeholder is replaced with the source of the function to describe.
# The literal placeholder text must stay in the string: the summarizer checks
# for it before substituting. The "###" separator is also what the summarizer
# uses to cut off anything the model generates past the first description.
PROMPT = """
function:
def fibonacci(n):
if n <= 0:
return 0
elif n == 1:
return 1
else:
return fibonacci(n-1) + fibonacci(n-2)

one sentence description: Calculates the fibonacci numbers to the n'th degree.
###
function:
def resize_image(image_path):
image = cv2.imread(image_path)
image = cv2.resize(image, (512, 512))
cv2.imwrite(image_path, image)

one sentence description: Reads an image, resizes it to (512, 512) using openCV, then saves it in the same path.
###
function:
<function_code>

one sentence description:"""
2 changes: 2 additions & 0 deletions gct/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ def __init__(
line_end: int,
name: str,
type: str = None,
description: str = None
):
self.line_start = line_start
self.line_end = line_end
self.name = name
self.type = type # options: [function, class]
self.id = uuid.uuid1().hex
self.description = description

def __repr__(self) -> str:
return f"{self.name} #{self.line_start + 1}"
Expand Down
6 changes: 3 additions & 3 deletions gct/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from gct.type_check import Metadata


def extract(tree: ast, raw_code: "list[str]"):
def extract(tree: ast, raw_code: "list[str]", summarize: bool):
"""2 pass algorithm"""

node_line_map: "dict[int, Node]" = {
constants.ROOT_NODE: Node(constants.ROOT_NODE_LINENO, len(raw_code), "root")
constants.ROOT_NODE: Node(constants.ROOT_NODE_LINENO, len(raw_code), raw_code, "root")
}
node_creation_graph = Graph()
edge_creation_graph = Graph()

# Node creation
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
func_visitor = UserDefinedFuncVisitor()
func_visitor = UserDefinedFuncVisitor(raw_code, summarize)
func_visitor.visit(node)
node_line_map[func_visitor.node.line_start] = func_visitor.node

Expand Down
33 changes: 33 additions & 0 deletions gct/summarize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import html
import os
import textwrap

import openai

from gct.constants import PROMPT

class CodeSummarizer:
def __init__(self, decription_max_len: int = 50, temperature: float = 0.8):
openai.api_key = os.getenv("OPENAI_API_KEY")
assert openai.api_key is not None, "You must set the environment variable OPENAI_API_KEY to enable function descriptions with gct"

self.description_max_len = decription_max_len
self.temperature = temperature


def summarize(self, code: "list[str]") -> str:
prompt = self._populate_prompt(code)
description = self._text_completion(prompt)
return "<BR />".join(textwrap.wrap(description, 42))


def _text_completion(self, prompt: str) -> str:
resp = openai.Completion.create(
model="text-davinci-003",
prompt=prompt,
max_tokens=self.description_max_len,
temperature=self.temperature,
)
return resp["choices"][0]["text"].split("###")[0]


def _populate_prompt(self, lines: "list[str]") -> str:
assert "<function_code>" in PROMPT, "Prompt must contain a phrase <function_code> with which to populate function code"
return PROMPT.replace("<function_code>", "\n".join(lines))
15 changes: 11 additions & 4 deletions gct/syntax_tree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import ast
from gct.network import Node
from collections import deque

from gct.summarize import CodeSummarizer
from gct.utils import fetch_full_function

class FunctionCallVisitor(ast.NodeVisitor):
"""Extract all function calls"""
Expand Down Expand Up @@ -32,12 +33,18 @@ def visit_Attribute(self, node):
class UserDefinedFuncVisitor(ast.NodeVisitor):
"""Extract all user defined functions and classes"""

def __init__(self):
def __init__(self, raw_code: "list[str]", summarize: bool):
self.node: Node = None
self.summarize = summarize
self.raw_code = raw_code
if self.summarize:
self.code_summarizer = CodeSummarizer()

def create_node(self, node: ast.AST, node_name: str, type: str):
    """Build the graph `Node` for a function/class definition and store it.

    @Parameter:
    1. node: ast.AST = The definition node being visited.
    2. node_name: str = Name to give the resulting graph node.
    3. type: str = Kind of definition, forwarded to `Node` unchanged.

    The result is stored on `self.node`; nothing is returned.
    """
    # `ast` line numbers are 1-based; the graph nodes are built with 0-based ones.
    line_start = node.lineno - 1
    # Use the definition's *last* line for the end marker. The previous code
    # derived it from `node.lineno`, which made line_end equal line_start.
    # `end_lineno` is optional on AST nodes, so keep the old "None" string
    # fallback when it is absent.
    end_lineno = getattr(node, "end_lineno", None)
    line_end = end_lineno - 1 if end_lineno is not None else str(None)
    lines = fetch_full_function(self.raw_code, line_start)
    # Only call the summarizer when summaries were requested.
    summary = self.code_summarizer.summarize(lines) if self.summarize else None
    self.node = Node(line_start, line_end, node_name, type, summary)

def visit_Lambda(self, node: ast.Lambda):
raise NotImplementedError("Lambda functions are not supported yet")
Expand Down
3 changes: 3 additions & 0 deletions gct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ def add_subgraphs(
style = "rounded, filled"
shape = "box"
bgcolor = generate_random_color()

if node.description is not None:
text = f"<{node.__repr__()} <BR/> <FONT POINT-SIZE=\"8\">{node.description}</FONT>>"

graphviz_graph.node(
node.id,
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
argparse==1.4.0
graphviz==0.20.1
networkx==2.8.8
platform==1.0.8
platform==1.0.8
openai
textwrap