In [15]:
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
import numpy as np
import torch
import dgl

In [5]:
n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_dislikes = 500
n_hetero_features = 10
n_user_classes = 5
n_max_clicks = 10

In [9]:
follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)

In [10]:
follow_src.shape

(3000,)

In [11]:
follow_dst.shape

(3000,)

In [12]:
dislike_dst.shape

(500,)

In [16]:
hetero_graph = dgl.heterograph({
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'click', 'item'): (click_src, click_dst),
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})

In [17]:
dir(hetero_graph)

['__class__',
 '__contains__',
 '__copy__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_batch_num_edges',
 '_batch_num_nodes',
 '_canonical_etypes',
 '_dsttypes_invmap',
 '_edge_frames',
 '_etype2canonical',
 '_etypes',
 '_etypes_invmap',
 '_find_etypes',
 '_get_e_repr',
 '_get_n_repr',
 '_graph',
 '_idtype_str',
 '_init',
 '_is_unibipartite',
 '_node_frames',
 '_ntypes',
 '_pop_e_repr',
 '_pop_n_repr',
 '_reset_cached_info',
 '_set_e_repr',
 '_set_n_repr',
 '_srctypes_invmap',
 'add_edge',
 'add_edges',
 'add_nodes',
 'add_self_loop',
 'adj',
 'adjacency_matrix',
 'adjacency_matrix_scipy',
 'all_edges',
 'apply_edges',
 'apply_node

In [18]:
hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()

In [29]:
hetero_graph.nodes['user'].data['feature']

tensor([[ 0.8514,  0.9998, -0.5562,  ...,  0.1700,  1.0875, -0.1016],
        [-0.5147,  0.4508, -0.2208,  ...,  1.4062,  0.0911,  0.4209],
        [ 0.1216,  0.5356,  0.3412,  ...,  2.3806, -0.5839,  1.3165],
        ...,
        [ 2.4860, -0.2420, -2.5164,  ..., -1.0627, -0.7137, -1.1781],
        [ 0.0209, -1.3354,  1.4816,  ...,  0.8056, -0.3182,  0.2704],
        [-2.0768, -0.4367, -0.7052,  ...,  0.8856, -0.1017,  0.3964]])

In [30]:
hetero_graph.nodes['user'].data['feature'].shape

torch.Size([1000, 10])

In [27]:
hetero_graph.nodes['item'].data['feature']

tensor([[ 0.2892,  0.0536, -0.5659,  ...,  2.0141,  1.5969, -1.0956],
        [ 0.7082, -0.1850,  0.7187,  ..., -1.1740, -0.2240, -0.4934],
        [ 0.4287,  0.3628, -0.8518,  ..., -0.1427,  0.6991, -0.3048],
        ...,
        [-1.2799,  1.3628, -0.7978,  ...,  0.6321, -0.2597,  0.9185],
        [ 0.8453, -0.6468,  0.2216,  ...,  1.2142, -0.1533, -2.4434],
        [ 0.3474,  1.7530, -0.6332,  ..., -0.0467, -0.3661,  0.8721]])

In [28]:
hetero_graph.nodes['item'].data['feature'].shape

torch.Size([500, 10])

In [26]:
hetero_graph.nodes['user'].data['label']

tensor([2, 0, 3, 3, 1, 1, 4, 1, 2, 4, 4, 2, 2, 3, 4, 0, 0, 3, 4, 4, 3, 1, 0, 0,
        4, 3, 2, 2, 1, 1, 1, 0, 2, 4, 0, 1, 4, 4, 0, 1, 0, 2, 3, 4, 2, 4, 3, 3,
        0, 3, 4, 0, 3, 0, 1, 0, 2, 3, 3, 3, 0, 4, 0, 3, 2, 0, 0, 3, 3, 3, 3, 4,
        0, 1, 1, 3, 4, 3, 3, 2, 0, 4, 3, 2, 3, 0, 2, 2, 1, 1, 0, 3, 4, 4, 1, 3,
        2, 2, 3, 3, 0, 1, 3, 0, 1, 2, 0, 2, 3, 0, 0, 3, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 1, 3, 1, 1, 3, 0, 2, 3, 4, 2, 0, 2, 0, 0, 4, 3, 2, 2, 3, 3, 4, 2, 4,
        3, 2, 2, 4, 0, 0, 1, 0, 0, 0, 1, 2, 4, 0, 3, 2, 3, 4, 0, 4, 4, 4, 0, 2,
        0, 0, 0, 0, 1, 1, 1, 2, 0, 4, 3, 2, 4, 1, 4, 2, 3, 4, 0, 4, 0, 1, 2, 0,
        4, 3, 0, 0, 4, 4, 2, 1, 1, 4, 3, 4, 3, 0, 1, 3, 3, 0, 3, 3, 1, 3, 0, 3,
        4, 3, 3, 0, 2, 4, 1, 2, 0, 4, 4, 4, 4, 2, 2, 3, 2, 3, 2, 4, 1, 0, 3, 2,
        1, 3, 4, 2, 0, 2, 4, 1, 0, 2, 4, 3, 4, 0, 2, 3, 0, 4, 0, 3, 2, 4, 4, 2,
        0, 3, 2, 3, 1, 0, 3, 2, 0, 4, 0, 3, 4, 4, 1, 4, 1, 1, 4, 4, 4, 4, 4, 0,
        3, 3, 3, 1, 0, 2, 2, 0, 1, 1, 1,

In [25]:
hetero_graph.nodes['user'].data['label'].shape

torch.Size([1000])

In [20]:
hetero_graph.edges['click'].data['label']

tensor([7., 3., 9.,  ..., 1., 1., 1.])

In [21]:
hetero_graph.edges['click'].data['label'].shape

torch.Size([5000])

In [19]:
# randomly generate training masks on user nodes and click edges
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

In [32]:
hetero_graph.nodes['user'].data['train_mask']

tensor([False,  True, False,  True, False,  True, False,  True,  True,  True,
         True,  True, False,  True, False,  True, False, False,  True,  True,
         True, False,  True,  True,  True,  True, False,  True,  True,  True,
        False,  True, False,  True,  True, False, False, False,  True, False,
        False,  True, False,  True,  True,  True,  True, False,  True,  True,
        False,  True,  True, False, False,  True, False,  True,  True,  True,
        False, False,  True, False, False,  True,  True, False,  True, False,
        False,  True,  True, False, False,  True, False, False,  True,  True,
        False, False, False, False,  True,  True, False, False,  True,  True,
        False,  True,  True,  True,  True,  True,  True,  True, False,  True,
        False, False,  True,  True,  True,  True,  True,  True, False,  True,
        False,  True, False,  True, False,  True,  True, False, False,  True,
         True,  True, False, False, False, False,  True,  True, 

In [33]:
hetero_graph.nodes['user'].data['train_mask'].shape

torch.Size([1000])

In [34]:
hetero_graph.edges['click'].data['train_mask']

tensor([ True,  True,  True,  ...,  True,  True, False])

In [35]:
hetero_graph.edges['click'].data['train_mask'].shape

torch.Size([5000])

In [23]:
dir(hetero_graph.nodes)

['__call__',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '_graph',
 '_typeid_getter']

In [3]:
# Define a Heterograph Conv model
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

In [31]:
model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']

In [36]:
node_features = {'user': user_feats, 'item': item_feats}
h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})
h_user = h_dict['user']
h_item = h_dict['item']

In [37]:
opt = torch.optim.Adam(model.parameters())

for epoch in range(5):
    model.train()
    # forward propagation by using all nodes and extracting the user embeddings
    logits = model(hetero_graph, node_features)['user']
    # compute loss
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # Compute validation accuracy.  Omitted in this example.
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())


1.7691476345062256
1.7608240842819214
1.753036618232727
1.7457612752914429
1.7389553785324097
