In [2]:
!pip install astor
!pip install ast

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting astor
  Downloading astor-0.8.1-py2.py3-none-any.whl.metadata (4.2 kB)
Downloading astor-0.8.1-py2.py3-none-any.whl (27 kB)
Installing collected packages: astor
Successfully installed astor-0.8.1
[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting ast
  Downloading AST-0.0.2.tar.gz (19 kB)
  Preparing metadata (setup.py) ... [?25lerror
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m [31m[8 lines of output][0m
  [31m   [0m Traceback (most recent call last):
  [31m   [0m   File "<string>", line 2, in <module>
  [31m   [0m   File "<pip-setuptools-caller>", line 34, in <module>
  [31m   [0m   File "/tmp/pip-install-81xsdey5/ast_124f37a7ae9b4b50845fc18b933f982c/setup.py", line 6, in <module>
  [31m   [0m     README = codecs.o

In [1]:
import ast
import astor

class FunctionExtractor(ast.NodeVisitor):
    def __init__(self, max_lines=5):
        self.functions = []
        self.max_lines = max_lines

    def visit_FunctionDef(self, node):
        if len(node.body) <= self.max_lines:
            self.functions.append(node)
        self.generic_visit(node)

def parse_code(source_code):
    return ast.parse(source_code)

def get_short_methods(ast_tree, max_lines=5):
    extractor = FunctionExtractor(max_lines)
    extractor.visit(ast_tree)
    return extractor.functions

class FunctionCallLocator(ast.NodeVisitor):
    def __init__(self, function_names):
        self.function_names = function_names
        self.calls = {name: [] for name in function_names}
        self.parents = []

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id in self.function_names:
            self.calls[node.func.id].append((node, self.parents[-1]))
        self.generic_visit(node)

    def generic_visit(self, node):
        self.parents.append(node)
        super().generic_visit(node)
        self.parents.pop()

def get_function_calls(ast_tree, function_names):
    locator = FunctionCallLocator(function_names)
    locator.visit(ast_tree)
    return locator.calls

def replace_node(parent, old_node, new_nodes):
    for field, value in ast.iter_fields(parent):
        if isinstance(value, list):
            for i, item in enumerate(value):
                if item is old_node:
                    value[i:i+1] = new_nodes
                    return
        elif value is old_node:
            setattr(parent, field, new_nodes[0])

def inline_function_calls(ast_tree, function_def, calls):
    new_tree = ast_tree
    for call, parent in calls:
        inlined_body = []
        for stmt in function_def.body:
            new_stmt = ast.copy_location(stmt, call)
            if isinstance(new_stmt, ast.Return):
                new_stmt = ast.Expr(value=new_stmt.value)
            inlined_body.append(new_stmt)

        replace_node(parent, call, inlined_body)

    return astor.to_source(new_tree)

def generate_dataset(source_code, short_methods, function_calls):
    dataset = []
    for function_def in short_methods:
        func_name = function_def.name
        if func_name in function_calls:
            calls = function_calls[func_name]
            before = inline_function_calls(ast_tree, function_def, calls)
            after = source_code
            dataset.append((before, after))
    return dataset

# Sample source code for testing
source_code = """
def add(a, b):
    return a + b

def main():
    x = add(1, 2)
    y = add(3, 4)
    print(x, y)

main()
"""

# Parse the code
ast_tree = parse_code(source_code)

# Identify short methods
short_methods = get_short_methods(ast_tree)

# Get function names
function_names = [func.name for func in short_methods]

# Locate function calls
function_calls = get_function_calls(ast_tree, function_names)

# Generate dataset
dataset = generate_dataset(source_code, short_methods, function_calls)

for before, after in dataset:
    print("Before Refactoring:\n", before)
    print("After Refactoring:\n", after)


Before Refactoring:
 def add(a, b):
    return a + b


def main():
    x = 
    a + b
    y = 
    a + b
    print(x, y)


main()

After Refactoring:
 
def add(a, b):
    return a + b

def main():
    x = add(1, 2)
    y = add(3, 4)
    print(x, y)

main()

Before Refactoring:
 def add(a, b):
    return a + b


def main():
    x = 
    a + b
    y = 
    a + b
    print(x, y)


x = 
a + b

After Refactoring:
 
def add(a, b):
    return a + b

def main():
    x = add(1, 2)
    y = add(3, 4)
    print(x, y)

main()

