In [5]:
import urllib.request
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [176]:
import dgl

src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
ratings = dgl.heterograph(
    {('user', 'rating', 'menuItem') : (np.concatenate([src, dst]), np.concatenate([dst, src]))})
ratings.nodes['user'].data['feat'] = torch.randn(100, 10)
ratings.nodes['menuItem'].data['feat'] = torch.randn(100, 10)
ratings.nodes['user'].data['h'] = torch.randn(100, 10)
ratings.nodes['menuItem'].data['h'] = torch.randn(100, 10)
ratings.edges['rating'].data['label'] = torch.randn(1000, 5)
hetero_graph = ratings

In [198]:
print(ratings.num_nodes())
print('Node types:', ratings.ntypes)
print('Edge types:', ratings.etypes)
print('Canonical edge types:', ratings.canonical_etypes)
ratings.nodes['user'].data['h']

200
Node types: ['menuItem', 'user']
Edge types: ['rating']
Canonical edge types: [('user', 'rating', 'menuItem')]


tensor([[-5.4788e-01, -4.3858e-01,  6.1055e-01,  6.7419e-01, -2.9468e-01,
         -2.2180e-01,  7.0531e-01,  2.1755e+00,  6.6196e-01,  2.0971e-01],
        [ 9.3083e-01,  4.8017e-01, -6.8178e-01, -4.8535e-01,  3.2071e-01,
         -2.5853e-02,  3.0396e-01, -8.5124e-01, -8.7602e-01,  5.1982e-01],
        [-1.9595e+00, -9.8200e-01,  1.5850e+00,  9.6900e-01,  4.0175e-02,
          1.0325e+00,  2.9664e-01, -1.5458e-04, -9.9686e-02, -3.9723e-01],
        [ 2.9910e-01, -1.1362e+00, -9.1420e-01, -5.2920e-02, -1.5072e+00,
         -1.2154e-01, -1.6402e-01, -2.6102e-01,  1.7016e+00, -6.5747e-01],
        [ 4.8218e-01, -1.8329e-01,  7.5698e-01,  7.3101e-01, -4.1037e-01,
          7.4353e-01, -3.0006e-01, -3.2393e-01,  2.9800e-01, -6.1953e-01],
        [ 6.3464e-01, -1.1022e-01, -1.8871e-02, -1.9165e+00, -2.0447e-01,
          2.6971e+00, -9.0890e-01,  8.7641e-01, -8.5182e-01, -1.0765e+00],
        [ 1.4204e+00,  2.1342e-01, -2.7295e-01,  5.8698e-01,  3.9386e-01,
         -7.9516e-01,  1.7091e+0

In [178]:
ratings.edges(etype='rating')
for c_etype in ratings.canonical_etypes:
    srctype, etype, dsttype = c_etype
    print(c_etype)
    print(ratings.edges[etype])

('user', 'rating', 'menuItem')
EdgeSpace(data={'label': tensor([[ 2.0302, -0.7687,  1.1129, -1.5148, -0.3823],
        [-0.5129,  1.4071,  1.0234,  0.3360,  0.3705],
        [-1.5700, -0.1940,  0.8504,  2.4718,  0.9155],
        ...,
        [-1.4285, -1.2889, -0.8775,  1.9161, -1.5724],
        [ 0.2555, -1.5693, -1.9820,  0.2515,  0.4006],
        [ 0.2942,  0.8836,  1.1265,  1.4447, -1.3095]])})


In [207]:
import pygraphviz as pgv
from IPython.display import display

def plot_graph(nxg):
    ag = pgv.AGraph(strict=False, directed=True)
    for u, v, k in nxg.edges(keys=True):
        ag.add_edge(u, v, label=k)
    ag.layout('dot')
    ag.draw('graph.png')
    ag.view()
 
print(ratings.metagraph().edges())
plot_graph(ratings.metagraph())

[('user', 'menuItem')]


AttributeError: 'AGraph' object has no attribute 'view'

In [199]:
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}

    def forward(self, graph, h, etype):
        # h contains the node representations for each edge type computed from
        # the GNN for heterogeneous graphs defined in the node classification
        # section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
            print(h)
            graph.apply_edges(self.apply_edges, etype=etype)
            return graph.edges[etype].data['score']

In [200]:
class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each edge type computed from
        # the GNN for heterogeneous graphs defined in the node classification
        # section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h   # assigns 'h' of all node types in one shot
            print(graph.ndata['h'])
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

In [201]:
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dgl.nn.HeteroGraphConv({
            rel: dgl.nn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dgl.nn.HeteroGraphConv({
            rel: dgl.nn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

In [202]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = MLPPredictor(in_features, out_features)
    def forward(self, g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype)

In [203]:
model = Model(10, 20, 5, hetero_graph.etypes)
#import networkx as nx
#nx_G = ratings.to_networkx()
#pos = nx.kamada_kawai_layout(nx_G)
#nx.draw(nx_G, pos, with_labels=True, node_color=[[.7, .7, .7]])
user_feats = hetero_graph.nodes['user'].data['feat']
item_feats = hetero_graph.nodes['menuItem'].data['feat']
label = hetero_graph.edges['rating'].data['label']
node_features = {'user': user_feats, 'menuItem': item_feats}

In [204]:
import dgl.function as fn
opt = torch.optim.Adam(model.parameters())
for epoch in range(1000):
    pred = model(ratings, node_features, 'rating')
    loss = ((pred - label) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 5 == 0:
        print('In epoch {}, loss: {}'.format(epoch, loss))
    #print(loss.item())

{}
In epoch 0, loss: 1.3585706949234009
{}
{}
{}
{}
{}
In epoch 5, loss: 1.332624912261963
{}
{}
{}
{}
{}
In epoch 10, loss: 1.3079216480255127
{}
{}
{}
{}
{}
In epoch 15, loss: 1.2844899892807007
{}
{}
{}
{}
{}
In epoch 20, loss: 1.2623289823532104
{}
{}
{}
{}
{}
In epoch 25, loss: 1.2414274215698242
{}
{}
{}
{}
{}
In epoch 30, loss: 1.2217588424682617
{}
{}
{}
{}
{}
In epoch 35, loss: 1.2032859325408936
{}
{}
{}
{}
{}
In epoch 40, loss: 1.1859662532806396
{}
{}
{}
{}
{}
In epoch 45, loss: 1.1697523593902588
{}
{}
{}
{}
{}
In epoch 50, loss: 1.1545902490615845
{}
{}
{}
{}
{}
In epoch 55, loss: 1.1404242515563965
{}
{}
{}
{}
{}
In epoch 60, loss: 1.1271979808807373
{}
{}
{}
{}
{}
In epoch 65, loss: 1.114855170249939
{}
{}
{}
{}
{}
In epoch 70, loss: 1.10334050655365
{}
{}
{}
{}
{}
In epoch 75, loss: 1.0926014184951782
{}
{}
{}
{}
{}
In epoch 80, loss: 1.082587718963623
{}
{}
{}
{}
{}
In epoch 85, loss: 1.0732511281967163
{}
{}
{}
{}
{}
In epoch 90, loss: 1.0645477771759033
{}
{}
{}
{}


{}
In epoch 765, loss: 0.9541775584220886
{}
{}
{}
{}
{}
In epoch 770, loss: 0.9541774392127991
{}
{}
{}
{}
{}
In epoch 775, loss: 0.9541773200035095
{}
{}
{}
{}
{}
In epoch 780, loss: 0.9541773200035095
{}
{}
{}
{}
{}
In epoch 785, loss: 0.9541773200035095
{}
{}
{}
{}
{}
In epoch 790, loss: 0.9541772603988647
{}
{}
{}
{}
{}
In epoch 795, loss: 0.9541771411895752
{}
{}
{}
{}
{}
In epoch 800, loss: 0.9541771411895752
{}
{}
{}
{}
{}
In epoch 805, loss: 0.9541771411895752
{}
{}
{}
{}
{}
In epoch 810, loss: 0.9541771411895752
{}
{}
{}
{}
{}
In epoch 815, loss: 0.9541769623756409
{}
{}
{}
{}
{}
In epoch 820, loss: 0.9541769623756409
{}
{}
{}
{}
{}
In epoch 825, loss: 0.9541769623756409
{}
{}
{}
{}
{}
In epoch 830, loss: 0.9541769623756409
{}
{}
{}
{}
{}
In epoch 835, loss: 0.9541769623756409
{}
{}
{}
{}
{}
In epoch 840, loss: 0.9541769623756409
{}
{}
{}
{}
{}
In epoch 845, loss: 0.9541768431663513
{}
{}
{}
{}
{}
In epoch 850, loss: 0.9541768431663513
{}
{}
{}
{}
{}
In epoch 855, loss: 0.954