## RGCN模型

问题： 
- 关于node和entity的关系
- ns_mode的作用
- 为什么 link prediction 中 num_rel * 2?


In [None]:

#  AIFB数据集 原本数据
#  Nodes: 7262 
#  num_edges: 48,810 (including reverse edges)
#  Target Category: Personen
#  Number of Classes: 4
#  Label Split:train:140, test:36
# entities = nodes? : 8285
# relations: 45
# edges: 29,043


class RGCN(nn.Module):
    """Two-layer relational GCN on top of a learnable node-embedding table.

    For comparison, the per-layer constructor in DGL's reference
    implementation has the signature
    ``__init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
    activation=None, is_input_layer=False)``; here ``in_feat`` corresponds to
    ``h_dim`` (width of h_i^(l)) and ``out_feat`` to ``out_dim`` (width of
    h_i^(l+1)).

    Args:
        num_nodes: Number of rows of the ``nn.Embedding`` table.
        h_dim: Embedding width; also the input and output width of the first
            ``RelGraphConv`` and the input width of the second.
        out_dim: Output width of the second ``RelGraphConv``.
        num_rels: Number of relation (edge) types given to both layers.
        regularizer: Weight regulariser forwarded to ``RelGraphConv``:
            ``"basis"`` (basis decomposition) or ``"bdd"``
            (block-diagonal decomposition).
        num_bases: Number of bases/blocks B used by the regulariser (the
            upper limit of the sum in the decomposition formula). It is a
            hyper-parameter chosen by the user; ``-1`` means "use
            ``num_rels``".
        dropout: Probability given to the ``nn.Dropout`` stored on the module.
        self_loop: Forwarded to both ``RelGraphConv`` layers (whether the
            node's own representation is added).
        ns_mode: Stored as ``self.ns_mode``; not otherwise used by this
            constructor. NOTE(review): presumably "neighbour-sampling mode"
            rather than an ``is_input_layer`` flag - confirm against the
            forward pass.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels,
                 regularizer="basis", num_bases=-1, dropout=0.,
                 self_loop=False,
                 ns_mode=False):
        super(RGCN, self).__init__()

        # Only the -1 sentinel is replaced here. The reference layer quoted in
        # the original notes validates more broadly:
        #     if self.num_bases <= 0 or self.num_bases > self.num_rels:
        #         self.num_bases = self.num_rels
        # i.e. a parameter check for the weight decomposition (the
        # regularisation section of the paper).
        if num_bases == -1:
            num_bases = num_rels

        # RelGraphConv(in_feat, out_feat, num_rels, regularizer, num_bases,
        #              bias, activation, self_loop, dropout, layer_norm=False)
        # Example sizes below are the author's FB15k-237 notes.
        self.emb = nn.Embedding(num_nodes, h_dim)  # e.g. (14541, 500)
        # First layer keeps the width: in_feat = out_feat = h_dim (e.g. 500);
        # num_rels is e.g. 474 = 237 * 2 in the link-prediction setup.
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer,
                                  num_bases, self_loop=self_loop)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer,
                                  num_bases, self_loop=self_loop)

        self.dropout = nn.Dropout(dropout)
        self.ns_mode = ns_mode



![formula_of_rgcn.jpg](https://s2.loli.net/2022/07/10/Gn9VWSRiMywT8oF.jpg)

In [None]:
# 原始的节点或边的类型和对应的ID被存储在 ndata 和 edata 中
def forward(self, g, nids=None):
    """Run embedding lookup, two relational conv layers, ReLU and dropout.

    Two code paths, selected by ``self.ns_mode``:

    * ``ns_mode`` true ("forward for neighbor sampling"): ``g`` is indexed as
      ``g[0]`` and ``g[1]``; the first element feeds ``conv1`` and the second
      feeds ``conv2``. Input features are the embeddings of
      ``g[0].srcdata[dgl.NID]``. ``nids`` is ignored on this path.
      NOTE(review): presumably ``g`` is the pair of per-layer blocks produced
      by a DGL neighbour sampler - confirm against the data loader.
    * ``ns_mode`` false: ``g`` is a single graph used by both layers. Input
      features are the whole embedding table when ``nids`` is ``None``,
      otherwise the embeddings of ``nids``.

    In both paths each conv layer also receives the edge data stored under
    ``dgl.ETYPE`` and ``'norm'``.
    NOTE(review): ``dgl.NID`` / ``dgl.ETYPE`` look like DGL's reserved field
    names for original node IDs and edge-type IDs, and ``'norm'`` like a
    per-edge normalisation coefficient - confirm against the DGL docs and
    the data-loading code.

    Returns:
        The output of ``conv2`` (no activation or dropout applied to it).
    """
    if not self.ns_mode:
        # Full-graph path: one graph serves both layers.
        if nids is None:
            feats = self.emb.weight
        else:
            feats = self.emb(nids)
        hidden = self.conv1(g, feats, g.edata[dgl.ETYPE], g.edata['norm'])
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.conv2(g, hidden, g.edata[dgl.ETYPE], g.edata['norm'])

    # Neighbour-sampling path: one element of ``g`` per conv layer.
    feats = self.emb(g[0].srcdata[dgl.NID])
    hidden = self.conv1(g[0], feats, g[0].edata[dgl.ETYPE], g[0].edata['norm'])
    hidden = F.relu(hidden)
    hidden = self.dropout(hidden)
    return self.conv2(g[1], hidden, g[1].edata[dgl.ETYPE], g[1].edata['norm'])


'''
dgl 对于图的表示方法:
# Source nodes for edges (2, 1), (3, 2), (4, 3)
src_ids = torch.tensor([2, 3, 4])
# Destination nodes for edges (2, 1), (3, 2), (4, 3)
dst_ids = torch.tensor([1, 2, 3])
g = dgl.graph((src_ids, dst_ids))
'''
# 通过 ndata 和 edata 分配和检索节点特征与边特征 (node and edge features)
'''for example,
# Assign a 3-dimensional node feature vector for each node.
g.ndata['x'] = torch.randn(6, 3)
# Assign a 4-dimensional edge feature vector for each edge.
g.edata['a'] = torch.randn(5, 4)
# Assign a 5x4 node feature matrix for each node.  Node and edge features in DGL can be multi-dimensional.
g.ndata['y'] = torch.randn(6, 5, 4)

print(g.edata['a'])

output:
tensor([[ 0.0498,  1.2527,  0.1431, -0.7624],
        [ 0.0602,  1.6373,  0.5788,  2.6319],
        [ 0.4912,  0.3511, -1.1502, -0.1934],
        [-1.5344, -0.4983,  1.6341, -0.2023],
        [-0.5982, -3.5050, -1.2111,  0.0091]])
'''

## 实体分类(Entity Classification)

In [None]:
def main(args):
    """Train the RGCN for entity classification, then print test accuracy.

    Reads ``args.dataset``, ``args.n_hidden``, ``args.n_bases``, ``args.gpu``
    and ``args.wd``. Training runs for a fixed 100 full-graph epochs with
    Adam (lr 1e-2) and prints accuracy and loss on the training split after
    every epoch.
    """
    (g, num_rels, num_classes, labels,
     train_idx, test_idx, target_idx) = load_data(args.dataset, get_norm=True)

    model = RGCN(g.num_nodes(), args.n_hidden, num_classes, num_rels,
                 num_bases=args.n_bases)

    # Use the requested GPU only when CUDA is actually available.
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    device = th.device(args.gpu) if use_cuda else th.device('cpu')

    labels = labels.to(device)
    model = model.to(device)
    g = g.int().to(device)

    optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)

    model.train()
    for epoch in range(100):
        # Keep only the rows selected by target_idx before splitting.
        # NOTE(review): presumably target_idx picks the labelled
        # target-category nodes - confirm against load_data.
        target_logits = model(g)[target_idx]
        train_logits = target_logits[train_idx]
        train_labels = labels[train_idx]

        loss = F.cross_entropy(train_logits, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy is computed from the logits produced before this step.
        train_acc = accuracy(train_logits.argmax(dim=1), train_labels).item()
        print("Epoch {:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
            epoch, train_acc, loss.item()))
    print()

    model.eval()
    with th.no_grad():
        eval_logits = model(g)
    eval_logits = eval_logits[target_idx]
    test_predictions = eval_logits[test_idx].argmax(dim=1)
    test_acc = accuracy(test_predictions, labels[test_idx]).item()
    print("Test Accuracy: {:.4f}".format(test_acc))

In [None]:
if __name__ == '__main__':
    # Command-line entry point for entity classification: register the
    # options from a table, parse them, echo them, and hand them to main().
    parser = argparse.ArgumentParser(description='RGCN for entity classification')
    option_table = [
        (("--n-hidden",), {"type": int, "default": 16,
                           "help": "number of hidden units"}),
        (("--gpu",), {"type": int, "default": -1,
                      "help": "gpu"}),
        (("--n-bases",), {"type": int, "default": -1,
                          "help": "number of filter weight matrices, default: -1 [use all]"}),
        (("-d", "--dataset"), {"type": str, "required": True,
                               "choices": ['aifb', 'mutag', 'bgs', 'am'],
                               "help": "dataset to use"}),
        (("--wd",), {"type": float, "default": 5e-4,
                     "help": "weight decay"}),
    ]
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)

    args = parser.parse_args()
    print(args)
    main(args)
    

## 关系预测(Link Prediction)

In [None]:
# 与entity classification 完全相同，只是将 relation 个数乘了2
# 将正则方式换成块对角分解
# for FB15k-237 数据集
        # h_dim：500
        # num_bases:100(人为设置，即公式(4)中的B i.e. 累加上限)
        # out_dim:500
        # num_nodes: 14541
        # num_rel:474 = 237 * 2(乘2原因？)

class LinkPredict(nn.Module):
    """RGCN encoder plus a per-relation weight table for link prediction.

    Args:
        in_dim: Passed to ``RGCN`` as its first argument (the size of its
            node-embedding table).
        num_rels: Number of relation types; ``w_relation`` has one row each.
        h_dim: Width of the encoder's hidden and output representations and
            of each relation vector.
        num_bases: Forwarded to ``RGCN`` (used with the ``"bdd"``
            block-diagonal regulariser).
        dropout: Dropout probability, used inside the encoder and for the
            ``nn.Dropout`` stored on this module.
        reg_param: Regularisation coefficient stored as ``self.reg_param``
            (the original notes describe it as an L2 penalty of 0.01 on the
            decoder).
    """

    def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100, dropout=0.2, reg_param=0.01):
        super().__init__()
        self.reg_param = reg_param

        # The encoder is built with twice as many relation types as the
        # decoder table below.
        # NOTE(review): presumably because the training graph carries an
        # inverse edge type for every relation - confirm against preprocess.
        encoder_rel_count = num_rels * 2
        self.rgcn = RGCN(in_dim, h_dim, h_dim, encoder_rel_count,
                         regularizer="bdd", num_bases=num_bases,
                         dropout=dropout, self_loop=True)
        self.dropout = nn.Dropout(dropout)

        # One learnable h_dim-wide vector per relation, Xavier-initialised
        # with the gain recommended for ReLU.
        relu_gain = nn.init.calculate_gain('relu')
        self.w_relation = nn.Parameter(th.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation, gain=relu_gain)

![score.jpg](https://s2.loli.net/2022/07/10/2U9OEWNiprl4RYG.jpg)

![loss.jpg](https://s2.loli.net/2022/07/10/Qx4cRq9svzPpoS6.jpg)

In [None]:
def calc_score(self, embedding, triplets):
    """Score each (subject, relation, object) triplet.

    For every row of ``triplets`` the score is the sum over the feature
    dimension of ``e_s * r * e_o``, i.e. ``e_s^T diag(r) e_o`` with the
    relation acting as a diagonal matrix. Entity vectors are rows of
    ``embedding``; relation vectors are rows of ``self.w_relation``.

    Args:
        embedding: Entity representations, indexed by the ids in columns 0
            and 2 of ``triplets``.
        triplets: Index array whose columns 0, 1, 2 hold subject id,
            relation id and object id.

    Returns:
        One score per triplet row.
    """
    subj_ids, rel_ids, obj_ids = triplets[:, 0], triplets[:, 1], triplets[:, 2]
    subj_vec = embedding[subj_ids]
    rel_vec = self.w_relation[rel_ids]
    obj_vec = embedding[obj_ids]
    return (subj_vec * rel_vec * obj_vec).sum(dim=1)

def forward(self, g, nids):
    """Encode nodes with ``self.rgcn`` and apply ``self.dropout`` to the result.

    ``nids`` is passed to the encoder by keyword.
    """
    encoded = self.rgcn(g, nids=nids)
    return self.dropout(encoded)

def regularization_loss(self, embedding):
    """Return mean squared embedding value plus mean squared relation weight."""
    return th.mean(embedding.pow(2)) + th.mean(self.w_relation.pow(2))


def get_loss(self, embed, triplets, labels):
    """Return the prediction loss plus the weighted regularisation term.

    Args:
        embed: Entity embeddings passed to ``calc_score`` and to
            ``regularization_loss``.
        triplets: Each row is a 3-tuple of (source, relation, destination).
        labels: Targets for ``binary_cross_entropy_with_logits``, one per
            triplet.

    Returns:
        ``BCE-with-logits(score, labels) + self.reg_param * reg_loss``.
    """
    # Scores are treated as raw logits by the loss below.
    score = self.calc_score(embed, triplets)
    predict_loss = F.binary_cross_entropy_with_logits(score, labels)
    reg_loss = self.regularization_loss(embed)
    return predict_loss + self.reg_param * reg_loss

In [None]:
# FB15K-237 dataset
# entities: 14,541
# relations: 237
# train edges: 272,115
# validation edges: 17,535
# test edges: 20466
def main(args):
    """Train ``LinkPredict`` on FB15k-237 with periodic MRR validation.

    Reads ``args.edge_sampler`` (given to ``SubgraphIterator``), ``args.gpu``
    and ``args.eval_protocol`` (given to ``calc_mrr``). One sampled subgraph
    is consumed per loop iteration; every 500 iterations the model is moved
    to the CPU, evaluated on ``test_g``, and checkpointed to
    ``model_state.pth`` when the MRR improves.
    """
    data = FB15k237Dataset(reverse=False)
    graph = data[0]
    num_nodes = graph.num_nodes()  # 14541 according to the author's notes
    num_rels = data.num_rels  # 237 according to the author's notes

    # Sizes recorded by the author:
    #   test_g  = Graph(num_nodes=14541, num_edges=544230)
    #   train_g = Graph(num_nodes=14541, num_edges=272115)
    train_g, test_g = preprocess(graph, num_rels)
    test_nids = th.arange(0, num_nodes)
    test_mask = graph.edata['test_mask']
    subg_iter = SubgraphIterator(train_g, num_rels, args.edge_sampler)
    # Each item of subg_iter is used as a whole batch: batch_size=1 and the
    # collate function unwraps the single-element list.
    dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])

    # Prepare data for metric computation
    src, dst = graph.edges()
    triplets = th.stack([src, graph.edata['etype'], dst], dim=1)

    model = LinkPredict(num_nodes, num_rels)
    optimizer = th.optim.Adam(model.parameters(), lr=1e-2)

    if args.gpu >= 0 and th.cuda.is_available():
        device = th.device(args.gpu)
    else:
        device = th.device('cpu')
    model = model.to(device)

    best_mrr = 0
    model_state_file = 'model_state.pth'
    for epoch, batch_data in enumerate(dataloader):
        model.train()
        g, train_nids, edges, labels = batch_data
        g = g.to(device)
        train_nids = train_nids.to(device)
        edges = edges.to(device)
        labels = labels.to(device)

        embed = model(g, train_nids)
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip gradients
        optimizer.step()

        print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))

        if (epoch + 1) % 500 == 0:
            # perform validation on CPU because full graph is too large
            model = model.cpu()
            model.eval()
            print("start eval")
            # No gradients are needed for validation; disabling autograd
            # avoids recording the full-graph forward pass.
            with th.no_grad():
                embed = model(test_g, test_nids)
                mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
                               batch_size=500, eval_p=args.eval_protocol)
            # save best model
            if best_mrr < mrr:
                best_mrr = mrr
                th.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)

            # Move back to the training device before the next iteration.
            model = model.to(device)

In [None]:
if __name__ == '__main__':
    # Command-line entry point for link prediction. The main() defined just
    # above reads args.gpu, args.edge_sampler and args.eval_protocol, so
    # those three options must exist on the parsed namespace.
    parser = argparse.ArgumentParser(description='RGCN for link prediction')
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    # NOTE(review): defaults and choices for the next two options follow
    # DGL's upstream RGCN link-prediction example - confirm they match this
    # project's calc_mrr and SubgraphIterator.
    parser.add_argument("--eval-protocol", type=str, default='filtered',
                        choices=['filtered', 'raw'],
                        help="Whether to use 'filtered' or 'raw' MRR for evaluation")
    parser.add_argument("--edge-sampler", type=str, default='uniform',
                        choices=['uniform', 'neighbor'],
                        help="Type of edge sampler: 'uniform' or 'neighbor'")

    # Options inherited from the entity-classification script. They are still
    # accepted so existing command lines keep parsing, but the
    # link-prediction main() does not read them; --dataset is therefore no
    # longer required.
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden units")
    parser.add_argument("--n-bases", type=int, default=-1,
                        help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("-d", "--dataset", type=str, default=None,
                        choices=['aifb', 'mutag', 'bgs', 'am'],
                        help="dataset to use")
    parser.add_argument("--wd", type=float, default=5e-4,
                        help="weight decay")

    args = parser.parse_args()
    print(args)
    main(args)
    