# Tests for AST transformers

In [1]:
import ast
from pto.core.automatic_names import transform_ast


In [2]:
def check_trans(source, transform=None):
    """Parse *source*, apply an AST transformer, and print the code before and after.

    The parsed tree is unparsed and printed under a ``BEFORE:`` header, then
    the transformer is applied and the result is printed under an ``AFTER:``
    header.

    Args:
        source: Python source code to parse with ``ast.parse``.
        transform: Callable that takes an ``ast`` tree and returns the
            transformed tree. When omitted, ``transform_ast`` (imported from
            ``pto.core.automatic_names``) is used, so existing calls behave
            exactly as before.

    Returns:
        None. Deliberately nothing, so a notebook cell ending in
        ``check_trans(...)`` does not echo an ``ast.Module`` repr.
    """
    if transform is None:
        # Looked up at call time rather than as a default argument, so this
        # helper can be defined (and exercised with another transformer)
        # even where ``transform_ast`` is not importable.
        transform = transform_ast

    tree = ast.parse(source)

    # Print the original before transforming, in case the transformer
    # mutates the tree in place.
    print('BEFORE:')
    print(ast.unparse(tree))

    # transform ast
    tree = transform(tree)

    print('AFTER:')
    print(ast.unparse(tree))
    

In [3]:
# Plain function: the transformer decorates it with @func_name.
source = "def f(): print('hello')"
check_trans(source)

BEFORE:
def f():
    print('hello')
AFTER:
@func_name
def f():
    print('hello')


In [4]:
# List comprehension: its iterable gets wrapped in iter_name(...).
source = "def g(): [x for x in range(10)]"
check_trans(source)

BEFORE:
def g():
    [x for x in range(10)]
AFTER:
@func_name
def g():
    [x for x in iter_name(range(10))]


In [5]:
# While loop: wrapped in a `with Loop_name() as count:` block, with a
# count() call appended to the loop body.
source = "\n".join([
    "def j():",
    "\tc = 10",
    "\twhile c:",
    "\t\tc-=1",
])
check_trans(source)

BEFORE:
def j():
    c = 10
    while c:
        c -= 1
AFTER:
@func_name
def j():
    c = 10
    with Loop_name() as count:
        while c:
            c -= 1
            count()


In [6]:
# Nested function definitions: the outer and both inner defs each get
# @func_name.
source = "\n".join([
    "def s():",
    "\tdef t():",
    "\t\tprint('hello')",
    "\tdef u():",
    "\t\tprint('world')",
])
check_trans(source)

BEFORE:
def s():

    def t():
        print('hello')

    def u():
        print('world')
AFTER:
@func_name
def s():

    @func_name
    def t():
        print('hello')

    @func_name
    def u():
        print('world')


In [7]:
# Comprehension with two `for` clauses: both iterables get wrapped in
# iter_name(...).
source = "\n".join([
    "def f(matrix):",
    "\treturn [val for sublist in matrix for val in sublist]",
])
check_trans(source)

BEFORE:
def f(matrix):
    return [val for sublist in matrix for val in sublist]
AFTER:
@func_name
def f(matrix):
    return [val for sublist in iter_name(matrix) for val in iter_name(sublist)]


In [8]:
# Nested for loops: each loop gets its own Loop_name() context.
# NOTE(review): in the AFTER output both `with` statements bind the same
# name `count`, so the outer loop's trailing count() appears to call the
# inner Loop_name's counter rather than the outer one -- looks like a
# naming bug in the transformer; confirm.
source = "\n".join([
    "def g(matrix):",
    "\tfor sublist in matrix:",
    "\t\tfor val in sublist:",
    "\t\t\tflatten_matrix.append(val)",
])
check_trans(source)

BEFORE:
def g(matrix):
    for sublist in matrix:
        for val in sublist:
            flatten_matrix.append(val)
AFTER:
@func_name
def g(matrix):
    with Loop_name() as count:
        for sublist in matrix:
            with Loop_name() as count:
                for val in sublist:
                    flatten_matrix.append(val)
                    count()
            count()


In [9]:
# A while loop nested inside a for loop: both loops get a Loop_name() context.
# NOTE(review): as in the nested-for case, both `with` statements bind
# `count`, so the outer count() appears to hit the inner counter; confirm.
source = "\n".join([
    "def h(matrix):",
    "\tfor sublist in matrix:",
    "\t\twhile val:",
    "\t\t\tflatten_matrix.append(val)",
])
check_trans(source)

BEFORE:
def h(matrix):
    for sublist in matrix:
        while val:
            flatten_matrix.append(val)
AFTER:
@func_name
def h(matrix):
    with Loop_name() as count:
        for sublist in matrix:
            with Loop_name() as count:
                while val:
                    flatten_matrix.append(val)
                    count()
            count()


In [10]:
# Loops plus a comprehension: Loop_name() contexts around both loops and
# iter_name(...) around the comprehension's iterable.
# NOTE(review): the outer count() again appears to refer to the inner
# loop's counter because both `with` targets are named `count`; confirm.
source = "\n".join([
    "def j(matrix):",
    "\tfor sublist in matrix:",
    "\t\twhile val:",
    "\t\t\t[x for x in range(5)]",
])
check_trans(source)

BEFORE:
def j(matrix):
    for sublist in matrix:
        while val:
            [x for x in range(5)]
AFTER:
@func_name
def j(matrix):
    with Loop_name() as count:
        for sublist in matrix:
            with Loop_name() as count:
                while val:
                    [x for x in iter_name(range(5))]
                    count()
            count()
