In [1]:
import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Number of nodes per node type
n_users, n_items = 1_000, 500

# Number of edges per relation
n_follows, n_clicks, n_dislikes = 3_000, 5_000, 500

# Feature / label dimensions
n_hetero_features = 10  # input feature width shared by "user" and "item" nodes
n_max_clicks = 10       # bound used when drawing the click-count edge labels
n_user_classes = 5      # number of classes for the user node labels

In [3]:
# Random edge lists for every relation. The fixed seed makes the synthetic graph
# reproducible; the draws below must stay in this order to reproduce it exactly.
torch.manual_seed(8)

# "follow" edges (user -> user). Both endpoints are drawn independently, so a
# user can end up following itself (self-loops are possible).
follow_src = torch.randint(low=0, high=n_users, size=(n_follows,))
follow_dst = torch.randint(low=0, high=n_users, size=(n_follows,))

# "click" edges (user -> item)
click_src = torch.randint(low=0, high=n_users, size=(n_clicks,))
click_dst = torch.randint(low=0, high=n_items, size=(n_clicks,))

# "dislike" edges (user -> item)
dislike_src = torch.randint(low=0, high=n_users, size=(n_dislikes,))
dislike_dst = torch.randint(low=0, high=n_items, size=(n_dislikes,))

In [4]:
# Build the heterograph. Only the forward direction of each relation is active;
# the reverse relations are left commented out. HeteroGraphConv collects
# outputs per *destination* node type, so with this schema "user" nodes are
# updated only through "follow": click/dislike information reaches "item"
# nodes but never flows back to users.
hg = dgl.heterograph(
    {
        ("user", "follow", "user"): (follow_src, follow_dst),
        # ("user", "followed-by", "user"): (follow_dst, follow_src),
        ("user", "click", "item"): (click_src, click_dst),
        # ("item", "clicked-by", "user"): (click_dst, click_src),
        ("user", "dislike", "item"): (dislike_src, dislike_dst),
        # ("item", "disliked-by", "user"): (dislike_dst, dislike_src),
    }
)

# Node data
hg.nodes["user"].data["feat"] = torch.randn(n_users, n_hetero_features)
hg.nodes["item"].data["feat"] = torch.randn(n_items, n_hetero_features)
hg.nodes["user"].data["label"] = torch.randint(0, n_user_classes, (n_users,))
# Each entry is True with probability 0.6. Stored as bool so it can index
# tensors directly.
hg.nodes["user"].data["train_mask"] = torch.ones(n_users).bernoulli(0.6).bool()

# Edge data
# torch.randint's upper bound is exclusive; +1 makes n_max_clicks itself a
# possible label, so click counts lie in [1, n_max_clicks].
hg.edges["click"].data["label"] = torch.randint(1, n_max_clicks + 1, (n_clicks,))
hg.edges["click"].data["train_mask"] = torch.ones(n_clicks).bernoulli(0.6).bool()

In [5]:
# Totals are summed over all node / edge types. The type names are read from
# the graph itself, so the labels stay correct if relations are added or removed
# (a hand-written label previously listed "like" instead of "dislike").
print(f"n of nodes({', '.join(hg.ntypes)})", hg.num_nodes())
print(f"n of edges({', '.join(hg.etypes)})", hg.num_edges())
print(hg.canonical_etypes)

n of nodes(user, item) 1500
n of edges(click, like, follow) 8500
[('user', 'click', 'item'), ('user', 'dislike', 'item'), ('user', 'follow', 'user')]


In [6]:
# Each canonical edge type is a (source node type, edge type, destination node type) triple.
for relation in hg.canonical_etypes:
    print(relation)

# The metagraph has one node per node type and one edge per relation.
print("\nmetagraph")
metagraph = hg.metagraph()
print(metagraph.edges())
print(metagraph.nodes())

('user', 'click', 'item')
('user', 'dislike', 'item')
('user', 'follow', 'user')

metagraph
[('user', 'item'), ('user', 'item'), ('user', 'user')]
['user', 'item']


`hg[relation]` 是只包含该关系的子图（源、目标节点类型不同时即为二分图），其结构与用 `dgl.graph` 构造的图类似。`HeteroGraphConv` 在每个这样的子图上运行对应的子模块，再按目标节点类型聚合结果，其前向计算大致如下（伪代码）：

```python
outputs = {nty : [] for nty in g.dsttypes}
# Apply sub-modules on their associating relation graphs in parallel
for relation in g.canonical_etypes:
    stype, etype, dtype = relation
    dstdata = relation_submodule(g[relation], ...)
    outputs[dtype].append(dstdata)

# Aggregate the results for each destination node type
rsts = {}
for ntype, ntype_outputs in outputs.items():
    if len(ntype_outputs) != 0:
        rsts[ntype] = aggregate(ntype_outputs)
return rsts
```

In [44]:
class RGCN(nn.Module):
    """Two-layer relational GCN built from ``HeteroGraphConv``.

    Each layer holds one ``GraphConv`` per entry of ``rel_names``. Outputs that
    land on the same destination node type are summed, and a ReLU is applied
    between the two layers.

    Args:
        in_feats: width of the input node features.
        hid_feats: width of the hidden representation after the first layer.
        out_feats: width of the output representation (e.g. number of classes).
        rel_names: relation keys (edge types or canonical edge types) used to
            key the per-relation sub-modules.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        def hetero_layer(n_in, n_out):
            # One GraphConv per relation, merged by summation per destination type.
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GraphConv(n_in, n_out)
            return dglnn.HeteroGraphConv(per_relation, aggregate="sum")

        self.conv1 = hetero_layer(in_feats, hid_feats)
        self.conv2 = hetero_layer(hid_feats, out_feats)

    def forward(self, graph, inputs):
        """Run both layers.

        Args:
            graph: the heterograph to convolve over.
            inputs: dict mapping node type -> input feature tensor.

        Returns:
            dict mapping destination node type -> output feature tensor.
        """
        hidden = self.conv1(graph, inputs)
        activated = {ntype: F.relu(feat) for ntype, feat in hidden.items()}
        return self.conv2(graph, activated)

In [88]:
# One GraphConv per canonical relation of `hg`: input width = node feature size,
# hidden width 20, output width = number of user classes.
model = RGCN(n_hetero_features, 20, n_user_classes, hg.canonical_etypes)

In [89]:
# Input features for every node type, keyed by node type name.
node_features = {ntype: hg.nodes[ntype].data["feat"] for ntype in ("user", "item")}

In [90]:
# Forward pass: a dict mapping each destination node type to its output
# representations (n_user_classes columns per node).
h_dict = model(graph=hg, inputs=node_features)
h_dict

{'item': tensor([[-0.0485,  0.5714,  0.2532,  0.0848,  0.0388],
         [-0.2480,  0.8053,  0.4096, -0.7093,  0.0435],
         [-0.1727,  0.2233,  0.3161, -0.1636,  0.3206],
         ...,
         [ 0.0509,  0.1758,  0.4099,  0.5674,  0.1094],
         [-0.0782,  0.7114,  0.4468,  0.6578,  0.6673],
         [-0.3112,  0.2152, -0.0139, -0.5720,  0.7135]], grad_fn=<SumBackward1>),
 'user': tensor([[-0.1323,  0.0412,  0.3033,  0.2523, -0.3688],
         [-0.1384,  0.2469, -0.0149, -0.1189, -0.2292],
         [ 0.0182,  0.4845,  0.0886,  0.3176, -0.2345],
         ...,
         [-0.1964,  0.3309, -0.1619,  0.1969, -0.3511],
         [-0.1551,  0.1830, -0.0325,  0.0985, -0.1261],
         [-0.3409,  0.2595, -0.3088,  0.3672, -0.3658]], grad_fn=<SumBackward1>)}

In [91]:
# Adam over every RGCN parameter, learning rate 0.01.
# NOTE: the misspelled name `optimzer` is kept because the training cell uses it.
optimzer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [92]:

# Per-epoch history filled by the training loop (loss value, train accuracy).
loss_list, train_score_list = [], []

In [94]:
# Full-batch training for user node classification.
# NOTE: re-running this cell keeps training the existing `model` / `optimzer`
# (created in earlier cells) and keeps appending to the history lists.
n_epochs = 501
log_every = 10

# Loop-invariant: labels and mask never change during training.
labels = hg.nodes["user"].data["label"]
train_mask = hg.nodes["user"].data["train_mask"].bool()  # also accepts a float 0/1 mask

for e in range(n_epochs):
    model.train()

    logits = model(hg, node_features)["user"]
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])

    optimzer.zero_grad()
    loss.backward()
    optimzer.step()

    # Accuracy of the logits computed before this step's parameter update.
    pred = logits.argmax(1)
    train_acc = (pred[train_mask] == labels[train_mask]).float().mean().item()
    loss_value = loss.item()
    train_score_list.append(train_acc)
    loss_list.append(loss_value)

    if e % log_every == 0:
        print(e, "loss:", loss_value, "  train_acc:", train_acc)

0 loss: 1.2283493280410767   train_acc: 0.47154471278190613
10 loss: 1.227639079093933   train_acc: 0.46991869807243347
20 loss: 1.227152943611145   train_acc: 0.46341463923454285
30 loss: 1.226840853691101   train_acc: 0.46341463923454285
40 loss: 1.226580023765564   train_acc: 0.4731707274913788
50 loss: 1.2264413833618164   train_acc: 0.46666666865348816
60 loss: 1.2263177633285522   train_acc: 0.4682926833629608
70 loss: 1.226258397102356   train_acc: 0.46341463923454285
80 loss: 1.2261302471160889   train_acc: 0.46666666865348816
90 loss: 1.2260348796844482   train_acc: 0.46666666865348816
100 loss: 1.225939393043518   train_acc: 0.46991869807243347
110 loss: 1.2259249687194824   train_acc: 0.46991869807243347
120 loss: 1.2258765697479248   train_acc: 0.46991869807243347
130 loss: 1.2258265018463135   train_acc: 0.46991869807243347
140 loss: 1.225786805152893   train_acc: 0.47154471278190613
150 loss: 1.225785255432129   train_acc: 0.47479674220085144
160 loss: 1.2257505655288696 

In [107]:
# Slicing with ":" merges every user -> item relation (click and dislike) into a
# single relation graph. Its edata records each edge's original relation index
# ("_TYPE") and its original per-relation edge ID ("_ID").
user_item_graph = hg["user", :, "item"]
user_item_graph.edata


{'_TYPE': tensor([0, 0, 0,  ..., 1, 1, 1]), '_ID': tensor([  0,   1,   2,  ..., 497, 498, 499])}

'_TYPE'