In [1]:
import yaml
import warnings
from datetime import datetime
import logging
import os
import pandas as pd 
import javalang
from javalang.ast import Node
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn.modules.linear import Linear
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.nn import MessagePassing, GatedGraphConv, GCNConv, global_mean_pool
from anytree import AnyNode
from torch_geometric.data import Data, DataLoader, ClusterData, ClusterLoader
from tqdm import tqdm_notebook as tqdm
import numpy as np
from torchsummary import summary
import random

warnings.filterwarnings('ignore')

Paramerters

In [2]:
config_file = 'config_dgnn.yml'

In [3]:
config = yaml.load(open(config_file), Loader=yaml.FullLoader)

# data source
raw_code_url = config['ccd_data']['data']
train_url = config['ccd_data']['train']
valid_url = config['ccd_data']['valid']
test_url = config['ccd_data']['test']

# training parameter
# batch_size = config['training']['batch_size']
batch_size = 64
num_epoches = config['training']['num_epoches']
lr = config['training']['lr']
decay_ratio = config['training']['lr']
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']

# model design
graph_embedding_size = config['model']['graph_embedding_size']
lstm_hidden_size = config['model']['lstm_hidden_size']
divide_node_num = config['model']['divide_node_num']
gnn_layers_num = config['model']['gnn_layers_num']
lstm_layers_num = config['model']['lstm_layers_num']
decoder_input_size = config['model']['decoder_input_size']

# logs
info_prefix = config['logs']['info_prefix']

Logs

In [4]:
run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = 'logs/' + run_id + '.log'
exp_dir = 'runs/' + run_id
os.mkdir(exp_dir)

In [5]:
class Info(object):
    def __init__(self, info_prefix=''):
        self.info_prefix = info_prefix
    
    def print_msg(self, msg):
        text = self.info_prefix + ' ' + msg
        print(text)
        logging.info(text)

In [6]:
logging.basicConfig(format='%(asctime)s | %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', filename=log_file, level=logging.DEBUG)
msgr = Info(info_prefix)

msgr.print_msg('run_id : {}'.format(run_id))
msgr.print_msg('log_file : {}'.format(log_file))
msgr.print_msg('exp_dir: {}'.format(exp_dir))
msgr.print_msg(str(config))

dgnn run_id : 2021-07-23--08-26-16
dgnn log_file : logs/2021-07-23--08-26-16.log
dgnn exp_dir: runs/2021-07-23--08-26-16
dgnn {'data': {'train': '/data/code/represent-code-in-human/data/code-summarization-new/train.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-new/valid.jsonl', 'test': '/data/code/represent-code-in-human/data/code-summarization-new/test.jsonl'}, 'small_data': {'train': '/data/code/represent-code-in-human/data/code-summarization-small/train.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-small/valid.jsonl', 'test': '/data/code/represent-code-in-human/data/code-summarization-small/test.jsonl'}, 'middle_data': {'train': '/data/code/represent-code-in-human/data/code-summarization-middle/train.jsonl', 'valid': '/data/code/represent-code-in-human/data/code-summarization-middle/valid.jsonl', 'test': '/data/code/represent-code-in-human/data/code-summarization-middle/test.jsonl'}, 'ccd_data': {'data': '/data/dataset/Co

Sequence Preprocess

In [7]:
# define one extra keywords 
UNK_TOKEN = '<UNK>'
UNK = 0

In [8]:
raw_code = pd.read_json(path_or_buf=raw_code_url, lines=True)
raw_code = raw_code.set_index('idx')
raw_code.head()

Unnamed: 0_level_0,func
idx,Unnamed: 1_level_1
10000832,public static void main(String[] args) {\n...
10005623,public synchronized String getSerialNumber...
10005624,public Object run() {\n ...
10005674,public String post() {\n if (conten...
10005879,@Override\n public void onCreate(Bundle...


In [9]:
raw_code_index = raw_code.index.tolist()

over_length_ids = [10000832, 10151623, 10540676, 10690321, 10717656, 10793825, 10934628, 11339042, 11603577, 11940679, 12082150, 12119068, 12335897, 12352751, 12389873, 12415477, 12838273, 12838274, 12914531, 13099033, 13292215, 13292580, 13456795, 1349815, 13563706, 13586003, 13644040, 13650923, 13747998, 13776078, 13961100, 14020143, 14190765, 14310068, 14661394, 1477292, 14794142, 15295408, 15453012, 15723802, 15758923, 16002345, 16002347, 16006791, 16142591, 1637147, 16518661, 17314208, 17551920, 17573144, 17750515, 17829989, 17874921, 17917053, 18284812, 18433984, 18574455, 18790182, 18822890, 19074021, 19090289, 19276021, 19276022, 19340788, 19382420, 19434890, 19434892, 1944490, 19478367, 19556732, 1962490, 19634773, 19841853, 19942676, 20122631, 20275058, 2067794, 20833509, 20856391, 20885480, 21028028, 21161120, 21318345, 21493541, 2157431, 21652119, 22031237, 22033685, 22035132, 22114133, 22222255, 2235431, 2247987, 22580642, 22673614, 2285441, 23041161, 23094550, 23188198, 23248619, 23370708, 23510383, 23611768, 2450, 2476569, 2511576, 2769195, 2771573, 2771574, 285947, 2996859, 325062, 3375714, 3375715, 3375723, 3400236, 37044, 3867253, 416857, 4420769, 4494367, 4581365, 4660318, 4681906, 4780347, 4792385, 4854974, 5021563, 5189131, 5252227, 5389524, 5430189, 552318, 557726, 5620476, 564191, 5691586, 592597, 6008880, 6121196, 6147227, 6304372, 6304373, 6333737, 6403868, 6644748, 6961579, 6966398, 7005223, 7149578, 726690, 7300257, 7300264, 7300267, 733283, 7394826, 7665877, 7687037, 7727956, 793122, 8079516, 8109022, 8335944, 83802, 8581437, 8641070, 9050003, 9221721, 9261908, 9530015, 9581835, 961457, 9647574, 9687064, 9705209, 9782243, 979163, 98309, 98428, 9980609]
# over_length_ids = [10690321, 10717656, 10793825, 11339042, 12119068, 12352751, 12415477, 13099033, 15723802, 1637147, 16518661, 17750515, 18822890, 19276021, 2067794, 20833509, 20856391, 20885480, 21028028, 22033685, 22035132, 2235431, 2247987, 22580642, 2285441, 23094550, 23248619, 2996859, 325062, 3375723, 3867253, 4494367, 4660318, 4780347, 5021563, 5189131, 5252227, 564191, 6147227, 7005223, 7300267, 793122, 8109022, 8581437, 9221721, 9647574, 98309]

def read_ccd_pairs(url):
    data = []
    with open(url) as f:
        for line in f:
            line = line.strip()
            id1, id2, label = line.split('\t')
            if int(id1) not in raw_code_index or int(id2) not in raw_code_index or int(id1) in over_length_ids or int(id2) in over_length_ids:
                continue
            label = 0 if label == '0' else 1
            data.append((int(id1), int(id2), label))
    return data

In [10]:
train_data = read_ccd_pairs(train_url)
valid_data = read_ccd_pairs(valid_url)
test_data = read_ccd_pairs(test_url)

msgr.print_msg('train size: {}, valid size: {}, test size: {}'.format(len(train_data), len(valid_data), len(test_data)))

dgnn train size: 864660, valid size: 402753, test size: 395605


In [11]:
train_data_small = random.sample(train_data, 1000)
valid_data_small = random.sample(valid_data, 500)
test_data_small = random.sample(test_data, 500)

In [12]:
# define vocab class
class Vocab(object):
    def __init__(self, word2id={}):
        self.word2id = dict(word2id)
        self.id2word = {v: k for k, v in self.word2id.items()}

    def build_vocab(self, sentences, min_count=1):
        word_counter = {}
        for word in sentences:
            word_counter[word] = word_counter.get(word, 0) + 1
        
        for word, count in sorted(word_counter.items(), key=lambda x: -x[1]):
            if count < min_count:
                break
            _id = len(self.word2id)
            self.word2id.setdefault(word, _id)
            self.id2word[_id] = word

In [13]:
# construct two vocabulary for ast nodes and natural language repectively
word2id = {
    UNK_TOKEN: UNK
}
vocab_astnodes = Vocab(word2id=word2id)


In [14]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    token = ''
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [15]:
# use train data to construction ast nodes corpus
astnodes_tokens = []
train_ids = []
for i in range(len(train_data)):
    train_ids.append(train_data[i][0])
    train_ids.append(train_data[i][1])

train_ids = list(set(train_ids))

for id in train_ids:
    sequence = []
    get_sequence(parse_program(raw_code['func'][id]), sequence)
    astnodes_tokens.extend(sequence)

vocab_astnodes.build_vocab(astnodes_tokens, min_count=0)

In [16]:
vocab_astnodes_size = len(vocab_astnodes.id2word)
msgr.print_msg('vocab_astnodes_size: ' + str(vocab_astnodes_size))

dgnn vocab_astnodes_size: 54866


Add Dataflow to AST to generate D-AST

In [17]:
#  generate tree for AST Node
def create_tree(root, node, node_list, parent=None):
    id = len(node_list)
    token, children = get_token(node), get_child(node)
    if id == 0:
        root.token = token
        root.data = node
    else:
        new_node = AnyNode(id=id, token=token, data=node, parent=parent)
    node_list.append(node)
    for child in children:
        if id == 0:
            create_tree(root, child, node_list, parent=root)
        else:
            create_tree(root, child, node_list, parent=new_node)

In [18]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, vocab_dict, src, tgt, variable_token_list, variable_id_list):
    token = node.token
    # print('token', token)
    node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    if token in ['VariableDeclarator', 'MemberReference']:
        variable_token_list.append(node.children[0].token)
        variable_id_list.append(node.children[0].id)
    for child in node.children:
        # print('child', child.token)
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        get_node_and_edge(child, node_index_list, vocab_dict, src, tgt, variable_token_list, variable_id_list)

In [19]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, vocab_dict):
    node_list = []
    new_tree = AnyNode(id=0, token=None, data=None)
    create_tree(new_tree, ast, node_list)
    x = []
    edge_src = []
    edge_tgt = []
    edge_attr = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, vocab_dict, edge_src, edge_tgt, variable_token_list, variable_id_list)
    # print('variable_token_list', variable_token_list)
    # print('variable_id_list', variable_id_list)

    ast_edge_num = len(edge_src)
    # print('ast_edge_num', ast_edge_num)
    # set ast edge type to 0
    for _ in range(ast_edge_num):
        edge_attr.append([0])

    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        # print('variable_dict', variable_dict)
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(variable_token_list[i], variable_id_list[i])
        else:
            # print('edge', variable_dict.get(variable_token_list[i]), variable_id_list[i])
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_src.append(variable_id_list[i])
            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            variable_dict[variable_token_list[i]] = variable_id_list[i]
    
    edge_index = [edge_src, edge_tgt]

    # set data flow edge type to 1
    dataflow_edge_num = len(edge_src) - ast_edge_num
    for _ in range(dataflow_edge_num):
        edge_attr.append([1])
    # print('dataflow_edge_num', dataflow_edge_num)
    return x, edge_index, edge_attr

Batch Data

In [20]:
def transform_to_pygdata(code):
    ast = parse_program(code)
    x, edge_index, edge_attr = get_pyg_data_from_ast(ast, vocab_astnodes)
    return Data(x=torch.tensor(x, dtype=torch.long),
                    edge_index=torch.tensor(edge_index, dtype=torch.long),
                    edge_attr=torch.tensor(edge_attr, dtype=torch.long))


In [21]:
pyg_datas = []
for index in raw_code_index:
    pyg_datas.append(transform_to_pygdata(raw_code['func'][index]))

raw_code['pyg_data'] = pyg_datas

In [22]:
raw_code['pyg_data'][10000832]

Data(edge_attr=[3752, 1], edge_index=[2, 3752], x=[1774, 1])

In [23]:
class PairData(Data):
    def __init__(self, edge_index_s, edge_attr_s, x_s, edge_index_t, edge_attr_t, x_t, label):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.edge_attr_s = edge_attr_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.edge_attr_t = edge_attr_t
        self.x_t = x_t
        self.label = label
    
    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        # if key == 'edge_attr_s':
        #     return self.x_s.size(0)
        # if key == 'edge_attr_t':
        #     return self.x_t.size(0)
        # if key == 'label':
        #     return self.x_s.size(0)
        else:
            return super().__inc__(key, value)
    

In [24]:
def batch_data(data):
    batches = []
    for i in range(len(data)):
        pyg1 = raw_code['pyg_data'][data[i][0]]
        pyg2 = raw_code['pyg_data'][data[i][1]]
        label = torch.tensor(data[i][2], dtype=torch.long)
        # print('attr 1', pyg1.edge_attr)
        # print('attr 2', pyg2.edge_attr)
        pair_data = PairData(x_s=pyg1.x, edge_index_s=pyg1.edge_index, edge_attr_s = pyg1.edge_attr,
                                x_t=pyg2.x, edge_index_t=pyg2.edge_index, edge_attr_t = pyg2.edge_attr,
                                label=label)
        batches.append(pair_data)    
    return batches

In [25]:
train_batch = batch_data(train_data_small)
valid_batch = batch_data(valid_data_small)
test_batch = batch_data(test_data_small)

In [26]:
train_loader = DataLoader(train_batch, batch_size=batch_size, follow_batch=['x_s', 'x_t'])
valid_loader = DataLoader(valid_batch, batch_size=batch_size, follow_batch=['x_s', 'x_t'])
test_loader = DataLoader(test_batch, batch_size=batch_size, follow_batch=['x_s', 'x_t'])

Model

In [27]:
# partitioning D-AST in model instead of data-prepocessing
# partitioning D-AST in model by the num of nodes, which is set in the hyper-parameter `divide_node_num`

In [28]:
class SequenceGNNEncoder(torch.nn.Module):
    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num,
                    decoder_input_size, device):
        super(SequenceGNNEncoder, self).__init__()
        self.device = device
        self.embed = nn.Embedding(vocab_len, graph_embedding_size)
        self.edge_embed = nn.Embedding(2, 1) # only two edge types to be set weights, which are AST edge and data flow edge
        self.ggnnlayer = GatedGraphConv(graph_embedding_size, gnn_layers_num)
        self.mlp_gate = nn.Sequential(
            nn.Linear(graph_embedding_size, 300), nn.Sigmoid(), nn.Linear(300, 1), nn.Sigmoid())
        self.pool = GlobalAttention(gate_nn=self.mlp_gate)
        self.divide_node_num = divide_node_num
        self.lstm = nn.LSTM(input_size=graph_embedding_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers_num)
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_layers_num = lstm_layers_num
        self.fc = nn.Linear(graph_embedding_size + lstm_hidden_size, decoder_input_size)

    def subgraph_forward(self, x, edge_index, edge_attr, batch):
        # print('edge_attr', edge_attr)
        # print('shape', edge_attr.shape)
        if type(edge_attr) == type(None):
            edge_weight = None
        else:
            edge_weight = self.edge_embed(edge_attr)
            edge_weight = edge_weight.squeeze(1)
        x = self.ggnnlayer(x, edge_index, edge_weight)
        return self.pool(x, batch=batch)
    
    # partitioning multiple subgraphs by dynamic allocating edges
    def partition_graph(self, x, edge_index, edge_attr, batch):        
        nodes_list = [] # record all nodes number for each subgraph in total batch
        graph_pos_in_batch, graph_length = self.get_subgraph_info_from_batch(batch)
        max_seq_len = max(graph_length)
        subgraph_num = int(max_seq_len/self.divide_node_num) + 1
        for i in range(subgraph_num):
            nodes = []
            for j in range(len(graph_pos_in_batch)):
                if graph_length[j] > i * self.divide_node_num:
                    if graph_length[j] > (i+1) * self.divide_node_num:
                        subgraph_len = self.divide_node_num
                    else:
                        subgraph_len = graph_length[j] - i * self.divide_node_num   
                    for m in range(subgraph_len):
                        nodes.append(graph_pos_in_batch[j] + m)          
            nodes_list.append(set(nodes)) 
        # only count the edge whose target node in subgraph
        sub_edge_src = [[] for _ in range(subgraph_num)]
        sub_edge_tgt = [[] for _ in range(subgraph_num)]
        sub_edge_attr = [[] for _ in range(subgraph_num)]
        # print('nodes_list', nodes_list)
        node_num = len(x)
        node_subgraph_index = [0 for _ in range(node_num)] # use a list to store the subgraph numbers for all nodes
        for i in range(len(nodes_list)):
            for node in nodes_list[i]:
                node_subgraph_index[node] = i
    
        for i in range(len(edge_index[1])):
            src = edge_index[0][i].item()
            tgt = edge_index[1][i].item()
            sub_edge_src[node_subgraph_index[tgt]].append(src)
            sub_edge_tgt[node_subgraph_index[tgt]].append(tgt)
            sub_edge_attr[node_subgraph_index[tgt]].append(edge_attr[i].item())
        edge_index_list = []
        edge_attr_list = []
        for i in range(subgraph_num):
            edge_index_list.append(torch.tensor([sub_edge_src[i], sub_edge_tgt[i]], dtype=torch.long))
            edge_attr_list.append(torch.tensor(sub_edge_attr[i], dtype=torch.long))
        return edge_index_list, edge_attr_list

    def get_subgraph_info_from_batch(self, batch):
        comp = 0
        pos = 0
        graph_pos_in_batch = [0] # record begin positions and end positions of every subgraph
        graph_length = [] # use a list to store the node nums in subgraph
        for i in range(len(batch)):
            if batch[i] != comp:
                graph_pos_in_batch.append(i)
                graph_length.append(i-pos)
                comp = batch[i]
                pos = i
        graph_length.append(len(batch)-pos)
        return graph_pos_in_batch, graph_length        

    def forward(self, x, edge_index, edge_attr, batch):
        edge_index_list, edge_attr_list = self.partition_graph(x, edge_index, edge_attr, batch)
        x = self.embed(x)
        x = x.squeeze(1)
        subgraph_pool_list = [
            self.subgraph_forward(x, edge_index_list[i].to(self.device), edge_attr_list[i].to(self.device), batch)
            for i in range(len(edge_index_list))
        ]
        graph_pool = self.subgraph_forward(x, edge_index, edge_attr, batch)
        subgraph_pool_seq = torch.stack(subgraph_pool_list)
        h0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        c0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1) ,self.lstm_hidden_size).to(self.device)
        subgraph_output, (_, _) = self.lstm(subgraph_pool_seq, (h0, c0))
        return self.fc(torch.cat((subgraph_output[-1], graph_pool), dim=1))
        
            

In [29]:
class SiameseDecoder(nn.Module):
    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num, 
    decoder_input_size, device):
        super(SiameseDecoder, self).__init__()
        self.encoder = SequenceGNNEncoder(vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num, 
    decoder_input_size, device)
        self.fc = nn.Linear(2* decoder_input_size, 2)

    
    def forward(self, x_s, edge_index_s, edge_attr_s, x_t, edge_index_t, edge_attr_t, x_s_batch, x_t_batch):
        decoder_input1 = self.encoder(x_s, edge_index_s, edge_attr_s, x_s_batch)
        decoder_input2 = self.encoder(x_t, edge_index_t, edge_attr_t, x_t_batch)
        output = torch.cat((decoder_input1, decoder_input2), dim=1)
        logits = self.fc(output) 
        return F.softmax(logits)      


Training

In [30]:
device = torch.device('cuda: 0')

model_args = {
    'vocab_len': vocab_astnodes_size,
    'graph_embedding_size': graph_embedding_size,
    'gnn_layers_num': gnn_layers_num,
    'lstm_layers_num': lstm_layers_num,
    'lstm_hidden_size': lstm_hidden_size,
    'divide_node_num': divide_node_num,
    'decoder_input_size': decoder_input_size,
    'device': device
}

model = SiameseDecoder(**model_args).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda= lambda epoch: decay_ratio ** epoch)

In [31]:
model_summary = summary(model)

Layer (type:depth-idx)                   Param #
├─SequenceGNNEncoder: 1-1                --
|    └─Embedding: 2-1                    5,486,600
|    └─Embedding: 2-2                    2
|    └─GatedGraphConv: 2-3               --
|    |    └─GRUCell: 3-1                 60,600
|    └─Sequential: 2-4                   --
|    |    └─Linear: 3-2                  30,300
|    |    └─Sigmoid: 3-3                 --
|    |    └─Linear: 3-4                  301
|    |    └─Sigmoid: 3-5                 --
|    └─GlobalAttention: 2-5              --
|    |    └─Sequential: 3-6              (recursive)
|    └─LSTM: 2-6                         381,952
|    └─Linear: 2-7                       68,700
├─Linear: 1-2                            1,202
Total params: 6,029,657
Trainable params: 6,029,657
Non-trainable params: 0


In [32]:
criterion = nn.CrossEntropyLoss()

In [33]:
def compute_loss(data, model, optimizer=None, is_train=True):
    x_s = (data.x_s).to(device)
    edge_index_s = (data.edge_index_s).to(device)
    edge_attr_s = (data.edge_attr_s).to(device)
    x_t = (data.x_t).to(device)
    edge_index_t = (data.edge_index_t).to(device)
    edge_attr_t = (data.edge_attr_t).to(device)
    x_s_batch = (data.x_s_batch).to(device)
    x_t_batch = (data.x_t_batch).to(device)
    label = (data.label).to(device)
    pred = model(x_s, edge_index_s, edge_attr_s, x_t, edge_index_t, edge_attr_t, x_s_batch, x_t_batch)
    loss = criterion(pred, label)

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    
    
    return loss.item(), label.cpu().detach().numpy(), pred.cpu().detach().numpy()

In [34]:
class EarlyStopping(object):
    def __init__(self, filename = None, patience=3, warm_up=0, verbose=False):

        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.warm_up = warm_up
        self.filename = filename

    def __call__(self, score, model, epoch):

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, model)
            
        elif (score <= self.best_score) and (epoch > self.warm_up) :
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            if (epoch <= self.warm_up):
                print('Warming up until epoch', self.warm_up)
            
            else:
                if self.verbose:
                    print(f'Score improved. ({self.best_score:.6f} --> {score:.6f}).')
                
                self.best_score = score
                self.save_checkpoint(score, model)
                self.counter = 0

    def save_checkpoint(self, score, model):
        
        if self.filename is not None:
            torch.save(model.state_dict(), self.filename)
            
        if self.verbose:
            print('Model saved...')

In [35]:
fname = exp_dir + save_name
early_stopping = EarlyStopping(fname, patience, warm_up, verbose=True)

In [36]:
import numpy as np
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score

def calculate(logits, y_trues):
    logits=np.concatenate(logits,0)
    y_trues=np.concatenate(y_trues,0)
    best_threshold=0
    best_f1=0
    for i in range(1,100):
        threshold=i/100
        y_preds=logits[:,1]>threshold        
        recall=recall_score(y_trues, y_preds, average='macro')        
        precision=precision_score(y_trues, y_preds, average='macro')           
        f1=f1_score(y_trues, y_preds, average='macro') 
        if f1>best_f1:
            best_f1=f1
            best_threshold=threshold

    y_preds=logits[:,1]>best_threshold
    recall=recall_score(y_trues, y_preds, average='macro')
    precision=precision_score(y_trues, y_preds, average='macro')   
    f1=f1_score(y_trues, y_preds, average='macro')
    return recall, precision, f1


In [37]:
for epoch in range(1, num_epoches + 1):
    train_loss = 0.
    train_refs = []
    train_hyps = []
    valid_loss = 0.
    valid_refs = []
    valid_hyps = []

    # train
    model.train()
    for data in tqdm(train_loader, total=len(train_loader), desc='TRAIN'):
        loss, gold, pred = compute_loss(data, model, optimizer, is_train=True)
        msgr.print_msg('train loss {}'.format(loss))
        train_loss += loss
        train_refs.append(gold)
        train_hyps.append(pred)
    

    # valid
    model.eval()
    for data in tqdm(valid_loader, total=len(valid_loader), desc='VALID'):
        loss, gold, pred = compute_loss(data, model, optimizer, is_train=False)
        # msgr.print_msg('valid loss {}'.format(loss))
        valid_loss += loss
        valid_refs.append(gold)
        valid_hyps.append(pred)
    
    
    train_loss = np.sum(train_loss) / len(train_loader)
    valid_loss = np.sum(valid_loss) / len(valid_loader)
    
    train_recall, train_precision, train_f1 = calculate(train_hyps, train_refs)
    valid_recall, valid_precision, valid_f1 = calculate(valid_hyps, valid_refs)

#     msgr.print_msg('train_hyps {}'.format(train_hyps[0]))
#     msgr.print_msg('valid_hyps {}'.format(valid_hyps[0]))


    msgr.print_msg('Epoch {}: train_loss: {:5.2f} train_recall: {:2.4f} train_precision: {:2.4f}  train_f1: {:2.4f}  valid_loss: {:5.2f}  valid_recall: {:2.4f} valid_precision: {:2.4f} valid_f1: {:2.4f}'.format(
            epoch, train_loss, train_recall, train_precision, train_f1, valid_loss, valid_recall, valid_precision, valid_f1))
    
    early_stopping(valid_f1, model, epoch)
    if early_stopping.early_stop:
        msgr.print_msg("Early stopping")
        break  
        
    print('-'*80)
    scheduler.step()


HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.6926808953285217
dgnn train loss 0.6931845545768738
dgnn train loss 0.6895528435707092
dgnn train loss 0.699258029460907
dgnn train loss 0.702467679977417
dgnn train loss 0.6882095336914062
dgnn train loss 0.6894508004188538
dgnn train loss 0.6884142160415649
dgnn train loss 0.6927498579025269
dgnn train loss 0.685588002204895
dgnn train loss 0.6796486973762512
dgnn train loss 0.7059070467948914
dgnn train loss 0.6701700687408447
dgnn train loss 0.7042304277420044
dgnn train loss 0.6709396839141846
dgnn train loss 0.6727725267410278



HBox(children=(IntProgress(value=0, description='VALID', max=8, style=ProgressStyle(description_width='initial…


dgnn train_hyps [[0.49625507 0.50374496]
 [0.50160646 0.49839354]
 [0.5055386  0.49446145]
 [0.5079428  0.49205717]
 [0.5050729  0.49492714]
 [0.49973947 0.50026053]
 [0.50332236 0.49667767]
 [0.50439996 0.49560004]
 [0.5056411  0.49435887]
 [0.50214636 0.49785367]
 [0.5032346  0.49676535]
 [0.5021597  0.49784026]
 [0.5037326  0.4962674 ]
 [0.50338    0.49662003]
 [0.5016635  0.4983365 ]
 [0.5023695  0.49763054]
 [0.5034854  0.4965146 ]
 [0.49722376 0.5027762 ]
 [0.50073653 0.49926347]
 [0.5017966  0.49820343]
 [0.49923396 0.50076604]
 [0.50598824 0.49401173]
 [0.5083487  0.49165127]
 [0.5018837  0.4981163 ]
 [0.49981126 0.5001887 ]
 [0.5023577  0.49764228]
 [0.50135994 0.4986401 ]
 [0.50269186 0.4973081 ]
 [0.5016433  0.4983567 ]
 [0.5036248  0.49637526]
 [0.5026432  0.4973568 ]
 [0.5001802  0.49981984]
 [0.50496596 0.49503404]
 [0.49902567 0.5009743 ]
 [0.5006911  0.4993089 ]
 [0.506674   0.49332598]
 [0.5019294  0.49807057]
 [0.50247645 0.49752355]
 [0.49809998 0.5019    ]
 [0.5025

HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.662265419960022
dgnn train loss 0.672351598739624
dgnn train loss 0.6715599894523621
dgnn train loss 0.6683922410011292
dgnn train loss 0.6753723621368408
dgnn train loss 0.6745390892028809
dgnn train loss 0.6707051992416382
dgnn train loss 0.6577792763710022
dgnn train loss 0.6754913330078125
dgnn train loss 0.6622796058654785
dgnn train loss 0.6639368534088135
dgnn train loss 0.6793146729469299
dgnn train loss 0.6486077904701233
dgnn train loss 0.6866417527198792
dgnn train loss 0.6646010875701904
dgnn train loss 0.6661267280578613



HBox(children=(IntProgress(value=0, description='VALID', max=8, style=ProgressStyle(description_width='initial…


dgnn train_hyps [[0.5712152  0.42878482]
 [0.5786814  0.4213186 ]
 [0.5691468  0.43085316]
 [0.6107492  0.3892508 ]
 [0.5853761  0.41462395]
 [0.5469747  0.4530253 ]
 [0.52432084 0.47567916]
 [0.509219   0.49078098]
 [0.5226795  0.47732046]
 [0.62569314 0.3743069 ]
 [0.58384186 0.41615817]
 [0.6230456  0.37695435]
 [0.556782   0.443218  ]
 [0.593134   0.406866  ]
 [0.57714486 0.42285517]
 [0.6144007  0.38559932]
 [0.49863553 0.50136447]
 [0.46248645 0.53751355]
 [0.6092166  0.39078343]
 [0.57340455 0.42659542]
 [0.5997918  0.40020815]
 [0.6508155  0.34918448]
 [0.6137325  0.38626745]
 [0.49766272 0.5023372 ]
 [0.5319243  0.46807566]
 [0.60647756 0.3935225 ]
 [0.54412425 0.45587578]
 [0.6562408  0.34375915]
 [0.5579286  0.44207138]
 [0.5526726  0.44732732]
 [0.46309915 0.5369008 ]
 [0.46563485 0.5343652 ]
 [0.6687622  0.33123776]
 [0.49215314 0.5078469 ]
 [0.50864935 0.4913507 ]
 [0.65341306 0.34658697]
 [0.6149729  0.3850271 ]
 [0.4950189  0.5049811 ]
 [0.56177634 0.43822366]
 [0.5515

HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.6621484756469727
dgnn train loss 0.6722412109375
dgnn train loss 0.6714584231376648
dgnn train loss 0.6682616472244263
dgnn train loss 0.675236701965332
dgnn train loss 0.6744804978370667
dgnn train loss 0.6706343293190002
dgnn train loss 0.6576984524726868
dgnn train loss 0.6754196882247925
dgnn train loss 0.6622290015220642
dgnn train loss 0.6639038920402527
dgnn train loss 0.6792457103729248
dgnn train loss 0.648556113243103
dgnn train loss 0.6866062879562378
dgnn train loss 0.66459059715271
dgnn train loss 0.6661160588264465



HBox(children=(IntProgress(value=0, description='VALID', max=8, style=ProgressStyle(description_width='initial…


dgnn train_hyps [[0.5711071  0.42889288]
 [0.5786027  0.4213973 ]
 [0.56905174 0.4309483 ]
 [0.6109412  0.38905883]
 [0.58539784 0.41460213]
 [0.5467439  0.4532561 ]
 [0.52397335 0.47602668]
 [0.5088192  0.49118084]
 [0.52228016 0.47771984]
 [0.62580156 0.37419847]
 [0.58382314 0.41617683]
 [0.6232514  0.37674865]
 [0.55658746 0.44341254]
 [0.5931774  0.4068226 ]
 [0.5771294  0.42287058]
 [0.6145575  0.38544253]
 [0.49803102 0.501969  ]
 [0.4617925  0.5382075 ]
 [0.6092539  0.3907461 ]
 [0.57328993 0.42671007]
 [0.5997144  0.4002856 ]
 [0.6511838  0.34881628]
 [0.6139261  0.38607392]
 [0.49726418 0.5027358 ]
 [0.5315606  0.4684394 ]
 [0.60653293 0.39346707]
 [0.5437617  0.45623836]
 [0.65646213 0.34353787]
 [0.55771935 0.44228062]
 [0.55244124 0.44755873]
 [0.46231288 0.5376871 ]
 [0.46498504 0.5350149 ]
 [0.669112   0.330888  ]
 [0.49150574 0.50849426]
 [0.5081967  0.4918033 ]
 [0.6537694  0.34623066]
 [0.6151089  0.38489106]
 [0.4945415  0.50545853]
 [0.56154925 0.43845078]
 [0.5513

HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.6621484756469727
dgnn train loss 0.6722412109375
dgnn train loss 0.67145836353302
dgnn train loss 0.6682615876197815
dgnn train loss 0.6752366423606873
dgnn train loss 0.6744804978370667
dgnn train loss 0.6706343293190002
dgnn train loss 0.6576984524726868
dgnn train loss 0.6754196882247925
dgnn train loss 0.6622290015220642
dgnn train loss 0.6639038920402527
dgnn train loss 0.6792457103729248
dgnn train loss 0.6485561728477478
dgnn train loss 0.6866061687469482
dgnn train loss 0.6645904779434204
dgnn train loss 0.6661160588264465



HBox(children=(IntProgress(value=0, description='VALID', max=8, style=ProgressStyle(description_width='initial…


dgnn train_hyps [[0.5711071  0.42889288]
 [0.5786027  0.4213973 ]
 [0.56905174 0.43094823]
 [0.6109413  0.38905865]
 [0.5853979  0.4146021 ]
 [0.5467438  0.45325616]
 [0.5239732  0.47602677]
 [0.50881904 0.49118093]
 [0.52228004 0.47771996]
 [0.6258016  0.37419838]
 [0.5838232  0.41617677]
 [0.62325156 0.3767485 ]
 [0.55658746 0.44341257]
 [0.5931775  0.4068225 ]
 [0.5771295  0.42287052]
 [0.6145576  0.38544238]
 [0.49803084 0.5019692 ]
 [0.46179223 0.53820777]
 [0.60925394 0.39074603]
 [0.57328993 0.42671004]
 [0.5997144  0.40028557]
 [0.65118396 0.34881604]
 [0.61392623 0.3860738 ]
 [0.49726406 0.5027359 ]
 [0.5315605  0.4684395 ]
 [0.60653305 0.39346698]
 [0.54376155 0.45623845]
 [0.6564623  0.34353772]
 [0.55771935 0.44228065]
 [0.5524412  0.4475588 ]
 [0.4623126  0.5376874 ]
 [0.46498477 0.53501517]
 [0.66911227 0.33088773]
 [0.49150556 0.5084945 ]
 [0.5081966  0.4918034 ]
 [0.6537695  0.3462305 ]
 [0.615109   0.38489097]
 [0.49454132 0.50545865]
 [0.5615492  0.43845087]
 [0.5513

HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.6621484756469727
dgnn train loss 0.6722412109375
dgnn train loss 0.67145836353302
dgnn train loss 0.6682615876197815
dgnn train loss 0.6752366423606873
dgnn train loss 0.6744804978370667
dgnn train loss 0.6706343293190002
dgnn train loss 0.6576984524726868
dgnn train loss 0.6754196882247925
dgnn train loss 0.6622289419174194
dgnn train loss 0.6639038920402527
dgnn train loss 0.6792457103729248
dgnn train loss 0.648556113243103
dgnn train loss 0.6866061687469482
dgnn train loss 0.66459059715271
dgnn train loss 0.6661160588264465



HBox(children=(IntProgress(value=0, description='VALID', max=8, style=ProgressStyle(description_width='initial…


dgnn train_hyps [[0.5711071  0.42889288]
 [0.5786027  0.4213973 ]
 [0.56905174 0.43094823]
 [0.6109413  0.38905865]
 [0.5853979  0.4146021 ]
 [0.5467438  0.45325616]
 [0.5239732  0.47602677]
 [0.50881904 0.49118093]
 [0.52228004 0.47771996]
 [0.6258016  0.37419838]
 [0.5838232  0.41617677]
 [0.62325156 0.3767485 ]
 [0.55658746 0.44341257]
 [0.5931775  0.4068225 ]
 [0.5771295  0.42287052]
 [0.6145576  0.38544238]
 [0.49803084 0.5019692 ]
 [0.46179223 0.5382077 ]
 [0.609254   0.390746  ]
 [0.57328993 0.42671004]
 [0.5997144  0.4002856 ]
 [0.65118396 0.34881604]
 [0.61392623 0.38607374]
 [0.4972641  0.5027359 ]
 [0.5315605  0.4684395 ]
 [0.606533   0.393467  ]
 [0.54376155 0.45623845]
 [0.6564623  0.34353772]
 [0.55771935 0.44228065]
 [0.5524412  0.4475588 ]
 [0.4623126  0.53768736]
 [0.46498477 0.53501517]
 [0.66911227 0.33088773]
 [0.49150556 0.5084945 ]
 [0.5081966  0.4918034 ]
 [0.6537695  0.34623045]
 [0.615109   0.38489094]
 [0.49454132 0.50545865]
 [0.5615492  0.43845087]
 [0.5513

HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.6621484756469727
dgnn train loss 0.6722412109375
dgnn train loss 0.67145836353302
dgnn train loss 0.6682616472244263
dgnn train loss 0.6752366423606873
dgnn train loss 0.6744804978370667
dgnn train loss 0.6706343293190002
dgnn train loss 0.6576984524726868
dgnn train loss 0.6754196882247925
dgnn train loss 0.6622290015220642
dgnn train loss 0.6639038920402527
dgnn train loss 0.6792457103729248
dgnn train loss 0.6485561728477478
dgnn train loss 0.6866061687469482
dgnn train loss 0.66459059715271
dgnn train loss 0.6661160588264465



HBox(children=(IntProgress(value=0, description='VALID', max=8, style=ProgressStyle(description_width='initial…


dgnn train_hyps [[0.5711071  0.42889288]
 [0.5786027  0.4213973 ]
 [0.56905174 0.43094823]
 [0.6109413  0.38905865]
 [0.5853979  0.4146021 ]
 [0.5467438  0.45325616]
 [0.52397317 0.4760268 ]
 [0.50881904 0.49118093]
 [0.52228004 0.47771996]
 [0.6258016  0.37419838]
 [0.5838232  0.41617677]
 [0.62325144 0.3767485 ]
 [0.55658746 0.44341257]
 [0.5931775  0.4068225 ]
 [0.5771295  0.42287052]
 [0.6145576  0.38544238]
 [0.49803084 0.5019692 ]
 [0.46179223 0.5382077 ]
 [0.60925394 0.39074603]
 [0.57328993 0.42671004]
 [0.5997144  0.4002856 ]
 [0.65118396 0.34881604]
 [0.61392623 0.3860738 ]
 [0.4972641  0.5027359 ]
 [0.5315605  0.4684395 ]
 [0.606533   0.393467  ]
 [0.54376155 0.45623845]
 [0.6564623  0.34353772]
 [0.55771935 0.44228065]
 [0.5524412  0.4475588 ]
 [0.4623126  0.5376874 ]
 [0.4649848  0.53501517]
 [0.66911227 0.33088773]
 [0.49150556 0.5084945 ]
 [0.5081966  0.4918034 ]
 [0.6537695  0.34623045]
 [0.6151091  0.38489094]
 [0.49454132 0.5054587 ]
 [0.5615492  0.43845087]
 [0.5513

HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.6621484756469727
dgnn train loss 0.6722412109375
dgnn train loss 0.67145836353302
dgnn train loss 0.6682616472244263
dgnn train loss 0.6752366423606873
dgnn train loss 0.6744804978370667
dgnn train loss 0.6706343293190002
dgnn train loss 0.6576984524726868
dgnn train loss 0.6754196882247925
dgnn train loss 0.6622290015220642
dgnn train loss 0.6639038324356079
dgnn train loss 0.6792457103729248
dgnn train loss 0.6485561728477478
dgnn train loss 0.6866061687469482
dgnn train loss 0.66459059715271
dgnn train loss 0.6661160588264465



HBox(children=(IntProgress(value=0, description='VALID', max=8, style=ProgressStyle(description_width='initial…


dgnn train_hyps [[0.5711071  0.42889288]
 [0.5786027  0.4213973 ]
 [0.56905174 0.43094823]
 [0.61094135 0.38905865]
 [0.5853979  0.4146021 ]
 [0.5467438  0.45325616]
 [0.5239732  0.47602677]
 [0.50881904 0.49118093]
 [0.52228004 0.47771996]
 [0.6258016  0.37419838]
 [0.5838232  0.41617677]
 [0.62325156 0.3767485 ]
 [0.55658746 0.44341257]
 [0.5931775  0.40682256]
 [0.5771295  0.42287052]
 [0.6145576  0.3854424 ]
 [0.49803084 0.5019692 ]
 [0.46179223 0.53820777]
 [0.609254   0.390746  ]
 [0.57328993 0.42671004]
 [0.5997144  0.4002856 ]
 [0.65118396 0.34881604]
 [0.61392623 0.38607374]
 [0.4972641  0.5027359 ]
 [0.5315605  0.4684395 ]
 [0.60653305 0.39346698]
 [0.54376155 0.45623845]
 [0.6564623  0.34353772]
 [0.55771935 0.44228065]
 [0.5524412  0.4475588 ]
 [0.4623126  0.53768736]
 [0.46498477 0.53501517]
 [0.66911227 0.33088773]
 [0.49150556 0.5084945 ]
 [0.5081966  0.49180344]
 [0.6537695  0.3462305 ]
 [0.6151091  0.3848909 ]
 [0.49454132 0.50545865]
 [0.5615492  0.43845087]
 [0.5513

HBox(children=(IntProgress(value=0, description='TRAIN', max=16, style=ProgressStyle(description_width='initia…

dgnn train loss 0.6621484756469727
dgnn train loss 0.6722412109375
dgnn train loss 0.67145836353302
dgnn train loss 0.6682616472244263
dgnn train loss 0.6752366423606873
dgnn train loss 0.6744804978370667
dgnn train loss 0.6706343293190002
dgnn train loss 0.6576984524726868


KeyboardInterrupt: 

Test