In [1]:
import inspect
import ast
import numpy as np
import typing

### Multiply integer variable if the argument is typed as INT

class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

def apply_multiply_int_variables(tree: ast.AST, int_variables: typing.Set[str]) -> ast.AST:
    transformer = MultiplyIntVariables(int_variables=int_variables)
    return transformer.visit(tree)


def collect_int_variables(typed_tree: ast.AST) -> typing.Set[str]:
    int_variables = set()

    for node in ast.walk(typed_tree):
        if isinstance(node, ast.arg) and isinstance(node.annotation, ast.Name) and node.annotation.id == 'int':
            int_variables.add(node.arg)

    return int_variables

In [42]:
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def greet(name: str) -> None:
    print(f'Hello, {name}!')

def add(a: int, b: int) -> int:
    return a * 100 + b * 100


Also identify derived INT varibles and Multiply them.

WRONG. It doesn't work! 

In [15]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
            assigned_var_name = node.targets[0].id
            if assigned_var_name in self.int_variables:
                node.value = ast.BinOp(left=node.value, op=ast.Mult(), right=ast.Constant(value=100))
        return self.generic_visit(node)

    def visit_Return(self, node: ast.Return) -> ast.Return:
        print("1", isinstance(node.value, ast.BinOp))
        print("2", isinstance(node.value.left, ast.Name))
        print("3", node.value.left.id in self.int_variables)
        print("4", node.value.left.id)
        print("5", self.int_variables)
        if isinstance(node.value, ast.BinOp) and isinstance(node.value.left, ast.Name) and node.value.left.id in self.int_variables:
            node.value.left = ast.BinOp(left=node.value.left, op=ast.Mult(), right=ast.Constant(value=100))
            node.value.right = ast.BinOp(left=node.value.right, op=ast.Mult(), right=ast.Constant(value=100))
        return self.generic_visit(node)

In [16]:
def add_square(a: int, b: int) -> int:
    c = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

NameError: name 'collect_int_variables' is not defined

Still not working, becaues 'c' is not recognized as an integer. 
The fix is given as follows 

In [16]:
def collect_int_variables(typed_tree: ast.AST) -> typing.Set[str]:
    int_variables = set()

    for node in ast.walk(typed_tree):
        if isinstance(node, ast.arg) and isinstance(node.annotation, ast.Name) and node.annotation.id == 'int':
            int_variables.add(node.arg)

        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and isinstance(node.annotation, ast.Name) and node.annotation.id == 'int':
            int_variables.add(node.target.id)

    return int_variables

In [18]:
def add_square(a: int, b: int) -> int:
    c: int = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
print(int_variables)

{'c', 'a', 'b'}


In [19]:
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

1 True
2 True
3 True
4 c
5 {'c', 'a', 'b'}
def add_square(a: int, b: int) -> int:
    c * 100: int = a * 100 + b * 100
    return c * 100 * 100 * (c * 100 * 100)


In [17]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        if (
            isinstance(node.left, ast.Name) and node.left.id in self.int_variables
            or isinstance(node.right, ast.Name) and node.right.id in self.int_variables
        ):
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        else:
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        return node

In [18]:
def add_square(a: int, b: int) -> int:
    c: int = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

NameError: name 'collect_int_variables' is not defined

결과가 integer인지 확인하는 방법. 

-> 이걸로 typed_tree에 custom property를 추가하면 될듯. 

# Current version

In [12]:
class TypedFunctionVisitor(ast.NodeTransformer):
    def __init__(self):
        self.current_function = None

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.current_function = node.name

        # Extract type annotations from function arguments
        arg_types: typing.Dict[str, ast.expr] = {}
        for arg in node.args.args:
            if arg.annotation:
                arg_types[arg.arg] = arg.annotation

        # Create new arguments with type annotations
        new_args = []
        for arg in node.args.args:
            if arg.arg in arg_types:
                new_args.append(ast.arg(arg=arg.arg, annotation=arg_types[arg.arg]))
            else:
                new_args.append(ast.arg(arg=arg.arg, annotation=None))

        # Replace the original arguments with the new typed ones
        node.args.args = new_args

        # Process the function body
        self.generic_visit(node)

        # Reset the current function
        self.current_function = None

        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        if self.current_function == 'add':
            # Assuming both left and right operands are named arguments from the 'add' function
            left_type = ast.Name(id='int', ctx=ast.Load())
            right_type = ast.Name(id='int', ctx=ast.Load())
            node.left_type = left_type
            node.right_type = right_type

        return self.generic_visit(node)


class MultiplyIntVariables(ast.NodeTransformer):
    def visit_Name(self, node: ast.Name) -> ast.expr:
        if hasattr(node, 'dtype') and node.dtype == 'int':
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        if is_integer_binop_result(node.left, node.right, node):
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        else:
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        return node

In [13]:
class AddDtypeAttribute(ast.NodeTransformer):
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        for arg in node.args.args:
            if isinstance(arg.annotation, ast.Name):
                arg.dtype = arg.annotation.id
        self.generic_visit(node)
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
        if isinstance(node.annotation, ast.Name):
            node.target.dtype = node.annotation.id
        self.generic_visit(node)
        return node
    

def apply_multiply_int_variables(tree: ast.AST) -> ast.AST:
    transformer = MultiplyIntVariables()
    return transformer.visit(tree)

def is_integer_binop_result(left: ast.AST, right: ast.AST, operation: ast.BinOp) -> bool:
    if not (isinstance(left, (ast.Name, ast.Constant)) and isinstance(right, (ast.Name, ast.Constant))):
        return False

    if not (isinstance(operation.op, (ast.Add, ast.Sub, ast.Mult, ast.FloorDiv, ast.Mod, ast.Pow, ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift, ast.RShift))):
        return False

    return True


def generate_typed_ast(source_code: str) -> ast.AST:
    tree = ast.parse(source_code)
    visitor = TypedFunctionVisitor()
    typed_tree = visitor.visit(tree)
    return typed_tree

In [27]:
def add_square(a: int, b: int) -> int:
    c: int = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)

In [30]:
typed_tree_with_dtype = AddDtypeAttribute().visit(typed_tree)

In [31]:
print(ast.dump(typed_tree_with_dtype.body[0], indent=2))

FunctionDef(
  name='add_square',
  args=arguments(
    posonlyargs=[],
    args=[
      arg(
        arg='a',
        annotation=Name(id='int', ctx=Load())),
      arg(
        arg='b',
        annotation=Name(id='int', ctx=Load()))],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
  body=[
    AnnAssign(
      target=Name(id='c', ctx=Store()),
      annotation=Name(id='int', ctx=Load()),
      value=BinOp(
        left=Name(id='a', ctx=Load()),
        op=Add(),
        right=Name(id='b', ctx=Load())),
      simple=1),
    Return(
      value=BinOp(
        left=Name(id='c', ctx=Load()),
        op=Mult(),
        right=Name(id='c', ctx=Load())))],
  decorator_list=[],
  returns=Name(id='int', ctx=Load()))


In [32]:
body = typed_tree_with_dtype.body[0]

In [40]:
print(ast.dump(body.args.args[0]))

arg(arg='a', annotation=Name(id='int', ctx=Load()))


In [41]:
body.args.args[0].dtype

'int'

In [25]:
multiplied_tree = apply_multiply_int_variables(typed_tree_with_dtype)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def add_square(a: int, b: int) -> int:
    c * 100: int = a + b
    return c * c


In [17]:
body = typed_tree_with_dtype.body[0]

In [18]:
assign = body.body[0]
ret = body.body[1]

In [22]:
assign.target.dtype

'c'

In [69]:
target.id

'c'

In [67]:
print(ast.dump(assign.targets[0], indent=2))

Name(id='c', ctx=Store())


In [16]:
print(ast.dump(typed_tree_with_dtype.body[0], indent=2))

FunctionDef(
  name='add_square',
  args=arguments(
    posonlyargs=[],
    args=[
      arg(
        arg='a',
        annotation=Name(id='int', ctx=Load())),
      arg(
        arg='b',
        annotation=Name(id='int', ctx=Load()))],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
  body=[
    AnnAssign(
      target=Name(id='int', ctx=Store()),
      annotation=Name(id='c', ctx=Load()),
      value=BinOp(
        left=Name(id='a', ctx=Load()),
        op=Add(),
        right=Name(id='b', ctx=Load())),
      simple=1),
    Return(
      value=BinOp(
        left=Name(id='c', ctx=Load()),
        op=Mult(),
        right=Name(id='c', ctx=Load())))],
  decorator_list=[],
  returns=Name(id='int', ctx=Load()))


In [76]:
def custom_dump(node: ast.AST, *, indent: int = 2, level: int = 0) -> str:
    def _format(node):
        if isinstance(node, ast.AST):
            return custom_dump(node, indent=indent, level=level + 1)
        elif isinstance(node, list):
            return f"[{', '.join(_format(x) for x in node)}]"
        else:
            return repr(node)

    if not isinstance(node, ast.AST):
        raise TypeError("Expected instance of ast.AST")

    prefix = f"{node.__class__.__name__}("
    if hasattr(node, "dtype"):
        prefix += f"dtype={node.dtype}, "

    fields = [f"{name}={_format(value)}" for name, value in ast.iter_fields(node)]
    result = f"{prefix}{', '.join(fields)})"

    if indent is not None:
        result = result.replace(",", f",\n{' ' * ((level + 1) * indent)}")
        #result = f"{' ' * (level * indent)}{result}"

    return result

In [77]:
print(custom_dump(typed_tree_with_dtype.body[0], indent=1))

FunctionDef(name='add_square',
  args=arguments(posonlyargs=[],
 
   args=[arg(dtype=int,
 
  
    arg='a',
 
  
    annotation=Name(id='int',
 
  
   
     ctx=Load()),
 
  
    type_comment=None),
 
   arg(dtype=int,
 
  
    arg='b',
 
  
    annotation=Name(id='int',
 
  
   
     ctx=Load()),
 
  
    type_comment=None)],
 
   vararg=None,
 
   kwonlyargs=[],
 
   kw_defaults=[],
 
   kwarg=None,
 
   defaults=[]),
  body=[AnnAssign(target=Name(dtype=int,
 
  
    id='c',
 
  
    ctx=Store()),
 
   annotation=Name(id='int',
 
  
    ctx=Load()),
 
   value=BinOp(left=Name(id='a',
 
  
   
     ctx=Load()),
 
  
    op=Add(),
 
  
    right=Name(id='b',
 
  
   
     ctx=Load())),
 
   simple=1),
  Return(value=BinOp(left=Name(id='c',
 
  
   
     ctx=Load()),
 
  
    op=Mult(),
 
  
    right=Name(id='c',
 
  
   
     ctx=Load())))],
  decorator_list=[],
  returns=Name(id='int',
 
   ctx=Load()),
  type_comment=None)


In [47]:
ast.dump??

In [25]:
def add_square(a: int, b: int) -> int:
    c = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def add_square(a: int, b: int) -> int:
    c = a * 100 + b * 100
    return c * c


In [35]:
typed_tree

<ast.Module at 0x7f6a3b821480>