In [None]:
import yaml
import warnings
from datetime import datetime
import logging
import os
import pandas as pd 
import javalang
from javalang.ast import Node
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn.modules.linear import Linear
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.nn import MessagePassing, GatedGraphConv, GCNConv, global_mean_pool
from anytree import AnyNode
from torch_geometric.data import Data, DataLoader, ClusterData, ClusterLoader
from tqdm import tqdm_notebook as tqdm
import numpy as np
from torchsummary import summary
from ignite.metrics.nlp import Bleu

warnings.filterwarnings('ignore')

Parameters

In [None]:
# YAML file holding data paths, preprocessing, training, model and logging settings.
config_file = "config_dgnn.yml"

In [None]:
# Load the experiment configuration. `with` closes the file handle, and safe_load only
# builds plain YAML types (no arbitrary Python object tags).
with open(config_file) as config_fh:
    config = yaml.safe_load(config_fh)

# data source
TRAIN_DIR = config['data']['train']
VALID_DIR = config['data']['valid']
TEST_DIR = config['data']['test']


# preprocess design
# Maximum summary length in tokens: labels are padded/truncated to it and the decoder emits
# exactly this many steps. Later cells read `max_seq_len`, so it must be defined here.
max_seq_len = config['preprocess']['max_seq_len']

# training parameters
batch_size = config['training']['batch_size']
num_epoches = config['training']['num_epoches']
lr = config['training']['lr']
# Per-epoch multiplicative LR decay used by the LambdaLR scheduler (lr * decay_ratio ** epoch).
# This previously read the 'lr' key by mistake, which shrank the LR to lr ** (epoch + 1).
# NOTE(review): the key name 'decay_ratio' is assumed; confirm it against config_dgnn.yml.
# A missing key means no decay (factor 1.0).
decay_ratio = config['training'].get('decay_ratio', 1.0)
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']

# model design
graph_embedding_size = config['model']['graph_embedding_size']
lstm_hidden_size = config['model']['lstm_hidden_size']
divide_node_num = config['model']['divide_node_num']
gnn_layers_num = config['model']['gnn_layers_num']
lstm_layers_num = config['model']['lstm_layers_num']
decoder_input_size = config['model']['decoder_input_size']
decoder_hidden_size = config['model']['decoder_hidden_size']
decoder_num_layers = config['model']['decoder_num_layers']
decoder_rnn_dropout = config['model']['decoder_rnn_dropout']

# logs
info_prefix = config['logs']['info_prefix']

Logs

In [None]:
run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = 'logs/' + run_id + '.log'
exp_dir = 'runs/' + run_id
# Create the parent directories on a fresh checkout. logging.basicConfig(filename=log_file)
# in a later cell needs logs/ to exist, and os.mkdir could not create runs/ itself.
os.makedirs('logs', exist_ok=True)
# No exist_ok here: a clash means two runs share a timestamp, which should still fail loudly.
os.makedirs(exp_dir)

In [None]:
class Info(object):
    """Echo messages to stdout and to the root logger, each tagged with a fixed prefix."""

    def __init__(self, info_prefix=''):
        # Text placed (followed by a single space) in front of every message.
        self.info_prefix = info_prefix

    def print_msg(self, msg):
        """Print ``msg`` behind the prefix and record the same line via ``logging.info``."""
        line = ' '.join((self.info_prefix, msg))
        print(line)
        logging.info(line)

In [None]:
logging.basicConfig(
    filename=log_file,
    level=logging.DEBUG,
    format='%(asctime)s | %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
)
msgr = Info(info_prefix)

# Record the run identity and the full configuration at the top of the log.
for header_line in ('run_id : {}'.format(run_id),
                    'log_file : {}'.format(log_file),
                    'exp_dir: {}'.format(exp_dir),
                    str(config)):
    msgr.print_msg(header_line)

Sequence Preprocess

In [None]:
# Four reserved vocabulary entries shared by both vocabularies. PAD has id 0, which is the
# padding_idx of the node embedding and the ignore_index of the loss.
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = '<PAD>', '<S>', '</S>', '<UNK>'
PAD, BOS, EOS, UNK = range(4)

In [None]:
# read dataset: each split is a JSON-lines file (one record per line)
train_data, valid_data, test_data = (
    pd.read_json(path_or_buf=split_path, lines=True)
    for split_path in (TRAIN_DIR, VALID_DIR, TEST_DIR)
)

In [None]:
train_data.head(n=5)

In [None]:
split_sizes = [len(split) for split in (train_data, valid_data, test_data)]
msgr.print_msg('train size: {}, valid size: {}, test size: {}'.format(*split_sizes))

In [None]:
# define vocab class
class Vocab(object):
    """Bidirectional token <-> id mapping, optionally seeded with reserved tokens.

    Ids are assumed dense (0..n-1): a new word always gets ``len(word2id)``.
    """

    def __init__(self, word2id=None):
        # Copy, so neither the caller's dict nor a shared default is ever mutated.
        self.word2id = dict(word2id or {})
        self.id2word = {v: k for k, v in self.word2id.items()}

    def build_vocab(self, sentences, min_count=1):
        """Add the tokens of ``sentences`` (a flat iterable of tokens) to the vocabulary.

        Tokens are added in order of decreasing frequency; ties keep first-occurrence order.
        Tokens seen fewer than ``min_count`` times are skipped.
        Tokens already present (e.g. reserved tokens) keep their existing id.
        """
        word_counter = {}
        for word in sentences:
            word_counter[word] = word_counter.get(word, 0) + 1

        for word, count in sorted(word_counter.items(), key=lambda kv: -kv[1]):
            if count < min_count:
                break  # sorted by descending count, so every remaining word is rarer
            if word in self.word2id:
                # Previously this fell through and wrote a stray id2word entry, so
                # id2word could grow larger than word2id (and the reported vocab size).
                continue
            new_id = len(self.word2id)
            self.word2id[word] = new_id
            self.id2word[new_id] = word

In [None]:
# Construct two vocabularies, one for AST nodes and one for natural language.
# Both start from the same reserved tokens; each Vocab copies the dict, so they evolve independently.
word2id = dict(zip((PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN), (PAD, BOS, EOS, UNK)))

vocab_astnodes = Vocab(word2id=word2id)
vocab_nl = Vocab(word2id=word2id)



In [None]:
# Natural-language corpus: every lower-cased `docstring_tokens` entry of the training split only.
nl_tokens = [
    token.lower()
    for doc_tokens in train_data['docstring_tokens']
    for token in doc_tokens
]

vocab_nl.build_vocab(nl_tokens, min_count=0)

In [None]:
vocab_nl_size = len(vocab_nl.id2word)
msgr.print_msg('vocab_nl_size: {}'.format(vocab_nl_size))

In [None]:
# Helpers that parse Java with javalang and walk the AST depth-first to build the AST-node corpus.
def get_token(node):
    """Return the vocabulary token for one javalang AST element.

    A string is its own token. A ``set`` (presumably a modifier set) maps to ``'Modifier'``.
    An AST ``Node`` maps to its class name, and anything else maps to ``''``.
    """
    if isinstance(node, str):
        return node
    if isinstance(node, set):
        return 'Modifier'
    if isinstance(node, Node):
        return type(node).__name__
    return ''


def get_child(root):
    """Return the direct children of a javalang AST element as a flat list.

    For a ``Node`` its ``children`` are used; for a ``set`` its members are used; anything
    else has no children. Nested lists are flattened and falsy entries (None, '', empty
    containers) are dropped.
    """
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        # Sort so the order does not depend on the per-process string hash seed; plain set
        # iteration made node order (and node ids/edges) differ between kernel restarts.
        # Assumes set members are mutually comparable (e.g. all strings).
        children = sorted(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                yield from expand(item)
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    """Append the pre-order token sequence of ``node`` (the node, then each child subtree) to ``sequence``."""
    children = get_child(node)
    sequence.append(get_token(node))
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    """Tokenize the Java source ``func`` and parse it as a single member declaration with javalang."""
    return javalang.parser.Parser(javalang.tokenizer.tokenize(func)).parse_member_declaration()

In [None]:
# Use only the training split to construct the AST-node corpus (pre-order token sequences).
astnodes_tokens = []
for code in tqdm(train_data['code']):
    tree = parse_program(code)
    get_sequence(tree, astnodes_tokens)

vocab_astnodes.build_vocab(astnodes_tokens, min_count=0)

In [None]:
vocab_astnodes_size = len(vocab_astnodes.id2word)
msgr.print_msg('vocab_astnodes_size: {}'.format(vocab_astnodes_size))

In [None]:
len({*astnodes_tokens})  # number of distinct AST-node tokens

In [None]:
sum(1 for _ in astnodes_tokens)  # total AST-node tokens, with repetition

In [None]:
astnodes_tokens_set = list({*astnodes_tokens})  # distinct tokens (arbitrary order)
len(astnodes_tokens_set)

In [None]:
import re

def camel_case_split(identifier):
    """Split a CamelCase identifier into words, keeping runs of capitals together.

    Example: ``'CamelCaseXYZ'`` -> ``['Camel', 'Case', 'XYZ']``.
    """
    # Lazily consume characters up to a lower->Upper boundary, an Upper->UpperLower boundary, or the end.
    boundary = '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)'
    return [match.group() for match in re.finditer(boundary, identifier)]

In [None]:
camel_case_split(identifier='CamelCaseXYZ')

In [None]:
# Split each distinct AST token into sub-words: on '_' for snake_case, otherwise on CamelCase boundaries.
all_tokens = []
for token in tqdm(astnodes_tokens_set):
    all_tokens.extend(token.split('_') if '_' in token else camel_case_split(token))

In [None]:
sum(1 for _ in all_tokens)  # total sub-words, with repetition

In [None]:
len({*all_tokens})  # number of distinct sub-words

In [None]:
# transform sentence to ids
def sentence_to_ids(vocab, sentence):
    """Map tokens to ids (case-insensitive lookup, UNK for unknown words) and append EOS."""
    lookup = vocab.word2id
    return [lookup.get(word.lower(), UNK) for word in sentence] + [EOS]

# transform ids to sentence
def ids_to_sentence(vocab, ids):
    """Map each id back to its token via ``vocab.id2word`` (KeyError for unknown ids)."""
    lookup = vocab.id2word
    return list(map(lookup.__getitem__, ids))

# pad sequence
def pad_seq(seq, max_length):
    """Truncate ``seq`` to ``max_length`` items, or right-pad it with PAD ids up to that length."""
    shortfall = max_length - len(seq)
    if shortfall <= 0:
        return seq[:max_length]
    return seq + [PAD] * shortfall


Add Dataflow to AST to generate D-AST

In [None]:
# generate an anytree copy of a javalang AST
def create_tree(root, node, node_list, parent=None):
    """Mirror the javalang AST ``node`` into AnyNode objects under ``root``.

    ``root`` is a pre-built AnyNode that receives the token/data of the first visited node.
    Every other node becomes a new AnyNode. Ids are assigned in pre-order as
    ``len(node_list)`` at visit time, and each visited javalang element is appended to ``node_list``.
    """
    node_id = len(node_list)
    token = get_token(node)
    children = get_child(node)
    if node_id == 0:
        root.token = token
        root.data = node
        current = root
    else:
        current = AnyNode(id=node_id, token=token, data=node, parent=parent)
    node_list.append(node)
    for child in children:
        create_tree(root, child, node_list, parent=current)

In [None]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, vocab_dict, src, tgt, variable_token_list, variable_id_list):
    """Pre-order walk that collects node ids, bidirectional AST edges and variable occurrences.

    Appends ``[vocab id]`` of each node to ``node_index_list`` (UNK if unknown). For every
    parent/child pair it appends both directions to ``src``/``tgt``. For
    'VariableDeclarator' and 'MemberReference' nodes it records the token and id of the
    first child in ``variable_token_list`` / ``variable_id_list``.
    """
    current_token = node.token
    node_index_list.append([vocab_dict.word2id.get(current_token, UNK)])
    if current_token in ('VariableDeclarator', 'MemberReference'):
        # NOTE(review): assumes the first child is the variable's identifier; verify for
        # qualified references.
        name_node = node.children[0]
        variable_token_list.append(name_node.token)
        variable_id_list.append(name_node.id)
    for child in node.children:
        src.extend((node.id, child.id))
        tgt.extend((child.id, node.id))
        get_node_and_edge(child, node_index_list, vocab_dict, src, tgt, variable_token_list, variable_id_list)

In [None]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, vocab_dict):
    """Turn a javalang AST into D-AST graph lists for torch_geometric.

    Returns:
        x: one ``[vocab id]`` per node, in pre-order (matching the node ids).
        edge_index: ``[src_ids, tgt_ids]``. It holds the AST edges in both directions, then the
            data-flow edges, which link each occurrence of a variable name to its previous
            occurrence, also in both directions.
        edge_attr: one ``[type]`` per edge (0 = AST edge, 1 = data-flow edge).
    """
    tree = AnyNode(id=0, token=None, data=None)
    create_tree(tree, ast, [])

    x, edge_src, edge_tgt = [], [], []
    variable_tokens, variable_ids = [], []
    get_node_and_edge(tree, x, vocab_dict, edge_src, edge_tgt, variable_tokens, variable_ids)

    # Every edge collected so far is an AST edge (type 0).
    edge_attr = [[0] for _ in edge_src]

    # Data-flow edges (type 1) chain consecutive occurrences of the same variable name.
    last_seen = {}
    for name, occurrence_id in zip(variable_tokens, variable_ids):
        if name in last_seen:
            previous_id = last_seen[name]
            edge_src += [previous_id, occurrence_id]
            edge_tgt += [occurrence_id, previous_id]
            edge_attr += [[1], [1]]
        last_seen[name] = occurrence_id

    return x, [edge_src, edge_tgt], edge_attr

Batch Data

In [None]:
def transform_to_pygdata(data):
    """Convert a DataFrame with 'code' and 'docstring_tokens' columns into torch_geometric Data objects.

    Each example gets node ids ``x``, ``edge_index`` and ``edge_attr`` from its D-AST. Its
    label ``y`` holds the docstring ids plus EOS, padded/truncated to ``max_seq_len``.
    Rows are visited positionally, so frames with a non-default index (e.g. after filtering)
    work too; the previous ``data['code'][i]`` lookup was label-based.
    """
    pyg_datas = []
    for code, doc_tokens in zip(data['code'], data['docstring_tokens']):
        ast = parse_program(code)
        label = sentence_to_ids(vocab_nl, doc_tokens)
        x, edge_index, edge_attr = get_pyg_data_from_ast(ast, vocab_astnodes)
        pyg_datas.append(
            Data(x=torch.tensor(x, dtype=torch.long),
                 edge_index=torch.tensor(edge_index, dtype=torch.long),
                 edge_attr=torch.tensor(edge_attr, dtype=torch.long),
                 y=torch.tensor(pad_seq(label, max_seq_len), dtype=torch.long))
        )
    return pyg_datas


In [None]:
train_pygdata, valid_pygdata, test_pygdata = (
    transform_to_pygdata(split) for split in (train_data, valid_data, test_data)
)

In [None]:
# Shuffle only the training split so each epoch sees a new batch order (it was never shuffled);
# validation/test order stays fixed.
train_loader = DataLoader(train_pygdata, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_pygdata, batch_size=batch_size)
test_loader = DataLoader(test_pygdata, batch_size=batch_size)

Model

In [None]:
# Partition the D-AST inside the model instead of during data preprocessing.
# The D-AST is partitioned by node count, set by the hyper-parameter `divide_node_num`.
# Every code snippet starts with a `MethodDeclaration` node, which we use as a super-node kept in all subgraphs.

In [None]:
class SequenceGNNEncoder(torch.nn.Module):
    """Encode a batch of D-AST graphs into one vector per graph.

    The same GatedGraphConv + attention-pooling stack runs (a) once with all edges and
    (b) once per edge partition produced by ``partition_graph``. The per-partition pooled
    vectors form a sequence that an LSTM summarises. Its last output is concatenated with the
    all-edges pooled vector and projected to ``decoder_input_size``.
    """

    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num,
                    decoder_input_size, device):
        super(SequenceGNNEncoder, self).__init__()
        self.device = device
        self.embed = nn.Embedding(vocab_len, graph_embedding_size, padding_idx=PAD)
        # Only two edge types get a learned scalar weight: 0 = AST edge, 1 = data-flow edge.
        self.edge_embed = nn.Embedding(2, 1)
        self.ggnnlayer = GatedGraphConv(graph_embedding_size, gnn_layers_num)
        self.mlp_gate = nn.Sequential(
            nn.Linear(graph_embedding_size, 300), nn.Sigmoid(), nn.Linear(300, 1), nn.Sigmoid())
        self.pool = GlobalAttention(gate_nn=self.mlp_gate)
        self.divide_node_num = divide_node_num
        self.lstm = nn.LSTM(input_size=graph_embedding_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers_num)
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_layers_num = lstm_layers_num
        self.fc = nn.Linear(graph_embedding_size + lstm_hidden_size, decoder_input_size)

    def subgraph_forward(self, x, edge_index, edge_attr, batch):
        """Run the GGNN over ``edge_index`` with per-edge-type weights, then attention-pool per graph."""
        if edge_attr is None:
            edge_weight = None
        else:
            # edge_attr is (E, 1) for the full graph and (E,) for partitions; flatten both to (E,).
            edge_weight = self.edge_embed(edge_attr).view(-1)
        x = self.ggnnlayer(x, edge_index, edge_weight)
        return self.pool(x, batch=batch)

    def partition_graph(self, x, edge_index, edge_attr, batch):
        """Split the batch's edges into partitions by node position within each graph.

        The nodes of every graph are cut, in order, into consecutive chunks of
        ``divide_node_num``. Chunk k of every graph belongs to partition k, and an edge is
        assigned to the partition of its target node.

        Returns:
            (edge_index_list, edge_attr_list): one ``(2, E_k)`` index tensor and one ``(E_k,)``
            type tensor per partition. There are ``ceil(longest graph / divide_node_num)``
            partitions, and at least one.
        """
        graph_pos_in_batch, graph_length = self.get_subgraph_info_from_batch(batch)
        max_graph_len = max(graph_length) if graph_length else 0
        # Ceiling division. The old int(n / d) + 1 added an empty extra partition whenever n % d == 0.
        subgraph_num = max(1, -(-max_graph_len // self.divide_node_num))

        # Partition id of every node = its offset inside its own graph // divide_node_num.
        # (Previously every partition reused the first nodes of each graph.)
        node_subgraph_index = [0] * len(x)
        for start, length in zip(graph_pos_in_batch, graph_length):
            for offset in range(length):
                node_subgraph_index[start + offset] = offset // self.divide_node_num
        node_subgraph_index = torch.tensor(node_subgraph_index, dtype=torch.long, device=edge_index.device)

        # Only the target node decides which partition an edge belongs to.
        edge_subgraph = node_subgraph_index[edge_index[1]]
        flat_edge_attr = edge_attr.view(-1)
        edge_index_list = []
        edge_attr_list = []
        for i in range(subgraph_num):
            mask = edge_subgraph == i
            edge_index_list.append(edge_index[:, mask])
            edge_attr_list.append(flat_edge_attr[mask])
        return edge_index_list, edge_attr_list

    def get_subgraph_info_from_batch(self, batch):
        """Return (start index of each graph, node count of each graph) for a contiguous batch vector.

        Assumes nodes of the same graph are contiguous in ``batch``, as in a torch_geometric batch.
        """
        batch_ids = batch.tolist() if torch.is_tensor(batch) else list(batch)
        graph_pos_in_batch = [0]
        graph_length = []
        start = 0
        for i in range(1, len(batch_ids)):
            if batch_ids[i] != batch_ids[i - 1]:
                graph_pos_in_batch.append(i)
                graph_length.append(i - start)
                start = i
        # The last graph's length is appended once, after the scan. It was previously appended
        # inside the loop, which duplicated lengths and missed single-graph batches.
        if batch_ids:
            graph_length.append(len(batch_ids) - start)
        return graph_pos_in_batch, graph_length

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode the batch; returns one ``decoder_input_size`` vector per graph."""
        edge_index_list, edge_attr_list = self.partition_graph(x, edge_index, edge_attr, batch)
        x = self.embed(x)
        x = x.squeeze(1)  # (N, 1) node ids -> (N, graph_embedding_size)
        subgraph_pool_list = [
            self.subgraph_forward(x, edge_index_list[i].to(self.device), edge_attr_list[i].to(self.device), batch)
            for i in range(len(edge_index_list))
        ]
        graph_pool = self.subgraph_forward(x, edge_index, edge_attr, batch)
        # (num_partitions, num_graphs, graph_embedding_size), assuming pool returns one row per graph.
        subgraph_pool_seq = torch.stack(subgraph_pool_list)
        h0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1), self.lstm_hidden_size, device=self.device)
        c0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1), self.lstm_hidden_size, device=self.device)
        subgraph_output, _ = self.lstm(subgraph_pool_seq, (h0, c0))
        return self.fc(torch.cat((subgraph_output[-1], graph_pool), dim=1))
        
            

In [None]:
class RNNDecoder(nn.Module):
    """GRU followed by a linear projection from hidden states to ``output_size`` logits."""

    def __init__(self, input_size, hidden_size, output_size, num_layers, rnn_dropout, device):
        super(RNNDecoder, self).__init__()
        self.hidden_size, self.output_size, self.num_layers = hidden_size, output_size, num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, dropout=rnn_dropout)
        self.fc = nn.Linear(hidden_size, output_size)
        self.device = device

    def forward(self, input, hidden):
        """Decode a sequence-first ``input`` from initial state ``hidden``; the final GRU state is discarded."""
        gru_states = self.gru(input, hidden)[0]
        return self.fc(gru_states)

In [None]:
class EncoderDecoder(nn.Module):
    """Graph encoder + GRU decoder producing ``(max_seq_len, batch, decoder_output_size)`` logits.

    The decoder is not autoregressive: the encoder's per-graph vector is repeated as the
    decoder input at each of the ``max_seq_len`` time steps.
    """

    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num,
                 decoder_input_size, decoder_hidden_size, decoder_output_size, decoder_num_layers, decoder_rnn_dropout,
                 max_seq_len, device):
        super(EncoderDecoder, self).__init__()
        self.encoder = SequenceGNNEncoder(vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num,
                                          lstm_hidden_size, divide_node_num, decoder_input_size, device)
        self.decoder = RNNDecoder(decoder_input_size, decoder_hidden_size, decoder_output_size, decoder_num_layers,
                                  decoder_rnn_dropout, device)
        self.max_seq_len = max_seq_len
        self.decoder_num_layers = decoder_num_layers
        self.decoder_hidden_size = decoder_hidden_size
        self.device = device

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode the graphs, then decode ``max_seq_len`` steps of logits."""
        decoder_input = self.encoder(x, edge_index, edge_attr, batch)  # (B, decoder_input_size)
        # Repeat the graph vector for every output step: (max_seq_len, B, decoder_input_size).
        decoder_input = decoder_input.unsqueeze(0).expand(self.max_seq_len, -1, -1)
        # Use the model's own device; the previous code read a notebook-level global `device`.
        decoder_h0 = torch.zeros(self.decoder_num_layers, decoder_input.size(1), self.decoder_hidden_size,
                                 device=self.device)
        return self.decoder(decoder_input, decoder_h0)


Training

In [None]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_args = {
    'vocab_len': vocab_astnodes_size,
    'graph_embedding_size': graph_embedding_size,
    'gnn_layers_num': gnn_layers_num,
    'lstm_layers_num': lstm_layers_num,
    'lstm_hidden_size': lstm_hidden_size,
    'divide_node_num': divide_node_num,
    'decoder_input_size': decoder_input_size,
    'decoder_hidden_size': decoder_hidden_size,
    'decoder_output_size': vocab_nl_size,
    'decoder_num_layers': decoder_num_layers,
    'decoder_rnn_dropout': decoder_rnn_dropout,
    'max_seq_len': max_seq_len,
    'device': device
}

model = EncoderDecoder(**model_args).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# LambdaLR scales the base lr by the lambda: lr at epoch e is lr * decay_ratio ** e.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: decay_ratio ** epoch)

In [None]:
# torchsummary's summary(model, input_size, ...) requires an input_size argument, so calling it
# with only `model` raised TypeError. Report the architecture and parameter counts directly instead.
num_params = sum(p.numel() for p in model.parameters())
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
msgr.print_msg('model parameters: {} (trainable: {})'.format(num_params, num_trainable_params))
model_summary = str(model)
print(model_summary)

In [None]:
# Token-level cross entropy summed over the batch, with PAD positions excluded.
# reduction='sum' is the non-deprecated spelling of size_average=False.
mce = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD)


def masked_cross_entropy(logits, target):
    """Flatten ``logits`` (..., vocab) and ``target`` (...) and apply ``mce``.

    Flattening does not reorder axes, so ``logits`` and ``target`` must share the same leading
    layout (e.g. both time-major). ``reshape`` also accepts non-contiguous inputs.
    """
    return mce(logits.reshape(-1, logits.size(-1)), target.reshape(-1))


metric = Bleu(ngram=4, smooth='smooth1')

In [None]:
def compute_loss(data, model, optimizer=None, is_train=True):
    """Run one batch through ``model`` and, when ``is_train``, take an optimizer step.

    Returns:
        (loss, refs, preds): the summed batch loss as a float, plus reference and predicted id
        lists, one row of ``max_seq_len`` ids per example.
    """
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    batch = data.batch.to(device)
    # The loader concatenates each example's padded label along dim 0; split back to (B, L).
    y = torch.stack(torch.split(data.y.to(device), max_seq_len))

    # Build the autograd graph only when training.
    with torch.set_grad_enabled(is_train):
        pred_y = model(x, edge_index, edge_attr, batch)  # (L, B, vocab): time-major decoder output
        # pred_y is time-major while y is batch-major. Transpose the target so the flattened
        # positions line up (they were previously misaligned for any batch size > 1).
        loss = masked_cross_entropy(pred_y.contiguous(), y.t().contiguous())

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    refs = y.cpu().tolist()
    pred = pred_y.detach().max(dim=-1)[1].t().cpu().tolist()  # (B, L) argmax ids

    return loss.item(), refs, pred

In [None]:
def compute_bleu4(metric, refs, hyps):
    """Reset ``metric``, feed each ``(hyps[i], [refs[i]])`` pair, and return ``metric.compute()``."""
    metric.reset()
    for i, reference in enumerate(refs):
        metric.update((hyps[i], [reference]))
    return metric.compute()

In [None]:
class EarlyStopping(object):
    """Flag early stopping once a higher-is-better score stops improving.

    The first call always records the score and saves a checkpoint. On later calls with
    ``epoch <= warm_up`` the score is neither compared nor saved. After warm-up, a score that
    does not beat the best one increments ``counter``, and ``patience`` such calls in a row set
    ``early_stop``. An improvement resets ``counter`` and saves the model to ``filename`` (if given).
    """

    def __init__(self, filename = None, patience=3, warm_up=0, verbose=False):
        self.filename = filename
        self.patience = patience
        self.warm_up = warm_up
        self.verbose = verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, score, model, epoch):
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, model)
            return

        warming_up = epoch <= self.warm_up
        if not warming_up and score <= self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        if warming_up:
            print('Warming up until epoch', self.warm_up)
            return

        if self.verbose:
            print(f'Score improved. ({self.best_score:.6f} --> {score:.6f}).')
        self.best_score = score
        self.save_checkpoint(score, model)
        self.counter = 0

    def save_checkpoint(self, score, model):
        """Write ``model.state_dict()`` to ``filename`` when one was configured."""
        if self.filename is not None:
            torch.save(model.state_dict(), self.filename)
        if self.verbose:
            print('Model saved...')

In [None]:
# Checkpoint path for the best model (the test cell rebuilds the same path).
fname = exp_dir + save_name
early_stopping = EarlyStopping(filename=fname, patience=patience, warm_up=warm_up, verbose=True)

In [None]:
for epoch in range(1, num_epoches + 1):
    # ---- train ----
    model.train()
    train_loss, train_refs, train_hyps = 0., [], []
    for data in tqdm(train_loader, total=len(train_loader), desc='TRAIN'):
        loss, gold, pred = compute_loss(data, model, optimizer, is_train=True)
        train_loss += loss
        train_refs += gold
        train_hyps += pred

    # ---- valid ---- (evaluation only: no autograd graph needed)
    model.eval()
    valid_loss, valid_refs, valid_hyps = 0., [], []
    with torch.no_grad():
        for data in tqdm(valid_loader, total=len(valid_loader), desc='VALID'):
            loss, gold, pred = compute_loss(data, model, is_train=False)
            valid_loss += loss
            valid_refs += gold
            valid_hyps += pred

    # Batch losses are sums over non-PAD target tokens; report the average per example.
    train_loss = train_loss / len(train_data)
    valid_loss = valid_loss / len(valid_data)
    train_bleu4 = compute_bleu4(metric, train_refs, train_hyps)
    valid_bleu4 = compute_bleu4(metric, valid_refs, valid_hyps)

    msgr.print_msg('Epoch {}: train_loss: {:5.2f}  train_bleu4: {:2.4f}  valid_loss: {:5.2f}  valid_bleu4: {:2.4f}'.format(
            epoch, train_loss, train_bleu4, valid_loss, valid_bleu4))

    early_stopping(valid_bleu4, model, epoch)
    if early_stopping.early_stop:
        msgr.print_msg("Early stopping")
        break

    print('-'*80)
    scheduler.step()


Test

In [None]:
# Reload the best checkpoint written by EarlyStopping and evaluate BLEU-4 on the test split.
model = EncoderDecoder(**model_args).to(device)
fname = exp_dir + save_name
# map_location lets a checkpoint saved on one device load on the current one.
ckpt = torch.load(fname, map_location=device)
model.load_state_dict(ckpt)
model.eval()

test_refs = []
test_hyps = []

with torch.no_grad():
    for data in tqdm(test_loader, total=len(test_loader), desc='TEST'):
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        edge_attr = data.edge_attr.to(device)
        batch = data.batch.to(device)
        pred_y = model(x, edge_index, edge_attr, batch)  # (max_seq_len, B, vocab)
        # Labels were concatenated by the loader; split back into (B, max_seq_len).
        y = torch.stack(torch.split(data.y, max_seq_len))
        # Store plain int lists, as compute_loss does. Rows of tensors hold 0-d tensor elements,
        # which hash by identity, so BLEU n-gram matching against them fails.
        test_refs += y.cpu().tolist()
        test_hyps += pred_y.max(dim=-1)[1].cpu().numpy().T.tolist()

test_bleu4 = compute_bleu4(metric, test_refs, test_hyps)
msgr.print_msg('test_bleu4: {:2.4f}'.format(test_bleu4))