In [4]:
from transformers import RobertaTokenizer
import javalang
from javalang.ast import Node
from anytree import AnyNode
import torch
from torch_geometric.data import Data, DataLoader
import torch.nn as nn
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.nn import MessagePassing, GatedGraphConv, GCNConv, global_mean_pool
import yaml

In [2]:
test_func = '''public int getLarger(int a, int b) {
		a = Math.abs(a);
		b = Math.abs(b);
		if(a > b) {
			return a;
		}else {
			return b;
		}
	}
'''

In [2]:
checkpoint = 'microsoft/codebert-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)

In [3]:
tokenizer.tokenize('getLarger')

['get', 'L', 'arger']

In [5]:
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']

In [6]:
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [8]:
num_added_toks

71

In [None]:
tokenizer.vocab_size

In [6]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    token = 'None'
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [7]:
#  generate tree for AST Node
def create_tree(root, node, node_list, sub_id_list, leave_list, tokenizer, parent=None):
    id = len(node_list)
    node_list.append(node)
    token, children = get_token(node), get_child(node)

    if children == []:
        # print('this is a leaf:', token, id)
        leave_list.append(id)

    # Use roberta.tokenizer to generate subtokens
    # If a token can be divided into multiple(>1) subtokens, the first subtoken will be set as the previous node, 
    # and the other subtokens will be set as its new children   
    sub_token_list = tokenizer.tokenize(token)
    
    if id == 0:
        root.token = sub_token_list[0] # the root node is one of the tokenizer's special tokens
        root.data = node
        # record the num of nodes for every children of root
        root_children_node_num = []
        for child in children:
            node_num = len(node_list)
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=root)
            root_children_node_num.append(len(node_list) - node_num)        
        return root_children_node_num
    else:
        # print(sub_token_list)
        new_node = AnyNode(id=id, token=sub_token_list[0], data=node, parent=parent)
        if len(sub_token_list) > 1:
            sub_id_list.append(id)
            for sub_token in sub_token_list[1:]:
                id += 1
                AnyNode(id=id, token=sub_token, data=node, parent=new_node)
                node_list.append(sub_token)
                sub_id_list.append(id)
        
        for child in children:
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=new_node)
    # print(token, id)
    

In [8]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list):
    token = node.token
    # print('id', node.id, 'token', node.token)
    node_index_list.append(tokenizer.convert_tokens_to_ids(token))
    # node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    if token in ['VariableDeclarator', 'MemberReference']:
        variable_token_list.append(node.children[0].token)
        variable_id_list.append(node.children[0].id)   
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        get_node_and_edge(child, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list)

In [9]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, tokenizer):
    node_list = []
    sub_id_list = [] # record the ids of node that can be divide into multple subtokens
    leave_list = [] # record the ids of leave 
    new_tree = AnyNode(id=0, token=None, data=None)
    root_children_node_num = create_tree(new_tree, ast, node_list, sub_id_list, leave_list, tokenizer)
    # print('root_children_node_num', root_children_node_num)
    x = []
    edge_src = []
    edge_tgt = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, variable_token_list, variable_id_list)

    ast_edge_num = len(edge_src)
    edge_attr = [[0] for _ in range(ast_edge_num)]
    # set subtoken edge type to 2
    for i in range(len(edge_attr)):
        if edge_src[i] in sub_id_list and edge_tgt[i] in sub_id_list:
            edge_attr[i] = [2]
    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        # print('variable_dict', variable_dict)
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(variable_token_list[i], variable_id_list[i])
        else:
            # print('edge', variable_dict.get(variable_token_list[i]), variable_id_list[i])
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_src.append(variable_id_list[i])
            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            variable_dict[variable_token_list[i]] = variable_id_list[i]
    dataflow_edge_num = len(edge_src) - ast_edge_num

    # add next-token edge
    nexttoken_edge_num = len(leave_list)-1
    for i in range(nexttoken_edge_num):
        edge_src.append(leave_list[i])
        edge_tgt.append(leave_list[i+1])
        edge_src.append(leave_list[i+1])
        edge_tgt.append(leave_list[i])

    edge_index = [edge_src, edge_tgt]

    # set data flow edge type to 1
    for _ in range(dataflow_edge_num):
        edge_attr.append([1])
    
    # set data flow edge type to 3
    for _ in range(nexttoken_edge_num * 2):
        edge_attr.append([3])
    
    return x, edge_index, edge_attr, root_children_node_num

In [10]:
x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast=parse_program(func=test_func), tokenizer=tokenizer)

In [11]:
root_children_node_num

[2, 2, 3, 4, 4, 10, 10, 15]

In [12]:
divide_node_num = 10
max_subgraph_num = int(100/divide_node_num)

def get_subgraph_node_num(root_children_node_num, divide_node_num):
    subgraph_node_num = []
    node_sum = 0
    real_graph_num = 0
    for num in root_children_node_num:
        node_sum += num
        if node_sum >= divide_node_num:
            subgraph_node_num.append(node_sum)
            node_sum = 0    
    real_graph_num = len(subgraph_node_num)

    # if the last subgraph node num < divide_node_num, then put the last subgraph to the second to last subgraph
    if subgraph_node_num[-1] < divide_node_num:
        subgraph_node_num[-2] = subgraph_node_num[-2] + subgraph_node_num[-1]
        subgraph_node_num[-1] = 0
        real_graph_num -= 1

    # zero padding for tensor transforming
    for _ in range(real_graph_num, max_subgraph_num):
        subgraph_node_num.append(0)
    
    return subgraph_node_num, real_graph_num

In [13]:

subgraph_node_num, real_graph_num = get_subgraph_node_num(root_children_node_num, divide_node_num)
subgraph_node_num

[11, 14, 10, 15, 0, 0, 0, 0, 0, 0]

In [14]:
real_graph_num

4

In [15]:
data = Data(
    x = torch.tensor(x, dtype=torch.long),
    edge_index = torch.tensor(edge_index, dtype=torch.long),
    edge_attr=torch.tensor(edge_attr, dtype=torch.long),
    subgraph_node_num=torch.tensor(subgraph_node_num, dtype=torch.long),
    real_graph_num=torch.tensor(real_graph_num, dtype=torch.long)
)

In [16]:
data_list = [data for _ in range(32)]
loader = DataLoader(data_list, batch_size=32)

In [17]:
batch = next(iter(loader))

In [26]:
class SequenceGNNEncoder(torch.nn.Module):
    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num,
                    decoder_input_size, device):
        super(SequenceGNNEncoder, self).__init__()
        self.device = device
        self.embed = nn.Embedding(vocab_len, graph_embedding_size)
        self.edge_embed = nn.Embedding(4, 1) # only two edge types to be set weights, which are AST edge and data flow edge
        self.ggnnlayer = GatedGraphConv(graph_embedding_size, gnn_layers_num)
        self.mlp_gate = nn.Sequential(
            nn.Linear(graph_embedding_size, 300), nn.Sigmoid(), nn.Linear(300, 1), nn.Sigmoid())
        self.pool = GlobalAttention(gate_nn=self.mlp_gate)
        self.divide_node_num = divide_node_num
        self.lstm = nn.LSTM(input_size=graph_embedding_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers_num)
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_layers_num = lstm_layers_num
        self.fc = nn.Linear(graph_embedding_size + lstm_hidden_size, decoder_input_size)

    def subgraph_forward(self, x, edge_index, edge_attr, batch):
        if type(edge_attr) == type(None):
            edge_weight = None
        else:
            edge_weight = self.edge_embed(edge_attr)
            edge_weight = edge_weight.squeeze(1)
        x = self.ggnnlayer(x, edge_index, edge_weight)
        return self.pool(x, batch=batch)
    
    # partitioning multiple subgraphs by dynamic allocating edges
    def partition_graph(self, x, edge_index, edge_attr, subgraph_node_num, real_graph_num, ptr):        
        nodes_list = [] # record all nodes number for each subgraph in total batch
        subgraph_num = max(real_graph_num)

        batch_size = subgraph_node_num.size(0)
        start_node_num = [1 for _ in range(batch_size)]
        for i in range(subgraph_num):
            subgraph_nodes_list = []
            for j in range(batch_size):
                if subgraph_node_num[j][i] != 0:
                    for k in range(ptr[j]+start_node_num[j], ptr[j]+start_node_num[j]+subgraph_node_num[j][i]):
                        subgraph_nodes_list.append(k)
                    start_node_num[j] += subgraph_node_num[j][i]
            nodes_list.append(subgraph_nodes_list)

        # only count the edge whose target node in subgraph
        sub_edge_src = [[] for _ in range(subgraph_num)]
        sub_edge_tgt = [[] for _ in range(subgraph_num)]
        sub_edge_attr = [[] for _ in range(subgraph_num)]
        # print('nodes_list', nodes_list)
        node_num = len(x)
        node_subgraph_index = [0 for _ in range(node_num)] # use a list to store the subgraph numbers for all nodes
        for i in range(len(nodes_list)):
            for node in nodes_list[i]:
                node_subgraph_index[node] = i

        for i in range(len(edge_index[1])):
            src = edge_index[0][i].item()
            tgt = edge_index[1][i].item()
            sub_edge_src[node_subgraph_index[tgt]].append(src)
            sub_edge_tgt[node_subgraph_index[tgt]].append(tgt)
            sub_edge_attr[node_subgraph_index[tgt]].append(edge_attr[i].item())
        edge_index_list = []
        edge_attr_list = []
        for i in range(subgraph_num):
            edge_index_list.append(torch.tensor([sub_edge_src[i], sub_edge_tgt[i]], dtype=torch.long))
            edge_attr_list.append(torch.tensor(sub_edge_attr[i], dtype=torch.long))
        print('nodes_list', nodes_list)
        return edge_index_list, edge_attr_list  

    def forward(self, x, edge_index, edge_attr, subgraph_node_num, real_graph_num, batch, ptr):
        edge_index_list, edge_attr_list = self.partition_graph(x, edge_index, edge_attr, subgraph_node_num, real_graph_num, ptr)
        print('edge_index_list', edge_index_list)
        print('edge_attr_list', edge_attr_list)
        x = self.embed(x)
        x = x.squeeze(1)
        subgraph_pool_list = [
            self.subgraph_forward(x, edge_index_list[i].to(self.device), edge_attr_list[i].to(self.device), batch)
            for i in range(len(edge_index_list))
        ]
        graph_pool = self.subgraph_forward(x, edge_index, edge_attr, batch)
        print('graph_pool', graph_pool.shape)
        subgraph_pool_seq = torch.stack(subgraph_pool_list)
        print('subgraph_pool_seq', subgraph_pool_seq.shape)
        h0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        c0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        subgraph_output, (_, _) = self.lstm(subgraph_pool_seq, (h0, c0))
        return self.fc(torch.cat((subgraph_output[-1], graph_pool), dim=1))

In [19]:
config_file = 'config_dgnn.yml'
config = yaml.load(open(config_file), Loader=yaml.FullLoader)

# data source
TRAIN_DIR = config['middle_data']['train']
VALID_DIR = config['middle_data']['valid']
TEST_DIR = config['middle_data']['test']


# training parameter
batch_size = config['training']['batch_size']
num_epoches = config['training']['num_epoches']
lr = config['training']['lr']
decay_ratio = config['training']['lr']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']

# model design
graph_embedding_size = config['model']['graph_embedding_size']
lstm_hidden_size = config['model']['lstm_hidden_size']
# divide_node_num = config['model']['divide_node_num']
gnn_layers_num = config['model']['gnn_layers_num']
lstm_layers_num = config['model']['lstm_layers_num']
decoder_input_size = config['model']['decoder_input_size']
decoder_hidden_size = config['model']['decoder_hidden_size']
decoder_num_layers = config['model']['decoder_num_layers']
decoder_rnn_dropout = config['model']['decoder_rnn_dropout']

# logs
info_prefix = config['logs']['info_prefix']

In [27]:
device = torch.device('cuda:0')

model_args = {
    'vocab_len': 51000,
    'graph_embedding_size': graph_embedding_size,
    'gnn_layers_num': gnn_layers_num,
    'lstm_layers_num': lstm_layers_num,
    'lstm_hidden_size': lstm_hidden_size,
    'divide_node_num': divide_node_num,
    'decoder_input_size': decoder_input_size,
    'device': device
}

model = SequenceGNNEncoder(**model_args).to(device)

In [21]:
batch = batch.to(device)
subgraph_node_num = torch.stack(torch.split(batch.subgraph_node_num, max_subgraph_num))
subgraph_node_num

tensor([[11, 14, 10, 15,  0,  0,  0,  0,  0,  0],
        [11, 14, 10, 15,  0,  0,  0,  0,  0,  0]], device='cuda:0')

In [22]:
real_graph_num = torch.stack(torch.split(batch.real_graph_num, 1))
real_graph_num

tensor([[4],
        [4]], device='cuda:0')

In [23]:
max(real_graph_num)

tensor([4], device='cuda:0')

In [24]:
batch.ptr

tensor([  0,  51, 102], device='cuda:0')

In [28]:
model(batch.x, batch.edge_index, batch.edge_attr, subgraph_node_num, real_graph_num, batch.batch, batch.ptr).shape

nodes_list [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76], [26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86], [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101]]
edge_index_list [tensor([[ 0,  1,  1,  2,  0,  3,  3,  4,  0,  5,  5,  6,  5,  7,  0,  8,  8,  9,
          9, 10,  8, 11, 12, 16, 26, 36,  2,  4,  4,  5,  5, 10, 10, 11, 14, 51,
         52, 52, 53, 51, 54, 54, 55, 51, 56, 56, 57, 56, 58, 51, 59, 59, 60, 60,
         61, 59, 62, 63, 67, 77, 87, 53, 55, 55, 56, 56, 61, 61, 62, 65],
        [ 1,  0,  2,  1,  3,  0,  4,  3,  5,  0,  6,  5,  7,  5,  8,  0,  9,  8,
         10,  9, 11,  8,  0,  0,  0,  0,  4,  2,  5,  4, 10,  5, 11, 10, 11, 52,
         51, 53, 52, 54, 51, 55, 54, 56, 51, 57, 56, 58, 56, 59, 51, 60, 59, 61,
         6

torch.Size([2, 300])

Use regex to match all special tokens in Javalang

In [None]:
ast_node = '''
from .ast import Node

# ------------------------------------------------------------------------------

class CompilationUnit(Node):
    attrs = ("package", "imports", "types")

class Import(Node):
    attrs = ("path", "static", "wildcard")

class Documented(Node):
    attrs = ("documentation",)

class Declaration(Node):
    attrs = ("modifiers", "annotations")

class TypeDeclaration(Declaration, Documented):
    attrs = ("name", "body")

    @property
    def fields(self):
        return [decl for decl in self.body if isinstance(decl, FieldDeclaration)]

    @property
    def methods(self):
        return [decl for decl in self.body if isinstance(decl, MethodDeclaration)]

    @property
    def constructors(self):
        return [decl for decl in self.body if isinstance(decl, ConstructorDeclaration)]

class PackageDeclaration(Declaration, Documented):
    attrs = ("name",)

class ClassDeclaration(TypeDeclaration):
    attrs = ("type_parameters", "extends", "implements")

class EnumDeclaration(TypeDeclaration):
    attrs = ("implements",)

    @property
    def fields(self):
        return [decl for decl in self.body.declarations if isinstance(decl, FieldDeclaration)]

    @property
    def methods(self):
        return [decl for decl in self.body.declarations if isinstance(decl, MethodDeclaration)]

class InterfaceDeclaration(TypeDeclaration):
    attrs = ("type_parameters", "extends",)

class AnnotationDeclaration(TypeDeclaration):
    attrs = ()

# ------------------------------------------------------------------------------

class Type(Node):
    attrs = ("name", "dimensions",)

class BasicType(Type):
    attrs = ()

class ReferenceType(Type):
    attrs = ("arguments", "sub_type")

class TypeArgument(Node):
    attrs = ("type", "pattern_type")

# ------------------------------------------------------------------------------

class TypeParameter(Node):
    attrs = ("name", "extends")

# ------------------------------------------------------------------------------

class Annotation(Node):
    attrs = ("name", "element")

class ElementValuePair(Node):
    attrs = ("name", "value")

class ElementArrayValue(Node):
    attrs = ("values",)

# ------------------------------------------------------------------------------

class Member(Documented):
    attrs = ()

class MethodDeclaration(Member, Declaration):
    attrs = ("type_parameters", "return_type", "name", "parameters", "throws", "body")

class FieldDeclaration(Member, Declaration):
    attrs = ("type", "declarators")

class ConstructorDeclaration(Declaration, Documented):
    attrs = ("type_parameters", "name", "parameters", "throws", "body")

# ------------------------------------------------------------------------------

class ConstantDeclaration(FieldDeclaration):
    attrs = ()

class ArrayInitializer(Node):
    attrs = ("initializers",)

class VariableDeclaration(Declaration):
    attrs = ("type", "declarators")

class LocalVariableDeclaration(VariableDeclaration):
    attrs = ()

class VariableDeclarator(Node):
    attrs = ("name", "dimensions", "initializer")

class FormalParameter(Declaration):
    attrs = ("type", "name", "varargs")

class InferredFormalParameter(Node):
    attrs = ('name',)

# ------------------------------------------------------------------------------

class Statement(Node):
    attrs = ("label",)

class IfStatement(Statement):
    attrs = ("condition", "then_statement", "else_statement")

class WhileStatement(Statement):
    attrs = ("condition", "body")

class DoStatement(Statement):
    attrs = ("condition", "body")

class ForStatement(Statement):
    attrs = ("control", "body")

class AssertStatement(Statement):
    attrs = ("condition", "value")

class BreakStatement(Statement):
    attrs = ("goto",)

class ContinueStatement(Statement):
    attrs = ("goto",)

class ReturnStatement(Statement):
    attrs = ("expression",)

class ThrowStatement(Statement):
    attrs = ("expression",)

class SynchronizedStatement(Statement):
    attrs = ("lock", "block")

class TryStatement(Statement):
    attrs = ("resources", "block", "catches", "finally_block")

class SwitchStatement(Statement):
    attrs = ("expression", "cases")

class BlockStatement(Statement):
    attrs = ("statements",)

class StatementExpression(Statement):
    attrs = ("expression",)

# ------------------------------------------------------------------------------

class TryResource(Declaration):
    attrs = ("type", "name", "value")

class CatchClause(Statement):
    attrs = ("parameter", "block")

class CatchClauseParameter(Declaration):
    attrs = ("types", "name")

# ------------------------------------------------------------------------------

class SwitchStatementCase(Node):
    attrs = ("case", "statements")

class ForControl(Node):
    attrs = ("init", "condition", "update")

class EnhancedForControl(Node):
    attrs = ("var", "iterable")

# ------------------------------------------------------------------------------

class Expression(Node):
    attrs = ()

class Assignment(Expression):
    attrs = ("expressionl", "value", "type")

class TernaryExpression(Expression):
    attrs = ("condition", "if_true", "if_false")

class BinaryOperation(Expression):
    attrs = ("operator", "operandl", "operandr")

class Cast(Expression):
    attrs = ("type", "expression")

class MethodReference(Expression):
    attrs = ("expression", "method", "type_arguments")

class LambdaExpression(Expression):
    attrs = ('parameters', 'body')

# ------------------------------------------------------------------------------

class Primary(Expression):
    attrs = ("prefix_operators", "postfix_operators", "qualifier", "selectors")

class Literal(Primary):
    attrs = ("value",)

class This(Primary):
    attrs = ()

class MemberReference(Primary):
    attrs = ("member",)

class Invocation(Primary):
    attrs = ("type_arguments", "arguments")

class ExplicitConstructorInvocation(Invocation):
    attrs = ()

class SuperConstructorInvocation(Invocation):
    attrs = ()

class MethodInvocation(Invocation):
    attrs = ("member",)

class SuperMethodInvocation(Invocation):
    attrs = ("member",)

class SuperMemberReference(Primary):
    attrs = ("member",)

class ArraySelector(Expression):
    attrs = ("index",)

class ClassReference(Primary):
    attrs = ("type",)

class VoidClassReference(ClassReference):
    attrs = ()

# ------------------------------------------------------------------------------

class Creator(Primary):
    attrs = ("type",)

class ArrayCreator(Creator):
    attrs = ("dimensions", "initializer")

class ClassCreator(Creator):
    attrs = ("constructor_type_arguments", "arguments", "body")

class InnerClassCreator(Creator):
    attrs = ("constructor_type_arguments", "arguments", "body")

# ------------------------------------------------------------------------------

class EnumBody(Node):
    attrs = ("constants", "declarations")

class EnumConstantDeclaration(Declaration, Documented):
    attrs = ("name", "arguments", "body")

class AnnotationMethod(Declaration):
    attrs = ("name", "return_type", "dimensions", "default")
'''

In [None]:
import re

result = re.findall('.*class (.*)\(.*', ast_node)

In [None]:
'\',\''.join(result)

"CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration','ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type','BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair','ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration','ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration','VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement','WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement','ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement','BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter','SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression','BinaryOperation','Cast','MethodReference','LambdaExpression','Pri

In [None]:
# try:   
#     !jupyter nbconvert --to python subword_test.ipynb
#     # python即转化为.py，script即转化为.html
#     # file_name.ipynb即当前module的文件名
# except:
#     pass
