In [None]:
import yaml
import warnings
from datetime import datetime
import logging
import os
import pandas as pd 
import javalang
from javalang.ast import Node
import torch
import torch.nn.functional as F
from torch.nn import Module, Embedding, LSTM, Sequential, Linear, BatchNorm1d, ReLU, Sigmoid, CrossEntropyLoss
import torch.optim as optim
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.nn import MessagePassing, GatedGraphConv, GCNConv, global_mean_pool, GINEConv, global_add_pool
from anytree import AnyNode
from torch_geometric.data import Data, DataLoader, ClusterData, ClusterLoader
from tqdm import tqdm_notebook as tqdm
import numpy as np
from torchsummary import summary
import random

warnings.filterwarnings('ignore')

Parameters

In [None]:
# Path of the YAML file that supplies the data locations, training, model and logging settings.
config_file = 'config_dgnn.yml'

In [None]:
# Load the experiment configuration; the context manager closes the file handle
# (the bare open() call left it to the garbage collector).
# NOTE(review): yaml.safe_load is sufficient if the config only holds plain
# scalars/lists/dicts -- confirm before switching loaders.
with open(config_file) as config_fh:
    config = yaml.load(config_fh, Loader=yaml.FullLoader)

# data source
raw_code_url = config['ccd_data']['data']
train_url = config['ccd_data']['train']
valid_url = config['ccd_data']['valid']
test_url = config['ccd_data']['test']

# training parameter
# NOTE(review): batch_size is hard-coded and overrides the configured value.
# batch_size = config['training']['batch_size']
batch_size = 64
num_epoches = config['training']['num_epoches']
lr = config['training']['lr']
# The decay ratio used to be read from the 'lr' key (copy-paste slip), so the
# scheduler decayed by the learning rate itself. Prefer a dedicated
# 'decay_ratio' key and fall back to the old value so existing configs still load.
# TODO(review): confirm the key name used in config_dgnn.yml.
decay_ratio = config['training'].get('decay_ratio', config['training']['lr'])
save_name = config['training']['save_name']
warm_up = config['training']['warm_up']
patience = config['training']['patience']

# model design
graph_embedding_size = config['model']['graph_embedding_size']
lstm_hidden_size = config['model']['lstm_hidden_size']
divide_node_num = config['model']['divide_node_num']
gnn_layers_num = config['model']['gnn_layers_num']
lstm_layers_num = config['model']['lstm_layers_num']
decoder_input_size = config['model']['decoder_input_size']
gine_dim = config['model']['gine_dim']

# logs
info_prefix = config['logs']['info_prefix']

Logs

In [None]:
# Timestamp shared by the log file name and the experiment directory of this run.
run_id = datetime.now().strftime('%Y-%m-%d--%H-%M-%S')
log_file = 'logs/' + run_id + '.log'
exp_dir = 'runs/' + run_id
# makedirs also creates the missing parent ('runs/'); 'logs/' has to exist
# before logging.basicConfig(filename=log_file) can open the log file.
os.makedirs(os.path.dirname(log_file), exist_ok=True)
os.makedirs(exp_dir, exist_ok=True)

In [None]:
class Info(object):
    """Small messenger that echoes prefixed messages to stdout and to the log."""

    def __init__(self, info_prefix=''):
        # Text placed (followed by one space) in front of every message.
        self.info_prefix = info_prefix

    def print_msg(self, msg):
        """Print ``msg`` with the prefix and record the same line via ``logging.info``."""
        line = ' '.join((self.info_prefix, msg))
        print(line)
        logging.info(line)

In [None]:
# Send every log record (DEBUG and above) to this run's log file.
logging.basicConfig(
    filename=log_file,
    level=logging.DEBUG,
    format='%(asctime)s | %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
)
msgr = Info(info_prefix)

# Record the identity of the run followed by the full configuration.
startup_lines = (
    'run_id : {}'.format(run_id),
    'log_file : {}'.format(log_file),
    'exp_dir: {}'.format(exp_dir),
    str(config),
)
for startup_line in startup_lines:
    msgr.print_msg(startup_line)

Sequence Preprocess

In [None]:
# Reserved vocabulary entry: tokens that are missing from the vocabulary map to UNK.
UNK_TOKEN = '<UNK>'
UNK = 0

In [None]:
# Load the JSON-lines code corpus and key the rows by their 'idx' column.
raw_code = pd.read_json(raw_code_url, lines=True).set_index('idx')
raw_code.head()

In [None]:
# Index labels of the corpus; read_ccd_pairs uses them to drop pairs that reference unknown snippets.
raw_code_index = raw_code.index.tolist()

over_length_ids = [10000832, 10151623, 10540676, 10690321, 10717656, 10793825, 10934628, 11339042, 11603577, 11940679, 12082150, 12119068, 12335897, 12352751, 12389873, 12415477, 12838273, 12838274, 12914531, 13099033, 13292215, 13292580, 13456795, 1349815, 13563706, 13586003, 13644040, 13650923, 13747998, 13776078, 13961100, 14020143, 14190765, 14310068, 14661394, 1477292, 14794142, 15295408, 15453012, 15723802, 15758923, 16002345, 16002347, 16006791, 16142591, 1637147, 16518661, 17314208, 17551920, 17573144, 17750515, 17829989, 17874921, 17917053, 18284812, 18433984, 18574455, 18790182, 18822890, 19074021, 19090289, 19276021, 19276022, 19340788, 19382420, 19434890, 19434892, 1944490, 19478367, 19556732, 1962490, 19634773, 19841853, 19942676, 20122631, 20275058, 2067794, 20833509, 20856391, 20885480, 21028028, 21161120, 21318345, 21493541, 2157431, 21652119, 22031237, 22033685, 22035132, 22114133, 22222255, 2235431, 2247987, 22580642, 22673614, 2285441, 23041161, 23094550, 23188198, 23248619, 23370708, 23510383, 23611768, 2450, 2476569, 2511576, 2769195, 2771573, 2771574, 285947, 2996859, 325062, 3375714, 3375715, 3375723, 3400236, 37044, 3867253, 416857, 4420769, 4494367, 4581365, 4660318, 4681906, 4780347, 4792385, 4854974, 5021563, 5189131, 5252227, 5389524, 5430189, 552318, 557726, 5620476, 564191, 5691586, 592597, 6008880, 6121196, 6147227, 6304372, 6304373, 6333737, 6403868, 6644748, 6961579, 6966398, 7005223, 7149578, 726690, 7300257, 7300264, 7300267, 733283, 7394826, 7665877, 7687037, 7727956, 793122, 8079516, 8109022, 8335944, 83802, 8581437, 8641070, 9050003, 9221721, 9261908, 9530015, 9581835, 961457, 9647574, 9687064, 9705209, 9782243, 979163, 98309, 98428, 9980609]
# over_length_ids = [10690321, 10717656, 10793825, 11339042, 12119068, 12352751, 12415477, 13099033, 15723802, 1637147, 16518661, 17750515, 18822890, 19276021, 2067794, 20833509, 20856391, 20885480, 21028028, 22033685, 22035132, 2235431, 2247987, 22580642, 2285441, 23094550, 23248619, 2996859, 325062, 3375723, 3867253, 4494367, 4660318, 4780347, 5021563, 5189131, 5252227, 564191, 6147227, 7005223, 7300267, 793122, 8109022, 8581437, 9221721, 9647574, 98309]

def read_ccd_pairs(url):
    """Read labelled code pairs from a tab-separated file.

    Every non-empty line must hold ``id1<TAB>id2<TAB>label``. A pair is kept
    only when both ids occur in ``raw_code_index`` and neither id is listed in
    ``over_length_ids``.

    Args:
        url: path of the pair file.

    Returns:
        A list of ``(id1, id2, label)`` tuples with integer ids; ``label`` is
        0 when the label field is exactly ``'0'`` and 1 otherwise.

    Raises:
        ValueError: if a non-empty line does not have exactly three
            tab-separated fields or an id is not an integer.
    """
    # Sets give O(1) membership tests; the module-level lists were scanned
    # linearly for every line of the pair file.
    known_ids = set(raw_code_index)
    skipped_ids = set(over_length_ids)

    data = []
    with open(url) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank or trailing empty lines instead of failing on unpack.
                continue
            id1, id2, label = line.split('\t')
            id1, id2 = int(id1), int(id2)
            if id1 not in known_ids or id2 not in known_ids:
                continue
            if id1 in skipped_ids or id2 in skipped_ids:
                continue
            data.append((id1, id2, 0 if label == '0' else 1))
    return data

In [None]:
# Read the three splits with the same reader, in train / valid / test order.
train_data, valid_data, test_data = (
    read_ccd_pairs(split_url) for split_url in (train_url, valid_url, test_url)
)

In [None]:
# Fixed seed so the sub-sampled splits are the same on every run.
random.seed(666)
# Cap the sample size at the split size: random.sample raises ValueError when
# asked for more items than the population holds.
train_data_small = random.sample(train_data, min(4000, len(train_data)))
valid_data_small = random.sample(valid_data, min(2000, len(valid_data)))
test_data_small = random.sample(test_data, min(2000, len(test_data)))

msgr.print_msg('train size: {}, valid size: {}, test size: {}'.format(len(train_data_small), len(valid_data_small), len(test_data_small)))

In [None]:
# define vocab class
class Vocab(object):
    """Two-way mapping between tokens and integer ids."""

    def __init__(self, word2id=None):
        """Start from an optional ``word2id`` mapping (copied, never mutated)."""
        self.word2id = {} if word2id is None else dict(word2id)
        self.id2word = {v: k for k, v in self.word2id.items()}

    def build_vocab(self, sentences, min_count=1):
        """Add the tokens of ``sentences`` to the vocabulary.

        Args:
            sentences: flat iterable of tokens (each item is counted as one word).
            min_count: tokens seen fewer than ``min_count`` times are skipped.

        New tokens receive consecutive ids in order of decreasing frequency.
        Tokens that already have an id keep it.
        """
        word_counter = {}
        for word in sentences:
            word_counter[word] = word_counter.get(word, 0) + 1

        for word, count in sorted(word_counter.items(), key=lambda x: -x[1]):
            if count < min_count:
                # Sorted by decreasing count, so every remaining word is too rare as well.
                break
            if word in self.word2id:
                # Already known (e.g. the pre-seeded unknown token): writing
                # id2word here would create an id that word2id does not contain.
                continue
            # Assumes existing ids are dense (0..n-1), as they are when ids come from this method.
            _id = len(self.word2id)
            self.word2id[word] = _id
            self.id2word[_id] = word

In [None]:
# construct the vocabulary for AST node tokens, pre-seeded with the unknown-token entry
word2id = {
    UNK_TOKEN: UNK
}
vocab_astnodes = Vocab(word2id=word2id)


In [None]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    """Return the token that represents one AST element.

    A string is its own token, a set is reported as ``'Modifier'``, a
    ``javalang.ast.Node`` yields its class name, and any other value gives
    the empty string.
    """
    if isinstance(node, str):
        return node
    if isinstance(node, set):
        return 'Modifier'
    if isinstance(node, Node):
        return node.__class__.__name__
    return ''


def get_child(root):
    """Return the children of an AST element as one flat list.

    A ``Node`` contributes its ``children`` attribute and a set contributes
    its members; anything else has no children. Nested lists are flattened
    depth-first and falsy entries (``None``, empty strings, empty sets, ...)
    are dropped.
    """
    if isinstance(root, Node):
        pending = root.children
    elif isinstance(root, set):
        pending = list(root)
    else:
        return []

    flat = []

    def collect(items):
        # Depth-first walk that keeps the left-to-right order of the entries.
        for entry in items:
            if isinstance(entry, list):
                collect(entry)
            elif entry:
                flat.append(entry)

    collect(pending)
    return flat


def get_sequence(node, sequence):
    """Append the tokens of ``node`` and all its descendants to ``sequence``.

    Tokens are appended in depth-first pre-order (a node before its children,
    children from left to right). ``sequence`` is modified in place; nothing
    is returned.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        token, children = get_token(current), get_child(current)
        sequence.append(token)
        # Push in reverse so the left-most child is popped (visited) first.
        stack.extend(reversed(children))


def parse_program(func):
    """Tokenize Java source text and parse it as a single member declaration.

    Returns whatever ``javalang``'s ``Parser.parse_member_declaration`` produces.
    """
    token_stream = javalang.tokenizer.tokenize(func)
    return javalang.parser.Parser(token_stream).parse_member_declaration()

In [None]:
# use the train data to construct the AST-node token corpus
astnodes_tokens = []
# Both ids of every training pair, in file order (id1, id2, id1, id2, ...).
train_ids = [code_id for pair in train_data for code_id in pair[:2]]

train_ids = list(set(train_ids))

# The loop variable is deliberately not called `id`: at notebook scope that
# would shadow the builtin id() for every later cell.
for code_id in train_ids:
    sequence = []
    get_sequence(parse_program(raw_code['func'][code_id]), sequence)
    astnodes_tokens.extend(sequence)

# min_count=0 keeps every token that occurs in the training snippets.
vocab_astnodes.build_vocab(astnodes_tokens, min_count=0)

In [None]:
# Number of vocabulary entries; passed to the model as `vocab_len` (size of the node embedding table).
vocab_astnodes_size = len(vocab_astnodes.id2word)
msgr.print_msg('vocab_astnodes_size: ' + str(vocab_astnodes_size))

Add Dataflow to AST to generate D-AST

In [None]:
#  generate tree for AST Node
def create_tree(root, node, node_list, parent=None):
    """Mirror an AST as an ``anytree`` tree with consecutive integer ids.

    Args:
        root: pre-built ``AnyNode`` that receives the token and data of the
            first visited element (id 0).
        node: AST element to convert.
        node_list: list of already visited elements; its length is the id of
            ``node`` and ``node`` is appended to it.
        parent: tree node that becomes the parent of ``node`` (unused for id 0).

    Ids follow the depth-first pre-order in which the elements are visited.
    """
    node_id = len(node_list)
    token = get_token(node)
    children = get_child(node)

    if node_id == 0:
        # The first element fills the caller-supplied root instead of creating a node.
        root.token = token
        root.data = node
        current = root
    else:
        current = AnyNode(id=node_id, token=token, data=node, parent=parent)

    node_list.append(node)
    for child in children:
        create_tree(root, child, node_list, parent=current)

In [None]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, vocab_dict, src, tgt, variable_token_list, variable_id_list):
    """Collect node features, tree edges and variable occurrences of a subtree.

    All list arguments are filled in place while the tree is walked in
    depth-first pre-order:

    * ``node_index_list`` gets one ``[vocab id]`` entry per node (``UNK`` for
      tokens missing from ``vocab_dict.word2id``);
    * ``src`` / ``tgt`` get both directions of every parent-child edge;
    * ``variable_token_list`` / ``variable_id_list`` get the token and id of
      the first child of every 'VariableDeclarator' or 'MemberReference' node.
    """
    token = node.token
    node_index_list.append([vocab_dict.word2id.get(token, UNK)])

    # The first child of these node kinds is recorded as a variable occurrence.
    if token in ('VariableDeclarator', 'MemberReference'):
        first_child = node.children[0]
        variable_token_list.append(first_child.token)
        variable_id_list.append(first_child.id)

    for child in node.children:
        # parent -> child followed by child -> parent
        src.extend((node.id, child.id))
        tgt.extend((child.id, node.id))
        get_node_and_edge(child, node_index_list, vocab_dict, src, tgt, variable_token_list, variable_id_list)

In [None]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, vocab_dict):
    """Turn an AST into the plain-list inputs of a graph data object.

    Returns:
        ``(x, edge_index, edge_attr)`` where ``x`` holds one ``[vocab id]``
        per node, ``edge_index`` is ``[sources, targets]`` and ``edge_attr``
        holds one ``[type]`` per edge: 0 for tree edges, 1 for data-flow edges.
        Tree edges come first, data-flow edges are appended after them.
    """
    root = AnyNode(id=0, token=None, data=None)
    create_tree(root, ast, [])

    x, edge_src, edge_tgt = [], [], []
    # Variable tokens and node ids, in visiting order, used to add data-flow edges.
    variable_token_list, variable_id_list = [], []
    get_node_and_edge(root, x, vocab_dict, edge_src, edge_tgt, variable_token_list, variable_id_list)

    # Everything collected so far is an AST edge (type 0).
    ast_edge_num = len(edge_src)
    edge_attr = [[0] for _ in range(ast_edge_num)]

    # Data flow: link each occurrence of a variable token to the previous
    # occurrence of the same token, in both directions.
    last_seen = {}
    for var_token, var_id in zip(variable_token_list, variable_id_list):
        if var_token in last_seen:
            previous_id = last_seen[var_token]
            edge_src.extend((previous_id, var_id))
            edge_tgt.extend((var_id, previous_id))
        last_seen[var_token] = var_id

    # The edges added after the AST edges are data-flow edges (type 1).
    dataflow_edge_num = len(edge_src) - ast_edge_num
    edge_attr.extend([1] for _ in range(dataflow_edge_num))

    return x, [edge_src, edge_tgt], edge_attr

Batch Data

In [None]:
def transform_to_pygdata(code):
    """Parse one Java snippet and wrap its graph as a ``Data`` object.

    The node ids, edge index and edge types produced by
    ``get_pyg_data_from_ast`` (using the module-level ``vocab_astnodes``) are
    converted to ``torch.long`` tensors.
    """
    x, edge_index, edge_attr = get_pyg_data_from_ast(parse_program(code), vocab_astnodes)
    fields = (('x', x), ('edge_index', edge_index), ('edge_attr', edge_attr))
    tensors = {name: torch.tensor(values, dtype=torch.long) for name, values in fields}
    return Data(**tensors)


In [None]:
# Convert every snippet of the corpus once and keep the graph next to its source.
pyg_datas = [transform_to_pygdata(raw_code['func'][index]) for index in raw_code_index]

raw_code['pyg_data'] = pyg_datas

In [None]:
# Spot-check the graph generated for one snippet (looked up by its corpus id).
raw_code['pyg_data'][10000832]

In [None]:
class PairData(Data):
    """A pair of graphs (source ``*_s`` and target ``*_t``) with one label.

    Both graphs are stored in a single data object so that one mini-batch
    carries both sides of every pair.
    """

    def __init__(self, edge_index_s, edge_attr_s, x_s, edge_index_t, edge_attr_t, x_t, label):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.edge_attr_s = edge_attr_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.edge_attr_t = edge_attr_t
        self.x_t = x_t
        self.label = label

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset added to ``key`` when samples are batched together.

        Each edge index is shifted by the node count of its own graph; every
        other attribute uses the default of ``Data``. ``*args`` / ``**kwargs``
        are forwarded unchanged because newer PyTorch Geometric releases call
        this hook with additional arguments, which the two-argument signature
        rejected with a ``TypeError``.
        """
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        return super().__inc__(key, value, *args, **kwargs)
    

In [None]:
def batch_data(data):
    """Build one ``PairData`` object per ``(id1, id2, label)`` entry of ``data``.

    The graphs are looked up by id in ``raw_code['pyg_data']`` and the label is
    stored as a ``torch.long`` scalar tensor. The result keeps the order of ``data``.
    """
    graphs = raw_code['pyg_data']
    pairs = []
    for sample in data:
        source, target = graphs[sample[0]], graphs[sample[1]]
        pairs.append(PairData(
            x_s=source.x, edge_index_s=source.edge_index, edge_attr_s=source.edge_attr,
            x_t=target.x, edge_index_t=target.edge_index, edge_attr_t=target.edge_attr,
            label=torch.tensor(sample[2], dtype=torch.long),
        ))
    return pairs

In [None]:
# Materialise the PairData lists of the three sub-sampled splits.
train_batch, valid_batch, test_batch = (
    batch_data(split) for split in (train_data_small, valid_data_small, test_data_small)
)

In [None]:
# One loader per split with identical settings. follow_batch is what provides the
# x_s_batch / x_t_batch vectors that compute_loss reads from every mini-batch.
# NOTE(review): the training loader keeps a fixed order (no shuffle) -- confirm this is intended.
train_loader, valid_loader, test_loader = (
    DataLoader(split_batch, batch_size=batch_size, follow_batch=['x_s', 'x_t'])
    for split_batch in (train_batch, valid_batch, test_batch)
)

Model

In [None]:
# The D-AST is partitioned inside the model instead of during data preprocessing.
# It is partitioned by number of nodes, which is set by the hyper-parameter `divide_node_num`.

In [None]:
class SequenceGNNEncoder(torch.nn.Module):
    """Encode a batch of D-AST graphs into one vector per graph.

    Every graph is cut into consecutive chunks of ``divide_node_num`` nodes.
    For each chunk position the GNN is run on the edges that end in that chunk
    and the result is attention-pooled per graph; the pooled vectors are fed to
    an LSTM as a sequence. The last LSTM output is concatenated with the pooled
    encoding of the complete graph and projected to ``decoder_input_size``.
    """

    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num,
                    decoder_input_size, gine_dim, device):
        super(SequenceGNNEncoder, self).__init__()
        self.device = device
        self.embed = Embedding(vocab_len, graph_embedding_size)
        self.edge_embed = Embedding(2, 1) # only two edge types to be set weights, which are AST edge and data flow edge
        self.edge_embed2 = Embedding(2, graph_embedding_size) # GINE requires node and edge features to be of same dimensionality

        # use LSTM to obtain subgraph output at each time step
        self.divide_node_num = divide_node_num
        self.lstm = LSTM(input_size=graph_embedding_size, hidden_size=lstm_hidden_size, num_layers=lstm_layers_num)
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_layers_num = lstm_layers_num
        self.fc = Linear(graph_embedding_size + lstm_hidden_size, decoder_input_size)

        # GatedGraphConv + GlobalAttention
        self.ggnnlayer = GatedGraphConv(graph_embedding_size, gnn_layers_num)
        self.gate_nn = Sequential(Linear(graph_embedding_size, graph_embedding_size), ReLU(), Linear(graph_embedding_size, 1), ReLU())
        self.nn = Sequential(Linear(graph_embedding_size, graph_embedding_size), ReLU())
        self.pool = GlobalAttention(gate_nn=self.gate_nn, nn=self.nn)

        # GINEConv
        # NOTE(review): these layers are only referenced by the commented-out
        # GINE variant of subgraph_forward; they are kept so that existing
        # checkpoints still match the module's state_dict.
        self.gine_dim = gine_dim
        self.conv1 = GINEConv(Sequential(Linear(graph_embedding_size, gine_dim), BatchNorm1d(gine_dim),
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))
        self.conv2 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim),
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))
        self.conv3 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim),
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))
        self.conv4 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim),
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))
        self.conv5 = GINEConv(Sequential(Linear(gine_dim, gine_dim), BatchNorm1d(gine_dim),
                                            ReLU(), Linear(gine_dim, gine_dim), ReLU()))
        self.lin1 = Linear(gine_dim, gine_dim)
        self.lin2 = Linear(gine_dim, graph_embedding_size)
        self.lin3 = Linear(graph_embedding_size, gine_dim)

    # GatedGraphConv + GlobalAttention
    def subgraph_forward(self, x, edge_index, edge_attr, batch):
        """Run the gated graph layer on the given edges and attention-pool per graph.

        ``edge_attr`` holds edge-type ids; they are embedded into one learned
        scalar weight per edge. ``None`` disables edge weighting.
        """
        if edge_attr is None:
            edge_weight = None
        else:
            edge_weight = self.edge_embed(edge_attr)
            edge_weight = edge_weight.squeeze(1)
        x = self.ggnnlayer(x, edge_index, edge_weight)
        return self.pool(x, batch=batch)

    # GINEConv
    # def subgraph_forward(self, x, edge_index, edge_attr, batch):
    #     if type(edge_attr) == type(None):
    #         edge_weight = None
    #     else:
    #         edge_weight = self.edge_embed2(edge_attr)
    #         edge_weight = edge_weight.squeeze(1)
    #     x = self.conv1(x, edge_index, edge_weight)
    #     edge_weight = self.lin3(edge_weight)
    #     x = self.conv2(x, edge_index, edge_weight)
    #     x = self.conv3(x, edge_index, edge_weight)
    #     x = self.conv4(x, edge_index, edge_weight)
    #     x = self.conv5(x, edge_index, edge_weight)
    #     x = global_add_pool(x, batch=batch)
    #     x = self.lin1(x).relu()
    #     x = F.dropout(x, p=0.5, training=self.training)
    #     x = self.lin2(x)
    #     return x

    # partitioning multiple subgraphs by dynamic allocating edges
    def partition_graph(self, x, edge_index, edge_attr, batch):
        """Split the edges of a batch by the node chunk their target falls into.

        The nodes of every graph are cut into consecutive chunks of
        ``divide_node_num`` nodes (chunk ``k`` holds the nodes at offsets
        ``k * divide_node_num`` to ``(k + 1) * divide_node_num - 1`` within the
        graph). An edge belongs to the chunk of its target node.

        The previous implementation omitted the ``k * divide_node_num`` offset
        when it enumerated the nodes of chunk ``k``, so every chunk covered the
        first nodes of each graph and most nodes ended up in chunk 0.

        Args:
            x: node features of the batch (unused; kept for the call signature).
            edge_index: tensor with two rows (source ids, target ids).
            edge_attr: edge-type ids, one entry per edge.
            batch: graph id of every node, with the nodes of a graph stored contiguously.

        Returns:
            ``(edge_index_list, edge_attr_list)``: one CPU ``torch.long`` edge
            index and one 1-D CPU ``torch.long`` edge-type tensor per chunk position.
        """
        _, graph_length = self.get_subgraph_info_from_batch(batch)
        max_seq_len = max(graph_length)
        # Unchanged chunk count: when the largest graph is an exact multiple of
        # divide_node_num the last position receives no edges.
        subgraph_num = max_seq_len // self.divide_node_num + 1

        # Chunk index of every node, laid out in batch order.
        node_subgraph_index = []
        for length in graph_length:
            node_subgraph_index.extend(offset // self.divide_node_num for offset in range(length))

        # Convert once to Python lists: calling .item() per edge forces a
        # separate device synchronisation for every single edge.
        edge_sources, edge_targets = edge_index.tolist()
        edge_types = edge_attr.reshape(-1).tolist()

        # only count the edge whose target node in subgraph
        sub_edge_src = [[] for _ in range(subgraph_num)]
        sub_edge_tgt = [[] for _ in range(subgraph_num)]
        sub_edge_attr = [[] for _ in range(subgraph_num)]
        for src, tgt, edge_type in zip(edge_sources, edge_targets, edge_types):
            chunk = node_subgraph_index[tgt]
            sub_edge_src[chunk].append(src)
            sub_edge_tgt[chunk].append(tgt)
            sub_edge_attr[chunk].append(edge_type)

        edge_index_list = []
        edge_attr_list = []
        for i in range(subgraph_num):
            edge_index_list.append(torch.tensor([sub_edge_src[i], sub_edge_tgt[i]], dtype=torch.long))
            edge_attr_list.append(torch.tensor(sub_edge_attr[i], dtype=torch.long))
        return edge_index_list, edge_attr_list

    def get_subgraph_info_from_batch(self, batch):
        """Return the start position and node count of every graph in ``batch``.

        ``batch`` assigns a graph id to every node; a new graph starts wherever
        the id changes. The first id is expected to be 0.

        Returns:
            ``(graph_pos_in_batch, graph_length)`` as two lists of ints.
        """
        comp = 0
        pos = 0
        graph_pos_in_batch = [0] # start position of every graph in the batch
        graph_length = [] # number of nodes of every graph
        # One bulk conversion instead of one tensor comparison per node.
        batch_ids = batch.tolist()
        for i, graph_id in enumerate(batch_ids):
            if graph_id != comp:
                graph_pos_in_batch.append(i)
                graph_length.append(i - pos)
                comp = graph_id
                pos = i
        graph_length.append(len(batch_ids) - pos)
        return graph_pos_in_batch, graph_length

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode every graph of the batch; returns one row per graph."""
        edge_index_list, edge_attr_list = self.partition_graph(x, edge_index, edge_attr, batch)
        x = self.embed(x)
        # assumes x arrives as one id per row ([num_nodes, 1]) so that the
        # embedding has a singleton middle dimension -- TODO confirm
        x = x.squeeze(1)
        # One pooled vector per graph for every chunk position.
        subgraph_pool_list = [
            self.subgraph_forward(x, edge_index_list[i].to(self.device), edge_attr_list[i].to(self.device), batch)
            for i in range(len(edge_index_list))
        ]
        # Pooled encoding of the complete graph.
        graph_pool = self.subgraph_forward(x, edge_index, edge_attr, batch)
        # Stacked along dim 0, i.e. the chunk position is the LSTM time axis.
        subgraph_pool_seq = torch.stack(subgraph_pool_list)
        h0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1), self.lstm_hidden_size).to(self.device)
        c0 = torch.zeros(self.lstm_layers_num, subgraph_pool_seq.size(1), self.lstm_hidden_size).to(self.device)
        subgraph_output, (_, _) = self.lstm(subgraph_pool_seq, (h0, c0))
        return self.fc(torch.cat((subgraph_output[-1], graph_pool), dim=1))
        
            

In [None]:
class SiameseDecoder(Module):
    """Siamese classifier: one shared encoder applied to both graphs of a pair.

    The two encodings are concatenated and mapped to two class scores, which
    are returned as probabilities.
    """

    def __init__(self, vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num,
    decoder_input_size, gine_dim, device):
        super(SiameseDecoder, self).__init__()
        # A single encoder instance, so both sides of the pair share all weights.
        self.encoder = SequenceGNNEncoder(vocab_len, graph_embedding_size, gnn_layers_num, lstm_layers_num, lstm_hidden_size, divide_node_num,
    decoder_input_size, gine_dim, device)
        self.fc = Linear(2* decoder_input_size, 2)

    def forward(self, x_s, edge_index_s, edge_attr_s, x_t, edge_index_t, edge_attr_t, x_s_batch, x_t_batch):
        """Return class probabilities with one row per pair and two columns.

        NOTE(review): the output is already softmax-normalised; a loss that
        applies its own (log-)softmax must not be fed these values as if they
        were raw logits.
        """
        decoder_input1 = self.encoder(x_s, edge_index_s, edge_attr_s, x_s_batch)
        decoder_input2 = self.encoder(x_t, edge_index_t, edge_attr_t, x_t_batch)
        output = torch.cat((decoder_input1, decoder_input2), dim=1)
        logits = self.fc(output)
        # Explicit class dimension: calling softmax without `dim` relies on the
        # deprecated implicit-dimension behaviour.
        return F.softmax(logits, dim=1)


Training

In [None]:
# Use the first GPU when CUDA is available, otherwise run on the CPU.
# The previous device string 'cuda: 0' contained a space; device strings have
# the form '<type>' or '<type>:<ordinal>'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_args = {
    'vocab_len': vocab_astnodes_size,
    'graph_embedding_size': graph_embedding_size,
    'gnn_layers_num': gnn_layers_num,
    'lstm_layers_num': lstm_layers_num,
    'lstm_hidden_size': lstm_hidden_size,
    'divide_node_num': divide_node_num,
    'decoder_input_size': decoder_input_size,
    'gine_dim': gine_dim,
    'device': device
}

model = SiameseDecoder(**model_args).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Learning rate at epoch e is lr * decay_ratio ** e (the schedule is stepped once per epoch).
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda= lambda epoch: decay_ratio ** epoch)

In [None]:
# Summarise the model with torchsummary and keep the returned value.
model_summary = summary(model)

In [None]:
# Classification loss used by compute_loss.
# NOTE(review): nn.CrossEntropyLoss expects unnormalised scores -- verify what the model outputs.
criterion = CrossEntropyLoss()

In [None]:
def compute_loss(data, model, optimizer=None, is_train=True):
    """Run one mini-batch through ``model`` and optionally take an optimizer step.

    Args:
        data: mini-batch exposing ``x_s``, ``edge_index_s``, ``edge_attr_s``,
            ``x_t``, ``edge_index_t``, ``edge_attr_t``, ``x_s_batch``,
            ``x_t_batch`` and ``label`` tensors.
        model: callable returning class probabilities with one row per pair.
        optimizer: optimizer to step; required when ``is_train`` is true.
        is_train: when false, no gradients are tracked and no step is taken.

    Returns:
        ``(loss, labels, predictions)``: the loss as a Python float, and the
        labels and model outputs as NumPy arrays.

    Raises:
        ValueError: if ``is_train`` is true but no optimizer was given.
    """
    if is_train and optimizer is None:
        raise ValueError('optimizer is required when is_train=True')

    x_s = data.x_s.to(device)
    edge_index_s = data.edge_index_s.to(device)
    edge_attr_s = data.edge_attr_s.to(device)
    x_t = data.x_t.to(device)
    edge_index_t = data.edge_index_t.to(device)
    edge_attr_t = data.edge_attr_t.to(device)
    x_s_batch = data.x_s_batch.to(device)
    x_t_batch = data.x_t_batch.to(device)
    label = data.label.to(device)

    # Skip building the autograd graph during evaluation.
    with torch.set_grad_enabled(is_train):
        pred = model(x_s, edge_index_s, edge_attr_s, x_t, edge_index_t, edge_attr_t, x_s_batch, x_t_batch)
        # SiameseDecoder.forward returns softmax probabilities, while the
        # criterion applies log-softmax itself. Feeding the probabilities
        # directly normalised them a second time. Passing log-probabilities
        # gives the intended cross entropy, because log_softmax(log p) == log p
        # when p sums to 1. The clamp avoids log(0).
        loss = criterion(torch.log(pred.clamp_min(1e-12)), label)

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item(), label.cpu().detach().numpy(), pred.cpu().detach().numpy()

In [None]:
class EarlyStopping(object):
    """Track a validation score, checkpoint improvements and signal when to stop.

    Higher scores are better. After the warm-up epochs, every call whose score
    does not exceed the best score increases a counter; once the counter
    reaches ``patience`` the ``early_stop`` flag is set.
    """

    def __init__(self, filename = None, patience=3, warm_up=0, verbose=False):
        self.filename = filename    # checkpoint path; None disables saving
        self.patience = patience    # non-improving calls tolerated before stopping
        self.warm_up = warm_up      # calls with epoch <= warm_up never count against patience
        self.verbose = verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, score, model, epoch):
        """Update the state with the ``score`` reached at ``epoch``."""
        # The very first score is the baseline and is always checkpointed.
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(score, model)
            return

        past_warm_up = epoch > self.warm_up

        if past_warm_up and score <= self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        if not past_warm_up:
            # During warm-up the best score and the checkpoint are left untouched.
            print('Warming up until epoch', self.warm_up)
            return

        if self.verbose:
            print(f'Score improved. ({self.best_score:.6f} --> {score:.6f}).')
        self.best_score = score
        self.save_checkpoint(score, model)
        self.counter = 0

    def save_checkpoint(self, score, model):
        """Write ``model.state_dict()`` to ``filename`` when a filename is configured."""
        if self.filename is not None:
            torch.save(model.state_dict(), self.filename)

        if self.verbose:
            print('Model saved...')

In [None]:
# Checkpoint path of the best model for this run.
# NOTE(review): exp_dir has no trailing separator, so this only lands inside the
# run directory if save_name starts with '/' -- confirm against the config value.
fname = exp_dir + save_name
early_stopping = EarlyStopping(fname, patience, warm_up, verbose=True)

In [None]:
import numpy as np
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score

def calculate(logits, y_trues):
    """Pick the decision threshold with the best macro F1 and report the metrics.

    Args:
        logits: sequence of per-batch score arrays with one row per sample;
            column 1 is compared against the threshold.
        y_trues: sequence of per-batch label arrays.

    Thresholds 0.01, 0.02, ..., 0.99 are tried and the first one reaching the
    highest macro F1 is kept (threshold 0 if no candidate gets an F1 above 0).

    Returns:
        ``(recall, precision, f1)``, all macro-averaged, at the selected threshold.
    """
    logits = np.concatenate(logits, 0)
    y_trues = np.concatenate(y_trues, 0)
    positive_scores = logits[:, 1]

    best_threshold = 0
    best_f1 = 0
    for i in range(1, 100):
        threshold = i / 100
        # Only F1 drives the selection; recall and precision used to be
        # recomputed for every threshold and then thrown away.
        f1 = f1_score(y_trues, positive_scores > threshold, average='macro')
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    y_preds = positive_scores > best_threshold
    recall = recall_score(y_trues, y_preds, average='macro')
    precision = precision_score(y_trues, y_preds, average='macro')
    f1 = f1_score(y_trues, y_preds, average='macro')
    return recall, precision, f1


In [None]:
# Main training loop: one training pass and one validation pass per epoch,
# followed by metric logging, early stopping on validation F1 and an LR-schedule step.
for epoch in range(1, num_epoches + 1):
    # Per-epoch accumulators: summed batch losses plus the labels (refs) and
    # model outputs (hyps) of every batch.
    train_loss = 0.
    train_refs = []
    train_hyps = []
    valid_loss = 0.
    valid_refs = []
    valid_hyps = []

    # train
    model.train()
    for data in tqdm(train_loader, total=len(train_loader), desc='TRAIN'):
        loss, gold, pred = compute_loss(data, model, optimizer, is_train=True)
        # msgr.print_msg('train loss {}'.format(loss))
        train_loss += loss
        train_refs.append(gold)
        train_hyps.append(pred)


    # valid
    model.eval()
    for data in tqdm(valid_loader, total=len(valid_loader), desc='VALID'):
        # is_train=False: compute_loss takes no optimizer step for these batches
        loss, gold, pred = compute_loss(data, model, optimizer, is_train=False)
        # msgr.print_msg('valid loss {}'.format(loss))
        valid_loss += loss
        valid_refs.append(gold)
        valid_hyps.append(pred)


    # Mean loss per batch (the accumulators are plain floats, so np.sum leaves them unchanged).
    train_loss = np.sum(train_loss) / len(train_loader)
    valid_loss = np.sum(valid_loss) / len(valid_loader)

    # calculate() takes the model outputs first and the labels second.
    train_recall, train_precision, train_f1 = calculate(train_hyps, train_refs)
    valid_recall, valid_precision, valid_f1 = calculate(valid_hyps, valid_refs)

    # msgr.print_msg('train_hyps {}'.format(train_hyps[0]))
    # msgr.print_msg('valid_hyps {}'.format(valid_hyps[0]))


    msgr.print_msg('Epoch {}: train_loss: {:5.2f} train_recall: {:2.4f} train_precision: {:2.4f}  train_f1: {:2.4f}  valid_loss: {:5.2f}  valid_recall: {:2.4f} valid_precision: {:2.4f} valid_f1: {:2.4f}'.format(
            epoch, train_loss, train_recall, train_precision, train_f1, valid_loss, valid_recall, valid_precision, valid_f1))

    # Early stopping (and checkpointing) is driven by the validation F1.
    early_stopping(valid_f1, model, epoch)
    if early_stopping.early_stop:
        msgr.print_msg("Early stopping")
        break

    print('-'*80)
    # Advance the learning-rate schedule; not reached on the epoch that triggers early stopping.
    scheduler.step()


Test