In [2]:
import inspect
import ast
import numpy as np
import typing

In [5]:
class TypedFunctionVisitor(ast.NodeTransformer):
    def __init__(self):
        self.current_function = None

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.current_function = node.name

        # Extract type annotations from function arguments
        arg_types: typing.Dict[str, ast.expr] = {}
        for arg in node.args.args:
            if arg.annotation:
                arg_types[arg.arg] = arg.annotation

        # Create new arguments with type annotations
        new_args = []
        for arg in node.args.args:
            if arg.arg in arg_types:
                new_args.append(ast.arg(arg=arg.arg, annotation=arg_types[arg.arg]))
            else:
                new_args.append(ast.arg(arg=arg.arg, annotation=None))

        # Replace the original arguments with the new typed ones
        node.args.args = new_args

        # Process the function body
        self.generic_visit(node)

        # Reset the current function
        self.current_function = None

        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        if self.current_function == 'add':
            # Assuming both left and right operands are named arguments from the 'add' function
            left_type = ast.Name(id='int', ctx=ast.Load())
            right_type = ast.Name(id='int', ctx=ast.Load())
            node.left_type = left_type
            node.right_type = right_type

        return self.generic_visit(node)


def generate_typed_ast(source_code: str) -> ast.AST:
    tree = ast.parse(source_code)
    visitor = TypedFunctionVisitor()
    typed_tree = visitor.visit(tree)
    return typed_tree

In [31]:
if __name__ == "__main__":
    source_code = '''
def greet(name: str) -> None:
    print(f"Hello, {name}!")

def add(a: int, b: int) -> int:
    return a + b
'''

    typed_tree = generate_typed_ast(source_code)
    typed_source_code = ast.unparse(typed_tree)
    print(typed_source_code)

def greet(name: str) -> None:
    print(f'Hello, {name}!')

def add(a: int, b: int) -> int:
    return a + b


In [32]:
print(ast.dump(typed_tree, indent=1))

Module(
 body=[
  FunctionDef(
   name='greet',
   args=arguments(
    posonlyargs=[],
    args=[
     arg(
      arg='name',
      annotation=Name(id='str', ctx=Load()))],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
   body=[
    Expr(
     value=Call(
      func=Name(id='print', ctx=Load()),
      args=[
       JoinedStr(
        values=[
         Constant(value='Hello, '),
         FormattedValue(
          value=Name(id='name', ctx=Load()),
          conversion=-1),
         Constant(value='!')])],
      keywords=[]))],
   decorator_list=[],
   returns=Constant(value=None)),
  FunctionDef(
   name='add',
   args=arguments(
    posonlyargs=[],
    args=[
     arg(
      arg='a',
      annotation=Name(id='int', ctx=Load())),
     arg(
      arg='b',
      annotation=Name(id='int', ctx=Load()))],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
   body=[
    Return(
     value=BinOp(
      left=Name(id='a', ctx=Load()),
      op=Add(),
      right=Name(id='b', ctx

In [37]:
print(ast.dump(typed_tree.body[1], indent=1))

FunctionDef(
 name='add',
 args=arguments(
  posonlyargs=[],
  args=[
   arg(
    arg='a',
    annotation=Name(id='int', ctx=Load())),
   arg(
    arg='b',
    annotation=Name(id='int', ctx=Load()))],
  kwonlyargs=[],
  kw_defaults=[],
  defaults=[]),
 body=[
  Return(
   value=BinOp(
    left=Name(id='a', ctx=Load()),
    op=Add(),
    right=Name(id='b', ctx=Load())))],
 decorator_list=[],
 returns=Name(id='int', ctx=Load()))


Multiply integer variable if the argument is typed as INT

In [None]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

In [7]:
def apply_multiply_int_variables(tree: ast.AST, int_variables: typing.Set[str]) -> ast.AST:
    transformer = MultiplyIntVariables(int_variables=int_variables)
    return transformer.visit(tree)


def collect_int_variables(typed_tree: ast.AST) -> typing.Set[str]:
    int_variables = set()

    for node in ast.walk(typed_tree):
        if isinstance(node, ast.arg) and isinstance(node.annotation, ast.Name) and node.annotation.id == 'int':
            int_variables.add(node.arg)

    return int_variables

In [42]:
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def greet(name: str) -> None:
    print(f'Hello, {name}!')

def add(a: int, b: int) -> int:
    return a * 100 + b * 100


Also identify derived INT varibles and Multiply them.

WRONG. It doesn't work! 

In [None]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
            assigned_var_name = node.targets[0].id
            if assigned_var_name in self.int_variables:
                node.value = ast.BinOp(left=node.value, op=ast.Mult(), right=ast.Constant(value=100))
        return self.generic_visit(node)

In [53]:
def add_square(a: int, b: int) -> int:
    c = a + b 
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def add_square(a: int, b: int) -> int:
    c = a * 100 + b * 100
    return c * c


In [54]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
            assigned_var_name = node.targets[0].id
            if assigned_var_name in self.int_variables:
                node.value = ast.BinOp(left=node.value, op=ast.Mult(), right=ast.Constant(value=100))
        return self.generic_visit(node)

    def visit_Return(self, node: ast.Return) -> ast.Return:
        if isinstance(node.value, ast.BinOp) and isinstance(node.value.left, ast.Name) and node.value.left.id in self.int_variables:
            node.value = ast.BinOp(left=node.value, op=ast.Mult(), right=ast.Constant(value=100))
        return self.generic_visit(node)

In [55]:
source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def add_square(a: int, b: int) -> int:
    c = a * 100 + b * 100
    return c * c


In [15]:
int_variables

{'a', 'b'}

In [13]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
            assigned_var_name = node.targets[0].id
            if assigned_var_name in self.int_variables:
                node.value = ast.BinOp(left=node.value, op=ast.Mult(), right=ast.Constant(value=100))
        return self.generic_visit(node)

    def visit_Return(self, node: ast.Return) -> ast.Return:
        print("1", isinstance(node.value, ast.BinOp))
        print("2", isinstance(node.value.left, ast.Name))
        print("3", node.value.left.id in self.int_variables)
        print("4", node.value.left.id)
        print("5", self.int_variables)
        if isinstance(node.value, ast.BinOp) and isinstance(node.value.left, ast.Name) and node.value.left.id in self.int_variables:
            node.value.left = ast.BinOp(left=node.value.left, op=ast.Mult(), right=ast.Constant(value=100))
            node.value.right = ast.BinOp(left=node.value.right, op=ast.Mult(), right=ast.Constant(value=100))
        return self.generic_visit(node)

In [14]:
def add_square(a: int, b: int) -> int:
    c = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

1 True
2 True
3 False
4 c
5 {'a', 'b'}
def add_square(a: int, b: int) -> int:
    c = a * 100 + b * 100
    return c * c


Still not working, becaues 'c' is not recognized as an integer. 
The fix is given as follows 

In [16]:
def collect_int_variables(typed_tree: ast.AST) -> typing.Set[str]:
    int_variables = set()

    for node in ast.walk(typed_tree):
        if isinstance(node, ast.arg) and isinstance(node.annotation, ast.Name) and node.annotation.id == 'int':
            int_variables.add(node.arg)

        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and isinstance(node.annotation, ast.Name) and node.annotation.id == 'int':
            int_variables.add(node.target.id)

    return int_variables

In [18]:
def add_square(a: int, b: int) -> int:
    c: int = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
print(int_variables)

{'c', 'a', 'b'}


In [19]:
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

1 True
2 True
3 True
4 c
5 {'c', 'a', 'b'}
def add_square(a: int, b: int) -> int:
    c * 100: int = a * 100 + b * 100
    return c * 100 * 100 * (c * 100 * 100)


In [20]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        if (
            isinstance(node.left, ast.Name) and node.left.id in self.int_variables
            or isinstance(node.right, ast.Name) and node.right.id in self.int_variables
        ):
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        else:
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        return node

In [21]:
def add_square(a: int, b: int) -> int:
    c: int = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def add_square(a: int, b: int) -> int:
    c * 100: int = a * 100 + b * 100
    return c * 100 * (c * 100)


결과가 integer인지 확인하는 방법. 

-> 이걸로 typed_tree에 custom property를 추가하면 될듯. 

In [22]:
def is_integer_binop_result(left: ast.AST, right: ast.AST, operation: ast.BinOp) -> bool:
    if not (isinstance(left, (ast.Name, ast.Constant)) and isinstance(right, (ast.Name, ast.Constant))):
        return False

    if not (isinstance(operation.op, (ast.Add, ast.Sub, ast.Mult, ast.FloorDiv, ast.Mod, ast.Pow, ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift, ast.RShift))):
        return False

    return True

In [23]:
class MultiplyIntVariables(ast.NodeTransformer):
    def __init__(self, int_variables: typing.Set[str]):
        self.int_variables = int_variables

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.int_variables:
            return ast.BinOp(left=node, op=ast.Mult(), right=ast.Constant(value=100))
        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        if is_integer_binop_result(node.left, node.right, node):
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        else:
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
        return node

In [25]:
def add_square(a: int, b: int) -> int:
    c = a + b
    return c * c

source_code = inspect.getsource(add_square)

typed_tree = generate_typed_ast(source_code)
int_variables = collect_int_variables(typed_tree)
multiplied_tree = apply_multiply_int_variables(typed_tree, int_variables)
multiplied_source_code = ast.unparse(multiplied_tree)
print(multiplied_source_code)

def add_square(a: int, b: int) -> int:
    c = a * 100 + b * 100
    return c * c


TypeError: 'FunctionDef' object is not iterable

In [28]:
print(ast.dump(typed_tree, indent=1))

Module(
 body=[
  FunctionDef(
   name='add_square',
   args=arguments(
    posonlyargs=[],
    args=[
     arg(
      arg='a',
      annotation=Name(id='int', ctx=Load())),
     arg(
      arg='b',
      annotation=Name(id='int', ctx=Load()))],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
   body=[
    Assign(
     targets=[
      Name(id='c', ctx=Store())],
     value=BinOp(
      left=BinOp(
       left=Name(id='a', ctx=Load()),
       op=Mult(),
       right=Constant(value=100)),
      op=Add(),
      right=BinOp(
       left=Name(id='b', ctx=Load()),
       op=Mult(),
       right=Constant(value=100)))),
    Return(
     value=BinOp(
      left=Name(id='c', ctx=Load()),
      op=Mult(),
      right=Name(id='c', ctx=Load())))],
   decorator_list=[],
   returns=Name(id='int', ctx=Load()))],
 type_ignores=[])


In [10]:
print(ast.dump(multiplied_tree, indent=1))

Module(
 body=[
  FunctionDef(
   name='add_square',
   args=arguments(
    posonlyargs=[],
    args=[
     arg(
      arg='a',
      annotation=Name(id='int', ctx=Load())),
     arg(
      arg='b',
      annotation=Name(id='int', ctx=Load()))],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
   body=[
    Assign(
     targets=[
      Name(id='c', ctx=Store())],
     value=BinOp(
      left=BinOp(
       left=Name(id='a', ctx=Load()),
       op=Mult(),
       right=Constant(value=100)),
      op=Add(),
      right=BinOp(
       left=Name(id='b', ctx=Load()),
       op=Mult(),
       right=Constant(value=100)))),
    Return(
     value=BinOp(
      left=Name(id='c', ctx=Load()),
      op=Mult(),
      right=Name(id='c', ctx=Load())))],
   decorator_list=[],
   returns=Name(id='int', ctx=Load()))],
 type_ignores=[])


In [40]:
ast_body.type_comment

In [18]:
def var(data:np.ndarray):
    """Calculate variance
    """
    m = np.mean(data)
    diff = (data - m)
    result = np.sum(diff*diff)/len(data)
    return result


tree = ast.parse(inspect.getsource(var))

In [27]:
print(ast.dump(tree, indent=1))

Module(
 body=[
  FunctionDef(
   name='var',
   args=arguments(
    posonlyargs=[],
    args=[
     arg(
      arg='data',
      annotation=Attribute(
       value=Name(id='np', ctx=Load()),
       attr='ndarray',
       ctx=Load()))],
    kwonlyargs=[],
    kw_defaults=[],
    defaults=[]),
   body=[
    Expr(
     value=Constant(value='Calculate variance\n    ')),
    Assign(
     targets=[
      Name(id='m', ctx=Store())],
     value=Call(
      func=Attribute(
       value=Name(id='np', ctx=Load()),
       attr='mean',
       ctx=Load()),
      args=[
       Name(id='data', ctx=Load())],
      keywords=[])),
    Assign(
     targets=[
      Name(id='diff', ctx=Store())],
     value=BinOp(
      left=Name(id='data', ctx=Load()),
      op=Sub(),
      right=Name(id='m', ctx=Load()))),
    Assign(
     targets=[
      Name(id='result', ctx=Store())],
     value=BinOp(
      left=Call(
       func=Attribute(
        value=Name(id='np', ctx=Load()),
        attr='sum',
        ctx=Load

In [20]:
ast.unparse(tree)

'def var(data: np.ndarray):\n    """Calculate variance\n    """\n    m = np.mean(data)\n    diff = data - m\n    result = np.sum(diff * diff) / len(data)\n    return result'

In [17]:
visitor = BinOpReplacer()
visitor.visit(tree)

tree = ast.fix_missing_locations(tree)
print("------")
print(ast.unparse(tree))

NameError: name 'BinOpReplacer' is not defined

In [22]:
def myfun(name: str, age: int) -> str:
    return f"Hello, {name}! You are {age} years old."

typed_signature = build_typed_function_signature(var)

In [23]:
print(ast.dump(typed_signature))

arguments(args=[], kwonlyargs=[], kw_defaults=[], defaults=[])


In [9]:
func_ast = ast.parse(inspect.getsource(myfun))
fun_ast = ast.iter_fields(func_ast)

In [15]:
func_ast.body[0]

<ast.FunctionDef at 0x7f1a3eead240>

In [12]:
for node in fun_ast:
    print(node)

('type_ignores', [])
