From a3ea56cee0a5269452ddb6a621a8f17ae5dae303 Mon Sep 17 00:00:00 2001 From: Nikhil Parasaram Date: Wed, 21 Feb 2024 16:06:34 +0000 Subject: [PATCH] Handle async --- BugsInPy/run_custom_patch.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/BugsInPy/run_custom_patch.py b/BugsInPy/run_custom_patch.py index 282b0a3..f186fe5 100644 --- a/BugsInPy/run_custom_patch.py +++ b/BugsInPy/run_custom_patch.py @@ -8,7 +8,7 @@ from git import Repo, NoSuchPathError from pathlib import Path import subprocess -from typing import Dict +from typing import Dict, Union from BugsInPy.utils import checkout @@ -16,11 +16,15 @@ class ReplaceFunctionNode(ast.NodeTransformer): - def __init__(self, target_lineno: int, replacement_node: ast.FunctionDef): + def __init__( + self, + target_lineno: int, + replacement_node: Union[ast.FunctionDef, ast.AsyncFunctionDef], + ): self.target_lineno = target_lineno self.replacement_node = replacement_node - def visit_FunctionDef(self, node: ast.FunctionDef): + def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): # Check if the function has decorators start_lineno = node.lineno if node.decorator_list: @@ -44,7 +48,9 @@ def replace_code(bug_data: Dict, repo_bug_id: str, file_path: Path) -> None: # Parse the replacement code to get its AST replacement_tree = ast.parse(replacement_code) - if not isinstance(replacement_tree.body[0], ast.FunctionDef): + if not isinstance( + replacement_tree.body[0], (ast.FunctionDef, ast.AsyncFunctionDef) + ): raise ValueError("Replacement code does not contain a function definition.") # Use the NodeTransformer to replace the original function with the replacement