In [19]:
import asttokens

In [21]:
code = '''
def add(a,b):
    return test_m(a) + b
'''

In [22]:
atok = asttokens.ASTTokens(code, parse=True)


In [34]:
for node in ast.walk(atok.tree):
  if hasattr(node, 'lineno'):
    print(atok.get_text_range(node), node.__class__.__name__, atok.get_text(node),node)


(1, 39) FunctionDef def add(a,b):
    return test_m(a) + b <_ast.FunctionDef object at 0x7fd2b5807da0>
(19, 39) Return return test_m(a) + b <_ast.Return object at 0x7fd2b5807e80>
(9, 10) arg a <_ast.arg object at 0x7fd2b5807e10>
(11, 12) arg b <_ast.arg object at 0x7fd2b5807e48>
(26, 39) BinOp test_m(a) + b <_ast.BinOp object at 0x7fd2b5807eb8>
(26, 35) Call test_m(a) <_ast.Call object at 0x7fd2b5807ef0>
(38, 39) Name b <_ast.Name object at 0x7fd2b5807f98>
(26, 32) Name test_m <_ast.Name object at 0x7fd2b5807f28>
(33, 34) Name a <_ast.Name object at 0x7fd2b5807f60>


In [28]:
next(n for n in asttokens.walk(atok.tree) if isinstance(n, ast.Attribute))


AttributeError: module 'asttokens' has no attribute 'walk'

In [40]:
import ast
tree = ast.parse(code)
for function in tree.body:
    if isinstance(function, ast.FunctionDef):
        # Just in case if there are loops in the definition
        lastBody = function.body[-1]
        while isinstance(lastBody, (ast.For, ast.While, ast.If)):
            lastBody = lastBody.Body[-1]
        lastLine = lastBody.lineno
        print("Name of the function is ", function.name)
        print("firstLine of the function is ", function.lineno)
        print("LastLine of the function is ", lastLine)
        print("the source lines are ")
        if isinstance(code, str):
            code = code.split("\n")
        for i, line in enumerate(code, 1):
            if i in range(function.lineno, lastLine+1):
                print(line)


Name of the function is  add
firstLine of the function is  2
LastLine of the function is  3
the source lines are 
def add(a,b):
    return test_m(a) + b


In [56]:
import asttokens
import javalang
import json
from tqdm import tqdm
import collections
import sys
import tokenize


def process_source(file_name, save_file):
    with open(file_name, 'r', encoding='utf-8') as source:
        lines = source.readlines()
    with open(save_file, 'w+', encoding='utf-8') as save:
        for line in lines:
            code = line.strip()
            tokens = list(javalang.tokenizer.tokenize(code))
            tks = []
            for tk in tokens:
                if tk.__class__.__name__ == 'String' or tk.__class__.__name__ == 'Character':
                    tks.append('STR_')
                elif 'Integer' in tk.__class__.__name__ or 'FloatingPoint' in tk.__class__.__name__:
                    tks.append('NUM_')
                elif tk.__class__.__name__ == 'Boolean':
                    tks.append('BOOL_')
                else:
                    tks.append(tk.value)
            save.write(" ".join(tks) + '\n')


def get_ast(file_name, w):
    with open(file_name, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    with open(w, 'w+', encoding='utf-8') as wf:
        ign_cnt = 0
        for line in tqdm(lines):
            code = line.strip()
            tokens = javalang.tokenizer.tokenize(code)
            token_list = list(javalang.tokenizer.tokenize(code))
            length = len(token_list)
            parser = javalang.parser.Parser(tokens)
            try:
                tree = parser.parse_member_declaration()
            except (javalang.parser.JavaSyntaxError, IndexError, StopIteration, TypeError):
                print(code)
                continue
            flatten = []
            for path, node in tree:
                flatten.append({'path': path, 'node': node})

            ign = False
            outputs = []
            stop = False
            for i, Node in enumerate(flatten):
                d = collections.OrderedDict()
                path = Node['path']
                node = Node['node']
                children = []
                for child in node.children:
                    child_path = None
                    if isinstance(child, javalang.ast.Node):
                        child_path = path + tuple((node,))
                        for j in range(i + 1, len(flatten)):
                            if child_path == flatten[j]['path'] and child == flatten[j]['node']:
                                children.append(j)
                    if isinstance(child, list) and child:
                        child_path = path + (node, child)
                        for j in range(i + 1, len(flatten)):
                            if child_path == flatten[j]['path']:
                                children.append(j)
                d["id"] = i
                d["type"] = str(node)
                if children:
                    d["children"] = children
                value = None
                if hasattr(node, 'name'):
                    value = node.name
                elif hasattr(node, 'value'):
                    value = node.value
                elif hasattr(node, 'position') and node.position:
                    for i, token in enumerate(token_list):
                        if node.position == token.position:
                            pos = i + 1
                            value = str(token.value)
                            while pos < length and token_list[pos].value == '.':
                                value = value + '.' + token_list[pos + 1].value
                                pos += 2
                            break
                elif type(node) is javalang.tree.This \
                        or type(node) is javalang.tree.ExplicitConstructorInvocation:
                    value = 'this'
                elif type(node) is javalang.tree.BreakStatement:
                    value = 'break'
                elif type(node) is javalang.tree.ContinueStatement:
                    value = 'continue'
                elif type(node) is javalang.tree.TypeArgument:
                    value = str(node.pattern_type)
                elif type(node) is javalang.tree.SuperMethodInvocation \
                        or type(node) is javalang.tree.SuperMemberReference:
                    value = 'super.' + str(node.member)
                elif type(node) is javalang.tree.Statement \
                        or type(node) is javalang.tree.BlockStatement \
                        or type(node) is javalang.tree.ForControl \
                        or type(node) is javalang.tree.ArrayInitializer \
                        or type(node) is javalang.tree.SwitchStatementCase:
                    value = 'None'
                elif type(node) is javalang.tree.VoidClassReference:
                    value = 'void.class'
                elif type(node) is javalang.tree.SuperConstructorInvocation:
                    value = 'super'

                if value is not None and type(value) is type('str'):
                    d['value'] = value
                if not children and not value:
                    # print('Leaf has no value!')
                    print(type(node))
                    print(code)
                    ign = True
                    ign_cnt += 1
                    # break
                outputs.append(d)
            if not ign:
                wf.write(json.dumps(outputs))
                wf.write('\n')
    print(ign_cnt)


PAD = 0
UNK = 1
BOS = 2
EOS = 3

PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'

NODE_FIX = '1*NODEFIX'  # '1*NODEFIX'


def python2tree(line):
    atok = asttokens.ASTTokens(line, parse=True)
    return atok, atok.tree


def traverse_python_tree(atok, root):
    iter_children = asttokens.util.iter_children_func(root)
    node_json = []
    current_global = {}
    current_idx, global_idx = 1, 1
    for node in asttokens.util.walk(root):
        if not next(iter_children(node), None) is None:
            child_num = 0
            for child in iter_children(node):
                child_num += 1
            global_idx = global_idx + child_num
            current_global[current_idx] = global_idx
        current_idx += 1
    # print current_global
    current_idx = 1
    for node in asttokens.util.walk(root):
        # print current_idx
        # idx_upper = current_idx
        new_node = {"id": current_idx, "type": type(node).__name__, "children": [],
                    "value": atok.get_text(node)}
        if new_node["type"] == 'Name':
            new_node['type'] = new_node['value']
        idx_upper = len(node_json)
        if not next(iter_children(node), None) is None:
            child_idx = 0
            for child in iter_children(node):
                child_idx += 1

                new_node['children'].append(
                    current_global[current_idx] - child_idx + 1)
        else:  # leaf node
            new_node['children'].append(atok.get_text(node))
        node_json.append(new_node)
        current_idx += 1

    # update_parent
    for k, node in enumerate(node_json):
        print(node)
        children = [c for c in node['children'] if c.startswith(NODE_FIX)]
        if len(children):
            for c in children:
                node_json[c]['parent'] = k

    return node_json


def process_python(file, des):
    atok, tree = python2tree(open(file, 'r', encoding='utf-8').read())
    tree_json = traverse_python_tree(atok, tree)
    open(des, 'w', encoding='utf-8').write(json.dumps(tree_json))


In [230]:
code = '''
def add(a,b):
    a = b**2
    return test_m(a) + b
'''


In [231]:
atok, tree = python2tree(code)


In [232]:
def get_child(root, iters):
    child_iter = next(iters(root), None)
    if child_iter is None:
        return []
    
    def expand(nested_list):
        for item in iters(nested_list):
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(child_iter))


In [233]:
def iter_node(node, iter_children):
  children = get_child(node,iter_children)
  if len(children)>0:
      print(f'node: {node.__class__.__name__} {node._fields}')
      try:
        print(f'{node.name}')
      except:
        pass
      try:
        print(f'{node.op._fields}')
      except:
        pass
  else:
      print(
          f'node: {node.__class__.__name__} {node._fields} {atok.get_text(node)}')
  try:
    print('attrs')
    for i in node._fields:
      if i == 'op':
        print(node.op)
  except:
    pass
  print(children)
  if children is not []:
    cnt = 0
    # print('have child')
    for child in iter_children(node):
      # print(f'child {cnt}')
      iter_node(child,iter_children)
      cnt += 1
  else:  # leaf node
    # print('no child')      
    pass


In [234]:
iter_children = asttokens.util.iter_children_func(tree)

iter_node(tree, iter_children)


node: Module ('body',)
attrs
[<_ast.arguments object at 0x7fcf44a42518>, <_ast.Assign object at 0x7fcf44a42e48>, <_ast.Return object at 0x7fcf44a42160>]
node: FunctionDef ('name', 'args', 'body', 'decorator_list', 'returns')
add
attrs
[<_ast.arg object at 0x7fcf44a428d0>, <_ast.arg object at 0x7fcf44a42e10>]
node: arguments ('args', 'vararg', 'kwonlyargs', 'kw_defaults', 'kwarg', 'defaults') a,b
attrs
[]
node: arg ('arg', 'annotation') a
attrs
[]
node: arg ('arg', 'annotation') b
attrs
[]
node: Assign ('targets', 'value') a = b**2
attrs
[]
node: Name ('id', 'ctx') a
attrs
[]
node: BinOp ('left', 'op', 'right') b**2
attrs
<_ast.Pow object at 0x7fd2ca1b9da0>
[]
node: Name ('id', 'ctx') b
attrs
[]
node: Num ('n',) 2
attrs
[]
node: Return ('value',)
attrs
[<_ast.Call object at 0x7fcf44a42828>, <_ast.Name object at 0x7fcf44a424a8>]
node: BinOp ('left', 'op', 'right')
()
attrs
<_ast.Add object at 0x7fd2ca1b9710>
[<_ast.Name object at 0x7fcf44a42f60>, <_ast.Name object at 0x7fcf44a42e80>]
nod

In [59]:
traverse_python_tree(atok, tree)


{'id': 1, 'type': 'Module', 'children': [2], 'value': '\ndef add(a,b):\n    return test_m(a) + b'}


AttributeError: 'int' object has no attribute 'startswith'

In [100]:
ast.Attribute

_ast.Attribute

In [101]:
python_special_name = ['Module', 'Interactive', 'Expression', 'FunctionType', 'FunctionDef', 'AsyncFunctionDef', 'ClassDef', 'Return', 'Delete', 'Assign', 'AugAssign', 'AnnAssign', 'For', 'AsyncFor', 'While', 'If', 'With', 'AsyncWith', 'Match', 'Raise', 'Try', 'Assert', 'Import', 'ImportFrom', 'Global', 'Nonlocal', 'Expr', 'Pass', 'Break', 'Continue', 'BoolOp', 'NamedExpr', 'BinOp', 'UnaryOp', 'Lambda', 'IfExp', 'Dict', 'Set', 'ListComp', 'SetComp', 'DictComp', 'GeneratorExp', 'Await', 'Yield', 'YieldFrom', 'Compare', 'Call',
    'FormattedValue', 'JoinedStr', 'Constant', 'Attribute', 'Subscript', 'Starred', 'Name', 'List', 'Tuple', 'Slice', 'Load', 'Store', 'Del', 'And', 'Or', 'Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift', 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv', 'Invert', 'Not', 'UAdd', 'USub', 'Eq', 'NotEq', 'Lt', 'LtE', 'Gt', 'GtE', 'Is', 'IsNot', 'In', 'NotIn', 'ExceptHandler', 'MatchValue', 'MatchSingleton', 'MatchSequence', 'MatchMapping', 'MatchClass', 'MatchStar', 'MatchAs', 'MatchOr', 'TypeIgnore']

In [236]:
ast.NodeTransformer()

<ast.NodeTransformer at 0x7fcf44a42978>

In [239]:
import ast


class NodeVisitor(ast.NodeVisitor):
    def visit_Str(self, tree_node):
        print('{}'.format(tree_node.s))


class NodeTransformer(ast.NodeTransformer):
    def visit_Str(self, tree_node):
        return ast.Str('String: ' + tree_node.s)


tree_node = ast.parse('''
fruits = ['grapes', 'mango']
name = 'peter'

for fruit in fruits:
    print('{} likes {}'.format(name, fruit))
''')

# NodeTransformer().visit(tree_node)
NodeVisitor().visit(tree)


In [240]:
ast.Str('String: ')


<_ast.Str at 0x7fcf44a42400>

In [241]:
import ast
import asttokens
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, DataCollatorWithPadding
from anytree import AnyNode
from tqdm import tqdm
from treelib import Tree
import re
import operator
from functools import reduce

'''
Child 0
Parent 1
NextSib 2
NextUse 3
NextToken 4
SplitChild 5
SplitParent 6
SplitNextSib 7
LoopNext
'''
AST_EDGE = 0
# Parent = 1
NextSib = 1
NextUse = 2
NextToken = 3
# SplitChild = 4
# SplitParent = 5
# SplitNextSib = 6
LoopNext = 4
ControlOut = 5
ConditionNext = 6

checkpoint = 'microsoft/codebert-base'
tokenize_token = '_<SplitNode>_'
ast_tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
python_special_tokens = ['arguments', 'Module', 'Interactive', 'Expression', 'FunctionType', 'FunctionDef', 'AsyncFunctionDef', 'ClassDef', 'Return', 'Delete', 'Assign', 'AugAssign', 'AnnAssign', 'For', 'AsyncFor', 'While', 'If', 'With', 'AsyncWith', 'Match', 'Raise', 'Try', 'Assert', 'Import', 'ImportFrom', 'Global', 'Nonlocal', 'Expr', 'Pass', 'Break', 'Continue', 'BoolOp', 'NamedExpr', 'BinOp', 'UnaryOp', 'Lambda', 'IfExp', 'Dict', 'Set', 'ListComp', 'SetComp', 'DictComp', 'GeneratorExp', 'Await', 'Yield', 'YieldFrom', 'Compare', 'Call',
                         'FormattedValue', 'JoinedStr', 'Constant', 'Attribute', 'Subscript', 'Starred', 'Name', 'List', 'Tuple', 'Slice', 'Load', 'Store', 'Del', 'And', 'Or', 'Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift', 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv', 'Invert', 'Not', 'UAdd', 'USub', 'Eq', 'NotEq', 'Lt', 'LtE', 'Gt', 'GtE', 'Is', 'IsNot', 'In', 'NotIn', 'ExceptHandler', 'MatchValue', 'MatchSingleton', 'MatchSequence', 'MatchMapping', 'MatchClass', 'MatchStar', 'MatchAs', 'MatchOr', 'TypeIgnore', tokenize_token]
special_tokens_dict = {'additional_special_tokens': python_special_tokens}
num_added_toks = ast_tokenizer.add_special_tokens(special_tokens_dict)


def visiulize_tree(any_node):
    tree = Tree()

    def new_tree(node, parent=None):
        if node is None:
            return
        tree.create_node(node.token, node.id, parent=(
            None if not parent else parent.id))
        for child in node.children:
         #  print(child.token)
            new_tree(child, node)

    new_tree(any_node)

    tree.show()

# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus


def has_child(node, iters):
    return next(iters(node), None) is not None


def get_token(node, atok, iters):
    token = 'None'
    class_name = node.__class__.__name__
    if class_name in python_special_tokens and class_name != 'Name':
        token = class_name
    elif not has_child(node, iters):
        token = atok.get_text(node)
    else:
        token = node.__class__.__name__
    return token


def get_child(root, iters):
    child_iter = next(iters(root), None)
    if child_iter is None:
        return []

    def expand(nested_list):
        for item in iters(nested_list):
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    children = list(expand(root))
    if 'op' in root._fields:
        children.append(ast.Str('Add'))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    atok = asttokens.ASTTokens(func, parse=True)
    return atok, atok.tree


#  generate tree for AST Node
def create_tree(atok, iters, root, node, node_list, sub_id_list, leave_list, tokenizer, parent=None):
    id = len(node_list)
    node_list.append(node)
    token, children = get_token(node, atok, iters), get_child(node, iters)

    if children == []:
        # print('this is a leaf:', token, id)
        leave_list.append(id)

    # Use roberta.tokenizer to generate subtokens
    # If a token can be divided into multiple(>1) subtokens, the first subtoken will be set as the previous node,
    # and the other subtokens will be set as its new children
    token = token.encode('utf-8', 'ignore').decode("utf-8")
    sub_token_list = tokenizer.tokenize(token)

    if len(sub_token_list) == 1 and sub_token_list[0] in python_special_tokens:
        pass
    # TODO convert token into lower
    else:
        sub_tokens = [tokenizer.tokenize(i.lower()) for i in re.sub(
            '([a-z0-9])([A-Z])', r'\1 \2', token).split()]
        for i in range(len(sub_tokens)):
            sub_tokens[i][0] = 'Ġ' + sub_tokens[i][0]
        sub_token_list = reduce(operator.add, sub_tokens)
        # print(sub_token_list)

    #  # TODO 叶子节点加上空白符号
    # if children is None or len(children) == 0:
    #     sub_token_list[0] = 'Ġ' + sub_token_list[0]

    if id == 0:
        # the root node is one of the tokenizer's special tokens
        root.token = sub_token_list[0]
        root.data = node
        # record the num of nodes for every children of root
        root_children_node_num = []
        for child in children:
            node_num = len(node_list)
            create_tree(atok, iters, root, child, node_list, sub_id_list,
                        leave_list, tokenizer, parent=root)
            root_children_node_num.append(len(node_list) - node_num)
        return root_children_node_num
    else:
        new_token = sub_token_list[0] if len(
            sub_token_list) <= 1 else tokenize_token
        new_node = AnyNode(
            id=id, token=new_token, data=node, parent=parent)

        if len(sub_token_list) > 1:
            sub_id_list.append(id)
            for sub_token in sub_token_list:
                id += 1
                AnyNode(id=id, token=sub_token, data=node, parent=new_node)
                node_list.append(sub_token)
                sub_id_list.append(id)

        for child in children:
            create_tree(atok, iters, root, child, node_list, sub_id_list,
                        leave_list, tokenizer, parent=new_node)


# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, edge_attrs, variable_token_list, variable_id_list, token_dicts, token_list, token_ids, parent_next=None):
    token = node.token
    node_id = tokenizer.convert_tokens_to_ids(token)
    assert isinstance(node_id, int)
    node_index_list.append(node_id)
    # node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    token_dicts[node.id] = node.token
    if not node.children and token not in python_special_tokens:
        token_list.append(token)
        token_ids.append(node.id)

    # if token in ['VariableDeclarator', 'MemberReference']:
    #     if node.children:  # some chidren are comprised by non-utf8 and will be removed
    #         child_token = node.children[0].token
    #         if child_token == tokenize_token:
    #             # print(node.children[0])
    #             child_token += ' - '+node.children[0].data
    #             print(child_token)
    #         variable_token_list.append(child_token)
    #         variable_id_list.append(node.children[0].id)

    children = node.children
    is_split_node = (node.token == tokenize_token)

    for idx, child in enumerate(children[:-1]):
        # edge_attr = NextSib if not is_split_node else SplitNextSib
        if not is_split_node:
            edge_attr = NextSib
            src.append(child.id)
            tgt.append(children[idx+1].id)
            edge_attrs.append(edge_attr)

            tgt.append(child.id)
            src.append(children[idx+1].id)
            edge_attrs.append(edge_attr)

        # if node.token == 'SwitchStatement' and child.token == 'SwitchStatementCase' and parent_next:
        #     if 'BreakStatement' in [i.token for i in child.children]:
        #         src.append(child.id)
        #         tgt.append(parent_next.id)
        #         edge_attrs.append(ControlOut)

        #         tgt.append(child.id)
        #         src.append(parent_next.id)
        #         edge_attrs.append(ControlOut)

    # # Control Flow
    # if node.token == 'ForStatement' or node.token == 'WhileStatement':
    #     # assert len(children) == 2 or (len(children)==3 and children[-1] == 'outer')
    #     if len(children) >= 2:
    #         src.append(children[1].id)
    #         tgt.append(children[0].id)
    #         edge_attrs.append(LoopNext)

    #         tgt.append(children[1].id)
    #         src.append(children[0].id)
    #         edge_attrs.append(LoopNext)

    # if node.token == 'IfStatement':
    #     assert (len(children) == 2 or len(children) == 3) or len(children) == 0
    #     if len(children) == 3:
    #         # assert children[0].token == 'BinaryOperation'
    #         src.append(children[0].id)
    #         tgt.append(children[-1].id)
    #         edge_attrs.append(ConditionNext)

    #         tgt.append(children[0].id)
    #         src.append(children[-1].id)
    #         edge_attrs.append(ConditionNext)

    for idx, child in enumerate(children):
        # parent_type = Parent if not is_split_node else SplitParent
        # child_type = Child if not is_split_node else SplitChild
        src.append(node.id)
        tgt.append(child.id)
        edge_attrs.append(AST_EDGE)
        src.append(child.id)
        tgt.append(node.id)
        edge_attrs.append(AST_EDGE)
        parent_next = children[idx+1] if (idx+1 < len(children)) else None
        get_node_and_edge(child, node_index_list, tokenizer,
                          src, tgt, edge_attrs, variable_token_list, variable_id_list, token_dicts, token_list, token_ids, parent_next)


# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(atok, ast, tokenizer=ast_tokenizer):
    node_list = []
    sub_id_list = []  # record the ids of node that can be divide into multple subtokens
    leave_list = []  # record the ids of leave
    new_tree = AnyNode(id=0, token=None, data=None)
    iter_children = asttokens.util.iter_children_func(ast)

    root_children_node_num = create_tree(
        atok, iter_children, new_tree, ast, node_list, sub_id_list, leave_list, tokenizer)

    # print('root_children_node_num', root_children_node_num)
    x = []
    edge_src = []
    edge_tgt = []
    edge_attrs = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    token_dicts = {}

    token_list, token_ids = [], []

    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, edge_attrs,
                      variable_token_list, variable_id_list, token_dicts, token_list, token_ids)
    # print(variable_token_list)
    # visiulize_tree(new_tree)
    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(
                variable_token_list[i], variable_id_list[i])
        else:
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_attrs.append(NextUse)

            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            edge_src.append(variable_id_list[i])
            edge_attrs.append(NextUse)
            variable_dict[variable_token_list[i]] = variable_id_list[i]

    for idx, item in enumerate(leave_list[:-1]):
        edge_src.append(item)
        edge_tgt.append(leave_list[idx+1])
        edge_attrs.append(NextToken)

        edge_tgt.append(item)
        edge_src.append(leave_list[idx+1])
        edge_attrs.append(NextToken)

    edge_index = [edge_src, edge_tgt]

    # TODO 第一个词不需要空格符号
    if token_list:
        token_list[0] = token_list[0].lstrip('Ġ')
        token_idx = tokenizer.convert_tokens_to_ids(token_list[0])
        x[token_ids[0]] = token_idx

    return x, edge_index, edge_attrs, root_children_node_num, token_list, token_ids


def get_graph_from_source(code, tokenizer=ast_tokenizer):
    atok, ast = parse_program(code)
    return get_pyg_data_from_ast(atok, ast, tokenizer)


def get_subgraph_node_num(root_children_node_num, divide_node_num, max_subgraph_num):
    subgraph_node_num = []
    node_sum = 0
    real_graph_num = 0
    for num in root_children_node_num:
        node_sum += num
        if node_sum >= divide_node_num:
            subgraph_node_num.append(node_sum)
            node_sum = 0

    subgraph_node_num.append(node_sum)
    real_graph_num = len(subgraph_node_num)

    if real_graph_num >= max_subgraph_num:
        return subgraph_node_num[: max_subgraph_num], max_subgraph_num

    # print(len(subgraph_node_num))
    # if the last subgraph node num < divide_node_num/2, then put the last subgraph to the second to last subgraph
    if subgraph_node_num[-1] < divide_node_num/2:
        subgraph_node_num[-2] = subgraph_node_num[-2] + subgraph_node_num[-1]
        subgraph_node_num[-1] = 0
        real_graph_num -= 1

    # zero padding for tensor transforming
    for _ in range(real_graph_num, max_subgraph_num):
        subgraph_node_num.append(0)

    return subgraph_node_num, real_graph_num


ConnectionError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/microsoft/codebert-base (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7fcf4690da20>: Failed to establish a new connection: [Errno -2] Name or service not known',))

In [None]:
code = '''
def add(a,b):
  if a>b:
    return a-b
  return a+b
  '''

In [None]:
get_graph_from_source(code, tokenizer=ast_tokenizer)


[]
Module
└── FunctionDef
    ├── If
    │   ├── Compare
    │   │   ├── Ġa
    │   │   └── Ġb
    │   └── Return
    │       └── BinOp
    │           ├── Ġa
    │           └── Ġb
    ├── Return
    │   └── BinOp
    │       ├── Ġa
    │       └── Ġb
    └── arguments
        ├── Ġa
        └── Ġb



([48720,
  50269,
  50265,
  102,
  741,
  1106,
  45448,
  10,
  741,
  42555,
  50284,
  10,
  741,
  42555,
  50284,
  10,
  741],
 [[0,
   1,
   2,
   5,
   5,
   13,
   1,
   2,
   3,
   4,
   2,
   3,
   2,
   4,
   1,
   5,
   6,
   9,
   5,
   6,
   7,
   8,
   6,
   7,
   6,
   8,
   5,
   9,
   9,
   10,
   11,
   12,
   10,
   11,
   10,
   12,
   1,
   13,
   13,
   14,
   15,
   16,
   14,
   15,
   14,
   16,
   3,
   4,
   4,
   7,
   7,
   8,
   8,
   11,
   11,
   12,
   12,
   15,
   15,
   16],
  [1,
   0,
   5,
   2,
   13,
   5,
   2,
   1,
   4,
   3,
   3,
   2,
   4,
   2,
   5,
   1,
   9,
   6,
   6,
   5,
   8,
   7,
   7,
   6,
   8,
   6,
   9,
   5,
   10,
   9,
   12,
   11,
   11,
   10,
   12,
   10,
   13,
   1,
   14,
   13,
   16,
   15,
   15,
   14,
   16,
   14,
   4,
   3,
   7,
   4,
   8,
   7,
   11,
   8,
   12,
   11,
   15,
   12,
   16,
   15]],
 [0,
  0,
  1,
  1,
  1,
  1,
  0,
  0,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  0,


In [176]:
''.join(get_graph_from_source(code, tokenizer=ast_tokenizer)[-2]).replace('Ġ',' ')


Module
└── FunctionDef
    ├── If
    │   ├── Assign
    │   │   ├── Attribute
    │   │   │   └── Ġspec
    │   │   └── Ġspec
    │   ├── If
    │   │   ├── Compare
    │   │   │   ├── Call
    │   │   │   │   ├── Dict
    │   │   │   │   ├── _<SplitNode>_
    │   │   │   │   │   ├── attr
    │   │   │   │   │   └── Ġget
    │   │   │   │   ├── _<SplitNode>_
    │   │   │   │   │   ├── '
    │   │   │   │   │   ├── __
    │   │   │   │   │   ├── __
    │   │   │   │   │   ├── dict
    │   │   │   │   │   └── Ġ'
    │   │   │   │   └── Ġspec
    │   │   │   └── Ġentry
    │   │   └── Return
    │   │       └── Ġfalse
    │   └── UnaryOp
    │       └── Call
    │           ├── _<SplitNode>_
    │           │   ├── instance
    │           │   └── Ġis
    │           ├── Ġspec
    │           └── Ġtype
    └── arguments
        ├── _<SplitNode>_
        │   ├── _
        │   ├── type
        │   └── Ġis
        ├── Ġentry
        └── Ġspec



"spec entry is_type isinstance spec type entry getattr spec '__dict__' false spec spec"

In [147]:
print(code)


def add(a,b):
    return test_m(a) + b



In [171]:
code = '''
def _must_skip(spec, entry, is_type): 
  if (not isinstance(spec, type)): 
    if (entry in getattr(spec, '__dict__', {})): 
      return False 
    spec = spec.__class__

'''