In [None]:
import yaml
import warnings
from datetime import datetime
import logging
import os
import pandas as pd 
import javalang
from javalang.ast import Node
import torch
import torch.nn.functional as F
from torch.nn import (Module, Embedding, LSTM, Sequential, Linear, BatchNorm1d, ReLU, Sigmoid, CrossEntropyLoss, TransformerDecoderLayer,
                        TransformerDecoder)
import torch.optim as optim
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.nn import MessagePassing, GatedGraphConv, GCNConv, global_mean_pool, GINEConv, global_add_pool, ResGatedGraphConv, GINEConv
from anytree import AnyNode
from torch_geometric.data import Data, DataLoader, ClusterData, ClusterLoader
from tqdm import tqdm_notebook as tqdm
import numpy as np
from torchsummary import summary
import random
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, DataCollatorWithPadding, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, SequentialSampler
import json
from codebert_seq2seq1 import Seq2Seq
import bleu

warnings.filterwarnings('ignore')

Parameters

In [None]:
config_file = 'config_dgnn.yml'

In [None]:
config = yaml.load(open(config_file), Loader=yaml.FullLoader)

# data source
TRAIN_DIR = config['small_data']['train']
VALID_DIR = config['small_data']['valid']
TEST_DIR = config['small_data']['test']


# prepocess design
max_source_length = config['preprocess']['max_source_length']
max_target_length = config['preprocess']['max_target_length']


# training parameter
batch_size = config['training']['batch_size']
num_epoches = config['training']['num_epoches']
lr = config['training']['lr']
decay_ratio = config['training']['lr']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']

# model design
graph_embedding_size = config['model']['graph_embedding_size']
lstm_hidden_size = config['model']['lstm_hidden_size']
divide_node_num = config['model']['divide_node_num']
gnn_layers_num = config['model']['gnn_layers_num']
lstm_layers_num = config['model']['lstm_layers_num']
decoder_input_size = config['model']['decoder_input_size']
decoder_hidden_size = config['model']['decoder_hidden_size']
decoder_num_layers = config['model']['decoder_num_layers']
decoder_rnn_dropout = config['model']['decoder_rnn_dropout']

# logs
info_prefix = config['logs']['info_prefix']

Logs

In [None]:
run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = 'logs/' + run_id + '.log'
exp_dir = 'runs/' + run_id
os.mkdir(exp_dir)

In [None]:
class Info(object):
    def __init__(self, info_prefix=''):
        self.info_prefix = info_prefix
    
    def print_msg(self, msg):
        text = self.info_prefix + ' ' + msg
        print(text)
        logging.info(text)

In [None]:
logging.basicConfig(format='%(asctime)s | %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', filename=log_file, level=logging.DEBUG)
msgr = Info(info_prefix)

msgr.print_msg('run_id : {}'.format(run_id))
msgr.print_msg('log_file : {}'.format(log_file))
msgr.print_msg('exp_dir: {}'.format(exp_dir))
msgr.print_msg(str(config))

Prepocess

In [None]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    token = 'None'
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [None]:
checkpoint = 'microsoft/codebert-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
roberta = RobertaModel.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
config = RobertaConfig.from_pretrained(checkpoint)
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [None]:
#  generate tree for AST Node
def create_tree(root, node, node_list, sub_id_list, tokenizer, parent=None):
    id = len(node_list)
    node_list.append(node)
    token, children = get_token(node), get_child(node)
    # Use roberta.tokenizer to generate subtokens
    # If a token can be divided into multiple(>1) subtokens, the first subtoken will be set as the previous node, 
    # and the other subtokens will be set as its new children
    token = token.encode('utf-8','ignore').decode("utf-8")
    sub_token_list = tokenizer.tokenize(token)
        
    if id == 0:
        root.token = sub_token_list[0] # the root node is one of the tokenizer's special tokens
        root.data = node
        for child in children:
            create_tree(root, child, node_list, sub_id_list, tokenizer, parent=root)
    else:
        new_node = AnyNode(id=id, token=sub_token_list[0], data=node, parent=parent)
        if len(sub_token_list) > 1:
            sub_id_list.append(id)
            for sub_token in sub_token_list[1:]:
                id += 1
                AnyNode(id=id, token=sub_token, data=node, parent=new_node)
                node_list.append(sub_token)
                sub_id_list.append(id)
        
        for child in children:
            create_tree(root, child, node_list, sub_id_list, tokenizer, parent=new_node)

In [None]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list):
    token = node.token
    node_index_list.append(tokenizer.convert_tokens_to_ids(token))
    # node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    if token in ['VariableDeclarator', 'MemberReference']:
        variable_token_list.append(node.children[0].token)
        variable_id_list.append(node.children[0].id)   
    
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        get_node_and_edge(child, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list)

In [None]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, tokenizer):
    node_list = []
    sub_id_list = [] # record the ids of node that can be divide into multple subtokens
    new_tree = AnyNode(id=0, token=None, data=None)
    create_tree(new_tree, ast, node_list, sub_id_list, tokenizer)
    x = []
    edge_src = []
    edge_tgt = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, variable_token_list, variable_id_list)

    ast_edge_num = len(edge_src)
    edge_attr = [[0] for _ in range(ast_edge_num)]
    # set subtoken edge type to 2
    for i in range(len(edge_attr)):
        if edge_src[i] in sub_id_list and edge_tgt[i] in sub_id_list:
            edge_attr[i] = [2]
    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        # print('variable_dict', variable_dict)
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(variable_token_list[i], variable_id_list[i])
        else:
            # print('edge', variable_dict.get(variable_token_list[i]), variable_id_list[i])
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_src.append(variable_id_list[i])
            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            variable_dict[variable_token_list[i]] = variable_id_list[i]

    edge_index = [edge_src, edge_tgt]

    # set data flow edge type to 1
    dataflow_edge_num = len(edge_src) - ast_edge_num
    for _ in range(dataflow_edge_num):
        edge_attr.append([1])
    return x, edge_index, edge_attr

In [None]:
def convert_examples_to_features(examples, tokenizer, mode='GNN_only', stage=None):
    features = []
    for example in examples:
        # pyg
        ast = parse_program(example.source)
        x, edge_index, edge_attr = get_pyg_data_from_ast(ast, tokenizer)

        # source
        if mode == 'GNN_only':
            source_tokens = tokenizer.tokenize('None')
        else:
            source_tokens = tokenizer.tokenize(example.source)[: max_source_length-2]
        source_tokens = [tokenizer.cls_token] + source_tokens + [tokenizer.sep_token]
        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
        source_mask = [1] * (len(source_ids))
        padding_length = max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_token_id] * padding_length
        source_mask += [0] * padding_length

        # target
        if stage == 'test':
            target_tokens = tokenizer.tokenize('None')
        else:
            target_tokens = tokenizer.tokenize(example.target)[: max_target_length-2]
        target_tokens = [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        target_mask = [1] * len(target_ids)
        padding_length = max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_token_id] * padding_length
        target_mask += [0] * padding_length

        features.append(
            Data(
                x= torch.tensor(x, dtype=torch.long),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.long),
                source_ids=torch.tensor(source_ids, dtype=torch.long),
                source_mask=torch.tensor(source_mask, dtype=torch.long),
                target_ids=torch.tensor(target_ids, dtype=torch.long),
                target_mask=torch.tensor(target_mask, dtype=torch.long),
            )
        )
    return features

In [None]:
class Example(object):
    def __init__(self, idx, source, target):
        self.idx = idx
        self.source = source
        self.target = target

In [None]:
# read dataset
def read_examples(filename):
    examples = []
    with open(filename, encoding='utf-8') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            js = json.loads(line)
            if 'idx' not in js:
                js['idx'] = idx
            
            code = js['code']
            nl = ' '.join(js['docstring_tokens']).replace('\n', '')
            nl = ' '.join(nl.strip().split())
            examples.append(
                Example(
                    idx = idx,
                    source = code,
                    target = nl,
                )
            )
    return examples

In [None]:
train_examples = read_examples(TRAIN_DIR)
valid_examples = read_examples(VALID_DIR)
test_examples = read_examples(TEST_DIR)
msgr.print_msg('train size: {}, valid size: {}, test size: {}'.format(len(train_examples), len(valid_examples), len(test_examples)))

In [None]:
train_features = convert_examples_to_features(train_examples, tokenizer, stage='train')
torch.save(train_features, 'features/train_features.pt')

In [17]:
train_dataloader = DataLoader(train_features, batch_size=batch_size)

Model

In [27]:
class GNNEncoder(Module):
    def __init__(self, vocab_len, embedding_dim, num_layers):
        super(GNNEncoder, self).__init__()
        self.device = device
        self.embeddings = Embedding(vocab_len, embedding_dim)
        # only two edge types to be set weights, which are AST edge and data flow edge, subtoken edge
        self.edge_embed = Embedding(3, embedding_dim)
        self.fc1 = Linear(embedding_dim, embedding_dim)
        # GINEConv
        gine_dim = 100
        self.conv1 = GINEConv(Sequential(Linear(embedding_dim, gine_dim), BatchNorm1d(gine_dim), 
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))
        self.conv2 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim), 
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))  
        self.conv3 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim), 
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))  
        self.conv4 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim), 
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))  
        self.conv5 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim), 
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))
#         self.res_ggnn_layer = ResGatedGraphConv(embedding_dim, embedding_dim)
#         self.ggnnlayer = GatedGraphConv(embedding_dim, num_layers)
#         self.mlp_gate = Sequential(Linear(embedding_dim, 1), Sigmoid())
#         self.pool = GlobalAttention(gate_nn=self.mlp_gate)
#         self.fc = Linear(embedding_dim, embedding_dim)
        self.fc = Linear(gine_dim, 768)

            
    def forward(self, x, edge_index, edge_attr, batch):
        x = self.embeddings(x)
        x = x.squeeze(1)

        edge_weight = self.edge_embed(edge_attr)
        edge_weight = self.fc1(edge_weight)
        edge_weight = edge_weight.squeeze(1)
        x = self.conv1(x, edge_index, edge_weight)
        x = self.conv2(x, edge_index, edge_weight)
        x = self.conv3(x, edge_index, edge_weight)
        x = self.conv4(x, edge_index, edge_weight)
        x = self.conv5(x, edge_index, edge_weight)

#         x = self.res_ggnn_layer(x, edge_index, edge_weight)

#         x = self.ggnnlayer(x, edge_index, edge_weight)
        x = global_mean_pool(x, batch)
        x = F.relu(self.fc(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(x)
        return x


In [29]:
device = torch.device('cuda: 0')
gnn_encoder = GNNEncoder(vocab_len=tokenizer.vocab_size+num_added_toks, embedding_dim=graph_embedding_size, num_layers=gnn_layers_num)
decoder_layer = TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = TransformerDecoder(decoder_layer, num_layers=6)
model = Seq2Seq(encoder=gnn_encoder, decoder=decoder, config=config, beam_size=10, max_length=max_target_length, 
                sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
model.to(device)

Seq2Seq(
  (encoder): GNNEncoder(
    (embeddings): Embedding(50336, 768)
    (edge_embed): Embedding(3, 768)
    (fc1): Linear(in_features=768, out_features=768, bias=True)
    (conv1): GINEConv(nn=Sequential(
      (0): Linear(in_features=768, out_features=100, bias=True)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=100, out_features=100, bias=True)
      (4): ReLU()
    ))
    (conv2): GINEConv(nn=Sequential(
      (0): Linear(in_features=100, out_features=100, bias=True)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=100, out_features=100, bias=True)
      (4): ReLU()
    ))
    (conv3): GINEConv(nn=Sequential(
      (0): Linear(in_features=100, out_features=100, bias=True)
      (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Line

In [30]:
max_source_length = 256
max_target_length = 32
batch_size = 32
beam_size = 10
lr = 5e-5
warmup_steps = 0
train_steps = 50000
# train_steps = 1000
weight_decay = 0.0
adam_epsilon = 1e-8
valid_steps = 1000
# valid_steps = 200
output_dir = exp_dir

train_url = TRAIN_DIR
valid_url = VALID_DIR
test_url = TEST_DIR

In [31]:
# optimizer and schedule
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': weight_decay},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                            num_training_steps=train_steps)

In [32]:
from itertools import cycle

#Start training
msgr.print_msg("***** Running training *****")
msgr.print_msg("  Num examples = {}".format(len(train_examples)))
msgr.print_msg("  Batch size = {}".format(batch_size))
msgr.print_msg("  Num epoch = {}".format(batch_size//len(train_examples)))
model.train()
valid_dataset = {}
nb_tr_examples, nb_tr_steps, tr_loss, global_step, best_bleu, best_loss = 0, 0, 0, 0, 0, 1e6
bar = tqdm(range(train_steps), total=train_steps)
train_dataloader = cycle(train_dataloader)

for step in bar:
    data = next(train_dataloader)
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)
    target_ids = data.target_ids.to(device)
    target_ids = torch.stack(torch.split(target_ids, max_target_length))
    target_mask = data.target_mask.to(device)
    target_mask = torch.stack(torch.split(target_mask, max_target_length))
    loss, _, _, = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch, target_ids=target_ids, target_mask=target_mask)

    tr_loss += loss.item()
    train_loss = round(tr_loss / (nb_tr_steps + 1), 4) 
    bar.set_description('loss {}'.format(train_loss))
    nb_tr_examples += x.size(0)
    nb_tr_steps += 1
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    global_step += 1
    
    if (global_step + 1) % valid_steps == 0:
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        if 'valid_loss' in valid_dataset:
            valid_examples, valid_features=valid_dataset['valid_loss']
        else:
            valid_examples = read_examples(valid_url)
            valid_features = convert_examples_to_features(valid_examples, tokenizer, stage='valid')
            valid_dataset['valid_loss']=valid_examples, valid_features
        valid_sampler = SequentialSampler(valid_features)
        valid_dataloader = DataLoader(valid_features, sampler=valid_sampler, batch_size=batch_size)

        msgr.print_msg("\n***** Running evaluation *****")
        msgr.print_msg("  Num examples = {}".format(len(valid_examples)))
        msgr.print_msg("  Batch size = {}".format(batch_size))

        #Start Evaling model
        model.eval()
        valid_loss, tokens_num = 0, 0
        for data in valid_dataloader:
            x = data.x.to(device)
            edge_index = data.edge_index.to(device)
            edge_attr = data.edge_attr.to(device)
            batch = data.batch.to(device)
            target_ids = data.target_ids.to(device)
            target_ids = torch.stack(torch.split(target_ids, max_target_length))
            target_mask = data.target_mask.to(device)
            target_mask = torch.stack(torch.split(target_mask, max_target_length))            

            with torch.no_grad():
                _,loss,num = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch, target_ids=target_ids, target_mask=target_mask)     
            valid_loss += loss.sum().item()
            tokens_num += num.sum().item()
        #Pring loss of valid dataset    
        model.train()
        valid_loss = valid_loss / tokens_num
        result = {'valid_ppl': round(np.exp(valid_loss), 5),
                    'global_step': global_step+1,
                    'train_loss': round(train_loss, 5),
                 'valid_loss': round(valid_loss, 5)}
        for key in sorted(result.keys()):
            msgr.print_msg("{}= {}".format(key, str(result[key])))
        msgr.print_msg("  "+"*"*20)   
        
        #save last checkpoint
        last_output_dir = os.path.join(output_dir, 'checkpoint-last')
        if not os.path.exists(last_output_dir):
            os.makedirs(last_output_dir)
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)                    
        if valid_loss < best_loss:
            msgr.print_msg("  Best ppl:{}".format(round(np.exp(valid_loss), 5)))
            msgr.print_msg("  " + "*" * 20)
            best_loss = valid_loss
            # Save best checkpoint for best ppl
            best_output_dir = os.path.join(output_dir, 'checkpoint-best-ppl')
            if not os.path.exists(best_output_dir):
                os.makedirs(best_output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(best_output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)  
                    
                    
        #Calculate bleu  
        if 'valid_bleu' in valid_dataset:
            valid_examples, valid_features = valid_dataset['valid_bleu']
        else:
            valid_examples = read_examples(valid_url)
            valid_examples = random.sample(valid_examples, min(1000,len(valid_examples)))
            valid_features = convert_examples_to_features(valid_examples, tokenizer,stage='test')  
            valid_dataset['valid_bleu']= valid_examples, valid_features

        valid_sampler = SequentialSampler(valid_features)
        valid_dataloader = DataLoader(valid_features, sampler=valid_sampler, batch_size=batch_size)

        model.eval() 
        p=[]
        # i = 0
        for data in valid_dataloader:
            # print('i', i)
            # i += 1
            x = data.x.to(device)
            edge_index = data.edge_index.to(device)
            edge_attr = data.edge_attr.to(device)
            batch = data.batch.to(device)                  
            with torch.no_grad():
                preds = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)  
                for pred in preds:
                    t=pred[0].cpu().numpy()
                    t=list(t)
                    if 0 in t:
                        t=t[:t.index(0)]
                    text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
                    p.append(text)
        model.train()
        predictions=[]
        with open(os.path.join(output_dir,"valid.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"valid.gold"),'w', encoding='utf-8') as f1:
            for ref, gold in zip(p, valid_examples):
                predictions.append(str(gold.idx)+'\t'+ref)
                f.write(str(gold.idx)+'\t'+ref+'\n')
                f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

        (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "valid.gold")) 
        valid_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
        msgr.print_msg("  {} = {}".format("bleu-4", str(valid_bleu)))
        msgr.print_msg("  "+"*"*20)    
        if valid_bleu>best_bleu:
            msgr.print_msg("  Best bleu:{}".format(valid_bleu))
            msgr.print_msg("  "+"*"*20)
            best_bleu=valid_bleu
            # Save best checkpoint for best bleu
            best_output_dir = os.path.join(output_dir, 'checkpoint-best-bleu')
            if not os.path.exists(best_output_dir):
                os.makedirs(best_output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(best_output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)
        

dgnn ***** Running training *****
dgnn   Num examples = 1000
dgnn   Batch size = 32
dgnn   Num epoch = 0


HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))

AssertionError: 

In [None]:
test_examples = read_examples(test_url)
test_features = convert_examples_to_features(test_examples, tokenizer, stage='test')
# Calculate bleu
test_sampler = SequentialSampler(test_features)
test_dataloader = DataLoader(test_features, sampler=test_sampler, batch_size=batch_size)

best_ppl_model = output_dir + '/checkpoint-best-ppl/pytorch_model.bin' 
model.load_state_dict(torch.load(best_ppl_model))
model.eval() 
p=[]
for batch in tqdm(test_dataloader,total=len(test_dataloader)):
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)                     
    with torch.no_grad():
        preds = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)   
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

In [None]:
best_bleu_model = output_dir + '/checkpoint-best-bleu/pytorch_model.bin' 
model.load_state_dict(torch.load(best_bleu_model))
model.eval() 
p=[]
for batch in tqdm(test_dataloader,total=len(test_dataloader)):
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)                 
    with torch.no_grad():
        preds = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)   
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"checkpoint-best-bleu/test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"checkpoint-best-bleu/test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     

In [None]:
last_model = output_dir + '/checkpoint-last/pytorch_model.bin' 
model.load_state_dict(torch.load(last_model))
model.eval() 
p=[]
for batch in tqdm(test_dataloader,total=len(test_dataloader)):
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)                   
    with torch.no_grad():
        preds = model(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)   
        for pred in preds:
            t=pred[0].cpu().numpy()
            t=list(t)
            if 0 in t:
                t=t[:t.index(0)]
            text = tokenizer.decode(t,clean_up_tokenization_spaces=False)
            p.append(text)
model.train()
predictions=[]
with open(os.path.join(output_dir,"checkpoint-last/test.output"),'w', encoding='utf-8') as f, open(os.path.join(output_dir,"checkpoint-last/test.gold"),'w', encoding='utf-8') as f1:
    for ref,gold in zip(p,test_examples):
        predictions.append(str(gold.idx)+'\t'+ref)
        f.write(str(gold.idx)+'\t'+ref+'\n')
        f1.write(str(gold.idx)+'\t'+gold.target+'\n')     

(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(output_dir, "test.gold")) 
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
msgr.print_msg(" {} = {} ".format("bleu-4", str(dev_bleu)))
msgr.print_msg("  "+"*"*20)     