In [None]:
Java,
C, C++,

Go, C#,
Python, PHP

In [4]:
def expand_dict(d, seq):
    """Append to ``seq`` every key of ``d`` whose value is itself a dict.

    ``seq`` is modified in place and nothing is returned.
    """
    for key, value in d.items():
        if isinstance(value, dict):
            seq.append(key)

In [1]:
# For Golang
'''
import os
import json
with os.popen('go run goblin.go -file "ast test/goblin.go"') as f:
    d = json.loads(f.read())
print(d)
'''

class Node_go(object):
    """Tree node wrapping one AST element together with its wrapped children.

    Attributes set by ``__init__``:
        node     -- the wrapped object.
        is_str   -- True when ``node`` is a plain string (always a leaf).
        token    -- textual label of the node, see ``get_token``.
        children -- list of ``Node_go`` wrappers, see ``add_children``.

    NOTE(review): the traversal only recognises ``_ast.AST`` objects, lists
    and sets; Go ASTs arrive as nested dicts -- confirm whether dict support
    is still to be added here.
    """

    def __init__(self, node):
        self.node = node
        self.is_str = isinstance(node, str)  # str => leaf node => no children
        self.token = self.get_token(node)
        self.children = self.add_children()

    def is_leaf(self):
        """Return True when this node has no child nodes."""
        if self.is_str:
            return True
        # Use the wrapped children computed in __init__; the raw node object
        # is not guaranteed to expose a ``children`` attribute.
        return len(self.children) == 0

    def get_token(self, node):
        """Return the label for ``node``.

        Strings are returned unchanged, ``_ast.AST`` instances yield their
        class name, anything else is converted with ``str`` (falling back to
        ``''`` if that conversion fails).
        """
        if self.is_str:
            return self.node
        import _ast
        if isinstance(node, _ast.AST):
            return node.__class__.__name__
        try:
            return str(node)
        except Exception:
            return ''

    def iter_fields(self, node, exclude=()):
        """Yield the value of every field named in ``node._fields``.

        Fields listed in ``exclude`` and fields that are not present as
        attributes on ``node`` are skipped.  Only the values are yielded,
        not ``(name, value)`` pairs.
        """
        for field in node._fields:
            if field in exclude:
                continue
            try:
                yield getattr(node, field)
            except AttributeError:
                pass

    def ori_children(self, root, exclude=()):
        """Return the flattened list of raw (unwrapped) children of ``root``.

        Lists are taken as they are, AST nodes contribute their field values
        (minus ``exclude``), sets are converted to lists and every other type
        has no children.  The result is flattened through ``expand``.
        """
        import _ast
        if isinstance(root, list):
            children = root
        elif isinstance(root, _ast.AST):
            children = self.iter_fields(root, exclude)
        elif isinstance(root, set):  # probably never triggered
            children = list(root)
        else:
            children = []
        return list(expand(children))

    def add_children(self):
        """Wrap every raw child of this node in a ``Node_go``.

        Fields registered for this token in ``EXCLUDE_FIELDS`` are left out.
        """
        if self.is_str:  # str => leaf node => no children
            return []
        ef = EXCLUDE_FIELDS.get(self.token, [])
        children = self.ori_children(self.node, ef)
        # Children of a Node_go are Node_go instances (previously Node_python).
        return [Node_go(child) for child in children]


In [3]:
# For Golang
def expand(nested_list):
    """Generator that flattens arbitrarily nested lists, depth first.

    Only ``list`` instances are descended into.  Leaves that are falsy
    (``None``, ``0``, ``''``, ``False``, empty containers) are dropped.
    """
    for element in nested_list:
        if not isinstance(element, list):
            if element:
                yield element
            continue
        for inner in expand(element):
            yield inner
            
class Preprocessor_go():
    """Convert Go source files into token sequences and block trees.

    The Go AST is produced by running ``goblin.go`` and arrives as JSON, i.e.
    as nested dicts and lists.

    Attributes:
        vocab     -- container queried with ``token in vocab`` and
                     ``vocab[token].index``.
        max_token -- ``len(vocab)``; used as the index of unknown tokens.
    """

    def __init__(self, vocab):
        self.vocab = vocab
        self.max_token = len(vocab)

    def file_to_ast(self, filename):
        """Parse a ``.go`` file with goblin and return its ``declarations``.

        Raises:
            AssertionError: ``filename`` is not an existing ``.go`` file.
            ValueError: goblin did not print valid JSON
                (``json.JSONDecodeError`` is a ``ValueError``).
            KeyError: the JSON has no ``'declarations'`` entry.
        """
        import os
        assert os.path.isfile(filename)
        assert filename[-3:].lower() == '.go'
        import json
        import subprocess
        # Pass an argument list (no shell) so quotes, spaces or shell
        # metacharacters in ``filename`` cannot break or alter the command.
        proc = subprocess.run(
            ['go', 'run', 'goblin.go', '-file', filename],
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
        d = json.loads(proc.stdout)
        return d['declarations']

    def get_functions(self, ast):
        """Return the declarations whose ``'type'`` equals ``'function'``."""
        return [f for f in ast if f['type'] == 'function']

    def get_function_name(self, function_ast):
        """Return ``function_ast['name']['value']``."""
        return function_ast['name']['value']

    def get_token(self, node):
        """Return the label of ``node``.

        For dicts the first available of ``'type'``, ``'kind'`` and
        ``'value'`` is used (``KeyError`` if none exists).  Lists map to
        ``''``; everything else goes through ``str``.
        """
        if isinstance(node, dict):
            for key in ('type', 'kind'):
                if key in node:
                    return node[key]
            return node['value']
        if isinstance(node, list):  # should not happen?
            return ''
        try:
            return str(node)
        except Exception:
            return ''

    def get_children(self, root, exclude=()):
        """Return the flattened children of ``root``.

        Dicts contribute all of their values (in key order), lists and sets
        their items, other types nothing.  ``exclude`` is accepted for
        signature parity with the Python preprocessor and is not used.
        Flattening goes through ``expand``, which also drops falsy leaves.
        """
        if isinstance(root, list):  # probably never triggered
            children = root
        elif isinstance(root, dict):
            children = list(root.values())
        elif isinstance(root, set):  # probably never triggered
            children = list(root)
        else:
            children = []
        return list(expand(children))

    def ast_to_block(self, ast):
        """Split ``ast`` into blocks and encode each one with vocab indices."""
        blocks = []
        self.get_blocks(ast, blocks)
        # Open question from the author: does each block match one code line?
        return [self.replaced_by_index(node) for node in blocks]

    def replaced_by_index(self, node):
        """Encode a node tree as ``[index, child_1, child_2, ...]``.

        Every child is itself encoded as a nested list of the same shape.
        Tokens missing from the vocabulary get ``self.max_token``.
        """
        token = node.token
        result = [self.vocab[token].index if token in self.vocab else self.max_token]
        for child in node.children:
            result.append(self.replaced_by_index(child))
        return result

    def ast_to_token_block(self, ast):
        """Like ``ast_to_block`` but keeps the raw tokens instead of indices."""
        blocks = []
        self.get_blocks(ast, blocks)
        return [self.replaced_by_token(node) for node in blocks]

    def replaced_by_token(self, node):
        """Encode a node tree as ``[token, child_1, child_2, ...]``."""
        result = [node.token]
        for child in node.children:
            result.append(self.replaced_by_token(child))
        return result

    def get_blocks(self, node, block_seq):
        """Append a wrapped ``node`` to ``block_seq`` and recurse into blocks.

        NOTE(review): this body matches the Python preprocessor -- it wraps
        with ``Node_python``, looks tokens up in the Python ``EXCLUDE_FIELDS``
        table and reads ``node._fields``.  Confirm how Go dict nodes are
        supposed to be split into blocks.
        """
        name = self.get_token(node)
        block_seq.append(Node_python(node))
        if name in EXCLUDE_FIELDS.keys():
            ef = [f for f in node._fields if f not in EXCLUDE_FIELDS[name]]
            children = self.get_children(node, ef)
            for child in children:
                self.get_blocks(child, block_seq)

    # The next two methods walk the AST in pre-order and collect the token
    # sequence that is used to train the word embedding.

    def get_sequence(self, node, sequence):
        """Append the pre-order token sequence of ``node`` to ``sequence``.

        No 'End' marker is appended for any kind of block.
        """
        token, children = self.get_token(node), self.get_children(node)
        sequence.append(token)

        for child in children:
            self.get_sequence(child, sequence)

    def trans_to_sequences(self, ast):
        """Return the pre-order token list of ``ast`` (embedding training)."""
        sequence = []
        self.get_sequence(ast, sequence)  # pre-order walk from the root
        return sequence

    # The methods below are undecided placeholders, meant for printing an
    # AST or block while debugging.

    def visit_block(self, block):
        pass

    def visit_token_block(self, block):
        pass

    def block_to_embedded(self):
        pass


In [9]:
import os
# The vocabulary is a one-element placeholder list; only token sequences are
# computed in this cell.
preprocessor = Preprocessor_go(['1'])

# NOTE(review): `file_name = file_name = ...` is a redundant chained assignment.
file_name = file_name = 'ast test/goblin.go'
file_name = os.path.normpath(file_name)

# Parse the Go file, keep the function declarations and collect their names.
ast = preprocessor.file_to_ast(file_name)
fun_asts = preprocessor.get_functions(ast)
fun_names = [preprocessor.get_function_name(f) for f in fun_asts]

# One pre-order token sequence per function.
seqence = [preprocessor.trans_to_sequences(f) for f in fun_asts]
# Block-tree conversions are currently disabled for Go:
#blocks = [preprocessor.ast_to_block(f) for f in fun_asts]
#token_blocks = [preprocessor.ast_to_token_block(f) for f in fun_asts]

In [11]:
# Show the pre-order token sequence of the first function.
print(seqence[0])

['function', 'if', 'expression', 'statement', 'expression', 'call', 'expression', 'binary', 'expression', 'binary', 'call', 'identifier', 'expression', 'ident', 'ident', 'pos', 'identifier', 'ident', 'ident', 'String', 'expression', 'call', '+', 'STRING', 'literal', 'STRING', '": "', 'expression', '+', 'identifier', 'expression', 'identifier', 'ident', 'ident', 'reason', 'expression', 'identifier', 'expression', 'identifier', 'ident', 'ident', 'panic', 'expression', 'call', 'identifier', 'expression', 'identifier', 'ident', 'ident', 'ShouldPanic', 'block', 'define', 'statement', 'identifier', 'expression', 'identifier', 'ident', 'ident', 'res', 'identifier', 'expression', 'identifier', 'ident', 'ident', '_', 'call', 'composite', 'map', 'identifier', 'type', 'identifier', 'ident', 'ident', 'string', 'type', 'map', 'interface', 'type', 'interface', 'literal', 'composite', 'key-value', 'STRING', 'literal', 'STRING', '"error"', 'expression', 'key-value', 'composite', 'map', 'identifier', '

In [12]:
{'body': [{'body': [{'kind': 'statement',
     'type': 'expression',
     'value': {'arguments': [{'kind': 'binary',
        'left': {'kind': 'binary',
         'left': {'arguments': [],
          'ellipsis': False,
          'function': {'kind': 'expression',
           'qualifier': {'kind': 'ident', 'value': 'pos'},
           'type': 'identifier',
           'value': {'kind': 'ident', 'value': 'String'}},
          'kind': 'expression',
          'type': 'call'},
         'operator': '+',
         'right': {'kind': 'literal', 'type': 'STRING', 'value': '": "'},
         'type': 'expression'},
        'operator': '+',
        'right': {'kind': 'expression',
         'type': 'identifier',
         'value': {'kind': 'ident', 'value': 'reason'}},
        'type': 'expression'}],
      'ellipsis': False,
      'function': {'kind': 'expression',
       'type': 'identifier',
       'value': {'kind': 'ident', 'value': 'panic'}},
      'kind': 'expression',
      'type': 'call'}}],
   'condition': {'kind': 'expression',
    'type': 'identifier',
    'value': {'kind': 'ident', 'value': 'ShouldPanic'}},
   'else': {'body': [{'kind': 'statement',
      'left': [{'kind': 'expression',
        'type': 'identifier',
        'value': {'kind': 'ident', 'value': 'res'}},
       {'kind': 'expression',
        'type': 'identifier',
        'value': {'kind': 'ident', 'value': '_'}}],
      'right': [{'arguments': [{'declared': {'key': {'kind': 'type',
            'type': 'identifier',
            'value': {'kind': 'ident', 'value': 'string'}},
           'kind': 'type',
           'type': 'map',
           'value': {'incomplete': False,
            'kind': 'type',
            'methods': [],
            'type': 'interface'}},
          'kind': 'literal',
          'type': 'composite',
          'values': [{'key': {'kind': 'literal',
             'type': 'STRING',
             'value': '"error"'},
            'kind': 'expression',
            'type': 'key-value',
            'value': {'declared': {'key': {'kind': 'type',
               'type': 'identifier',
               'value': {'kind': 'ident', 'value': 'string'}},
              'kind': 'type',
              'type': 'map',
              'value': {'incomplete': False,
               'kind': 'type',
               'methods': [],
               'type': 'interface'}},
             'kind': 'literal',
             'type': 'composite',
             'values': [{'key': {'kind': 'literal',
                'type': 'STRING',
                'value': '"type"'},
               'kind': 'expression',
               'type': 'key-value',
               'value': {'kind': 'expression',
                'type': 'identifier',
                'value': {'kind': 'ident', 'value': 'typ'}}},
              {'key': {'kind': 'literal', 'type': 'STRING', 'value': '"info"'},
               'kind': 'expression',
               'type': 'key-value',
               'value': {'kind': 'expression',
                'type': 'identifier',
                'value': {'kind': 'ident', 'value': 'reason'}}}]}}]}],
        'ellipsis': False,
        'function': {'kind': 'expression',
         'qualifier': {'kind': 'ident', 'value': 'json'},
         'type': 'identifier',
         'value': {'kind': 'ident', 'value': 'Marshal'}},
        'kind': 'expression',
        'type': 'call'}],
      'type': 'define'},
     {'kind': 'statement',
      'type': 'expression',
      'value': {'arguments': [{'kind': 'expression',
         'type': 'identifier',
         'value': {'kind': 'ident', 'value': 'res'}}],
       'ellipsis': False,
       'function': {'field': {'kind': 'ident', 'value': 'Write'},
        'kind': 'expression',
        'target': {'kind': 'expression',
         'qualifier': {'kind': 'ident', 'value': 'os'},
         'type': 'identifier',
         'value': {'kind': 'ident', 'value': 'Stderr'}},
        'type': 'selector'},
       'kind': 'expression',
       'type': 'call'}}],
    'kind': 'statement',
    'type': 'block'},
   'init': None,
   'kind': 'statement',
   'type': 'if'},
  {'kind': 'statement',
   'type': 'expression',
   'value': {'arguments': [{'kind': 'literal', 'type': 'INT', 'value': '1'}],
    'ellipsis': False,
    'function': {'kind': 'expression',
     'qualifier': {'kind': 'ident', 'value': 'os'},
     'type': 'identifier',
     'value': {'kind': 'ident', 'value': 'Exit'}},
    'kind': 'expression',
    'type': 'call'}}],
 'comments': [],
 'kind': 'decl',
 'name': {'kind': 'ident', 'value': 'Perish'},
 'params': [{'declared-type': {'kind': 'type',
    'qualifier': {'kind': 'ident', 'value': 'token'},
    'type': 'identifier',
    'value': {'kind': 'ident', 'value': 'Position'}},
   'kind': 'field',
   'names': [{'kind': 'ident', 'value': 'pos'}],
   'tag': None},
  {'declared-type': {'kind': 'type',
    'type': 'identifier',
    'value': {'kind': 'ident', 'value': 'string'}},
   'kind': 'field',
   'names': [{'kind': 'ident', 'value': 'typ'}],
   'tag': None},
  {'declared-type': {'kind': 'type',
    'type': 'identifier',
    'value': {'kind': 'ident', 'value': 'string'}},
   'kind': 'field',
   'names': [{'kind': 'ident', 'value': 'reason'}],
   'tag': None}],
 'results': None,
 'type': 'function'}

{'body': [{'body': [{'kind': 'statement',
     'type': 'expression',
     'value': {'arguments': [{'kind': 'binary',
        'left': {'kind': 'binary',
         'left': {'arguments': [],
          'ellipsis': False,
          'function': {'kind': 'expression',
           'qualifier': {'kind': 'ident', 'value': 'pos'},
           'type': 'identifier',
           'value': {'kind': 'ident', 'value': 'String'}},
          'kind': 'expression',
          'type': 'call'},
         'operator': '+',
         'right': {'kind': 'literal', 'type': 'STRING', 'value': '": "'},
         'type': 'expression'},
        'operator': '+',
        'right': {'kind': 'expression',
         'type': 'identifier',
         'value': {'kind': 'ident', 'value': 'reason'}},
        'type': 'expression'}],
      'ellipsis': False,
      'function': {'kind': 'expression',
       'type': 'identifier',
       'value': {'kind': 'ident', 'value': 'panic'}},
      'kind': 'expression',
      'type': 'call'}}],
   'condit

In [1]:
# For Python
# Maps an AST node class name to the fields that are split off into separate
# blocks: Node_python leaves these fields out when it builds a node's
# children, and Preprocessor_python.get_blocks recurses into exactly these
# fields.
#Logic1 = ['For', 'AsyncFor', 'While', 'If', 'With', 'AsyncWith', 'Try']
#Logic2 = ['FunctionDef', 'AsyncFunctionDef']
EXCLUDE_FIELDS = {'FunctionDef': ['body', 'decorator_list'],
                  'AsyncFunctionDef': ['body', 'decorator_list'],
                  'For': ['body', 'orelse'],
                  'AsyncFor': ['body', 'orelse'],
                  'While': ['body', 'orelse'], 
                  'If': ['body', 'orelse'],
                  'With': ['body'],
                  'AsyncWith': ['body'],
                  'Try': ['body', 'handlers', 'orelse', 'finalbody'],
                  'ExceptHandler': ['body']
                 }

class Node_python(object):
    """Tree node wrapping one Python AST element and its wrapped children.

    Attributes set by ``__init__``:
        node     -- the wrapped object (AST node, string or other value).
        is_str   -- True when ``node`` is a plain string (always a leaf).
        token    -- textual label of the node, see ``get_token``.
        children -- list of ``Node_python`` wrappers, see ``add_children``.
    """

    def __init__(self, node):
        self.node = node
        self.is_str = isinstance(node, str)  # str => leaf node => no children
        self.token = self.get_token(node)
        self.children = self.add_children()

    def is_leaf(self):
        """Return True when this node has no child nodes."""
        if self.is_str:
            return True
        # Use the wrapped children computed in __init__; the raw node object
        # is not guaranteed to expose a ``children`` attribute.
        return len(self.children) == 0

    def get_token(self, node):
        """Return the label for ``node``.

        Strings are returned unchanged, ``_ast.AST`` instances yield their
        class name, anything else is converted with ``str`` (falling back to
        ``''`` if that conversion fails).
        """
        if self.is_str:
            return self.node
        import _ast
        if isinstance(node, _ast.AST):
            return node.__class__.__name__
        try:
            return str(node)
        except Exception:
            return ''

    def iter_fields(self, node, exclude=()):
        """Yield the value of every field named in ``node._fields``.

        Fields listed in ``exclude`` and fields that are not present as
        attributes on ``node`` are skipped.  Only the values are yielded,
        not ``(name, value)`` pairs.
        """
        for field in node._fields:
            if field in exclude:
                continue
            try:
                yield getattr(node, field)
            except AttributeError:
                pass

    def ori_children(self, root, exclude=()):
        """Return the flattened list of raw (unwrapped) children of ``root``.

        Lists are taken as they are, AST nodes contribute their field values
        (minus ``exclude``), sets are converted to lists and every other type
        has no children.  The result is flattened through ``expand``.
        """
        import _ast
        if isinstance(root, list):
            children = root
        elif isinstance(root, _ast.AST):
            children = self.iter_fields(root, exclude)
        elif isinstance(root, set):  # probably never triggered
            children = list(root)
        else:
            children = []
        return list(expand(children))

    def add_children(self):
        """Wrap every raw child of this node in a ``Node_python``.

        Fields registered for this token in ``EXCLUDE_FIELDS`` are left out.
        """
        if self.is_str:  # str => leaf node => no children
            return []
        ef = EXCLUDE_FIELDS.get(self.token, [])
        children = self.ori_children(self.node, ef)
        return [Node_python(child) for child in children]


In [2]:
def fun_cleanout(code, fun_pos, mode=None, comment_open_close_pattern=None, comment_inline_pattern=None):
    """Placeholder for cleaning up the source text of an extracted function.

    Every branch is still a no-op stub, so the function currently always
    returns ``None``.  The branches record the planned modes:

        None              -- strip only superfluous trailing blank lines
        'tailer_comment'  -- strip trailing blank lines and trailing comments
        'comment'         -- remove all comments
        'all'             -- remove all comments and blank lines (compact form)

    The ``mode`` checks use ``==``; the former ``=`` was a syntax error that
    kept the whole cell from compiling.
    """
    if mode is None:  # strip only superfluous trailing blank lines
        pass
    elif mode == 'tailer_comment':  # strip trailing blank lines and comments
        pass
    elif mode == 'comment':  # remove all comments
        pass
    elif mode == 'all':  # remove all comments and blank lines (compact form)
        pass
    else:
        pass


def expand(nested_list):
    """Generator that flattens arbitrarily nested lists, depth first.

    Only ``list`` instances are descended into.  Leaves that are falsy
    (``None``, ``0``, ``''``, ``False``, empty containers) are dropped.
    """
    for entry in nested_list:
        if not isinstance(entry, list):
            if entry:
                yield entry
            continue
        for flattened in expand(entry):
            yield flattened

# For Python
class Preprocessor_python():
    """Convert Python source into per-function ASTs, token sequences and
    block trees.

    Attributes:
        vocab     -- container queried with ``token in vocab`` and
                     ``vocab[token].index``.
        max_token -- ``len(vocab)``; used as the index of unknown tokens.
    """

    def __init__(self, vocab):
        self.vocab = vocab
        self.max_token = len(vocab)

    # The next five methods go from a file to function ASTs: they parse the
    # code of a file into an AST and split out the individual functions.

    def file_to_code(self, filename):
        """Return the UTF-8 decoded text of a ``.py`` file.

        Returns ``None`` when the file cannot be read or decoded.

        Raises:
            AssertionError: ``filename`` is not an existing ``.py`` file.
        """
        import os
        assert os.path.isfile(filename)
        assert filename[-3:].lower() == '.py'
        try:
            with open(filename, 'r', encoding="utf-8") as f:
                return f.read()
        except (OSError, UnicodeError):
            return None

    def file_to_ast(self, file):
        """Read ``file`` and return its parsed module AST."""
        return self.code_to_ast(self.file_to_code(file))

    def code_to_ast(self, code):
        """Parse ``code`` with the standard library ``ast`` module."""
        # Uses the stdlib module directly (as extract_functions does) rather
        # than a ``Lib`` package that is only importable in some setups.
        import ast as python_ast
        return python_ast.parse(code)

    def get_functions(self, ast):
        """Collect the ``FunctionDef`` nodes found directly in ``ast.body``.

        ``ClassDef`` entries are searched recursively, so methods (including
        those of nested classes) are returned as well.
        """
        function_asts = []
        for c in ast.body:
            kind = c.__class__.__name__
            if kind == 'FunctionDef':
                function_asts.append(c)
            elif kind == 'ClassDef':
                function_asts.extend(self.get_functions(c))
        return function_asts

    def get_function_name(self, function_ast):
        """Return ``function_ast.name``."""
        return function_ast.name

    def extract_functions(self, code):
        """Return ``(function_nodes, function_pos)`` for ``code``.

        ``function_nodes`` lists the ``FunctionDef`` nodes that are either
        top-level statements or direct members of a top-level class.
        ``function_pos`` holds a matching ``(start_line, end_line)`` pair for
        each: the start is the node's ``lineno`` and the end is the line
        before the next sibling statement (or before the statement following
        the enclosing class, or the last line of ``code``).

        Returns ``(None, None)`` when ``code`` cannot be parsed.
        """
        import ast
        try:
            tree = ast.parse(code)
        except Exception:
            return None, None

        linecount = code.count("\n")
        if not code.endswith("\n"):
            linecount += 1

        def end_of(body, idx, fallback):
            # Last line owned by body[idx]: one before the next sibling, or
            # ``fallback`` when body[idx] is the final entry.
            if idx == len(body) - 1:
                return fallback
            return body[idx + 1].lineno - 1

        function_nodes = []
        function_pos = []

        for index, stmt in enumerate(tree.body):
            stmt_end = end_of(tree.body, index, linecount)
            if isinstance(stmt, ast.ClassDef):
                for idx, s in enumerate(stmt.body):
                    if isinstance(s, ast.FunctionDef):
                        function_nodes.append(s)
                        function_pos.append((s.lineno, end_of(stmt.body, idx, stmt_end)))
            elif isinstance(stmt, ast.FunctionDef):
                # Record the top-level function itself (``stmt``), not the
                # leftover method variable of a previously visited class.
                function_nodes.append(stmt)
                function_pos.append((stmt.lineno, stmt_end))

        return function_nodes, function_pos

    # The next two methods return a node's token and its children; all later
    # operations build on them.

    def get_token(self, node):
        """Return the label for ``node``.

        Strings are returned unchanged, ``_ast.AST`` instances yield their
        class name, anything else is converted with ``str`` (falling back to
        ``''`` if that conversion fails).
        """
        import _ast
        if isinstance(node, str):
            return node
        if isinstance(node, _ast.AST):
            return node.__class__.__name__
        try:
            return str(node)
        except Exception:
            return ''

    def iter_fields(self, node, exclude=()):
        """Yield the value of every field named in ``node._fields``.

        Fields listed in ``exclude`` and fields that are not present as
        attributes on ``node`` are skipped.  Only the values are yielded,
        not ``(name, value)`` pairs.
        """
        for field in node._fields:
            if field in exclude:
                continue
            try:
                yield getattr(node, field)
            except AttributeError:
                pass

    def get_children(self, root, exclude=()):
        """Return the flattened list of children of ``root``.

        Lists are taken as they are, AST nodes contribute their field values
        (minus ``exclude``), sets are converted to lists and every other type
        has no children.  The result is flattened through ``expand``.
        """
        import _ast
        if isinstance(root, list):
            children = root
        elif isinstance(root, _ast.AST):
            children = self.iter_fields(root, exclude)
        elif isinstance(root, set):  # probably never triggered
            children = list(root)
        else:
            children = []
        return list(expand(children))

    # The following methods turn an AST into the ASTNN input structure.  The
    # first two use token indices and are the ones mainly used; the last two
    # keep the tokens themselves so the result can be printed for debugging.

    def ast_to_block(self, ast):
        """Split ``ast`` into blocks and encode each one with vocab indices."""
        blocks = []
        self.get_blocks(ast, blocks)
        # Open question from the author: does each block match one code line?
        return [self.replaced_by_index(node) for node in blocks]

    def replaced_by_index(self, node):
        """Encode a node tree as ``[index, child_1, child_2, ...]``.

        Every child is itself encoded as a nested list of the same shape.
        Tokens missing from the vocabulary get ``self.max_token``.
        """
        token = node.token
        result = [self.vocab[token].index if token in self.vocab else self.max_token]
        for child in node.children:
            result.append(self.replaced_by_index(child))
        return result

    def ast_to_token_block(self, ast):
        """Like ``ast_to_block`` but keeps the raw tokens instead of indices."""
        blocks = []
        self.get_blocks(ast, blocks)
        return [self.replaced_by_token(node) for node in blocks]

    def replaced_by_token(self, node):
        """Encode a node tree as ``[token, child_1, child_2, ...]``."""
        result = [node.token]
        for child in node.children:
            result.append(self.replaced_by_token(child))
        return result

    def get_blocks(self, node, block_seq):
        """Append ``Node_python(node)`` to ``block_seq`` and recurse into the
        block fields of compound statements.

        When the node's token has an entry in ``EXCLUDE_FIELDS``, all fields
        *except* the ones listed there are excluded from the child lookup, so
        the recursion visits exactly the listed fields (e.g. ``body``).
        """
        name = self.get_token(node)
        block_seq.append(Node_python(node))
        if name in EXCLUDE_FIELDS.keys():
            ef = [f for f in node._fields if f not in EXCLUDE_FIELDS[name]]
            children = self.get_children(node, ef)
            for child in children:
                self.get_blocks(child, block_seq)

    # The next two methods walk the AST in pre-order and collect the token
    # sequence that is used to train the word embedding.

    def get_sequence(self, node, sequence):
        """Append the pre-order token sequence of ``node`` to ``sequence``.

        No 'End' marker is appended for any kind of block.
        """
        token, children = self.get_token(node), self.get_children(node)
        sequence.append(token)

        for child in children:
            self.get_sequence(child, sequence)

    def trans_to_sequences(self, ast):
        """Return the pre-order token list of ``ast`` (embedding training)."""
        sequence = []
        self.get_sequence(ast, sequence)  # pre-order walk from the root
        return sequence

    # The methods below are undecided placeholders, meant for printing an
    # AST or block while debugging.

    def visit_block(self, block):
        pass

    def visit_token_block(self, block):
        pass

    def block_to_embedded(self):
        pass
    
    

In [3]:
import _ast
def iter_fields(node):
    """
    Generate ``(fieldname, value)`` pairs for the entries of ``node._fields``
    that actually exist as attributes on *node*; absent ones are skipped.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            continue
        yield name, value


def iter_child_nodes(node):
    """
    Generate the direct child nodes of *node*: every field value that is an
    AST node, plus every AST node contained in a field that is a list.
    """
    import _ast
    for _, value in iter_fields(node):
        if isinstance(value, list):
            for element in value:
                if isinstance(element, _ast.AST):
                    yield element
        elif isinstance(value, _ast.AST):
            yield value
                    
def walk(node):
    """
    Generate *node* followed by all of its descendants.  Nodes are taken
    from the front of a queue and their children appended to its back.
    Handy when nodes are to be modified in place and their context does not
    matter.
    """
    from collections import deque
    pending = deque((node,))
    while len(pending) > 0:
        current = pending.popleft()
        pending.extend(iter_child_nodes(current))
        yield current
        
        
def ast_dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  If annotate_fields is true (by default),
    the returned string will show the names and the values for fields.
    If annotate_fields is false, the result string will be more compact by
    omitting unambiguous field names.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    include_attributes can be set to true.  If indent is a non-negative
    integer or string, then the tree will be pretty-printed with that indent
    level. None (the default) selects the single line representation.

    Raises TypeError when node is not an ``_ast.AST`` instance.
    """
    # NOTE(review): this looks like a local copy of the standard library's
    # ``ast.dump`` with ``indent`` support -- confirm before diverging.
    def _format(node, level=0):
        # Returns a ``(text, simple)`` pair; ``simple`` tells the caller
        # whether the text is short enough to stay on the parent's line.
        if indent is not None:
            level += 1
            prefix = '\n' + indent * level
            sep = ',\n' + indent * level
        else:
            prefix = ''
            sep = ', '
        if isinstance(node, _ast.AST):
            cls = type(node)
            args = []
            allsimple = True
            keywords = annotate_fields
            for name in node._fields:
                try:
                    value = getattr(node, name)
                except AttributeError:
                    # A field was skipped, so later fields must be named to
                    # stay unambiguous.
                    keywords = True
                    continue
                # Omit a None value when None is also the class-level value
                # of that field.
                if value is None and getattr(cls, name, ...) is None:
                    keywords = True
                    continue
                value, simple = _format(value, level)
                allsimple = allsimple and simple
                if keywords:
                    args.append('%s=%s' % (name, value))
                else:
                    args.append(value)
            if include_attributes and node._attributes:
                for name in node._attributes:
                    try:
                        value = getattr(node, name)
                    except AttributeError:
                        continue
                    if value is None and getattr(cls, name, ...) is None:
                        continue
                    value, simple = _format(value, level)
                    allsimple = allsimple and simple
                    args.append('%s=%s' % (name, value))
            # Up to three simple arguments are kept on a single line.
            if allsimple and len(args) <= 3:
                return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args
            return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False
        elif isinstance(node, list):
            if not node:
                return '[]', True
            return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False
        # Anything else (constants, strings, None, ...) is shown via repr.
        return repr(node), True

    if not isinstance(node, _ast.AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    # An integer indent is turned into that many spaces.
    if indent is not None and not isinstance(indent, str):
        indent = ' ' * indent
    return _format(node)[0]

In [3]:
# NOTE(review): `python_ast` is not used in this cell, and `Lib` is only
# importable in some environments -- confirm it is still needed.
from Lib import ast as python_ast
import os
from gensim.models.word2vec import Word2Vec
root = 'data/'
lang = 'java'

# NOTE(review): the embedding is loaded from the 'java' directory although the
# file processed below is a Python source -- confirm this is intended.
word2vec = Word2Vec.load(root+lang+"/train/embedding/node_w2v_128").wv
preprocessor = Preprocessor_python(word2vec.vocab)

# NOTE(review): `file_name = file_name = ...` is a redundant chained assignment.
file_name = file_name = 'ast test/python test 1.py'
file_name = os.path.normpath(file_name)

# Read and parse the file, then collect the function ASTs and their names.
code = preprocessor.file_to_code(file_name)
ast = preprocessor.code_to_ast(code)
fun_asts = preprocessor.get_functions(ast)
fun_names = [preprocessor.get_function_name(f) for f in fun_asts]


# Per function: pre-order token sequence, index-encoded blocks, token blocks.
seqence = [preprocessor.trans_to_sequences(f) for f in fun_asts]
blocks = [preprocessor.ast_to_block(f) for f in fun_asts]
token_blocks = [preprocessor.ast_to_token_block(f) for f in fun_asts]

# Function nodes together with their (start_line, end_line) positions.
nodes, pos = preprocessor.extract_functions(code)


In [9]:
# (start_line, end_line) pairs returned by extract_functions.
pos

[(11, 13),
 (14, 59),
 (60, 93),
 (94, 160),
 (162, 194),
 (195, 200),
 (201, 204),
 (205, 213),
 (214, 315),
 (316, 366),
 (367, 409),
 (410, 527),
 (539, 548),
 (557, 560),
 (561, 568)]

In [4]:
# Number of blocks, then the token blocks, of the third-to-last function.
print(len(token_blocks[-3]))
token_blocks[-3]

7


[['FunctionDef', ['get_device'], ['arguments']],
 ['If',
  ['Call',
   ['Attribute',
    ['Attribute', ['Name', ['torch'], ['Load']], ['cuda'], ['Load']],
    ['is_available'],
    ['Load']]]],
 ['If',
  ['Compare',
   ['Call',
    ['Attribute',
     ['Attribute', ['Name', ['torch'], ['Load']], ['cuda'], ['Load']],
     ['get_device_name'],
     ['Load']],
    ['Constant']],
   ['Eq'],
   ['Constant', ['GeForce GT 730']]]],
 ['Assign', ['Name', ['device'], ['Store']], ['Constant', ['cpu']]],
 ['Assign', ['Name', ['device'], ['Store']], ['Constant', ['cuda']]],
 ['Assign', ['Name', ['device'], ['Store']], ['Constant', ['cpu']]],
 ['Return',
  ['Call',
   ['Attribute', ['Name', ['torch'], ['Load']], ['device'], ['Load']],
   ['Name', ['device'], ['Load']]]]]

In [17]:
# Pretty-print the AST of the third-to-last function with a 4-space indent.
#for i in walk(fun_asts[-3]):
#    print(i, preprocessor.get_token(i), i._fields)
print(ast_dump(fun_asts[-3], indent=4))
#python_ast

FunctionDef(
    name='get_device',
    args=arguments(
        posonlyargs=[],
        args=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[]),
    body=[
        If(
            test=Call(
                func=Attribute(
                    value=Attribute(
                        value=Name(id='torch', ctx=Load()),
                        attr='cuda',
                        ctx=Load()),
                    attr='is_available',
                    ctx=Load()),
                args=[],
                keywords=[]),
            body=[
                If(
                    test=Compare(
                        left=Call(
                            func=Attribute(
                                value=Attribute(
                                    value=Name(id='torch', ctx=Load()),
                                    attr='cuda',
                                    ctx=Load()),
                                attr='get_de

In [9]:
# Pre-order token sequence of the third-to-last function.
print(seqence[-3])

['FunctionDef', 'get_device', 'arguments', 'If', 'Call', 'Attribute', 'Attribute', 'Name', 'torch', 'Load', 'cuda', 'Load', 'is_available', 'Load', 'If', 'Compare', 'Call', 'Attribute', 'Attribute', 'Name', 'torch', 'Load', 'cuda', 'Load', 'get_device_name', 'Load', 'Constant', 'Eq', 'Constant', 'GeForce GT 730', 'Assign', 'Name', 'device', 'Store', 'Constant', 'cpu', 'Assign', 'Name', 'device', 'Store', 'Constant', 'cuda', 'Assign', 'Name', 'device', 'Store', 'Constant', 'cpu', 'Return', 'Call', 'Attribute', 'Name', 'torch', 'Load', 'device', 'Load', 'Name', 'device', 'Load']


In [10]:
# For C++
class Node_cpp(object):
    """Light wrapper around a pycparser-style AST node (or a plain string).

    Attributes:
        node: the wrapped object; a ``str`` is treated as a childless leaf.
        is_str: True when ``node`` is a ``str``.
        token: textual label computed by :meth:`get_token`.
        children: list of ``Node_cpp`` built by :meth:`add_children`, or an
            empty list when the wrapper was created with ``add_children=False``.
    """

    def __init__(self, node, add_children = True):
        self.node = node
        self.is_str = isinstance(self.node, str)
        self.token = self.get_token()
        self.children = self.add_children() if add_children else []

    def is_leaf(self):
        """Return True for string nodes and for nodes whose ``children()`` is empty."""
        return self.is_str or len(self.node.children()) == 0

    def get_token(self, lower=True):
        """Compute the label of this node.

        Strings are returned unchanged. For leaves the label comes from the
        ``names`` / ``name`` / ``value`` attribute (picked through
        ``attr_names``); for inner nodes it is the class name, replaced by
        ``declname`` for ``TypeDecl`` and by the operator when the node has an
        ``op`` attribute. A ``None`` label falls back to the class name.

        Args:
            lower: when true, labels taken from a ``name`` attribute are
                lower-cased.
        """
        if self.is_str:
            return self.node

        node = self.node
        class_name = node.__class__.__name__
        token = class_name
        taken_from_name = False

        if self.is_leaf():
            attrs = node.attr_names
            if attrs:
                if 'names' in attrs:
                    token = node.names[0]
                elif 'name' in attrs:
                    token = node.name
                    taken_from_name = True
                else:
                    token = node.value
        else:
            if class_name == 'TypeDecl':
                token = node.declname
            attrs = node.attr_names
            if attrs and 'op' in attrs:
                op = node.op
                # A leading 'p' is dropped from the operator (e.g. 'p++' -> '++').
                token = op[1:] if op[0] == 'p' else op

        if token is None:
            token = class_name
        if lower and taken_from_name:
            token = token.lower()
        return token

    def add_children(self):
        """Wrap the child nodes that belong to this node's own block.

        ``FuncDef``/``If``/``While``/``DoWhile`` keep only their first child,
        ``For`` keeps every child except the last one, and any other node
        keeps all of its children.
        """
        if self.is_str:
            return []

        pairs = self.node.children()
        if self.token in ['FuncDef', 'If', 'While', 'DoWhile']:
            return [Node_cpp(pairs[0][1])]
        if self.token == 'For':
            return [Node_cpp(pair[1]) for pair in pairs[:-1]]
        return [Node_cpp(child) for _, child in pairs]

In [18]:
# For C++
class Preprocessor_cpp():
    """Convert C/C++ source into the structures consumed by an ASTNN-style model.

    ASTs are produced with pycparser (``file_to_ast`` / ``code_to_ast``);
    function line ranges are located with libclang (``extract_functions``).

    Attributes:
        vocab: mapping from token to an entry exposing ``.index``; it is also
            used for membership tests (``token in vocab``).
        max_token: ``len(vocab)``; used as the index of out-of-vocabulary tokens.
    """

    def __init__(self, vocab):
        self.vocab = vocab
        self.max_token = len(vocab)

    # ------------------------------------------------------------------
    # file -> function ASTs: parse a file and split it into functions.
    # ------------------------------------------------------------------
    def file_to_code(self, filename):
        """Return the text of a ``.cpp`` file, or ``None`` if reading fails.

        Raises:
            AssertionError: if ``filename`` is not an existing file or does
                not end with ``.cpp`` (case-insensitive).
        """
        import os
        assert os.path.isfile(filename)
        assert filename[-4:].lower()=='.cpp'
        try:
            with open(filename, 'r', encoding="utf-8") as f:
                return f.read()
        except Exception:
            # Best effort: unreadable or undecodable files are reported as None.
            return None

    def file_to_ast(self, file):
        """Parse ``file`` with ``pycparser.parse_file`` and return the AST.

        The C preprocessor is not run (``use_cpp=False``).
        """
        #ast = self.code_to_ast(self.file_to_code(file))
        from pycparser import parse_file
        # No trailing comma after the call: it used to wrap the AST in a 1-tuple.
        ast = parse_file(file, use_cpp=False, cpp_path='cpp')
        # A possible extra argument when the preprocessor is enabled:
        #   cpp_args=r'-Iutils/fake_libc_include'
        return ast

    def code_to_ast(self, code):
        """Parse a source string with ``pycparser.c_parser.CParser``.

        Parser exceptions are propagated to the caller.
        """
        from pycparser import c_parser
        parser = c_parser.CParser()
        return parser.parse(code)

    def get_functions(self, ast):
        """Return the top-level ``FuncDef`` nodes found in ``ast.ext``."""
        return [func_ast for func_ast in ast.ext if func_ast.__class__.__name__ == 'FuncDef']

    def get_function_name(self, function_ast):
        """Return ``function_ast.decl.name``."""
        return function_ast.decl.name

    def extract_functions(self, code):
        """Locate functions and class methods in ``code`` using libclang.

        Args:
            code: C++ source text; it is handed to libclang as the unsaved
                file ``0.cpp``.

        Returns:
            ``(function_nodes, function_pos)``: the cursors of top-level
            functions and of methods declared inside top-level classes, and a
            parallel list of ``(start_line, end_line)`` tuples.
            ``(None, None)`` is returned when index creation or parsing raises.
        """
        import clang
        import clang.cindex
        from clang.cindex import CursorKind

        try:
            index = clang.cindex.Index.create()
            tu = index.parse(path='0.cpp', unsaved_files=[('0.cpp',code)])
        except Exception:
            return None, None

        function_nodes = []
        function_pos = []

        file_string_split = code.split('\n')
        linecount = code.count("\n")
        if not code.endswith("\n"):
            linecount += 1
        ast_list = list(tu.cursor.get_children())

        def end_of_top_level(idx):
            # A top-level declaration ends on the line before the next one
            # starts, or on the last line of the file for the final one.
            if idx == len(ast_list) - 1:
                return linecount
            return ast_list[idx+1].location.line - 1

        for idx, cur in enumerate(ast_list):
            if cur.kind == CursorKind.FUNCTION_DECL:
                function_nodes.append(cur)
                function_pos.append((cur.location.line, end_of_top_level(idx)))

            elif cur.kind == CursorKind.CLASS_DECL:
                members = list(cur.get_children())
                for idx_in_class, member in enumerate(members):
                    if member.kind != CursorKind.CXX_METHOD:
                        continue
                    start_lineno = member.location.line
                    if idx_in_class == len(members) - 1:
                        end_lineno = end_of_top_level(idx)
                        # Walk backwards to the line that starts with '}' in
                        # column 0 (presumably the class's closing brace).
                        # `lineno` is a 0-based index, so as a 1-based line
                        # number it names the line just before that brace.
                        # The start is clamped so a line number beyond the
                        # source text cannot raise IndexError.
                        scan_from = min(end_lineno, len(file_string_split)) - 1
                        for lineno in range(scan_from, 0, -1):
                            if file_string_split[lineno] and file_string_split[lineno][0]=='}':
                                end_lineno = lineno
                                break
                    else:
                        end_lineno = members[idx_in_class+1].location.line - 1
                    function_nodes.append(member)
                    function_pos.append((start_lineno, end_lineno))

        return function_nodes, function_pos

    # ------------------------------------------------------------------
    # AST -> model input. The first pair of methods emits vocabulary
    # indices (the main path); the second pair emits the tokens themselves,
    # which is convenient when printing for debugging.
    # ------------------------------------------------------------------
    def ast_to_block(self, ast):
        """Split ``ast`` into statement blocks and encode each as nested index lists."""
        blocks = []
        self.get_blocks(ast, blocks)
        return [self.replaced_by_index(node) for node in blocks]

    def replaced_by_index(self, node):
        """Encode a wrapped node as ``[index, child1, child2, ...]``.

        Each child is encoded recursively in the same form. Tokens missing
        from ``self.vocab`` are encoded as ``self.max_token``.
        """
        token = node.token
        result = [self.vocab[token].index if token in self.vocab else self.max_token]
        for child in node.children:
            result.append(self.replaced_by_index(child))
        return result

    def ast_to_token_block(self, ast):
        """Like :meth:`ast_to_block`, but keeps tokens instead of indices."""
        blocks = []
        self.get_blocks(ast, blocks)
        return [self.replaced_by_token(node) for node in blocks]

    def replaced_by_token(self, node):
        """Encode a wrapped node as ``[token, child1, ...]``; unknown tokens become ``'UNKNOWN'``."""
        result = [node.token if node.token in self.vocab else 'UNKNOWN']
        for child in node.children:
            result.append(self.replaced_by_token(child))
        return result

    def get_blocks(self, node, block_seq):
        """Append the statement blocks under ``node`` to ``block_seq`` in source order.

        ``FuncDef``/``If``/``For``/``While``/``DoWhile`` nodes contribute a
        header block and their remaining children are then processed; a
        ``Compound`` contributes a ``'Compound'`` marker, its statements and a
        closing ``'End'`` marker; any other node is only descended into.
        """
        children = node.children()
        name = node.__class__.__name__
        if name in ['FuncDef', 'If', 'For', 'While', 'DoWhile']:
            block_seq.append(Node_cpp(node))
            # Skip the children that Node_cpp already attached to the header
            # block: the first child, or all but the last child for 'For'.
            if name != 'For':
                skip = 1
            else:
                skip = len(children) - 1

            for i in range(skip, len(children)):
                child = children[i][1]
                if child.__class__.__name__ not in ['FuncDef', 'If', 'For', 'While', 'DoWhile', 'Compound']:
                    block_seq.append(Node_cpp(child))
                self.get_blocks(child, block_seq)
        elif name == 'Compound':
            block_seq.append(Node_cpp(name))
            for _, child in node.children():
                if child.__class__.__name__ not in ['If', 'For', 'While', 'DoWhile']:
                    block_seq.append(Node_cpp(child))
                self.get_blocks(child, block_seq)
            block_seq.append(Node_cpp('End'))
        else:
            for _, child in node.children():
                self.get_blocks(child, block_seq)

    # ------------------------------------------------------------------
    # Pre-order token sequences, used to train the word embedding.
    # ------------------------------------------------------------------
    def get_sequence(self, node, sequence):
        """Append the pre-order token sequence of ``node`` to ``sequence``.

        An ``'End'`` token is appended after the subtree of every
        ``Compound`` node.
        """
        current = Node_cpp(node, False)
        sequence.append(current.token)
        for _, child in node.children():
            self.get_sequence(child, sequence)
        # Decide by node class, as get_blocks does; comparing the lower-cased
        # token also matched identifiers that happen to be named "compound".
        if node.__class__.__name__ == 'Compound':
            sequence.append('End')

    def trans_to_sequences(self, ast):
        """Return the pre-order token list of ``ast``, starting at its root."""
        sequence = []
        self.get_sequence(ast, sequence)
        return sequence

    # ------------------------------------------------------------------
    # Debugging helpers (work in progress).
    # ------------------------------------------------------------------
    def visit_block(self, block):
        """Placeholder; not implemented."""
        pass

    def visit_token_block(self, block):
        """Call ``block.show()``."""
        block.show()

    def block_to_embedded(self):
        """Placeholder; not implemented."""
        pass

In [31]:
import os
import pandas as pd
from gensim.models.word2vec import Word2Vec
# Location of the pre-trained embedding: <root><lang>/train/embedding/node_w2v_128
root = 'data/'
# NOTE(review): the 'java' embedding is loaded here although the preprocessor
# built below is Preprocessor_cpp - confirm this is intended.
lang = 'java'

# Keep only the KeyedVectors (.wv) and hand its vocabulary to the preprocessor.
word2vec = Word2Vec.load(root+lang+"/train/embedding/node_w2v_128").wv
preprocessor = Preprocessor_cpp(word2vec.vocab)

# Sample C++ file used by the following cells (path normalized for the current OS).
file_name = 'ast test/cpp test.cpp'
file_name = os.path.normpath(file_name)

#source = pd.read_pickle('D:/GitHub/astnn/classification/data/programs.pkl')

In [32]:
# Read the sample file; file_to_code returns None if the file cannot be read.
code = preprocessor.file_to_code(file_name)
# The pycparser-based pipeline below is disabled in this version of the cell.
#code = source[1][0]
#ast = preprocessor.code_to_ast(code)
#fun_asts = preprocessor.get_functions(ast)
#fun_names = [preprocessor.get_function_name(f) for f in fun_asts]


#seqence = [preprocessor.trans_to_sequences(f) for f in fun_asts]
#blocks = [preprocessor.ast_to_block(f) for f in fun_asts]
#token_blocks = [preprocessor.ast_to_token_block(f) for f in fun_asts]

# libclang-based extraction: function cursors and their (start_line, end_line) ranges.
nodes, pos = preprocessor.extract_functions(code)

In [68]:
import clang
import clang.cindex
from clang.cindex import CursorKind

#index = clang.cindex.Index.create()
#tu = index.parse('test.cpp')

def preorder_travers_AST(cursor):
    """Print every descendant of ``cursor`` in pre-order.

    One line is printed per visited cursor: its spelling, kind, line and
    column. ``cursor`` itself is not printed. Returns ``None``.
    """
    for child in cursor.get_children():
        location = child.location
        print(child.spelling, child.kind, location.line, location.column)
        # Visit the subtree of this child before moving on to its siblings.
        preorder_travers_AST(child)

# Load a second sample file.
# NOTE(review): this rebinds the global `code` that other cells also assign.
with open('ast test/cpp test 2.cpp', encoding='utf-8') as f:
    code = f.read()
    
# Parse the text as an in-memory ("unsaved") file; the path is only a label.
index = clang.cindex.Index.create()
tu = index.parse(path='anything.cpp', unsaved_files=[('anything.cpp',code)])
#print(tu)
#print(tu.spelling, tu.kind, tu.location.line, tu.location.column)

AST_root_node = tu.cursor  # root cursor of the translation unit
#print(AST_root_node)
#print(AST_root_node.spelling, AST_root_node.kind, AST_root_node.location.line, AST_root_node.location.column)

# Print every cursor below the root (spelling, kind, line, column).
preorder_travers_AST(AST_root_node)

#A = getFunctions(code, None, None)

main CursorKind.FUNCTION_DECL 4 5
 CursorKind.COMPOUND_STMT 6 1
 CursorKind.RETURN_STMT 13 5
 CursorKind.INTEGER_LITERAL 13 12
A CursorKind.VAR_DECL 16 5
 CursorKind.BINARY_OPERATOR 16 9
 CursorKind.INTEGER_LITERAL 16 9
 CursorKind.INTEGER_LITERAL 16 11
mainA CursorKind.FUNCTION_DECL 18 5
 CursorKind.COMPOUND_STMT 19 1
 CursorKind.DECL_STMT 20 5
n CursorKind.VAR_DECL 20 9
i CursorKind.VAR_DECL 20 11
shuzu CursorKind.VAR_DECL 20 13
 CursorKind.INTEGER_LITERAL 20 19
count1 CursorKind.VAR_DECL 20 24
 CursorKind.INTEGER_LITERAL 20 31
count3 CursorKind.VAR_DECL 20 33
 CursorKind.INTEGER_LITERAL 20 40
count2 CursorKind.VAR_DECL 20 42
 CursorKind.INTEGER_LITERAL 20 49
count4 CursorKind.VAR_DECL 20 51
 CursorKind.INTEGER_LITERAL 20 58
count5 CursorKind.VAR_DECL 20 60
 CursorKind.INTEGER_LITERAL 20 67
count6 CursorKind.VAR_DECL 20 69
 CursorKind.INTEGER_LITERAL 20 76
 CursorKind.WHILE_STMT 22 5
 CursorKind.BINARY_OPERATOR 22 11
n CursorKind.UNEXPOSED_EXPR 22 11
n CursorKind.DECL_REF_EXPR 22 11


In [86]:
# Drill into the first top-level cursor: take its first child, then display the
# kind of that child's first child (the cell's last expression is its output).
func0 = [i for i in AST_root_node.get_children()][0]
func_body = [i for i in func0.get_children()][0]
func1 = func_body.kind
[i for i in func_body.get_children()][0].kind



CursorKind.RETURN_STMT

In [37]:
# Display the cursor kind of the first entry of `nodes`
# (`nodes` is assigned from preprocessor.extract_functions(code) in another cell).
nodes[0].kind

CursorKind.FUNCTION_DECL

In [5]:
# Display the collected function names.
# NOTE(review): in the C++ cell above the assignment of `fun_names` is commented
# out, so on a fresh kernel this depends on which cells were run before.
fun_names

['mainA', 'mainB']

In [6]:
# Print all token sequences held in `seqence` (assigned in another cell).
print(seqence)

[['FuncDef', 'Decl', 'FuncDecl', 'mainA', 'int', 'Compound', 'Decl', 'n', 'int', 'Decl', 'i', 'int', 'Decl', 'ArrayDecl', 'shuzu', 'int', '111', 'Decl', 'count1', 'int', '0', 'Decl', 'count3', 'int', '0', 'Decl', 'count2', 'int', '0', 'Decl', 'count4', 'int', '0', 'Decl', 'count5', 'int', '0', 'Decl', 'count6', 'int', '0', 'FuncCall', 'scanf', 'ExprList', '"%d"', '&', 'n', 'While', '>=', 'n', '100', 'Compound', '=', 'n', '-', 'n', '100', '++', 'count1', 'End', 'While', '>=', 'n', '50', 'Compound', '=', 'n', '-', 'n', '50', '++', 'count2', 'End', 'While', '>=', 'n', '20', 'Compound', '=', 'n', '-', 'n', '20', '++', 'count3', 'End', 'While', '>=', 'n', '10', 'Compound', '=', 'n', '-', 'n', '10', '++', 'count4', 'End', 'While', '>=', 'n', '5', 'Compound', '=', 'n', '-', 'n', '5', '++', 'count5', 'End', 'While', '>=', 'n', '1', 'Compound', '=', 'n', '-', 'n', '1', '++', 'count6', 'End', 'FuncCall', 'printf', 'ExprList', '"%d\\n%d\\n%d\\n%d\\n%d\\n%d"', 'count1', 'count2', 'count3', 'count4

In [7]:
# Display the index-encoded blocks (`blocks` is assigned in another cell).
blocks

[[[2957, [2957, [2957, [2957, [22]]]]],
  [2957],
  [2957, [140, [22]]],
  [2957, [21, [22]]],
  [2957, [2957, [2957, [22]], [2957]]],
  [2957, [2957, [22]], [18]],
  [2957, [2957, [22]], [18]],
  [2957, [2957, [22]], [18]],
  [2957, [2957, [22]], [18]],
  [2957, [2957, [22]], [18]],
  [2957, [2957, [22]], [18]],
  [2957, [2957], [2957, [2957], [156, [140]]]],
  [2957, [190, [140], [410]]],
  [2957],
  [12, [140], [39, [140], [410]]],
  [48, [2957]],
  [9],
  [2957, [190, [140], [1103]]],
  [2957],
  [12, [140], [39, [140], [1103]]],
  [48, [2957]],
  [9],
  [2957, [190, [140], [655]]],
  [2957],
  [12, [140], [39, [140], [655]]],
  [48, [2957]],
  [9],
  [2957, [190, [140], [302]]],
  [2957],
  [12, [140], [39, [140], [302]]],
  [48, [2957]],
  [9],
  [2957, [190, [140], [241]]],
  [2957],
  [12, [140], [39, [140], [241]]],
  [48, [2957]],
  [9],
  [2957, [190, [140], [26]]],
  [2957],
  [12, [140], [39, [140], [26]]],
  [48, [2957]],
  [9],
  [2957,
   [1643],
   [2957, [2957], [2957

In [8]:
# Display the token-encoded blocks (`token_blocks` is assigned in another cell).
token_blocks

[[['UNKNOWN', ['UNKNOWN', ['UNKNOWN', ['UNKNOWN', ['int']]]]],
  ['UNKNOWN'],
  ['UNKNOWN', ['n', ['int']]],
  ['UNKNOWN', ['i', ['int']]],
  ['UNKNOWN', ['UNKNOWN', ['UNKNOWN', ['int']], ['UNKNOWN']]],
  ['UNKNOWN', ['UNKNOWN', ['int']], ['0']],
  ['UNKNOWN', ['UNKNOWN', ['int']], ['0']],
  ['UNKNOWN', ['UNKNOWN', ['int']], ['0']],
  ['UNKNOWN', ['UNKNOWN', ['int']], ['0']],
  ['UNKNOWN', ['UNKNOWN', ['int']], ['0']],
  ['UNKNOWN', ['UNKNOWN', ['int']], ['0']],
  ['UNKNOWN', ['UNKNOWN'], ['UNKNOWN', ['UNKNOWN'], ['&', ['n']]]],
  ['UNKNOWN', ['>=', ['n'], ['100']]],
  ['UNKNOWN'],
  ['=', ['n'], ['-', ['n'], ['100']]],
  ['++', ['UNKNOWN']],
  ['End'],
  ['UNKNOWN', ['>=', ['n'], ['50']]],
  ['UNKNOWN'],
  ['=', ['n'], ['-', ['n'], ['50']]],
  ['++', ['UNKNOWN']],
  ['End'],
  ['UNKNOWN', ['>=', ['n'], ['20']]],
  ['UNKNOWN'],
  ['=', ['n'], ['-', ['n'], ['20']]],
  ['++', ['UNKNOWN']],
  ['End'],
  ['UNKNOWN', ['>=', ['n'], ['10']]],
  ['UNKNOWN'],
  ['=', ['n'], ['-', ['n'], ['10']]

In [9]:
# Print the source text currently bound to `code`.
print(code)



int A = 1+2;
    
int mainA()
{
    int n,i,shuzu[111],count1=0,count3=0,count2=0,count4=0,count5=0,count6=0;
    scanf("%d",&n);
    while(n>=100){
                  n=n-100;
                  count1++;
                  }
                      while(n>=50){
                  n=n-50;
                  count2++;
                  }
                      while(n>=20){
                  n=n-20;
                  count3++;
                  }
                      while(n>=10){
                  n=n-10;
                  count4++;
                  }    while(n>=5){
                  n=n-5;
                  count5++;
                  }
                      while(n>=1){
                  n=n-1;
                  count6++;
                  }
               printf("%d\n%d\n%d\n%d\n%d\n%d",count1,count2,count3,count4,count5,count6);
               return 0;
               }

float B = 2/3;

int mainB()
{
	int num,j,i,an[6]={100,50,20,10,5,1};
	cin>>num;
	cout<<num/an[0]<<endl;
	for(i=1;

In [4]:
def expand(nested_list):
    """Lazily flatten nested lists, skipping falsy leaves.

    Only ``list`` instances are descended into (tuples, sets and other
    containers are yielded as they are). Non-list items that are falsy, such
    as ``None``, ``0`` or ``''``, are dropped.
    """
    for element in nested_list:
        if not isinstance(element, list):
            if element:
                yield element
            continue
        yield from expand(element)

In [1]:
# For java
# Statement types that open a nested block.
Logic1 = ['IfStatement', 'ForStatement', 'WhileStatement', 'DoStatement', 'SwitchStatement']
# Declaration types that own a body.
Logic2 = ['MethodDeclaration', 'ConstructorDeclaration']

class Node_java(object):
    """Wrapper that turns a javalang AST element into a token plus wrapped children.

    The wrapped object may be a ``str`` (a leaf), a ``set`` (labelled
    ``'Modifier'``; presumably javalang's modifier set - verify) or a
    ``javalang.ast.Node``.

    Attributes:
        node: the wrapped object.
        is_str: True when ``node`` is a ``str``.
        token: label computed by :meth:`get_token`.
        children: list of ``Node_java`` built by :meth:`add_children`.
    """

    def __init__(self, node):
        self.node = node
        self.is_str = isinstance(self.node, str)
        self.token = self.get_token(node)
        self.children = self.add_children()

    def is_leaf(self):
        """Return True for strings, otherwise whether ``node.children`` is empty."""
        return True if self.is_str else len(self.node.children) == 0

    def get_token(self, node):
        """Return the label of ``node``.

        A string is its own label, a set is ``'Modifier'``, a javalang node is
        its class name, and anything else is the empty string.
        """
        from javalang.ast import Node
        if isinstance(node, str):
            return node
        if isinstance(node, set):
            return 'Modifier'
        if isinstance(node, Node):
            return node.__class__.__name__
        return ''

    def ori_children(self, root):
        """Return the flattened, non-empty raw children of ``root``.

        For method/constructor declarations (``Logic2``) the last entry of
        ``root.children`` is dropped (presumably the body - verify against
        javalang's attribute order).
        """
        from javalang.ast import Node
        if isinstance(root, Node):
            raw = root.children[:-1] if self.token in Logic2 else root.children
        elif isinstance(root, set):
            raw = list(root)
        else:
            raw = []
        return list(expand(raw))

    def add_children(self):
        """Recursively wrap the children that belong to this node's own block.

        ``Logic1`` statements keep only their first child, ``Logic2``
        declarations keep every child, and any other node keeps the children
        that are not themselves ``Logic1`` statements.
        """
        if self.is_str:
            return []
        kids = self.ori_children(self.node)

        if self.token in Logic1:
            return [Node_java(kids[0])]
        if self.token in Logic2:
            return [Node_java(kid) for kid in kids]
        return [Node_java(kid) for kid in kids if self.get_token(kid) not in Logic1]

In [2]:
class Preprocessor_java():
    """Convert Java source into the structures consumed by an ASTNN-style model.

    All parsing is done with javalang.

    Attributes:
        vocab: mapping from token to an entry exposing ``.index``; it is also
            used for membership tests (``token in vocab``).
        max_token: ``len(vocab)``; used as the index of out-of-vocabulary tokens.
    """

    def __init__(self, vocab):
        self.vocab = vocab
        self.max_token = len(vocab)

    # ------------------------------------------------------------------
    # file -> function ASTs: parse a file and split it into functions.
    # ------------------------------------------------------------------
    def file_to_code(self, filename):
        """Return the text of a ``.java`` file, or ``None`` if reading fails.

        Raises:
            AssertionError: if ``filename`` is not an existing file or does
                not end with ``.java`` (case-insensitive).
        """
        import os
        assert os.path.isfile(filename)
        assert filename[-5:].lower()=='.java'
        try:
            with open(filename, 'r', encoding="utf-8") as f:
                return f.read()
        except Exception:
            # Best effort: unreadable or undecodable files are reported as None.
            return None

    def file_to_ast(self, file):
        """Read ``file`` and parse it; see :meth:`code_to_ast` for the result."""
        return self.code_to_ast(self.file_to_code(file))

    def code_to_ast(self, code):
        """Parse ``code`` with javalang.

        The text is first parsed as a whole compilation unit; if that fails it
        is parsed as a single member declaration. Returns ``None`` when both
        attempts fail.
        """
        import javalang
        try:
            return javalang.parse.parse(code)
        except Exception:
            pass
        try:
            tokens = javalang.tokenizer.tokenize(code)
            parser = javalang.parser.Parser(tokens)
            return parser.parse_member_declaration()
        except Exception:
            return None

    def get_functions(self, ast):
        """Return all constructor declarations followed by all method declarations."""
        import javalang
        fun_list = list(ast.filter(javalang.tree.ConstructorDeclaration))
        fun_list.extend(ast.filter(javalang.tree.MethodDeclaration))
        # filter() yields (path, node) pairs; only the node is kept.
        return [f[1] for f in fun_list]

    def get_function_name(self, function_ast):
        """Return ``function_ast.name``."""
        return function_ast.name

    def extract_functions(self, code):
        """Locate constructors and methods in ``code`` together with their line ranges.

        The end line is found textually: starting at the declaration's line,
        braces are counted (after removing double-quoted strings and ``//``
        comments) until they balance. A declaration whose first ``;`` appears
        before any ``{`` is treated as body-less and ends on that line.
        Block comments and character literals are not stripped.

        Returns:
            ``(function_nodes, function_pos)``: the declaration nodes
            (constructors first, then methods) and a parallel list of 1-based
            ``(start_line, end_line)`` tuples. ``(None, None)`` is returned
            when javalang cannot parse ``code``.
        """
        import re
        import javalang
        import itertools

        re_string = re.escape("\"") + '.*?' + re.escape("\"")
        comment_inline_pattern = re.escape('//') + '.*?$'

        try:
            tree = javalang.parse.parse(code)
        except Exception:
            return None, None

        function_nodes = []
        function_pos = []

        file_string_split = code.split('\n')
        nodes = itertools.chain(tree.filter(
            javalang.tree.ConstructorDeclaration), tree.filter(javalang.tree.MethodDeclaration))

        for _, node in nodes:
            start_lineno, _ = node.position

            end_lineno = start_lineno
            closed = 0
            openned = 0

            # Enumerate with real line numbers so that blank lines inside the
            # body still advance the reported end line.
            for lineno, line in enumerate(file_string_split[start_lineno-1:], start_lineno):
                if len(line.strip()) == 0:
                    continue

                # Strip string literals first, then line comments.
                line_re = re.sub(re_string, '', line, flags=re.DOTALL)
                line_re = re.sub(comment_inline_pattern, '', line_re, flags=re.MULTILINE)

                closed += line_re.count('}')
                openned += line_re.count('{')
                end_lineno = lineno

                if openned > 0:
                    if closed == openned:
                        break
                elif ';' in line_re:
                    # No body was opened: abstract or interface method.
                    break
                # Otherwise the opening brace is on a later line (multi-line
                # signature or brace-on-next-line style): keep scanning.

            function_pos.append((start_lineno, end_lineno))
            function_nodes.append(node)

        return function_nodes, function_pos

    # ------------------------------------------------------------------
    # Token / children accessors on raw javalang elements; the methods
    # below build on these two.
    # ------------------------------------------------------------------
    def get_token(self, node):
        """Return the label of a raw javalang element.

        A string is its own label, a set is ``'Modifier'``, a javalang node is
        its class name, and anything else is the empty string.
        """
        from javalang.ast import Node
        if isinstance(node, str):
            token = node
        elif isinstance(node, set):
            token = 'Modifier'#node.pop()
        elif isinstance(node, Node):
            token = node.__class__.__name__
        else:
            token = ''

        return token

    def get_children(self, root):
        """Return the flattened, non-empty children of a raw javalang element.

        A node contributes ``root.children``; a set (labelled ``'Modifier'``
        by :meth:`get_token`) contributes its members.
        """
        from javalang.ast import Node
        if isinstance(root, Node):
            children = root.children
        elif isinstance(root, set):
            children = list(root)
        else:
            children = []

        return list(expand(children))

    # ------------------------------------------------------------------
    # AST -> model input. The first pair of methods emits vocabulary
    # indices (the main path); the second pair emits the tokens themselves,
    # which is convenient when printing for debugging.
    # ------------------------------------------------------------------
    def ast_to_block(self, ast):
        """Split ``ast`` into statement blocks and encode each as nested index lists."""
        blocks = []
        self.get_blocks(ast, blocks)
        return [self.replaced_by_index(node) for node in blocks]

    def replaced_by_index(self, node):
        """Encode a wrapped node as ``[index, child1, child2, ...]``.

        Each child is encoded recursively in the same form. Tokens missing
        from ``self.vocab`` are encoded as ``self.max_token``.
        """
        token = node.token
        result = [self.vocab[token].index if token in self.vocab else self.max_token]
        for child in node.children:
            result.append(self.replaced_by_index(child))
        return result

    def ast_to_token_block(self, ast):
        """Like :meth:`ast_to_block`, but keeps tokens instead of indices."""
        blocks = []
        self.get_blocks(ast, blocks)
        return [self.replaced_by_token(node) for node in blocks]

    def replaced_by_token(self, node):
        """Encode a wrapped node as ``[token, child1, ...]``; unknown tokens become ``'UNKNOWN'``."""
        result = [node.token if node.token in self.vocab else 'UNKNOWN']
        for child in node.children:
            result.append(self.replaced_by_token(child))
        return result

    def get_blocks(self, node, block_seq):
        """Append the statement blocks under ``node`` to ``block_seq``.

        Everything appended is a ``Node_java``. Four cases are distinguished:
        method/constructor declarations (``Logic2``), block-opening statements
        (``Logic1``), block statements (or nodes with a ``block`` attribute),
        and everything else, which is only descended into.
        """
        name, children = self.get_token(node), self.get_children(node)

        if name in Logic2:
            block_seq.append(Node_java(node))
            body = node.body
            for child in body:
                if self.get_token(child) not in Logic1 and not hasattr(child, 'block'):
                    block_seq.append(Node_java(child))
                else:
                    self.get_blocks(child, block_seq)

        elif name in Logic1:
            block_seq.append(Node_java(node))
            # The first child is skipped here; Node_java keeps only that child
            # for Logic1 statements.
            for child in children[1:]:
                token = self.get_token(child)
                if not hasattr(node, 'block') and token not in Logic1 + ['BlockStatement']:
                    block_seq.append(Node_java(child))
                else:
                    self.get_blocks(child, block_seq)
                # An 'End' marker is emitted after each remaining child.
                block_seq.append(Node_java('End'))

        elif name == 'BlockStatement' or hasattr(node, 'block'):
            block_seq.append(Node_java(name))
            for child in children:
                if self.get_token(child) not in Logic1:
                    block_seq.append(Node_java(child))
                else:
                    self.get_blocks(child, block_seq)

        else:
            for child in children:
                self.get_blocks(child, block_seq)

    # ------------------------------------------------------------------
    # Pre-order token sequences, used to train the word embedding.
    # ------------------------------------------------------------------
    def get_sequence(self, node, sequence):
        """Append the pre-order token sequence of ``node`` to ``sequence``.

        An ``'End'`` token is appended after the subtree of every ``Logic1``
        statement; it stands for the closing brace that matches the
        ``BlockStatement`` following such a statement.
        """
        token, children = self.get_token(node), self.get_children(node)
        sequence.append(token)

        for child in children:
            self.get_sequence(child, sequence)

        if token in Logic1:
            sequence.append('End')

    def trans_to_sequences(self, ast):
        """Return the pre-order token list of ``ast``, starting at its root."""
        sequence = []
        self.get_sequence(ast, sequence)
        return sequence

    # ------------------------------------------------------------------
    # Debugging helpers (work in progress).
    # ------------------------------------------------------------------
    def visit_block(self, block):
        """Placeholder; not implemented."""
        pass

    def visit_token_block(self, block):
        """Placeholder; not implemented."""
        pass

    def block_to_embedded(self):
        """Placeholder; not implemented."""
        pass
    

In [5]:
import os
from gensim.models.word2vec import Word2Vec
# Location of the pre-trained embedding: <root><lang>/train/embedding/node_w2v_128
root = 'data/'
lang = 'java'

# Keep only the KeyedVectors (.wv) and hand its vocabulary to the preprocessor.
word2vec = Word2Vec.load(root+lang+"/train/embedding/node_w2v_128").wv
preprocessor = Preprocessor_java(word2vec.vocab)

# Sample Java file used by the following cells (path normalized for the current OS).
file_name = 'data/other_java/abdera/adapters/filesystem/src/main/java/org/apache/abdera/protocol/server/adapters/filesystem/FilesystemAdapter.java'
file_name = os.path.normpath(file_name)

# Parse the file and collect every constructor/method declaration with its name.
code = preprocessor.file_to_code(file_name)
ast = preprocessor.code_to_ast(code)
fun_asts = preprocessor.get_functions(ast)
fun_names = [preprocessor.get_function_name(f) for f in fun_asts]


# Per function: pre-order token sequence, index-encoded blocks, token-encoded blocks.
seqence = [preprocessor.trans_to_sequences(f) for f in fun_asts]
blocks = [preprocessor.ast_to_block(f) for f in fun_asts]
token_blocks = [preprocessor.ast_to_token_block(f) for f in fun_asts]

# Declaration nodes and their (start_line, end_line) ranges in `code`.
nodes, pos = preprocessor.extract_functions(code)

#embedding_seq = 

In [8]:
# Display the (start_line, end_line) tuples returned by extract_functions.
pos

[(57, 60),
 (62, 76),
 (78, 89),
 (91, 108),
 (110, 125),
 (127, 137),
 (139, 146),
 (148, 157),
 (159, 176),
 (178, 193),
 (195, 197),
 (199, 204),
 (206, 212),
 (214, 231),
 (234, 236)]

In [5]:
# Print the pre-order token sequence of the function at index 2.
print(seqence[2])

['MethodDeclaration', 'Modifier', 'private', 'addPagingLinks', 'FormalParameter', 'ReferenceType', 'RequestContext', 'request', 'FormalParameter', 'ReferenceType', 'Feed', 'feed', 'FormalParameter', 'BasicType', 'int', 'currentpage', 'FormalParameter', 'BasicType', 'int', 'count', 'LocalVariableDeclaration', 'ReferenceType', 'Map', 'TypeArgument', 'ReferenceType', 'String', 'TypeArgument', 'ReferenceType', 'Object', 'VariableDeclarator', 'params', 'ClassCreator', 'ReferenceType', 'HashMap', 'TypeArgument', 'ReferenceType', 'String', 'TypeArgument', 'ReferenceType', 'Object', 'StatementExpression', 'MethodInvocation', 'params', 'Literal', '"count"', 'MemberReference', 'count', 'put', 'StatementExpression', 'MethodInvocation', 'params', 'Literal', '"page"', 'BinaryOperation', '+', 'MemberReference', 'currentpage', 'Literal', '1', 'put', 'LocalVariableDeclaration', 'ReferenceType', 'String', 'VariableDeclarator', 'next', 'MethodInvocation', 'paging_template', 'MemberReference', 'params', 

In [6]:
# Print the name and block count of the function at index 2, then display its
# index-encoded blocks (the bare last expression is the cell output).
print(fun_names[2], len(blocks[2]))
blocks[2]


addPagingLinks 18


[[25,
  [20, [89]],
  [2957],
  [19, [4, [2957]], [122]],
  [19, [4, [2957]], [2957]],
  [19, [15, [22]], [2957]],
  [19, [15, [22]], [201]]],
 [7,
  [4, [266], [58, [4, [16]]], [58, [4, [145]]]],
  [6, [304], [13, [4, [367], [58, [4, [16]]], [58, [4, [145]]]]]]],
 [3, [1, [304], [2, [2957]], [0, [201]], [119]]],
 [3, [1, [304], [2, [2957]], [5, [8], [0, [2957]], [2, [26]]], [119]]],
 [7, [4, [16]], [6, [188], [1, [2957], [0, [304]], [2957]]]],
 [3,
  [10,
   [0, [188]],
   [1, [122], [1, [0, [188]], [2238]], [1, [63]], [2957]],
   [12]]],
 [3, [1, [2957], [0, [188]], [2, [2957]], [2957]]],
 [14, [5, [70], [0, [2957]], [2, [18]]]],
 [11],
 [3, [1, [304], [2, [2957]], [5, [39], [0, [2957]], [2, [26]]], [119]]],
 [7, [4, [16]], [6, [2558], [1, [2957], [0, [304]], [2957]]]],
 [3,
  [10,
   [0, [2558]],
   [1, [122], [1, [0, [2558]], [2238]], [1, [63]], [2957]],
   [12]]],
 [3, [1, [2957], [0, [2558]], [2, [2957]], [2957]]],
 [9],
 [3, [1, [304], [2, [2957]], [2, [18]], [119]]],
 [7, [4, [

In [7]:
# Display the token-encoded blocks of the function at index 2.
token_blocks[2]

[['MethodDeclaration',
  ['Modifier', ['private']],
  ['UNKNOWN'],
  ['FormalParameter', ['ReferenceType', ['UNKNOWN']], ['request']],
  ['FormalParameter', ['ReferenceType', ['UNKNOWN']], ['UNKNOWN']],
  ['FormalParameter', ['BasicType', ['int']], ['UNKNOWN']],
  ['FormalParameter', ['BasicType', ['int']], ['count']]],
 ['LocalVariableDeclaration',
  ['ReferenceType',
   ['Map'],
   ['TypeArgument', ['ReferenceType', ['String']]],
   ['TypeArgument', ['ReferenceType', ['Object']]]],
  ['VariableDeclarator',
   ['params'],
   ['ClassCreator',
    ['ReferenceType',
     ['HashMap'],
     ['TypeArgument', ['ReferenceType', ['String']]],
     ['TypeArgument', ['ReferenceType', ['Object']]]]]]],
 ['StatementExpression',
  ['MethodInvocation',
   ['params'],
   ['Literal', ['UNKNOWN']],
   ['MemberReference', ['count']],
   ['put']]],
 ['StatementExpression',
  ['MethodInvocation',
   ['params'],
   ['Literal', ['UNKNOWN']],
   ['BinaryOperation',
    ['+'],
    ['MemberReference', ['UNKNOW

In [8]:
# Print the source text currently bound to `code`.
print(code)
#print(fun_asts[0])

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  The ASF licenses this file to You
 * under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.  For additional information regarding
 * copyright in this work, please see the NOTICE file in the top level
 * directory of this distribution.
 */
package org.apache.abdera.protocol.server.adapters.filesystem;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.i