In [1]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [2]:
# Root directory that holds the serialized graph.
data_dir = '/datadrive/data/'
# NOTE: batch_size / batch_num / epoch_num / samp_num and `device` are assigned
# again in a later cell of this notebook before the models are built, so on a
# top-to-bottom run the later values are the ones training actually uses.
batch_size = 256
batch_num  = 128
epoch_num  = 1000
samp_num   = 7

device = torch.device("cuda:1")
# Load the pre-built graph. The context manager closes the file handle as soon
# as loading finishes (the previous bare open() left it to the garbage collector).
# SECURITY: dill, like pickle, can execute arbitrary code while loading -- only
# point this at a graph file you produced yourself or otherwise trust.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [3]:
# Partition the graph's time stamps into train / validation / test sets.
# Each range is a {time: True} lookup; entries whose time is None are dropped.
#   train: t <= 2015      valid: 2015 < t < 2018      test: t >= 2018
train_range = {t: True for t in graph.times if t is not None and t <= 2015}
valid_range = {t: True for t in graph.times if t is not None and 2015 < t < 2018}
test_range  = {t: True for t in graph.times if t is not None and t >= 2018}

In [4]:
class GNN(nn.Module):
    """Linear input adapter followed by a stack of ``RAGCNConv`` layers.

    Args:
        n_hid: Feature width. The adapter is ``nn.Linear(n_hid, n_hid)``, so the
            node features passed to ``forward`` must have ``n_hid`` columns.
        num_types, num_relations, n_heads, dropout: Forwarded unchanged to every
            ``RAGCNConv`` layer.
        n_layers: Number of ``RAGCNConv`` layers to stack (0 keeps only the adapter).
        device: Forwarded to every ``RAGCNConv`` layer and stored on the model.
    """
    def __init__(self, n_hid, num_types, num_relations, n_heads, n_layers, device, dropout = 0.3):
        super(GNN, self).__init__()
        # Sub-modules are registered in the same order as before so that
        # parameter / state_dict ordering is unchanged.
        self.gcs = nn.ModuleList()
        self.adapt = nn.Linear(n_hid, n_hid)
        self.n_hid = n_hid
        # Previously this attribute only existed after set_device() was called.
        self.device = device
        for l in range(n_layers):
            self.gcs.append(RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, device, dropout))
    def set_device(self, device):
        """Record ``device`` on the model and on every convolution layer."""
        self.device = device
        for gc in self.gcs:
            gc.device = device
    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        """Adapt the raw node features, then run them through each layer in turn.

        Note the argument order handed to each layer differs from this method's
        own signature: ``(x, node_type, edge_index, edge_type, edge_time)``.
        """
        meta_xs = F.elu(self.adapt(node_feature))
        for gc in self.gcs:
            meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs
    def __repr__(self):
        # Fixed: the layer list lives on the instance; the bare name `gcs`
        # raised NameError whenever the model was printed.
        return '{}(n_hid={}, n_layers={})'.format(
            self.__class__.__name__, self.n_hid, len(self.gcs))
    
class CPC_Predictor(nn.Module):
    """Scores positive and negative paper ids on top of a GNN + MatchConv encoder.

    The encoder output is indexed by the given ids and projected to one score
    per id with ``nn.Linear(n_heads, 1)``.
    """

    def __init__(self, n_hid, num_types, num_relations, n_heads, n_layers, device, dropout = 0.3):
        super().__init__()
        self.gnn = GNN(n_hid, num_types, num_relations, n_heads, n_layers, device, dropout)
        self.matcher = MatchConv(n_hid, n_hid, num_types, num_relations, n_heads, device, dropout)
        self.score   = nn.Linear(n_heads, 1)

    def forward(self, pos_paper_ids, neg_paper_ids, node_feature, node_type, edge_time, edge_index, edge_type):
        """Return one row per positive id: its score followed by its negatives' scores."""
        encoded = self.gnn(node_feature, node_type, edge_time, edge_index, edge_type)
        matched = self.matcher(encoded, node_type, edge_index, edge_type)

        pos_res = self.score(matched[pos_paper_ids])
        neg_res = self.score(matched[neg_paper_ids])

        # Negatives arrive flattened; fold them so each positive gets an equal share.
        n_pos = pos_res.size(0)
        neg_per_pos = neg_res.size(0) // n_pos
        neg_res = neg_res.view(n_pos, neg_per_pos)
        return torch.cat([pos_res, neg_res], dim=-1)

def gen_negative_sample(dim1, dim2, num):
    """Return ``num`` random ``[i, j]`` pairs with ``i`` drawn from
    ``np.random.randint(dim1)`` and ``j`` from ``np.random.randint(dim2)``.

    Pairs are drawn independently, so duplicates are possible.
    """
    # The first coordinate is drawn before the second for every pair, which
    # keeps the consumption of the global NumPy random stream unchanged.
    return [[np.random.randint(dim1), np.random.randint(dim2)] for _ in range(num)]

def neg_sample(size, pos_id, num):
    """Draw ``num`` distinct ids from ``range(size)``, never returning ``pos_id``.

    Sampling is by rejection from ``np.random.choice(size)``; the returned list
    keeps the order in which ids were first accepted.

    Args:
        size: Number of candidate ids (ids are ``0 .. size - 1``).
        pos_id: Id that must not be returned (the positive example).
        num: How many distinct negative ids to return.

    Raises:
        ValueError: if ``num`` is negative or larger than the number of ids that
            can be accepted. The rejection loop would otherwise never terminate.
    """
    # Count the candidates the loop below would always reject (those equal to
    # pos_id). An elementwise compare also copes with pos_id being a 1-element array.
    blocked = int(np.count_nonzero(np.arange(size) == pos_id))
    available = size - blocked
    if num < 0 or num > available:
        raise ValueError(
            'cannot draw %s distinct negatives: only %s candidate ids available' % (num, available))
    res = {}  # dict used as an insertion-ordered set
    while len(res) != num:
        s = np.random.choice(size)
        if s in res or s == pos_id:
            continue
        res[s] = True
    return list(res.keys())

In [5]:
def to_torch(feature, time, edge_list):
    """Flatten a per-type sub-graph into index-based tensors on the global ``device``.

    Args:
        feature: mapping node type -> sequence of feature vectors.
        time: mapping node type -> sequence of time values, aligned with ``feature``.
        edge_list: nested mapping ``edge_list[target_type][source_type][relation]``
            -> iterable of ``(target_index, source_index)`` pairs, indices being
            local to their node type.

    Returns:
        ``(node_feature, node_type, edge_time, edge_index, edge_type,
        node_dict, edge_dict)`` where ``node_dict[type] = [offset, type_id]``
        and ``edge_dict[relation] = relation_id``.
    """
    # Number the node types in iteration order; 'fake_paper' gets no id of its
    # own and instead shares the id assigned to 'paper'.
    node_dict = {}
    for type_name in feature:
        if type_name == 'fake_paper':
            continue
        node_dict[type_name] = [0, len(node_dict)]
    node_dict['fake_paper'] = [0, node_dict['paper'][1]]

    # Concatenate all node types into one flat list, recording each type's offset.
    node_feature, node_type, node_time = [], [], []
    offset = 0
    for type_name in feature:
        count = len(feature[type_name])
        node_dict[type_name][0] = offset
        node_feature.extend(feature[type_name])
        node_time.extend(time[type_name])
        node_type.extend([node_dict[type_name][1]] * count)
        offset += count

    # Translate per-type edge indices into global ones; relation ids are handed
    # out in first-seen order.
    edge_index, edge_type, edge_time = [], [], []
    edge_dict = {}
    for target_type in edge_list:
        per_source = edge_list[target_type]
        for source_type in per_source:
            per_relation = per_source[source_type]
            for relation_type in per_relation:
                relation_id = edge_dict.setdefault(relation_type, len(edge_dict))
                for ti, si in per_relation[relation_type]:
                    sid = si + node_dict[source_type][0]
                    tid = ti + node_dict[target_type][0]
                    edge_index.append([sid, tid])
                    edge_type.append(relation_id)
                    # Time gap target - source, shifted by a constant 120.
                    edge_time.append(node_time[tid] - node_time[sid] + 120)

    return (
        torch.FloatTensor(node_feature).to(device),
        torch.LongTensor(node_type).to(device),
        torch.LongTensor(edge_time).to(device),
        torch.LongTensor(edge_index).to(device).t(),
        torch.LongTensor(edge_type).to(device),
        node_dict,
        edge_dict,
    )

def mask(feature, time, edge_list, batch_size, target_type, source_type, rel_type):
    """Hide ``batch_size`` random edges of one relation and tensorize the rest.

    The chosen edges are removed from ``edge_list[target_type][source_type][rel_type]``
    and from the matching ``'rev_' + rel_type`` list while ``to_torch`` runs, then
    ``edge_list`` is put back (forward list rebuilt from all edges, reverse list
    rebuilt as their flipped pairs).

    Args:
        feature, time, edge_list: sub-graph description, passed through to ``to_torch``.
        batch_size: number of edges to hide; sampled without replacement, so
            ``np.random.choice`` raises ValueError if it exceeds the edge count.
        target_type, source_type, rel_type: select the relation to mask.

    Returns:
        ``(rem_edges, node_feature, node_type, edge_time, edge_index, edge_type,
        node_dict)``: the hidden ``(target, source)`` pairs followed by the
        first six outputs of ``to_torch`` on the masked graph.
    """
    edges = np.array(edge_list[target_type][source_type][rel_type])
    rem_ids = np.random.choice(np.arange(len(edges)), batch_size, replace = False)
    # A boolean keep-mask replaces the per-index `i not in rem_ids` scan, and
    # still indexes correctly when every edge is removed (the old empty index
    # array defaulted to float dtype and raised IndexError).
    keep = np.ones(len(edges), dtype = bool)
    keep[rem_ids] = False
    ori_edges = edges[keep]
    rem_edges = edges[rem_ids]
    rev_type = 'rev_' + rel_type

    try:
        edge_list[target_type][source_type][rel_type] = list(ori_edges)
        edge_list[source_type][target_type][rev_type] = list(np.stack((ori_edges[:,1], ori_edges[:,0])).T)
        # The third value returned by to_torch is the per-edge time tensor.
        node_feature, node_type, edge_time, edge_index, edge_type, node_dict, _ = to_torch(feature, time, edge_list)
    finally:
        # Always undo the masking, even if tensorization fails, so the caller's
        # edge_list is never left with edges missing.
        edge_list[target_type][source_type][rel_type] = list(edges)
        edge_list[source_type][target_type][rev_type] = list(np.stack((edges[:,1], edges[:,0])).T)
    return rem_edges, node_feature, node_type, edge_time, edge_index, edge_type, node_dict

def edge_loss(gnn, predictor, feature, time, edge_list, batch_size, target_type, source_type, rel_type):
    rem_edges, node_feature, node_type, node_time, edge_index, edge_type, node_dict = \
            mask(feature, time, edge_list, batch_size, target_type, source_type, rel_type)
    target_size = len(feature[target_type])
    positive_target_ids, source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1] + node_dict[source_type][0]
    negative_target_ids = np.array([neg_sample(target_size, pos_id, samp_num - 1) for pos_id in positive_target_ids])
    target_ids = np.concatenate((positive_target_ids, negative_target_ids), axis=-1) + node_dict[target_type][0]

    node_emb = gnn.forward(node_feature, node_type, node_time, edge_index, edge_type)
    source_emb = node_emb[source_ids].repeat(1, target_ids.shape[1]\
                                      ).view(target_ids.shape[0] * target_ids.shape[1], -1)
    
    target_emb = node_emb[target_ids].view(target_ids.shape[0] * target_ids.shape[1], -1)
    res = predictor.forward(source_emb, target_emb).view(-1, target_ids.shape[1])
    return -torch.log_softmax(res, dim=-1)[:,0].mean()

In [6]:
# Hyper-parameters used by the training cells below. These assignments replace
# the values given to the same names in the earlier configuration cell.
batch_size, batch_num, epoch_num, samp_num = 128, 128, 100, 7

device = torch.device("cuda:1")

# Sample one training sub-graph and tensorize it; node_dict / edge_dict from
# this sample are what the model-construction cell uses to size the network.
train_feature, train_time, train_edge_list = graph.sample_subgraph(
    time_range=train_range, sampled_depth = 4, sampled_number = 64)
(node_feature, node_type, edge_time, edge_index,
 edge_type, node_dict, edge_dict) = to_torch(train_feature, train_time, train_edge_list)

In [7]:
# Shared encoder, sized from the node types / relations of the sampled sub-graph.
gnn = GNN(
    n_hid = 400,
    num_types = len(node_dict),
    num_relations = len(edge_dict),
    n_heads = 4,
    n_layers = 2,
    device = device,
    dropout = 0.5,
).to(device)

# One identically configured NTN scorer per task, built in the order
# fp, vp, pp, ap.
fp_predictor, vp_predictor, pp_predictor, ap_predictor = [
    NTN(200, 200, 4).to(device) for _ in range(4)
]

In [8]:
# A single Adam optimizer updates the four predictors and the shared GNN;
# parameters are collected in the order fp, vp, pp, ap, gnn.
trainable_params = [
    param
    for module in (fp_predictor, vp_predictor, pp_predictor, ap_predictor, gnn)
    for param in module.parameters()
]
optimizer = torch.optim.Adam(trainable_params)
# Cosine learning-rate schedule with T_max = 2000 steps and a floor of 1e-5.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 2000, eta_min=1e-5)

In [None]:
# Multi-task training: one shared GNN and four edge-prediction heads.
stats = []          # per epoch: [train_losses, valid_losses]
train_step = 0      # number of optimizer steps taken; also drives the LR scheduler
best_val   = 100    # lowest mean validation loss seen so far
models = [gnn, fp_predictor, vp_predictor, pp_predictor, ap_predictor]
# Each task is [predictor, target node type, source node type, relation name].
task_set = [[fp_predictor, 'field', 'paper', 'PF_in'],\
            [vp_predictor, 'venue', 'paper', 'PV_in'],\
            [pp_predictor, 'paper', 'paper', 'PP_cite'],\
            [ap_predictor, 'paper', 'author', 'AP_write']]
for epoch in np.arange(epoch_num)+1:
    train_losses = []
    valid_losses = []
    torch.cuda.empty_cache()
    for m in models:
        m.train()
    # Every outer batch samples one sub-graph and takes one step per task,
    # hence batch_num // 4 outer iterations.
    for out_batch in np.arange(batch_num // 4) + 1:
        train_feature, train_time, train_edge_list = \
            graph.sample_subgraph(time_range=train_range, sampled_depth = 5, sampled_number = 64)
        for predictor, target_type, source_type, rel_type in task_set:
            # Train
            loss = edge_loss(gnn, predictor, train_feature, train_time, train_edge_list, batch_size, target_type, source_type, rel_type)
            optimizer.zero_grad() 
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()
            train_losses += [loss.cpu().detach().tolist()]
            train_step += 1
            scheduler.step(train_step)
    # Valid
    for m in models:
        m.eval()
    valid_feature, valid_time, valid_edge_list = \
        graph.sample_subgraph(time_range=valid_range, sampled_depth = 5, sampled_number = 64)
    # No backward pass happens here, so skip building autograd graphs; this
    # keeps validation from holding activation memory on the GPU.
    with torch.no_grad():
        for predictor, target_type, source_type, rel_type in task_set:
            loss = edge_loss(gnn, predictor, valid_feature, valid_time, valid_edge_list, batch_size, target_type, source_type, rel_type)
            valid_losses += [loss.cpu().detach().tolist()]
    print(("Epoch: %d  LR: %.5f Train Loss: %f  Valid Loss: %s") % \
          (epoch, optimizer.param_groups[0]['lr'], np.average(train_losses), valid_losses))
    # Test (disabled). The block below references cpc_loss / cpc_predictor,
    # which are not defined in the cells shown in this notebook.
#     test_feature, test_time, test_edge_list = \
#         graph.sample_subgraph(time_range=test_range, sampled_depth = 5, sampled_number = 128)
#     res = []
#     for i in range(21):
#         ratio = i / 20.
#         loss, pred = cpc_loss(test_feature, test_time, test_edge_list, cpc_predictor, batch_size, samp_num, ratio = ratio)
#         for i in loss.detach().cpu().tolist():
#             res += [[ratio, i]]
#         del loss, pred
#     sb.lineplot(x="ratio", y="loss",
#                  data=pd.DataFrame(res, columns = ['ratio', 'loss']))
#     plt.show()
    stats += [[train_losses, valid_losses]]
    # Checkpoint the shared GNN whenever the mean validation loss improves.
    if np.average(valid_losses) < best_val:
        best_val = np.average(valid_losses)
        torch.save(gnn, './save/mt_model.pt')

Epoch: 1  LR: 0.00099 Train Loss: 1.940670  Valid Loss: [1.2019134759902954, 2.1395201683044434, 1.9515711069107056, 1.9282042980194092]


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2  LR: 0.00096 Train Loss: 1.697782  Valid Loss: [1.7943284511566162, 1.992253303527832, 1.956581950187683, 1.9190278053283691]
Epoch: 3  LR: 0.00091 Train Loss: 1.679364  Valid Loss: [1.0076727867126465, 2.341188907623291, 1.9548419713974, 1.8858458995819092]


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 4  LR: 0.00085 Train Loss: 1.653797  Valid Loss: [1.3943597078323364, 2.197718620300293, 1.9615848064422607, 1.867427110671997]
Epoch: 5  LR: 0.00077 Train Loss: 1.644723  Valid Loss: [1.5225019454956055, 2.1239171028137207, 1.946450114250183, 1.897566556930542]
Epoch: 6  LR: 0.00068 Train Loss: 1.616090  Valid Loss: [1.135371208190918, 2.09604811668396, 1.9305073022842407, 1.8914133310317993]


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 7  LR: 0.00059 Train Loss: 1.596987  Valid Loss: [0.9930037260055542, 2.379387140274048, 1.8653638362884521, 1.8962191343307495]
Epoch: 8  LR: 0.00049 Train Loss: 1.569882  Valid Loss: [1.0702203512191772, 2.2987213134765625, 2.0838184356689453, 1.888795256614685]
Epoch: 9  LR: 0.00039 Train Loss: 1.558523  Valid Loss: [1.2430710792541504, 2.6062722206115723, 1.9731640815734863, 1.9110252857208252]
Epoch: 10  LR: 0.00029 Train Loss: 1.547903  Valid Loss: [0.6451696157455444, 2.0382494926452637, 1.932870626449585, 1.8846811056137085]


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 11  LR: 0.00021 Train Loss: 1.525456  Valid Loss: [0.9846047163009644, 2.2114388942718506, 1.9248523712158203, 1.8778998851776123]
Epoch: 12  LR: 0.00014 Train Loss: 1.506042  Valid Loss: [0.844925045967102, 2.3784337043762207, 1.8391391038894653, 1.8840439319610596]
