In [1]:
import sys
sys.path.append('../')
from utils import *
from model import *
from collections import defaultdict

In [2]:
# Run configuration.
data_dir = '/datadrive/data_cs/'   # NOTE(review): absolute, machine-specific path
batch_size = 256                   # positive papers per batch in the periodic test sweep below
batch_num  = 128                   # batch_num // 8 sampling jobs are submitted per epoch (last one is validation)
epoch_num  = 200

device = torch.device("cuda:1")
# dill (like pickle) can execute arbitrary code while loading; only open trusted files.
# The `with` block closes the file handle once the graph is loaded.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [3]:
# Temporal split on graph.times (values look like years — NOTE(review): confirm):
# train < 2015, valid 2015..2016 inclusive, test > 2016. Entries with no time are skipped.
train_range = {t: True for t in graph.times if t is not None and t < 2015}
valid_range = {t: True for t in graph.times if t is not None and 2015 <= t <= 2016}
test_range  = {t: True for t in graph.times if t is not None and t > 2016}

In [None]:
# len(graph.node_feature['paper']['time'])
# tot = 0
# for t in train_range:
#     tot += np.sum(graph.node_feature['paper']['time'] == t)
# print(tot)
# tot = 0
# for t in valid_range:
#     tot += np.sum(graph.node_feature['paper']['time'] == t)
# print(tot)
# tot = 0
# for t in test_range:
#     tot += np.sum(graph.node_feature['paper']['time'] == t)
# print(tot)
# l = [377082, 77467, 89695]
# l = np.array(l)
# l / np.sum(l)

In [36]:
class GNN(nn.Module):
    '''
        Heterogeneous GNN encoder.

        Every node type has its own input projection (Linear + tanh) into an
        n_hid-dimensional space. The projected features pass through dropout
        and then through n_layers stacked RAGCNConv layers.
    '''
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        super(GNN, self).__init__()
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim    = in_dim
        self.n_hid     = n_hid
        # One input projection per node type; the attribute name is kept as-is
        # so existing checkpoints / state_dicts still load.
        self.aggregat_ws   = nn.ModuleList()
        self.drop          = nn.Dropout(dropout)
        for t in range(num_types):
            self.aggregat_ws.append(nn.Linear(in_dim, n_hid))
        for l in range(n_layers):
            self.gcs.append(RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout))

    def set_device(self, device):
        '''Record `device` on the encoder and on every conv layer.'''
        self.device = device
        for gc in self.gcs:
            gc.device = device

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        '''
            node_feature: (num_nodes, in_dim) input features.
            node_type:    per-node type ids; rows whose id is outside
                          range(num_types) keep an all-zero input representation.
            Returns the output of the last RAGCNConv layer, or the projected,
            dropped-out input when n_layers == 0.
        '''
        # Allocate directly on the input's device instead of building on CPU and copying.
        res = torch.zeros(node_feature.size(0), self.n_hid, device=node_feature.device)
        for t_id in range(self.num_types):
            idx = (node_type == t_id)
            if idx.sum() == 0:
                continue
            res[idx] = torch.tanh(self.aggregat_ws[t_id](node_feature[idx]))
        meta_xs = self.drop(res)
        for gc in self.gcs:
            # RAGCNConv is called as (x, node_type, edge_index, edge_type, edge_time).
            meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs
    
class CPC_Predictor(nn.Module):
    '''
        Scores real papers against their corrupted copies.

        A GNN encodes every node. A two-layer head (Linear -> ELU -> Linear)
        then maps each node to one scalar score. forward() returns, per positive
        paper, its own score in column 0 followed by the scores of its negatives.
    '''
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        super().__init__()
        # Registration order (gnn, matcher, score) is kept to preserve parameter order.
        self.gnn = GNN(in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout)
        self.matcher = nn.Linear(n_hid, n_heads)
        self.score   = nn.Linear(n_heads, 1)

    def forward(self, pos_paper_ids, neg_paper_ids, node_feature, node_type, edge_time, edge_index, edge_type):
        '''
            pos_paper_ids: global row ids of the B positive papers.
            neg_paper_ids: global row ids of the negatives, grouped so that the
                           k negatives of positive i are consecutive.
            Returns a (B, 1 + k) score matrix; column 0 is the positive.
        '''
        node_repr = self.gnn(node_feature, node_type, edge_time, edge_index, edge_type)
        node_scores = self.score(F.elu(self.matcher(node_repr)))
        positives = node_scores[pos_paper_ids]
        num_pos = positives.size(0)
        negatives = node_scores[neg_paper_ids]
        negatives = negatives.view(num_pos, negatives.size(0) // num_pos)
        return torch.cat((positives, negatives), dim=-1)

def gen_negative_sample(dim1, dim2, num):
    '''
        Draw `num` random index pairs [i, j] with 0 <= i < dim1 and 0 <= j < dim2,
        using numpy's global RNG (i is drawn before j for every pair).
    '''
    return [[np.random.randint(dim1), np.random.randint(dim2)] for _ in range(num)]

def neg_sample(size, pos_id, num):
    '''
        Draw `num` distinct random ids from range(size), never returning `pos_id`.

        Returns the ids as a list, in the order they were drawn.
        Raises ValueError when fewer than `num` candidates exist. Previously,
        the rejection loop below would spin forever in that case.
    '''
    excluded = 1 if pos_id in range(size) else 0
    available = max(size, 0) - excluded
    if num < 0 or num > available:
        raise ValueError('cannot draw %d distinct negatives from %d candidates' % (num, available))
    res = {}  # dict used as an insertion-ordered set
    while len(res) != num:
        s = np.random.choice(size)
        if s in res or s == pos_id:
            continue
        res[s] = True
    return list(res.keys())

In [None]:
def to_torch(feature, time, edge_list, graph):
    '''
        Transform a sampled sub-graph into pytorch Tensors.

        All nodes are stacked into one global index space: one contiguous slice
        per node type, in the order of graph.get_types(). If `feature` has a
        'fake_paper' entry (as added by cpc_sample), those nodes follow as a
        final slice and share the 'paper' type ID.

        node_dict: {node_type: [offset, node_type_ID]}; offset + local index is
                   the node's row in the returned tensors.
        edge_dict: {relation_type: edge_type_ID}, plus an extra 'self' relation.

        Returns (node_feature, node_type, edge_time, edge_index, edge_type,
                 node_dict, edge_dict). edge_index has shape (2, num_edges),
                 even when there are no edges.
    '''
    node_dict = {}
    node_num = 0
    # Type ids follow graph.get_types() positions.
    for type_id, t in enumerate(graph.get_types()):
        if t == 'fake_paper':
            # Fake papers are always appended after the real types below.
            continue
        node_dict[t] = [node_num, type_id]
        node_num += len(feature[t])
    if 'fake_paper' in feature:
        node_dict['fake_paper'] = [node_num, node_dict['paper'][1]]
        node_num += len(feature['fake_paper'])

    node_feature = []
    node_type    = []
    node_time    = []
    # Iterate node_dict (insertion order) so rows line up with the offsets above.
    for t, (_, type_id) in node_dict.items():
        node_feature += list(feature[t])
        node_time    += list(time[t])
        node_type    += [type_id] * len(feature[t])

    edge_dict = {e[2]: i for i, e in enumerate(graph.get_meta_graph())}
    edge_dict['self'] = len(edge_dict)
    edge_index   = []
    edge_type    = []
    edge_time    = []
    for target_type in edge_list:
        for source_type in edge_list[target_type]:
            for relation_type in edge_list[target_type][source_type]:
                for ti, si in edge_list[target_type][source_type][relation_type]:
                    sid, tid = si + node_dict[source_type][0], ti + node_dict[target_type][0]
                    edge_index += [[sid, tid]]
                    edge_type  += [edge_dict[relation_type]]
                    # Our time ranges from 1900 - 2020, largest span is 120,
                    # so the shifted relative time stays non-negative.
                    edge_time  += [node_time[tid] - node_time[sid] + 120]
    node_feature = torch.FloatTensor(node_feature)
    node_type    = torch.LongTensor(node_type)
    edge_time    = torch.LongTensor(edge_time)
    # view(-1, 2) keeps the (2, E) layout even when the edge list is empty.
    edge_index   = torch.LongTensor(edge_index).view(-1, 2).t()
    edge_type    = torch.LongTensor(edge_type)
    return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict

In [50]:
def cpc_sample(seed, in_feature, in_times, in_edge_list, batch_size, neg_size, ratio):
    '''
        Build corrupted ("fake") copies of a random batch of papers for CPC training.

        Samples `batch_size` distinct papers. Each one gets `neg_size` fake copies
        ('fake_paper' nodes) with the same feature row and time. A fake copy
        inherits the paper's edges from edge_list['paper'], i.e. the edges where
        the paper is the target. For each non-'self' edge, when random() < ratio
        and the donor paper has neighbours of that source type / relation, the
        source is replaced by a random neighbour of a donor paper.

        seed:       re-seeds numpy's global RNG, so the result is reproducible.
        in_feature, in_times, in_edge_list:
                    sampled sub-graph. These are deep-copied, so the caller's
                    objects are not modified. Edge lists hold
                    [target_ser, source_ser] pairs.
        batch_size: number of real papers to corrupt.
        neg_size:   number of fake copies per paper.
        ratio:      per-edge rewiring probability (see above).

        Returns (feature, times, edge_list, paper_ids). Fake copy j of
        paper_ids[i] is the 'fake_paper' node at index i * neg_size + j.
    '''
    np.random.seed(seed)
    feature, times, edge_list = copy.deepcopy(in_feature), copy.deepcopy(in_times), copy.deepcopy(in_edge_list)
    paper_ids = np.random.choice(len(times['paper']), batch_size, replace = False)
    # One donor paper per fake copy. np.random.choice(replace=False) raises ValueError
    # when there are fewer than batch_size * neg_size papers.
    # NOTE(review): a donor may be the target paper itself, in which case "rewiring"
    # draws from the paper's own neighbours.
    shuff_ids = np.random.choice(len(times['paper']), batch_size * neg_size, replace = False)
    paper_dict = {p: i for i, p in enumerate(paper_ids)}
    # repeat(neg_size, 0) repeats each row consecutively, so row i * neg_size + j
    # is a copy of paper_ids[i].
    feature['fake_paper'] = np.array(feature['paper'])[paper_ids].repeat(neg_size, 0)
    times['fake_paper']   = np.array(times['paper'])[paper_ids].repeat(neg_size, 0)
    te = edge_list['paper']
    for source_type in te:
        tes = te[source_type]
        for relation_type in tes:
            tesr = tes[relation_type]
            # Group source nodes by target paper for this (source_type, relation_type).
            tesd = defaultdict(lambda: [])
            for target_ser, source_ser in tesr:
                tesd[target_ser] += [source_ser]
            for target_ser in paper_dict:
                for j in range(neg_size):
                    fake_ser = paper_dict[target_ser] * neg_size + j
                    shuf_ser = shuff_ids[fake_ser]
                    for source_ser in tesd[target_ser]:
                        if relation_type == 'self':
                            # Self-loop on the fake node itself.
                            # NOTE(review): assumes edge_list is a nested defaultdict (as
                            # presumably returned by sample_subgraph) so the 'fake_paper'
                            # keys are created on demand — confirm.
                            edge_list['fake_paper']['fake_paper']['self'] += [[fake_ser, fake_ser]]
                        elif np.random.random() >= ratio or len(tesd[shuf_ser]) == 0:
                            # Keep the original neighbour. This is also the fallback when
                            # the donor has no neighbours of this relation.
                            edge_list['fake_paper'][source_type][relation_type] += [[fake_ser, source_ser]]
                        else:
                            # Always true here given the elif above; kept as a guard.
                            if len(tesd[shuf_ser]) > 0:
                                rd_ser = np.random.choice(tesd[shuf_ser])
                                edge_list['fake_paper'][source_type][relation_type] += [[fake_ser, rd_ser]]
    return feature, times, edge_list, paper_ids

def cpc_loss(data, paper_ids, num_fake, model, device):
    '''
        Contrastive loss for one batch produced by to_torch on cpc_sample output.

        data:      the 7-tuple returned by to_torch.
        paper_ids: local indices of the positive papers inside the 'paper' slice.
        num_fake:  number of 'fake_paper' nodes (the negatives).
        Returns (per-positive loss, raw scores). The loss is the negative
        log-softmax probability of column 0, the real paper.
    '''
    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, _ = data
    paper_offset = node_dict['paper'][0]
    fake_offset = node_dict['fake_paper'][0]
    graph_tensors = [x.to(device) for x in (node_feature, node_type, edge_time, edge_index, edge_type)]
    pred = model.forward(paper_ids + paper_offset, np.arange(num_fake) + fake_offset, *graph_tensors)
    log_probs = torch.log_softmax(pred, dim=-1)
    return -log_probs[:, 0], pred

def random_sample(seed, t_range, ratio, sampled_depth = 4, sampled_number = 100, neg_num = 3, batch_size = 128):
    '''
        Produce one CPC batch from the global `graph` (run in pool workers by prepare_data).

        Samples a sub-graph restricted to `t_range`, corrupts `batch_size` of its
        papers with cpc_sample (neg_num fakes each, rewiring ratio `ratio`), and
        converts the result with to_torch.
        Note: the default batch_size here (128) is independent of the notebook's
        global batch_size.

        Returns (to_torch output, paper_ids, number of fake papers).
    '''
    np.random.seed(seed)
    subgraph = sample_subgraph(graph, t_range, inp = None, \
                        sampled_depth = sampled_depth, sampled_number = sampled_number)
    sub_feature, sub_times, sub_edges, _ = subgraph
    fake_feature, fake_time, fake_edge_list, paper_ids = cpc_sample(seed, sub_feature, sub_times, sub_edges, batch_size, neg_num, ratio)
    tensors = to_torch(fake_feature, fake_time, fake_edge_list, graph)
    num_fake = len(fake_time['fake_paper'])
    return tensors, paper_ids, num_fake

def prepare_data(pool, process_ids):
    '''
        Submit one epoch of batch-sampling jobs to `pool`.

        Submits len(process_ids) - 1 training jobs over train_range, each with a
        random rewiring ratio in [0.1, 0.5). Then submits one validation job over
        valid_range with ratio 0.3. Returns the AsyncResult handles; the
        validation job is last.
    '''
    jobs = []
    for _ in process_ids[:-1]:
        # Draw the seed before the ratio to keep the global RNG sequence unchanged.
        job_seed = np.random.randint(2**32 - 1)
        job_ratio = np.random.random() * 0.4 + 0.1
        jobs.append(pool.apply_async(random_sample, args=(job_seed, train_range, job_ratio)))
    valid_seed = np.random.randint(2**32 - 1)
    jobs.append(pool.apply_async(random_sample, args=(valid_seed, valid_range, 0.3)))
    return jobs

In [42]:
# CPC model (GNN encoder + scoring head), optimised with AdamW under a cosine LR schedule.
cpc_predictor = CPC_Predictor(
    # NOTE(review): 401 extra input columns beyond the 'emb' vector — confirm their source in utils.
    in_dim        = len(graph.node_feature['paper']['emb'][0]) + 401,
    n_hid         = 256,
    num_types     = len(graph.get_types()),
    # +1 for the extra 'self' relation that to_torch adds to edge_dict.
    num_relations = len(graph.get_meta_graph()) + 1,
    n_heads       = 8,
    n_layers      = 3,
).to(device)
optimizer = torch.optim.AdamW(cpc_predictor.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000, eta_min=1e-6)

In [53]:
stats = []
pool = mp.Pool(8)
process_ids = np.arange(batch_num // 8)
st = time.time()
jobs = prepare_data(pool, process_ids)
train_step = 1500
best_val   = 0
try:
    for epoch in np.arange(epoch_num)+1:
        # Collect this epoch's batches, then immediately start sampling the next
        # epoch's batches on a fresh pool so data preparation overlaps training.
        train_data = [job.get() for job in jobs[:-1]]
        valid_data = jobs[-1].get()
        pool.close()
        pool.join()
        pool = mp.Pool(8)
        jobs = prepare_data(pool, process_ids)
        et = time.time()
        print('Data Preparation: %.1fs' % (et - st))

        train_losses = []
        cpc_predictor.train()
        torch.cuda.empty_cache()
        for data, paper_ids, num_fake in train_data:
            loss, pred = cpc_loss(data, paper_ids, num_fake, cpc_predictor, device)
            optimizer.zero_grad()
            loss.mean().backward()
            torch.nn.utils.clip_grad_norm_(cpc_predictor.parameters(), 0.2)
            optimizer.step()
            train_losses += loss.cpu().detach().tolist()
            train_step += 1
            # NOTE(review): passing an explicit step to scheduler.step() is deprecated in
            # PyTorch. With T_max=1000 and train_step starting at 1500, the cosine LR
            # cycles rather than decaying once — confirm this is intended.
            scheduler.step(train_step)
            del loss, pred
            torch.cuda.empty_cache()

        # Validation
        cpc_predictor.eval()
        with torch.no_grad():
            data, paper_ids, num_fake = valid_data
            loss, pred = cpc_loss(data, paper_ids, num_fake, cpc_predictor, device)
            valid_losses = loss.cpu().detach().tolist()
            # Column 0 is the real paper. Normalise by the number of rows actually scored:
            # validation batches use random_sample's own batch_size default (128), not the
            # global batch_size (256), so dividing by batch_size capped accuracy at 0.5.
            valid_acc = (pred.argmax(dim=1) == 0).sum().item() / pred.size(0)
            st = time.time()
            print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %f  Valid Loss: %f  Valid Acc: %f") % \
                  (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), np.average(valid_losses), valid_acc))
            if valid_acc > best_val:
                best_val = valid_acc
                torch.save(cpc_predictor.gnn, './save/cpc_model.pt')
            stats += [[train_losses, valid_losses]]

            if epoch % 5 == 0:
                # Sweep the rewiring ratio on a test-period sub-graph.
                # `*_` tolerates sample_subgraph returning extra values
                # (random_sample unpacks four values from the same call).
                feature, _time, edge_list, *_ = sample_subgraph(graph, test_range, inp = None, sampled_depth = 4, sampled_number = 128)
                ratios = np.arange(10) / 10
                losses = []
                accs   = []
                for ratio in ratios:
                    fake_feature, fake_time, fake_edge_list, paper_ids = cpc_sample(np.random.randint(2**32 - 1), \
                                                                    feature, _time, edge_list, batch_size, 5, ratio)
                    data = to_torch(fake_feature, fake_time, fake_edge_list, graph)
                    loss, pred = cpc_loss(data, paper_ids, len(fake_time['fake_paper']), cpc_predictor, device)
                    for li in loss.tolist():
                        losses += [[li, ratio]]
                    accs += [(pred.argmax(dim=1) == 0).sum().item() / pred.size(0)]
                    del loss, pred
                    torch.cuda.empty_cache()
                sb.lineplot(data = pd.DataFrame(losses, columns=['loss', 'ratio']), x = 'ratio', y='loss')
                plt.show()
                plt.plot(ratios, accs)
                plt.xlabel('ratio')
                plt.ylabel('accuracy')
                plt.show()
finally:
    # Stop the pool that is still sampling for an epoch that will never run (or
    # that was left running when a worker raised), so its processes do not leak.
    pool.terminate()
    pool.join()

IndexError: list index out of range

In [None]:
def to_torch(feature, time, edge_list, graph):
    '''
        Transform a sampled sub-graph into pytorch Tensors.

        All nodes are stacked into one global index space: one contiguous slice
        per node type, in the order of graph.get_types(). If `feature` has a
        'fake_paper' entry (as added by cpc_sample), those nodes follow as a
        final slice and share the 'paper' type ID.

        node_dict: {node_type: [offset, node_type_ID]}; offset + local index is
                   the node's row in the returned tensors.
        edge_dict: {relation_type: edge_type_ID}, plus an extra 'self' relation.

        Returns (node_feature, node_type, edge_time, edge_index, edge_type,
                 node_dict, edge_dict). edge_index has shape (2, num_edges),
                 even when there are no edges.
    '''
    node_dict = {}
    node_num = 0
    # Type ids follow graph.get_types() positions.
    for type_id, t in enumerate(graph.get_types()):
        if t == 'fake_paper':
            # Fake papers are always appended after the real types below.
            continue
        node_dict[t] = [node_num, type_id]
        node_num += len(feature[t])
    if 'fake_paper' in feature:
        node_dict['fake_paper'] = [node_num, node_dict['paper'][1]]
        node_num += len(feature['fake_paper'])

    node_feature = []
    node_type    = []
    node_time    = []
    # Iterate node_dict (insertion order) so rows line up with the offsets above.
    for t, (_, type_id) in node_dict.items():
        node_feature += list(feature[t])
        node_time    += list(time[t])
        node_type    += [type_id] * len(feature[t])

    edge_dict = {e[2]: i for i, e in enumerate(graph.get_meta_graph())}
    edge_dict['self'] = len(edge_dict)
    edge_index   = []
    edge_type    = []
    edge_time    = []
    for target_type in edge_list:
        for source_type in edge_list[target_type]:
            for relation_type in edge_list[target_type][source_type]:
                for ti, si in edge_list[target_type][source_type][relation_type]:
                    sid, tid = si + node_dict[source_type][0], ti + node_dict[target_type][0]
                    edge_index += [[sid, tid]]
                    edge_type  += [edge_dict[relation_type]]
                    # Our time ranges from 1900 - 2020, largest span is 120,
                    # so the shifted relative time stays non-negative.
                    edge_time  += [node_time[tid] - node_time[sid] + 120]
    node_feature = torch.FloatTensor(node_feature)
    node_type    = torch.LongTensor(node_type)
    edge_time    = torch.LongTensor(edge_time)
    # view(-1, 2) keeps the (2, E) layout even when the edge list is empty.
    edge_index   = torch.LongTensor(edge_index).view(-1, 2).t()
    edge_type    = torch.LongTensor(edge_type)
    return node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict

In [None]:
# Long-form rows [loss, 'Train', epoch], one per training sample, for seaborn.
# Validation losses are not plotted; to include them, also append
# [vi, 'Valid', epoch] for each vi in valid_losses.
losses = []
for epoch, (train_losses, valid_losses) in enumerate(stats):
    for ti in train_losses:
        losses.append([ti, 'Train', epoch])

In [None]:
# Per-epoch training loss curve (mean with seaborn's confidence band).
fig, ax = plt.subplots()
sb.lineplot(data = pd.DataFrame(losses, columns=['loss', 'Type', 'Epoch']), x = 'Epoch', y='loss', hue='Type', ax=ax)
ax.set_title('CPC loss per epoch')
plt.show()