In [1]:
import yaml
import warnings
from datetime import datetime
import logging
import os
import pandas as pd 
import javalang
from javalang.ast import Node
import torch
import torch.nn.functional as F
from torch.nn import (Module, Embedding, LSTM, Sequential, Linear, BatchNorm1d, ReLU, Sigmoid, CrossEntropyLoss, TransformerDecoderLayer,
                        TransformerDecoder)
import torch.optim as optim
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.nn import MessagePassing, GatedGraphConv, GCNConv, global_mean_pool, GINEConv, global_add_pool
from anytree import AnyNode
from torch_geometric.data import Data, DataLoader, ClusterData, ClusterLoader
from tqdm import tqdm_notebook as tqdm
import numpy as np
from torchsummary import summary
import random
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, DataCollatorWithPadding, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, SequentialSampler
import json
from codebert_seq2seq3 import Seq2Seq
import bleu

warnings.filterwarnings('ignore')

Parameters

In [2]:
config_file = 'config_dgnn.yml'

In [3]:
config = yaml.load(open(config_file), Loader=yaml.FullLoader)

# data source
TRAIN_DIR = config['data']['train']
VALID_DIR = config['data']['valid']
TEST_DIR = config['data']['test']


# prepocess design
max_source_length = config['preprocess']['max_source_length']
max_target_length = config['preprocess']['max_target_length']


# training parameter
batch_size = config['training']['batch_size']
num_epoches = config['training']['num_epoches']
lr = config['training']['lr']
decay_ratio = config['training']['lr']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']

# model design
graph_embedding_size = config['model']['graph_embedding_size']
lstm_hidden_size = config['model']['lstm_hidden_size']
divide_node_num = config['model']['divide_node_num']
gnn_layers_num = config['model']['gnn_layers_num']
lstm_layers_num = config['model']['lstm_layers_num']
decoder_input_size = config['model']['decoder_input_size']
decoder_hidden_size = config['model']['decoder_hidden_size']
decoder_num_layers = config['model']['decoder_num_layers']
decoder_rnn_dropout = config['model']['decoder_rnn_dropout']

# logs
info_prefix = config['logs']['info_prefix']

In [4]:
MAX_NODE_NUM = 300 # the max num of subgraph, set for zero padding 
max_subgraph_num = int(MAX_NODE_NUM/divide_node_num) 

Logs

In [5]:
run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = 'logs/' + run_id + '.log'
exp_dir = 'runs/' + run_id
os.mkdir(exp_dir)

In [6]:
class Info(object):
    def __init__(self, info_prefix=''):
        self.info_prefix = info_prefix
    
    def print_msg(self, msg):
        text = self.info_prefix + ' ' + msg
        print(text)
        logging.info(text)

In [7]:
logging.basicConfig(format='%(asctime)s | %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', filename=log_file, level=logging.DEBUG)
msgr = Info(info_prefix)

msgr.print_msg('run_id : {}'.format(run_id))
msgr.print_msg('log_file : {}'.format(log_file))
msgr.print_msg('exp_dir: {}'.format(exp_dir))
msgr.print_msg(str(config))

dgnn run_id : 2021-08-09--15-27-47
dgnn log_file : logs/2021-08-09--15-27-47.log
dgnn exp_dir: runs/2021-08-09--15-27-47
dgnn {'data': {'train': '/data/code/represent-code-in-human/data/code-summarization-enhanced-full/train_utf8.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-enhanced-full/valid_utf8.jsonl', 'test': '/data/code/represent-code-in-human/data/code-summarization-enhanced-full/test_utf8.jsonl'}, 'small_data': {'train': '/data/code/represent-code-in-human/data/code-summarization-enhanced-small/train_utf8.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-enhanced-small/valid_utf8.jsonl', 'test': '/data/code/represent-code-in-human/data/code-summarization-enhanced-small/test_utf8.jsonl'}, 'middle_data': {'train': '/data/code/represent-code-in-human/data/code-summarization-enhanced-middle/train_utf8.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-enhanced-middle/valid_utf8.jsonl', 'test': '/da

In [8]:
checkpoint = 'microsoft/codebert-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
ast_tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
roberta = RobertaModel.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
config = RobertaConfig.from_pretrained(checkpoint)
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
num_added_toks = ast_tokenizer.add_special_tokens(special_tokens_dict)

In [9]:
class Example(object):
    def __init__(self, idx, source, ast_des, target):
        self.idx = idx
        self.source = source
        self.ast_des = ast_des
        self.target = target

In [10]:
# read dataset
def read_examples(filename):
    examples = []
    with open(filename, encoding='utf-8') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            js = json.loads(line)
            if 'idx' not in js:
                js['idx'] = idx
            
            code = js['code']
            nl = ' '.join(js['docstring_tokens']).replace('\n', '')
            nl = ' '.join(nl.strip().split())
            ast_des = js['ast_des']
            examples.append(
                Example(
                    idx = idx,
                    source = code,
                    ast_des = ast_des,
                    target = nl,
                )
            )
    return examples

In [11]:
train_examples = read_examples(TRAIN_DIR)
valid_examples = read_examples(VALID_DIR)
test_examples = read_examples(TEST_DIR)
msgr.print_msg('train size: {}, valid size: {}, test size: {}'.format(len(train_examples), len(valid_examples), len(test_examples)))

dgnn train size: 164814, valid size: 5179, test size: 10952


In [12]:
train_features = torch.load('features/pgnn/train_features.pt')
valid_features = torch.load('features/pgnn/valid_features.pt')
test_features = torch.load('features/pgnn/test_features.pt')

Model

In [13]:
class GNNEncoder(Module):
    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, decoder_input_size, device):
        super(GNNEncoder, self).__init__()
        self.device = device
        self.embeddings = Embedding(vocab_len, graph_embedding_size)
        self.edge_embed = Embedding(4, 1) # only two edge types to be set weights, which are AST edge and data flow edge
        self.ggnnlayer = GatedGraphConv(graph_embedding_size, gnn_layers_num)
        self.mlp_gate = Sequential(
            Linear(graph_embedding_size, 300), Sigmoid(), Linear(300, 1), Sigmoid())
        self.pool = GlobalAttention(gate_nn=self.mlp_gate)
        self.lstm = LSTM(input_size=graph_embedding_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers_num)
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_layers_num = lstm_layers_num
        self.fc = Linear(graph_embedding_size + lstm_hidden_size, decoder_input_size)

    def subgraph_forward(self, x, edge_index, edge_attr, batch):
        if type(edge_attr) == type(None):
            edge_weight = None
        else:
            edge_weight = self.edge_embed(edge_attr)
            edge_weight = edge_weight.squeeze(1)
        x = self.ggnnlayer(x, edge_index, edge_weight)
        return self.pool(x, batch=batch)
    
    # partitioning multiple subgraphs by dynamic allocating edges
    def partition_graph(self, x, edge_index, edge_attr, subgraph_node_num, real_graph_num, ptr):        
        nodes_list = [] # record all nodes number for each subgraph in total batch
        subgraph_num = max(real_graph_num)

        batch_size = subgraph_node_num.size(0)
        start_node_num = [1 for _ in range(batch_size)]
        for i in range(subgraph_num):
            subgraph_nodes_list = []
            for j in range(batch_size):
                if subgraph_node_num[j][i] != 0:
                    for k in range(ptr[j]+start_node_num[j], ptr[j]+start_node_num[j]+subgraph_node_num[j][i]):
                        subgraph_nodes_list.append(k)
                    start_node_num[j] += subgraph_node_num[j][i]
            nodes_list.append(subgraph_nodes_list)

        # only count the edge whose target node in subgraph
        sub_edge_src = [[] for _ in range(subgraph_num)]
        sub_edge_tgt = [[] for _ in range(subgraph_num)]
        sub_edge_attr = [[] for _ in range(subgraph_num)]
        # print('nodes_list', nodes_list)
        node_num = len(x)
        node_subgraph_index = [0 for _ in range(node_num)] # use a list to store the subgraph numbers for all nodes
        for i in range(len(nodes_list)):
            for node in nodes_list[i]:
                node_subgraph_index[node] = i

        for i in range(len(edge_index[1])):
            src = edge_index[0][i].item()
            tgt = edge_index[1][i].item()
            sub_edge_src[node_subgraph_index[tgt]].append(src)
            sub_edge_tgt[node_subgraph_index[tgt]].append(tgt)
            sub_edge_attr[node_subgraph_index[tgt]].append(edge_attr[i].item())
        edge_index_list = []
        edge_attr_list = []
        for i in range(subgraph_num):
            edge_index_list.append(torch.tensor([sub_edge_src[i], sub_edge_tgt[i]], dtype=torch.long))
            edge_attr_list.append(torch.tensor(sub_edge_attr[i], dtype=torch.long))
        # print('nodes_list', nodes_list)
        return edge_index_list, edge_attr_list  

    def forward(self, x, edge_index, edge_attr, subgraph_node_num, real_graph_num, batch, ptr):
        edge_index_list, edge_attr_list = self.partition_graph(x, edge_index, edge_attr, subgraph_node_num, real_graph_num, ptr)
        # print('edge_index_list', edge_index_list)
        # print('edge_attr_list', edge_attr_list)
        x = self.embeddings(x)
        x = x.squeeze(1)
        subgraph_pool_list = [
            self.subgraph_forward(x, edge_index_list[i].to(self.device), edge_attr_list[i].to(self.device), batch)
            for i in range(len(edge_index_list))
        ]
        graph_pool = self.subgraph_forward(x, edge_index, edge_attr, batch)
        # print('graph_pool', graph_pool.shape)
        subgraph_pool_seq = torch.stack(subgraph_pool_list)
        # print('subgraph_pool_seq', subgraph_pool_seq.shape)
        h0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        c0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        subgraph_output, (_, _) = self.lstm(subgraph_pool_seq, (h0, c0))
        return self.fc(torch.cat((subgraph_output[-1], graph_pool), dim=1))

In [14]:
device = torch.device('cuda: 0')
gnn_encoder = GNNEncoder(vocab_len=tokenizer.vocab_size+num_added_toks, graph_embedding_size=graph_embedding_size,
                         gnn_layers_num=gnn_layers_num, lstm_layers_num=lstm_layers_num, lstm_hidden_size=lstm_hidden_size,
                        decoder_input_size=decoder_input_size, device=device)
decoder_layer = TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = TransformerDecoder(decoder_layer, num_layers=6)
model = Seq2Seq(encoder=roberta, decoder=decoder, gnn_encoder=gnn_encoder, config=config, beam_size=10, max_length=max_target_length, 
                sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
model.to(device)

Seq2Seq(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), e

In [15]:
max_source_length = 256
max_target_length = 32
batch_size = 32
beam_size = 10
lr = 4e-5
warmup_steps = 0
train_steps = 50000
weight_decay = 0.0
adam_epsilon = 1e-8
valid_loss_steps = 500
valid_bleu_steps = 5000
output_dir = exp_dir

train_url = TRAIN_DIR
valid_url = VALID_DIR
test_url = TEST_DIR

In [16]:
train_dataloader = DataLoader(train_features, batch_size=batch_size)

In [17]:
# optimizer and schedule
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': weight_decay},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=adam_epsilon)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5000,
#                                             num_training_steps=30000)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                            num_training_steps=train_steps)

In [18]:
from itertools import cycle

#Start training
msgr.print_msg("***** Running training *****")
msgr.print_msg("  Num examples = {}".format(len(train_features)))
msgr.print_msg("  Batch size = {}".format(batch_size))
msgr.print_msg("  lr= {}".format(lr))
msgr.print_msg("  Num epoch = {}".format(batch_size//len(train_features)))
model.train()
valid_dataset = {}
nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_bleu, best_loss = 0, 0, 0, 0, 0, 1e6
bar = tqdm(range(train_steps), total=train_steps)
train_dataloader = cycle(train_dataloader)

for step in bar:
    data = next(train_dataloader)
    data = data.to(device)
    subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
    real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
    source_ids = torch.stack(torch.split(data.source_ids, max_source_length))
    source_mask = torch.stack(torch.split(data.source_mask, max_source_length))
    target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
    target_mask = torch.stack(torch.split(data.target_mask, max_target_length))
    loss, _, _, = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, subgraph_node_num=subgraph_node_num, 
                        real_graph_num=real_graph_num, batch=data.batch, ptr=data.ptr, source_ids=source_ids, source_mask=source_mask, 
                        target_ids=target_ids, target_mask=target_mask)

    tr_loss += loss.item()
    train_loss = round(tr_loss / (nb_tr_steps + 1), 4)
    bar.set_description('loss {}'.format(train_loss))
    nb_tr_examples += data.x.size(0)
    nb_tr_steps += 1
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    global_step += 1

    if (global_step + 1) % valid_loss_steps == 0:
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        valid_sampler = SequentialSampler(valid_features)
        valid_dataloader = DataLoader(valid_features, sampler=valid_sampler, batch_size=batch_size)

        msgr.print_msg("\n***** Running evaluation *****")
        msgr.print_msg("  Num examples = {}".format(len(valid_features)))
        msgr.print_msg("  Batch size = {}".format(batch_size))

        #Start Evaling model
        model.eval()
        valid_loss, tokens_num = 0, 0
        for data in valid_dataloader:
            data = data.to(device)
            subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
            real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
            source_ids = torch.stack(torch.split(data.source_ids, max_source_length))
            source_mask = torch.stack(torch.split(data.source_mask, max_source_length))
            target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
            target_mask = torch.stack(torch.split(data.target_mask, max_target_length))            

            with torch.no_grad():
                _,loss,num = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, batch=data.batch, 
                                    subgraph_node_num=subgraph_node_num, real_graph_num=real_graph_num,  ptr=data.ptr,
                                    source_ids=source_ids, source_mask=source_mask, target_ids=target_ids, target_mask=target_mask)    
            valid_loss += loss.sum().item()
            tokens_num += num.sum().item()
        #Pring loss of valid dataset    
        model.train()
        valid_loss /= tokens_num
        result = { 'valid_loss': valid_loss,
                    'valid_ppl': round(np.exp(valid_loss), 5),
                    'global_step': global_step+1,
                    'train_loss': round(train_loss, 5)}
        for key in sorted(result.keys()):
            msgr.print_msg("{}= {}".format(key, str(result[key])))
        msgr.print_msg("  "+"*"*20)   

        #save last checkpoint
        last_output_dir = os.path.join(output_dir, 'checkpoint-last')
        if not os.path.exists(last_output_dir):
            os.makedirs(last_output_dir)
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)
        if valid_loss < best_loss:
            msgr.print_msg("  Best ppl:{}".format(round(np.exp(valid_loss), 5)))
            msgr.print_msg("  " + "*" * 20)
            best_loss = valid_loss
            # Save best checkpoint for best ppl
            best_output_dir = os.path.join(output_dir, 'checkpoint-best-ppl')
            if not os.path.exists(best_output_dir):
                os.makedirs(best_output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(best_output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)  


    if (global_step + 1) % valid_bleu_steps == 0:
        model.eval()
        p=[]
        for data in valid_dataloader:
            data = data.to(device)
            subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
            real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
            source_ids = torch.stack(torch.split(data.source_ids, max_source_length))
            source_mask = torch.stack(torch.split(data.source_mask, max_source_length))
            target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
            target_mask = torch.stack(torch.split(data.target_mask, max_target_length))                  
            with torch.no_grad():
                preds = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, batch=data.batch, 
                                subgraph_node_num=subgraph_node_num, real_graph_num=real_graph_num, ptr=data.ptr,
                                source_ids=source_ids, source_mask=source_mask)
                for pred in preds:
                    t=pred[0].cpu().numpy()
                    t=list(t)
                    if 0 in t:
                        t=t[:t.index(0)]
                    text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
                    p.append(text)
        predictions=[]
        with open(os.path.join(output_dir,"valid.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"valid.gold"),'w', encoding='utf-8') as f1:
            for ref, gold in zip(p, valid_examples):
                predictions.append(str(gold.idx)+'\t'+ref)
                f.write(str(gold.idx)+'\t'+ref+'\n')
                f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

        (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "valid.gold"))
        valid_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
        msgr.print_msg("  {} = {}".format("bleu-4", str(valid_bleu)))
        msgr.print_msg("  "+"*"*20)
        if valid_bleu>best_bleu:
            msgr.print_msg("  Best bleu:{}".format(valid_bleu))
            msgr.print_msg("  "+"*"*20)
            best_bleu=valid_bleu
            # Save best checkpoint for best bleu
            best_output_dir = os.path.join(output_dir, 'checkpoint-best-bleu')
            if not os.path.exists(best_output_dir):
                os.makedirs(best_output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(best_output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)
    
    model.train()

        

dgnn ***** Running training *****
dgnn   Num examples = 164814
dgnn   Batch size = 32
dgnn   lr= 4e-05
dgnn   Num epoch = 0


HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))

dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 500
dgnn train_loss= 7.9312
dgnn valid_loss= 6.079093101666136
dgnn valid_ppl= 436.63303
dgnn   ********************
dgnn   Best ppl:436.63303
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 1000
dgnn train_loss= 5.889
dgnn valid_loss= 5.290224075046379
dgnn valid_ppl= 198.38787
dgnn   ********************
dgnn   Best ppl:198.38787
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 1500
dgnn train_loss= 5.198
dgnn valid_loss= 4.806069939706137
dgnn valid_ppl= 122.25022
dgnn   ********************
dgnn   Best ppl:122.25022
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 2000
dgnn train_loss= 4.781
dgnn valid_loss= 4.506602704822944
dgnn

Total: 5179


dgnn   bleu-4 = 15.14
dgnn   ********************
dgnn   Best bleu:15.14
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 5500
dgnn train_loss= 3.6047
dgnn valid_loss= 3.5507051182627167
dgnn valid_ppl= 34.83787
dgnn   ********************
dgnn   Best ppl:34.83787
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 6000
dgnn train_loss= 3.5206
dgnn valid_loss= 3.4716236750769354
dgnn valid_ppl= 32.18896
dgnn   ********************
dgnn   Best ppl:32.18896
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 6500
dgnn train_loss= 3.4571
dgnn valid_loss= 3.416497039509136
dgnn valid_ppl= 30.46252
dgnn   ********************
dgnn   Best ppl:30.46252
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   B

Total: 5179


dgnn   bleu-4 = 17.26
dgnn   ********************
dgnn   Best bleu:17.26
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 10500
dgnn train_loss= 3.0238
dgnn valid_loss= 3.200051673841168
dgnn valid_ppl= 24.5338
dgnn   ********************
dgnn   Best ppl:24.5338
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 11000
dgnn train_loss= 2.9868
dgnn valid_loss= 3.178075927163592
dgnn valid_ppl= 24.00053
dgnn   ********************
dgnn   Best ppl:24.00053
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 11500
dgnn train_loss= 2.9883
dgnn valid_loss= 3.163534036884518
dgnn valid_ppl= 23.65404
dgnn   ********************
dgnn   Best ppl:23.65404
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Ba

Total: 5179


dgnn   bleu-4 = 17.88
dgnn   ********************
dgnn   Best bleu:17.88
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 15500
dgnn train_loss= 2.7317
dgnn valid_loss= 3.099601671793917
dgnn valid_ppl= 22.18911
dgnn   ********************
dgnn   Best ppl:22.18911
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 16000
dgnn train_loss= 2.7119
dgnn valid_loss= 3.09990675454699
dgnn valid_ppl= 22.19588
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 16500
dgnn train_loss= 2.7137
dgnn valid_loss= 3.083300565866119
dgnn valid_ppl= 21.83034
dgnn   ********************
dgnn   Best ppl:21.83034
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 17000
dgnn train_los

Total: 5179


dgnn   bleu-4 = 18.13
dgnn   ********************
dgnn   Best bleu:18.13
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 20500
dgnn train_loss= 2.5362
dgnn valid_loss= 3.0542000426012286
dgnn valid_ppl= 21.20422
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 21000
dgnn train_loss= 2.5115
dgnn valid_loss= 3.060889314084771
dgnn valid_ppl= 21.34653
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 21500
dgnn train_loss= 2.519
dgnn valid_loss= 3.0470578870890463
dgnn valid_ppl= 21.05331
dgnn   ********************
dgnn   Best ppl:21.05331
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 22000
dgnn train_loss= 2.5124
dgnn valid_loss= 3.047290673546202
dgnn v

Total: 5179


dgnn   bleu-4 = 18.35
dgnn   ********************
dgnn   Best bleu:18.35
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 25500
dgnn train_loss= 2.3959
dgnn valid_loss= 3.0354243193642305
dgnn valid_ppl= 20.80981
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 26000
dgnn train_loss= 2.3821
dgnn valid_loss= 3.0419563819007105
dgnn valid_ppl= 20.94618
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 26500
dgnn train_loss= 2.3703
dgnn valid_loss= 3.03663705517773
dgnn valid_ppl= 20.83506
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 27000
dgnn train_loss= 2.3939
dgnn valid_loss= 3.0368306610048688
dgnn valid_ppl= 20.83909
dgnn   ********************
dgnn 

Total: 5179


dgnn   bleu-4 = 18.39
dgnn   ********************
dgnn   Best bleu:18.39
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 30500
dgnn train_loss= 2.3178
dgnn valid_loss= 3.0246469806156777
dgnn valid_ppl= 20.58674
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 31000
dgnn train_loss= 2.3065
dgnn valid_loss= 3.0246470089971553
dgnn valid_ppl= 20.58674
dgnn   ********************
dgnn 
***** Running evaluation *****
dgnn   Num examples = 5179
dgnn   Batch size = 32
dgnn global_step= 31500
dgnn train_loss= 2.2997
dgnn valid_loss= 3.0246469877110473
dgnn valid_ppl= 20.58674
dgnn   ********************


KeyboardInterrupt: 

In [19]:
# Calculate bleu
test_sampler = SequentialSampler(test_features)
test_dataloader = DataLoader(test_features, sampler=test_sampler, batch_size=batch_size)

msgr.print_msg("\n***** Running testing *****")
msgr.print_msg("  Num examples = {}".format(len(test_features)))
msgr.print_msg("  Batch size = {}".format(batch_size))

dgnn 
***** Running testing *****
dgnn   Num examples = 10952
dgnn   Batch size = 32


In [20]:
best_ppl_model = output_dir + '/checkpoint-best-ppl/pytorch_model.bin' 
model.load_state_dict(torch.load(best_ppl_model))
model.eval() 
p=[]
for data in tqdm(test_dataloader,total=len(test_dataloader)):
    data = data.to(device)
    subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
    real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
    source_ids = torch.stack(torch.split(data.source_ids, max_source_length))
    source_mask = torch.stack(torch.split(data.source_mask, max_source_length))
    target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
    target_mask = torch.stack(torch.split(data.target_mask, max_target_length))                     
    with torch.no_grad():
        preds = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, subgraph_node_num=subgraph_node_num, 
                        real_graph_num=real_graph_num, batch=data.batch, ptr=data.ptr, source_ids=source_ids, source_mask=source_mask)
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

HBox(children=(IntProgress(value=0, max=343), HTML(value='')))

Total: 10952


dgnn  bleu-4 = 18.36 
dgnn   ********************


In [21]:
best_bleu_model = output_dir + '/checkpoint-best-bleu/pytorch_model.bin' 
model.load_state_dict(torch.load(best_bleu_model))
model.eval() 
p=[]
for data in tqdm(test_dataloader,total=len(test_dataloader)):
    data = data.to(device)
    subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
    real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
    source_ids = torch.stack(torch.split(data.source_ids, max_source_length))
    source_mask = torch.stack(torch.split(data.source_mask, max_source_length))
    target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
    target_mask = torch.stack(torch.split(data.target_mask, max_target_length))                 
    with torch.no_grad():
        preds = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, subgraph_node_num=subgraph_node_num, 
                        real_graph_num=real_graph_num, batch=data.batch, ptr=data.ptr, source_ids=source_ids, source_mask=source_mask)
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"checkpoint-best-bleu/test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"checkpoint-best-bleu/test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

HBox(children=(IntProgress(value=0, max=343), HTML(value='')))

Total: 10952


dgnn  bleu-4 = 18.31 
dgnn   ********************


In [22]:
last_model = output_dir + '/checkpoint-last/pytorch_model.bin' 
model.load_state_dict(torch.load(last_model))
model.eval() 
p=[]
for data in tqdm(test_dataloader,total=len(test_dataloader)):
    data = data.to(device)
    subgraph_node_num = torch.stack(torch.split(data.subgraph_node_num, max_subgraph_num))
    real_graph_num = torch.stack(torch.split(data.real_graph_num, 1))
    source_ids = torch.stack(torch.split(data.source_ids, max_source_length))
    source_mask = torch.stack(torch.split(data.source_mask, max_source_length))
    target_ids = torch.stack(torch.split(data.target_ids, max_target_length))
    target_mask = torch.stack(torch.split(data.target_mask, max_target_length))                   
    with torch.no_grad():
        preds = model(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr, subgraph_node_num=subgraph_node_num, 
                        real_graph_num=real_graph_num, batch=data.batch, ptr=data.ptr, source_ids=source_ids, source_mask=source_mask)
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"checkpoint-last/test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"checkpoint-last/test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

HBox(children=(IntProgress(value=0, max=343), HTML(value='')))

Total: 10952


dgnn  bleu-4 = 18.31 
dgnn   ********************
