In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')

import yaml
import os
import warnings
import logging
import torch
import pickle 
import random
import numpy as np
import time
from src import utils
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader

# 1-Data

In [3]:
config_file = '../config_code2seq.yml'
config = yaml.load(open(config_file), Loader=yaml.FullLoader)
# Data source
DATA_HOME = config['data']['home']
DICT_FILE = DATA_HOME + config['data']['dict']
TRAIN_FILE = DATA_HOME + config['data']['train']
VALID_FILE = DATA_HOME + config['data']['valid']
TEST_FILE = DATA_HOME + config['data']['test']

# Training parameter
batch_size = config['training']['batch_size']
num_epochs = config['training']['num_epochs']
lr = config['training']['lr']
teacher_forcing_rate = config['training']['teacher_forcing_rate']
nesterov = config['training']['nesterov']
weight_decay = config['training']['weight_decay']
momentum = config['training']['momentum']
decay_ratio = config['training']['decay_ratio']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']



# Model parameter
token_size = config['model']['token_size']
hidden_size = config['model']['hidden_size']
num_layers = config['model']['num_layers']
bidirectional = config['model']['bidirectional']
rnn_dropout = config['model']['rnn_dropout']
embeddings_dropout = config['model']['embeddings_dropout']
num_k = config['model']['num_k']

# etc
slack_url_path = config['etc']['slack_url_path']
info_prefix = config['etc']['info_prefix']

In [4]:
# warnings.filterwarnings('ignore')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1)
random_state = 22

In [5]:
PAD_TOKEN = '<PAD>' 
BOS_TOKEN = '<S>' 
EOS_TOKEN = '</S>'
UNK_TOKEN = '<UNK>'
PAD = 0
BOS = 1
EOS = 2
UNK = 3

# load vocab dict
with open(DICT_FILE, 'rb') as file:
    subtoken_to_count = pickle.load(file)
    node_to_count = pickle.load(file) 
    target_to_count = pickle.load(file)
    max_contexts = pickle.load(file)
    num_training_examples = pickle.load(file)
    print('Dictionaries loaded.')


Dictionaries loaded.


In [6]:
# display information in the dict
# idx = 0
# for k, v in node_to_count.items():
#     idx += 1
#     if idx == 5:
#         break
#     print(k, v)
print('vocab_size_subtoken：' + str(len(subtoken_to_count.items())))
print('vocab_size_nodes：' + str(len(node_to_count.items())))
print('vocab_size_target：' + str(len(target_to_count.items())))

vocab_size_subtoken：73904
vocab_size_nodes：321
vocab_size_target：11316


In [7]:
    
# making vocab dicts (词汇词典) for terminal subtoken, nonterminal node and target.
word2id = {
    PAD_TOKEN: PAD,
    BOS_TOKEN: BOS,
    EOS_TOKEN: EOS,
    UNK_TOKEN: UNK,
    }

vocab_subtoken = utils.Vocab(word2id=word2id)
vocab_nodes = utils.Vocab(word2id=word2id)
vocab_target = utils.Vocab(word2id=word2id)

vocab_subtoken.build_vocab(list(subtoken_to_count.keys()), min_count=1)
vocab_nodes.build_vocab(list(node_to_count.keys()), min_count=1)
vocab_target.build_vocab(list(target_to_count.keys()), min_count=1)

vocab_size_subtoken = len(vocab_subtoken.id2word)
vocab_size_nodes = len(vocab_nodes.id2word)
vocab_size_target = len(vocab_target.id2word)


print('vocab_size_subtoken：' + str(vocab_size_subtoken))
print('vocab_size_nodes：' + str(vocab_size_nodes))
print('vocab_size_target：' + str(vocab_size_target))

num_length_train = num_training_examples
print('num_examples : ' + str(num_length_train))

vocab_size_subtoken：73908
vocab_size_nodes：325
vocab_size_target：11320
num_examples : 691974


In [108]:
class MyDataset(Dataset):
    def __init__(self, data_path):
        super(MyDataset, self).__init__()

        f = open(data_path, 'r')
        # 存储data_path中所有method的AST
        self.seqs_S = []
        self.seqs_E = []
        self.seqs_N = []
        self.seqs_Y = []
        # 每一个line对应一个java的method
        # 即一个 AST = Start Terminal + Nodes + End Terminals
        for line in f:
            seq_S = []
            seq_N = []
            seq_E = []

            # 每个line由target和syntax_path两部分组成
            target, *syntax_path = f.readline().split(' ')
            # 将target映射到相应idx上，也要划分为subtoken
            target = utils.sentence_to_ids(vocab_target, target.split('|'))

            # 去掉syntax_path中的''和'\n'
            syntax_path = [s for s in syntax_path if s != '' and s != '\n']

            # 如果syntax_path的长度大于num_k，则随机从中选出num_k个node
            if len(syntax_path) > num_k:
                sampled_path_index = random.sample(range(len(syntax_path)), self.num_k)
            else:
                sampled_path_index = range(len(syntax_path))

            # 对于每一个path由三部分组成
            for j in sampled_path_index:
                terminal1, ast_path, terminal2 = syntax_path[j].split(',')

                terminal1 = utils.sentence_to_ids(vocab_subtoken, terminal1.split('|'))
                ast_path = utils.sentence_to_ids(vocab_nodes, ast_path.split('|'))
                terminal2 = utils.sentence_to_ids(vocab_subtoken, terminal2.split('|'))

                # 将3部分拆开来存储
                seq_S.append(terminal1)
                seq_E.append(terminal2)
                seq_N.append(ast_path)

            # 对于每个method即AST，X对应SNE分别存储
            self.seqs_S.append(seq_S)
            self.seqs_N.append(seq_N)
            self.seqs_E.append(seq_E)
            self.seqs_Y.append(target)

    def __getitem__(self, idx):
        return self.seqs_S[idx], self.seqs_N[idx], self.seqs_E[idx], self.seqs_Y[idx]

    def __len__(self):
        return len(self.seqs_Y)


In [109]:
f = open(TRAIN_FILE, 'r')
idx = 0
# seqs_S = []
# seqs_E = []
# seqs_N = []
# seqs_Y = []
print("对于每一个完整的java method，即一个AST：")
print("-----------------------------------------------------------------------------")
for line in f:
    seq_S = []
    seq_N = []
    seq_E = []

    target, *syntax_path = f.readline().split(' ')
    # 将target映射到相应idx上
    target_name = target
    target = utils.sentence_to_ids(vocab_target, target.split('|'))
    
    print("Java Method Name: " + str(target_name))
    print("target length(包EOS): "+str(len(target)))
    print(str(target) + '\n')

    # 去掉syntax_path中的 ‘’和'\n'
    syntax_path = [s for s in syntax_path if s != '' and s != '\n']

    # 如果syntax_path的长度大于num_k，则随机从中选出num_k个node
    if len(syntax_path) > num_k:
        sampled_path_index = random.sample(range(len(syntax_path)) , self.num_k)
    else:
        sampled_path_index = range(len(syntax_path))

    print("AST length: "+str(len(syntax_path)))
    print(str(syntax_path[:3])+'\n')    
    
    # 对于AST下的每一个path
    for j in sampled_path_index:
        if j == 6:
            break
        terminal1, ast_path, terminal2 = syntax_path[j].split(',')
        print("第" + str(j+1) + "条path:")
        print("Start Terminal Name —— " + str(terminal1) )
        print("AST path Name       —— " + str(ast_path) )
        print("End Terminal Name   —— " + str(terminal2) )
        
        terminal1 = utils.sentence_to_ids(vocab_subtoken, terminal1.split('|'))
        ast_path = utils.sentence_to_ids(vocab_nodes, ast_path.split('|'))
        terminal2 = utils.sentence_to_ids(vocab_subtoken, terminal2.split('|')) 
        
        print("Start Terminal vector —— " + str(terminal1) )
        print("AST path vector       —— " + str(ast_path) )
        print("End Terminal vector   —— " + str(terminal2) )
        
        seq_S.append(terminal1)
        seq_N.append(ast_path)
        seq_E.append(terminal2)
    
#     seqs_S.append(seq_S)
#     seqs_N.append(seq_N)
#     seqs_E.append(seq_E)
#     seqs_Y.append(target)
# -----------------------------------------------------------------------------

    
    
    print('-------------------------------------------------------')
    idx += 1
    if idx == 5:
        break

对于每一个完整的java method，即一个AST：
-----------------------------------------------------------------------------
Java Method Name: handle|exception
target length(包EOS): 3
[108, 172, 2]

AST length: 20
['exception|handler,Nm0|SMEx|ClsEx|Cls0,exception', 'exception|handler,Nm0|SMEx|Mth|MarkerExpr|Nm0,response|body', 'exception|handler,Nm0|SMEx|Mth|Cls2,string']

第1条path:
Start Terminal Name —— exception|handler
AST path Name       —— Nm0|SMEx|ClsEx|Cls0
End Terminal Name   —— exception
Start Terminal vector —— [8, 134, 2]
AST path vector       —— [7, 105, 83, 17, 2]
End Terminal vector   —— [8, 2]
第2条path:
Start Terminal Name —— exception|handler
AST path Name       —— Nm0|SMEx|Mth|MarkerExpr|Nm0
End Terminal Name   —— response|body
Start Terminal vector —— [8, 134, 2]
AST path vector       —— [7, 105, 11, 44, 7, 2]
End Terminal vector   —— [89, 465, 2]
第3条path:
Start Terminal Name —— exception|handler
AST path Name       —— Nm0|SMEx|Mth|Cls2
End Terminal Name   —— string
Start Terminal vector 

In [110]:
class DataLoader(object):
    
    def __init__(self, data_path, batch_size, num_k, vocab_subtoken, vocab_nodes, vocab_target, shuffle=True, batch_time=False):
        """
        data_path : path for data 
        num_examples : total lines of data file
        batch_size : batch size
        num_k : max ast pathes included to one examples
        vocab_subtoken : dict of subtoken and its id
        vocab_nodes : dict of node simbol and its id
        vocab_target : dict of target simbol and its id
        """
        self.data_path = data_path
        self.batch_size = batch_size
        
        self.num_examples = self.file_count(data_path)
        self.num_k = num_k
        
        self.vocab_subtoken = vocab_subtoken
        self.vocab_nodes = vocab_nodes
        self.vocab_target = vocab_target
        
        self.index = 0
        self.pointer = np.array(range(self.num_examples))
        self.shuffle = shuffle
        
        self.batch_time = batch_time
        
        self.reset()
        
    def read_batch(self, ids):
        seqs_S = []
        seqs_E = []
        seqs_N = []
        seqs_Y = []
        
        f = open(TRAIN_FILE, 'r')
        idx = 0
        for line in f:
            seq_S = []
            seq_N = []
            seq_E = []

            target, *syntax_path = f.readline().split(' ')
            # 将target映射到相应idx上
            target = utils.sentence_to_ids(self.vocab_target, target.split('|'))

            # 去掉syntax_path中的 ‘’和'\n'
            syntax_path = [s for s in syntax_path if s != '' and s != '\n']

            # 如果syntax_path的长度大于num_k，则随机从中选出num_k个node
            if len(syntax_path) > num_k:
                sampled_path_index = random.sample(range(len(syntax_path)) , self.num_k)
            else:
                sampled_path_index = range(len(syntax_path))

            # 对于每一个path
            for j in sampled_path_index:
                terminal1, ast_path, terminal2 = syntax_path[j].split(',')

                terminal1 = utils.sentence_to_ids(self.vocab_subtoken, terminal1.split('|'))
                ast_path = utils.sentence_to_ids(self.vocab_nodes, ast_path.split('|'))
                terminal2 = utils.sentence_to_ids(self.vocab_subtoken, terminal2.split('|')) 

                seq_S.append(terminal1)
                seq_E.append(terminal2)
                seq_N.append(ast_path)
                
            seqs_S.append(seq_S)
            seqs_N.append(seq_N)
            seqs_E.append(seq_E)
            seqs_Y.append(target)
            
    return seqs_S, seqs_N, seqs_E, seqs_Y

    def reset(self):
        if self.shuffle:
            self.pointer = shuffle(self.pointer)
        self.index = 0 

SyntaxError: 'return' outside function (<ipython-input-110-96dd5973cc68>, line 74)

# 3-Loss Function & Optimizer