In [1]:
import yaml
import warnings
from datetime import datetime
import logging
import os
import pandas as pd 
import javalang
from javalang.ast import Node
import torch
import torch.nn.functional as F
from torch.nn import (Module, Embedding, LSTM, Sequential, Linear, BatchNorm1d, ReLU, Sigmoid, CrossEntropyLoss, TransformerDecoderLayer,
                        TransformerDecoder)
import torch.optim as optim
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.nn import MessagePassing, GatedGraphConv, GCNConv, global_mean_pool, GINEConv, global_add_pool
from anytree import AnyNode
from torch_geometric.data import Data, DataLoader, ClusterData, ClusterLoader
from tqdm import tqdm_notebook as tqdm
import numpy as np
from torchsummary import summary
import random
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, DataCollatorWithPadding, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, SequentialSampler
import json
from codebert_seq2seq4 import Seq2Seq
import bleu
import bleu

warnings.filterwarnings('ignore')

Parameters

In [2]:
config_file = 'config_dgnn.yml'

In [3]:
config = yaml.load(open(config_file), Loader=yaml.FullLoader)

# data source
TRAIN_DIR = config['data']['train']
VALID_DIR = config['data']['valid']
TEST_DIR = config['data']['test']


# prepocess design
max_source_length = config['preprocess']['max_source_length']
max_target_length = config['preprocess']['max_target_length']


# training parameter
batch_size = config['training']['batch_size']
num_epoches = config['training']['num_epoches']
lr = config['training']['lr']
decay_ratio = config['training']['lr']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']

# model design
graph_embedding_size = config['model']['graph_embedding_size']
lstm_hidden_size = config['model']['lstm_hidden_size']
divide_node_num = config['model']['divide_node_num']
gnn_layers_num = config['model']['gnn_layers_num']
lstm_layers_num = config['model']['lstm_layers_num']
decoder_input_size = config['model']['decoder_input_size']
decoder_hidden_size = config['model']['decoder_hidden_size']
decoder_num_layers = config['model']['decoder_num_layers']
decoder_rnn_dropout = config['model']['decoder_rnn_dropout']

# logs
info_prefix = config['logs']['info_prefix']

In [22]:
MAX_NODE_NUM = 300 # the max num of subgraph, set for zero padding 
max_subgraph_num = int(MAX_NODE_NUM/divide_node_num) 

Logs

In [4]:
run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = 'logs/' + run_id + '.log'
exp_dir = 'runs/' + run_id
os.mkdir(exp_dir)

In [5]:
class Info(object):
    def __init__(self, info_prefix=''):
        self.info_prefix = info_prefix
    
    def print_msg(self, msg):
        text = self.info_prefix + ' ' + msg
        print(text)
        logging.info(text)

In [6]:
logging.basicConfig(format='%(asctime)s | %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', filename=log_file, level=logging.DEBUG)
msgr = Info(info_prefix)

msgr.print_msg('run_id : {}'.format(run_id))
msgr.print_msg('log_file : {}'.format(log_file))
msgr.print_msg('exp_dir: {}'.format(exp_dir))
msgr.print_msg(str(config))

dgnn run_id : 2021-08-03--15-15-07
dgnn log_file : logs/2021-08-03--15-15-07.log
dgnn exp_dir: runs/2021-08-03--15-15-07
dgnn {'data': {'train': '/data/code/represent-code-in-human/data/code-summarization-enhanced-full/train_utf8.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-enhanced-full/valid_utf8.jsonl', 'test': '/data/code/represent-code-in-human/data/code-summarization-enhanced-full/test_utf8.jsonl'}, 'small_data': {'train': '/data/code/represent-code-in-human/data/code-summarization-enhanced-small/train_utf8.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-enhanced-small/valid_utf8.jsonl', 'test': '/data/code/represent-code-in-human/data/code-summarization-enhanced-small/test_utf8.jsonl'}, 'middle_data': {'train': '/data/code/represent-code-in-human/data/code-summarization-enhanced-middle/train_utf8.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-enhanced-middle/valid_utf8.jsonl', 'test': '/da

Prepocess

In [7]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    token = 'None'
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [10]:
checkpoint = 'microsoft/codebert-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
ast_tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
roberta = RobertaModel.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
config = RobertaConfig.from_pretrained(checkpoint)
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
num_added_toks = ast_tokenizer.add_special_tokens(special_tokens_dict)

In [9]:
#  generate tree for AST Node
def create_tree(root, node, node_list, sub_id_list, leave_list, tokenizer, parent=None):
    id = len(node_list)
    node_list.append(node)
    token, children = get_token(node), get_child(node)

    if children == []:
        # print('this is a leaf:', token, id)
        leave_list.append(id)

    # Use roberta.tokenizer to generate subtokens
    # If a token can be divided into multiple(>1) subtokens, the first subtoken will be set as the previous node, 
    # and the other subtokens will be set as its new children
    token = token.encode('utf-8','ignore').decode("utf-8")   
    sub_token_list = tokenizer.tokenize(token)
    
    if id == 0:
        root.token = sub_token_list[0] # the root node is one of the tokenizer's special tokens
        root.data = node
        # record the num of nodes for every children of root
        root_children_node_num = []
        for child in children:
            node_num = len(node_list)
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=root)
            root_children_node_num.append(len(node_list) - node_num)        
        return root_children_node_num
    else:
        # print(sub_token_list)
        new_node = AnyNode(id=id, token=sub_token_list[0], data=node, parent=parent)
        if len(sub_token_list) > 1:
            sub_id_list.append(id)
            for sub_token in sub_token_list[1:]:
                id += 1
                AnyNode(id=id, token=sub_token, data=node, parent=new_node)
                node_list.append(sub_token)
                sub_id_list.append(id)
        
        for child in children:
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=new_node)
    # print(token, id)

In [10]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list):
    token = node.token
    node_index_list.append(tokenizer.convert_tokens_to_ids(token))
    # node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    if token in ['VariableDeclarator', 'MemberReference']:
        variable_token_list.append(node.children[0].token)
        variable_id_list.append(node.children[0].id)   
    
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        get_node_and_edge(child, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list)

In [11]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, tokenizer):
    node_list = []
    sub_id_list = [] # record the ids of node that can be divide into multple subtokens
    leave_list = [] # record the ids of leave 
    new_tree = AnyNode(id=0, token=None, data=None)
    root_children_node_num = create_tree(new_tree, ast, node_list, sub_id_list, leave_list, tokenizer)
    # print('root_children_node_num', root_children_node_num)
    x = []
    edge_src = []
    edge_tgt = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, variable_token_list, variable_id_list)

    ast_edge_num = len(edge_src)
    edge_attr = [[0] for _ in range(ast_edge_num)]
    # set subtoken edge type to 2
    for i in range(len(edge_attr)):
        if edge_src[i] in sub_id_list and edge_tgt[i] in sub_id_list:
            edge_attr[i] = [2]
    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        # print('variable_dict', variable_dict)
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(variable_token_list[i], variable_id_list[i])
        else:
            # print('edge', variable_dict.get(variable_token_list[i]), variable_id_list[i])
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_src.append(variable_id_list[i])
            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            variable_dict[variable_token_list[i]] = variable_id_list[i]
    dataflow_edge_num = len(edge_src) - ast_edge_num

    # add next-token edge
    nexttoken_edge_num = len(leave_list)-1
    for i in range(nexttoken_edge_num):
        edge_src.append(leave_list[i])
        edge_tgt.append(leave_list[i+1])
        edge_src.append(leave_list[i+1])
        edge_tgt.append(leave_list[i])

    edge_index = [edge_src, edge_tgt]

    # set data flow edge type to 1
    for _ in range(dataflow_edge_num):
        edge_attr.append([1])
    
    # set data flow edge type to 3
    for _ in range(nexttoken_edge_num * 2):
        edge_attr.append([3])
    
    return x, edge_index, edge_attr, root_children_node_num

In [12]:
# def get_ast_statistic(data):
#     node_num = []
#     edge_num = []
#     sub_node_num = []
#     for code in tqdm(data['code']):
#         x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(parse_program(code), tokenizer)
#         node_num.append(len(x))
#         edge_num.append(len(edge_attr))
#         sub_node_num.append(sum(root_children_node_num))
#     return node_num, edge_num, sub_node_num

# train_data = pd.read_json(TRAIN_DIR, lines=True)
# valid_data = pd.read_json(VALID_DIR, lines=True)
# test_data = pd.read_json(TEST_DIR, lines=True)

# train_node_num, train_edge_num, train_sub_node_num = get_ast_statistic(train_data)
# valid_node_num, valid_edge_num, valid_sub_node_num = get_ast_statistic(valid_data)
# test_node_num, test_edge_num, test_sub_node_num = get_ast_statistic(test_data)

HBox(children=(IntProgress(value=0, max=164814), HTML(value='')))




HBox(children=(IntProgress(value=0, max=5179), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10952), HTML(value='')))




In [21]:
# train_sub_node_series = pd.Series(train_sub_node_num)
# train_sub_node_series.describe()

count    164814.000000
mean        136.569824
std         120.903685
min          20.000000
25%          59.000000
50%          94.000000
75%         166.000000
max        3240.000000
dtype: float64

In [35]:

def get_subgraph_node_num(root_children_node_num, divide_node_num):
    subgraph_node_num = []
    node_sum = 0
    real_graph_num = 0
    for num in root_children_node_num:
        node_sum += num
        if node_sum >= divide_node_num:
            subgraph_node_num.append(node_sum)
            node_sum = 0    
    
    subgraph_node_num.append(node_sum)
    real_graph_num = len(subgraph_node_num)

    if real_graph_num >= max_subgraph_num:
        return subgraph_node_num[: max_subgraph_num], max_subgraph_num

    # print(len(subgraph_node_num))
    # if the last subgraph node num < divide_node_num, then put the last subgraph to the second to last subgraph
    # if subgraph_node_num[-1] < divide_node_num:
    #     subgraph_node_num[-2] = subgraph_node_num[-2] + subgraph_node_num[-1]
    #     subgraph_node_num[-1] = 0
    #     real_graph_num -= 1

    # zero padding for tensor transforming
    for _ in range(real_graph_num, max_subgraph_num):
        subgraph_node_num.append(0)
    
    return subgraph_node_num, real_graph_num

In [36]:
def convert_examples_to_features(examples, ast_tokenizer, tokenizer, stage=None):
    features = []
    for example in tqdm(examples):
        # pyg
        ast = parse_program(example.source)
        x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast, ast_tokenizer)
        subgraph_node_num, real_graph_num = get_subgraph_node_num(root_children_node_num, divide_node_num)

        # source
        source_tokens = tokenizer.tokenize(example.ast_des)[: max_source_length-2]
        source_tokens = [tokenizer.cls_token] + source_tokens + [tokenizer.sep_token]
        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
        source_mask = [1] * (len(source_ids))
        padding_length = max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_token_id] * padding_length
        source_mask += [0] * padding_length

        # target
        if stage == 'test':
            target_tokens = tokenizer.tokenize('None')
        else:
            target_tokens = tokenizer.tokenize(example.target)[: max_target_length-2]
        target_tokens = [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        target_mask = [1] * len(target_ids)
        padding_length = max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_token_id] * padding_length
        target_mask += [0] * padding_length

        features.append(
            Data(
                x= torch.tensor(x, dtype=torch.long),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.long),
                source_ids=torch.tensor(source_ids, dtype=torch.long),
                source_mask=torch.tensor(source_mask, dtype=torch.long),
                target_ids=torch.tensor(target_ids, dtype=torch.long),
                target_mask=torch.tensor(target_mask, dtype=torch.long),
                subgraph_node_num=torch.tensor(subgraph_node_num, dtype=torch.long),
                real_graph_num=torch.tensor(real_graph_num, dtype=torch.long)
            )
        )
    return features

In [19]:
class Example(object):
    def __init__(self, idx, source, ast_des, target):
        self.idx = idx
        self.source = source
        self.ast_des = ast_des
        self.target = target

In [17]:
# read dataset
def read_examples(filename):
    examples = []
    with open(filename, encoding='utf-8') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            js = json.loads(line)
            if 'idx' not in js:
                js['idx'] = idx
            
            code = js['code']
            nl = ' '.join(js['docstring_tokens']).replace('\n', '')
            nl = ' '.join(nl.strip().split())
            ast_des = js['ast_des']
            examples.append(
                Example(
                    idx = idx,
                    source = code,
                    ast_des = ast_des,
                    target = nl,
                )
            )
    return examples

In [20]:
train_examples = read_examples(TRAIN_DIR)
valid_examples = read_examples(VALID_DIR)
test_examples = read_examples(TEST_DIR)
msgr.print_msg('train size: {}, valid size: {}, test size: {}'.format(len(train_examples), len(valid_examples), len(test_examples)))

dgnn train size: 164814, valid size: 5179, test size: 10952


In [37]:
train_features = convert_examples_to_features(train_examples, ast_tokenizer, tokenizer, stage='train')
valid_features = convert_examples_to_features(valid_examples, ast_tokenizer, tokenizer, stage='valid')
test_features = convert_examples_to_features(test_examples, ast_tokenizer, tokenizer, stage='test')

HBox(children=(IntProgress(value=0, max=164814), HTML(value='')))




HBox(children=(IntProgress(value=0, max=5179), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10952), HTML(value='')))




In [38]:
torch.save(train_features,'features/pgnn/train_features.pt')
torch.save(valid_features,'features/pgnn/valid_features.pt')
torch.save(test_features,'features/pgnn/test_features.pt')

In [7]:
train_features = torch.load('features/pgnn/train_features.pt')
valid_features = torch.load('features/pgnn/valid_features.pt')
test_features = torch.load('features/pgnn/test_features.pt')

Model

In [8]:
class GNNEncoder(Module):
    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, decoder_input_size, device):
        super(GNNEncoder, self).__init__()
        self.device = device
        self.embeddings = Embedding(vocab_len, graph_embedding_size)
        self.edge_embed = Embedding(4, 1) # only two edge types to be set weights, which are AST edge and data flow edge
        self.ggnnlayer = GatedGraphConv(graph_embedding_size, gnn_layers_num)
        self.mlp_gate = Sequential(
            Linear(graph_embedding_size, 300), Sigmoid(), Linear(300, 1), Sigmoid())
        self.pool = GlobalAttention(gate_nn=self.mlp_gate)
        self.lstm = LSTM(input_size=graph_embedding_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers_num)
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_layers_num = lstm_layers_num
        self.fc = Linear(graph_embedding_size + lstm_hidden_size, decoder_input_size)

    def subgraph_forward(self, x, edge_index, edge_attr, batch):
        if type(edge_attr) == type(None):
            edge_weight = None
        else:
            edge_weight = self.edge_embed(edge_attr)
            edge_weight = edge_weight.squeeze(1)
        x = self.ggnnlayer(x, edge_index, edge_weight)
        return self.pool(x, batch=batch)
    
    # partitioning multiple subgraphs by dynamic allocating edges
    def partition_graph(self, x, edge_index, edge_attr, subgraph_node_num, real_graph_num, ptr):        
        nodes_list = [] # record all nodes number for each subgraph in total batch
        subgraph_num = max(real_graph_num)

        batch_size = subgraph_node_num.size(0)
        start_node_num = [1 for _ in range(batch_size)]
        for i in range(subgraph_num):
            subgraph_nodes_list = []
            for j in range(batch_size):
                if subgraph_node_num[j][i] != 0:
                    for k in range(ptr[j]+start_node_num[j], ptr[j]+start_node_num[j]+subgraph_node_num[j][i]):
                        subgraph_nodes_list.append(k)
                    start_node_num[j] += subgraph_node_num[j][i]
            nodes_list.append(subgraph_nodes_list)

        # only count the edge whose target node in subgraph
        sub_edge_src = [[] for _ in range(subgraph_num)]
        sub_edge_tgt = [[] for _ in range(subgraph_num)]
        sub_edge_attr = [[] for _ in range(subgraph_num)]
        # print('nodes_list', nodes_list)
        node_num = len(x)
        node_subgraph_index = [0 for _ in range(node_num)] # use a list to store the subgraph numbers for all nodes
        for i in range(len(nodes_list)):
            for node in nodes_list[i]:
                node_subgraph_index[node] = i

        for i in range(len(edge_index[1])):
            src = edge_index[0][i].item()
            tgt = edge_index[1][i].item()
            sub_edge_src[node_subgraph_index[tgt]].append(src)
            sub_edge_tgt[node_subgraph_index[tgt]].append(tgt)
            sub_edge_attr[node_subgraph_index[tgt]].append(edge_attr[i].item())
        edge_index_list = []
        edge_attr_list = []
        for i in range(subgraph_num):
            edge_index_list.append(torch.tensor([sub_edge_src[i], sub_edge_tgt[i]], dtype=torch.long))
            edge_attr_list.append(torch.tensor(sub_edge_attr[i], dtype=torch.long))
        # print('nodes_list', nodes_list)
        return edge_index_list, edge_attr_list  

    def forward(self, x, edge_index, edge_attr, subgraph_node_num, real_graph_num, batch, ptr):
        edge_index_list, edge_attr_list = self.partition_graph(x, edge_index, edge_attr, subgraph_node_num, real_graph_num, ptr)
        # print('edge_index_list', edge_index_list)
        # print('edge_attr_list', edge_attr_list)
        x = self.embeddings(x)
        x = x.squeeze(1)
        subgraph_pool_list = [
            self.subgraph_forward(x, edge_index_list[i].to(self.device), edge_attr_list[i].to(self.device), batch)
            for i in range(len(edge_index_list))
        ]
        graph_pool = self.subgraph_forward(x, edge_index, edge_attr, batch)
        # print('graph_pool', graph_pool.shape)
        subgraph_pool_seq = torch.stack(subgraph_pool_list)
        # print('subgraph_pool_seq', subgraph_pool_seq.shape)
        h0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        c0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        subgraph_output, (_, _) = self.lstm(subgraph_pool_seq, (h0, c0))
        return self.fc(torch.cat((subgraph_output[-1], graph_pool), dim=1))

In [11]:
device = torch.device('cuda: 0')
gnn_encoder = GNNEncoder(vocab_len=tokenizer.vocab_size+num_added_toks, graph_embedding_size=graph_embedding_size,
                         gnn_layers_num=gnn_layers_num, lstm_layers_num=lstm_layers_num, lstm_hidden_size=lstm_hidden_size,
                        decoder_input_size=decoder_input_size, device=device)
decoder_layer = TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = TransformerDecoder(decoder_layer, num_layers=6)
model = Seq2Seq(encoder=gnn_encoder, decoder=decoder, config=config, beam_size=10, max_length=max_target_length, 
                sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
model.to(device)

Seq2Seq(
  (encoder): GNNEncoder(
    (embeddings): Embedding(50336, 768)
    (edge_embed): Embedding(4, 1)
    (ggnnlayer): GatedGraphConv(768, num_layers=2)
    (mlp_gate): Sequential(
      (0): Linear(in_features=768, out_features=300, bias=True)
      (1): Sigmoid()
      (2): Linear(in_features=300, out_features=1, bias=True)
      (3): Sigmoid()
    )
    (pool): GlobalAttention(gate_nn=Sequential(
      (0): Linear(in_features=768, out_features=300, bias=True)
      (1): Sigmoid()
      (2): Linear(in_features=300, out_features=1, bias=True)
      (3): Sigmoid()
    ), nn=None)
    (lstm): LSTM(768, 128, num_layers=3)
    (fc): Linear(in_features=896, out_features=768, bias=True)
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): Linear(in_

In [35]:
max_source_length = 256
max_target_length = 32
batch_size = 32
beam_size = 10
lr = 5e-5
warmup_steps = 0
train_steps = 50000
weight_decay = 0.0
adam_epsilon = 1e-8
valid_steps = 500
output_dir = exp_dir

train_url = TRAIN_DIR
valid_url = VALID_DIR
test_url = TEST_DIR

In [26]:
train_dataloader = DataLoader(train_features, batch_size=batch_size)

In [14]:
# optimizer and schedule
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': weight_decay},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                            num_training_steps=train_steps)

In [36]:
from itertools import cycle

#Start training
msgr.print_msg("***** Running training *****")
msgr.print_msg("  Num examples = {}".format(len(train_examples)))
msgr.print_msg("  Batch size = {}".format(batch_size))
msgr.print_msg("  Num epoch = {}".format(batch_size//len(train_examples)))
model.train()
valid_dataset = {}
nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_bleu, best_loss = 0, 0, 0, 0, 0, 1e6
bar = tqdm(range(train_steps), total=train_steps)
train_dataloader = cycle(train_dataloader)

for step in bar:
    data = next(train_dataloader)
    data = data.to(device)
    subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
    real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
    target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
    target_mask = torch.stack(torch.split(data.target_mask, max_target_length))  
    loss, _, _, = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, subgraph_node_num=subgraph_node_num, 
                        real_graph_num=real_graph_num, batch=data.batch, ptr=data.ptr, target_ids=target_ids, target_mask=target_mask)

    tr_loss += loss.item()
    train_loss = round(tr_loss / (nb_tr_steps + 1), 4) 
    bar.set_description('loss {}'.format(train_loss))
    nb_tr_examples += data.x.size(0)
    nb_tr_steps += 1
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    global_step += 1
    
    if (global_step + 1) % valid_steps == 0:
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        valid_sampler = SequentialSampler(valid_features)
        valid_dataloader = DataLoader(valid_features, sampler=valid_sampler, batch_size=batch_size)

        msgr.print_msg("\n***** Running evaluation *****")
        msgr.print_msg("  Num examples = {}".format(len(valid_examples)))
        msgr.print_msg("  Batch size = {}".format(batch_size))

        #Start Evaling model
        model.eval()
        valid_loss, tokens_num = 0, 0
        for data in valid_dataloader:
            data = data.to(device)
            subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
            real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
            target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
            target_mask = torch.stack(torch.split(data.target_mask, max_target_length))            

            with torch.no_grad():
                _,loss,num = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, subgraph_node_num=subgraph_node_num, 
                        real_graph_num=real_graph_num, batch=data.batch, ptr=data.ptr)    
            valid_loss += loss.sum().item()
            tokens_num += num.sum().item()
        #Pring loss of valid dataset    
        model.train()
        valid_loss = valid_loss / tokens_num
        result = { 'valid_loss': valid_loss,
                    'valid_ppl': round(np.exp(valid_loss), 5),
                    'global_step': global_step+1,
                    'train_loss': round(train_loss, 5)}
        for key in sorted(result.keys()):
            msgr.print_msg("{}= {}".format(key, str(result[key])))
        msgr.print_msg("  "+"*"*20)   
        
        #save last checkpoint
        last_output_dir = os.path.join(output_dir, 'checkpoint-last')
        if not os.path.exists(last_output_dir):
            os.makedirs(last_output_dir)
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)                    
        if valid_loss < best_loss:
            msgr.print_msg("  Best ppl:{}".format(round(np.exp(valid_loss), 5)))
            msgr.print_msg("  " + "*" * 20)
            best_loss = valid_loss
            # Save best checkpoint for best ppl
            best_output_dir = os.path.join(output_dir, 'checkpoint-best-ppl')
            if not os.path.exists(best_output_dir):
                os.makedirs(best_output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(best_output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)  
                    
                    
        #Calculate bleu  
        # if 'valid_bleu' in valid_dataset:
        #     valid_examples, valid_features = valid_dataset['valid_bleu']
        # else:
        #     valid_examples = read_examples(valid_url)
        #     valid_examples = random.sample(valid_examples, min(1000,len(valid_examples)))
        #     valid_features = convert_examples_to_features(valid_examples, tokenizer,stage='test')  
        #     valid_dataset['valid_bleu']= valid_examples, valid_features

        # valid_sampler = SequentialSampler(valid_features)
        # valid_dataloader = DataLoader(valid_features, sampler=valid_sampler, batch_size=batch_size)

        model.eval() 
        p=[]
        # i = 0
        for data in valid_dataloader:
            # print('i', i)
            # i += 1
            data = data.to(device)
            subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
            real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
            target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
            target_mask = torch.stack(torch.split(data.target_mask, max_target_length))                  
            with torch.no_grad():
                preds = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, subgraph_node_num=subgraph_node_num, 
                        real_graph_num=real_graph_num, batch=data.batch, ptr=data.ptr)
                for pred in preds:
                    t=pred[0].cpu().numpy()
                    t=list(t)
                    if 0 in t:
                        t=t[:t.index(0)]
                    text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
                    p.append(text)
        model.train()
        predictions=[]
        with open(os.path.join(output_dir,"valid.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"valid.gold"),'w', encoding='utf-8') as f1:
            for ref, gold in zip(p, valid_examples):
                predictions.append(str(gold.idx)+'\t'+ref)
                f.write(str(gold.idx)+'\t'+ref+'\n')
                f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

        (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "valid.gold")) 
        valid_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
        msgr.print_msg("  {} = {}".format("bleu-4", str(valid_bleu)))
        msgr.print_msg("  "+"*"*20)    
        if valid_bleu>best_bleu:
            msgr.print_msg("  Best bleu:{}".format(valid_bleu))
            msgr.print_msg("  "+"*"*20)
            best_bleu=valid_bleu
            # Save best checkpoint for best bleu
            best_output_dir = os.path.join(output_dir, 'checkpoint-best-bleu')
            if not os.path.exists(best_output_dir):
                os.makedirs(best_output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(best_output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)
        

dgnn ***** Running training *****
dgnn   Num examples = 164814
dgnn   Batch size = 32
dgnn   Num epoch = 0


HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))

In [23]:
test_examples = read_examples(test_url)
test_features = convert_examples_to_features(test_examples, tokenizer, stage='test')
# Calculate bleu
test_sampler = SequentialSampler(test_features)
test_dataloader = DataLoader(test_features, sampler=test_sampler, batch_size=batch_size)

best_ppl_model = output_dir + '/checkpoint-best-ppl/pytorch_model.bin' 
model.load_state_dict(torch.load(best_ppl_model))
model.eval() 
p=[]
for batch in tqdm(test_dataloader,total=len(test_dataloader)):
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)                     
    with torch.no_grad():
        preds = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)   
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

HBox(children=(IntProgress(value=0, max=18), HTML(value='')))


dgnn  bleu-4 = 9.43 
dgnn   ********************


Total: 1095


In [24]:
best_bleu_model = output_dir + '/checkpoint-best-bleu/pytorch_model.bin' 
model.load_state_dict(torch.load(best_bleu_model))
model.eval() 
p=[]
for batch in tqdm(test_dataloader,total=len(test_dataloader)):
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)                 
    with torch.no_grad():
        preds = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)   
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"checkpoint-best-bleu/test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"checkpoint-best-bleu/test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

HBox(children=(IntProgress(value=0, max=18), HTML(value='')))


dgnn  bleu-4 = 9.68 
dgnn   ********************


Total: 1095


In [25]:
last_model = output_dir + '/checkpoint-last/pytorch_model.bin' 
model.load_state_dict(torch.load(last_model))
model.eval() 
p=[]
for batch in tqdm(test_dataloader,total=len(test_dataloader)):
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)                   
    with torch.no_grad():
        preds = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)   
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"checkpoint-last/test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"checkpoint-last/test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

HBox(children=(IntProgress(value=0, max=18), HTML(value='')))


dgnn  bleu-4 = 8.89 
dgnn   ********************


Total: 1095
