In [1]:

import os 
import dgl
import pickle
import torch 
import pandas as pd
import numpy as np 
from pandas.api.types import is_numeric_dtype, is_categorical_dtype, is_categorical
import torch
import pandas as pd 
import dgl.function as fn


Using backend: pytorch


In [2]:
import numpy as np
import dgl
import torch
from torch.utils.data import IterableDataset, DataLoader

In [3]:
# Load the pre-processed tables.  The file is opened in a `with` block so the
# handle is closed deterministically instead of being left to the garbage collector.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('ns_music_all_data.p', 'rb') as data_file:
    ns_music_all_data = pickle.load(data_file)

In [4]:
# Unpack the three dataframes stored in the pickle:
#   df_playlists      -- looked up by 'df_playlist'
#   df_playlists_info -- looked up by 'df_playlist_info'
#   df_tracks         -- looked up by 'df_track'
df_playlists, df_playlists_info, df_tracks = (
    ns_music_all_data[table_key]
    for table_key in ('df_playlist', 'df_playlist_info', 'df_track')
)

# 1. Build Playlist-Track graph

In [5]:
def _series_to_tensor(series):
    if is_categorical(series):
        return torch.LongTensor(series.cat.codes.values.astype('int64'))
    else:       # numeric
        return torch.FloatTensor(series.values)

class PandasGraphBuilder(object):
    """Assemble a DGL heterogeneous graph out of pandas dataframes.

    Entity tables (one row per node) are registered with ``add_entities`` and
    relation tables (one row per edge) with ``add_binary_relations``.  Node ids
    follow the row order of the entity table; edges are stored in the row order
    of the relation table.

    Examples
    --------
    Given a ``users`` table with primary key ``user_id`` (XYZZY, FOO, BAR), a
    ``games`` table with primary key ``game_id`` (1, 2) and a ``plays`` table
    listing ``(user_id, game_id)`` pairs (XYZZY-1, FOO-1, FOO-2, BAR-2), a
    bidirectional bipartite graph is created as follows:

    >>> builder = PandasGraphBuilder()
    >>> builder.add_entities(users, 'user_id', 'user')
    >>> builder.add_entities(games, 'game_id', 'game')
    >>> builder.add_binary_relations(plays, 'user_id', 'game_id', 'plays')
    >>> builder.add_binary_relations(plays, 'game_id', 'user_id', 'played-by')
    >>> g = builder.build()
    >>> g.number_of_nodes('user')
    3
    >>> g.number_of_edges('plays')
    4
    """

    def __init__(self):
        # Raw tables, keyed by entity / relation name.
        self.entity_tables = {}
        self.relation_tables = {}

        # Entity bookkeeping.
        self.entity_pk_to_name = {}     # primary key column -> entity name
        self.entity_pk = {}             # entity name -> primary key column
        self.entity_key_map = {}        # entity name -> categorical primary key values
        self.num_nodes_per_type = {}    # entity name -> number of rows (nodes)

        # Relation bookkeeping.
        self.edges_per_relation = {}    # (srctype, relation, dsttype) -> (src ids, dst ids)
        self.relation_name_to_etype = {}
        self.relation_src_key = {}      # relation name -> source key column
        self.relation_dst_key = {}      # relation name -> destination key column

    def add_entities(self, entity_table, primary_key, name):
        """Register ``entity_table`` as node type ``name`` keyed by ``primary_key``.

        Raises
        ------
        ValueError
            If a primary key value occurs more than once.
        """
        key_column = entity_table[primary_key]
        keys = key_column.astype('category')
        if (keys.value_counts() != 1).any():
            raise ValueError('Different entity with the same primary key detected.')
        # Category codes become node ids, so order the categories like the table rows.
        keys = keys.cat.reorder_categories(key_column.values)

        self.entity_tables[name] = entity_table
        self.entity_key_map[name] = keys
        self.num_nodes_per_type[name] = entity_table.shape[0]
        self.entity_pk[name] = primary_key
        self.entity_pk_to_name[primary_key] = name

    def _align_with_entity(self, relation_table, key):
        """Return ``relation_table[key]`` as a categorical sharing the entity's categories.

        Values that are not known primary keys of the entity become null.
        """
        column = relation_table[key].astype('category')
        entity_keys = self.entity_key_map[self.entity_pk_to_name[key]]
        return column.cat.set_categories(entity_keys.cat.categories)

    def add_binary_relations(self, relation_table, source_key, destination_key, name):
        """Register one edge type ``name`` from ``source_key`` to ``destination_key``.

        Both key columns must already have been registered through
        ``add_entities``.

        Raises
        ------
        ValueError
            If the relation references a key value missing from its entity table.
        """
        src = self._align_with_entity(relation_table, source_key)
        dst = self._align_with_entity(relation_table, destination_key)
        if src.isnull().any():
            raise ValueError(
                'Some source entities in relation %s do not exist in entity %s.' %
                (name, source_key))
        if dst.isnull().any():
            raise ValueError(
                'Some destination entities in relation %s do not exist in entity %s.' %
                (name, destination_key))

        etype = (
            self.entity_pk_to_name[source_key],
            name,
            self.entity_pk_to_name[destination_key],
        )
        src_ids = src.cat.codes.values.astype('int64')
        dst_ids = dst.cat.codes.values.astype('int64')

        self.relation_name_to_etype[name] = etype
        self.edges_per_relation[etype] = (src_ids, dst_ids)
        self.relation_tables[name] = relation_table
        self.relation_src_key[name] = source_key
        self.relation_dst_key[name] = destination_key

    def build(self):
        """Create the heterograph from every registered relation and node count."""
        return dgl.heterograph(self.edges_per_relation, self.num_nodes_per_type)


Build a bipartite heterogeneous graph:

- a track is identified by `tid`
- a playlist is identified by `pid`
- edge `contains`: a playlist contains a track
- edge `contained_by`: a track is contained by a playlist


The ids of the nodes in the graph and the rows in the dataset are in the same order:
- the 1st track node is the track with tid = 1
- the 1st playlist node is the playlist with pid = 1

In [6]:
# NOTE(review): redundant -- `graph_builder` is created again, unused in between,
# by the cell that registers the entities and relations.
graph_builder = PandasGraphBuilder()

In [7]:
# Node ids follow the row order of the entity table, so order the playlists by
# `pid` (with a fresh 0..n-1 index) before they are registered.
df_playlists_info = df_playlists_info.sort_values('pid', ignore_index=True)

In [8]:
graph_builder = PandasGraphBuilder()

# Node types: one per entity table, keyed by its primary key column.
for entity_table, primary_key, node_type in (
        (df_tracks, 'tid', 'track'),
        (df_playlists_info, 'pid', 'playlist')):
    graph_builder.add_entities(entity_table, primary_key, node_type)

# Edge types: the same (pid, tid) table is registered once per direction.
for source_key, destination_key, relation in (
        ('pid', 'tid', 'contains'),
        ('tid', 'pid', 'contained_by')):
    graph_builder.add_binary_relations(df_playlists, source_key, destination_key, relation)


In [9]:
# Materialise the playlist/track heterograph from the registered tables.
g = graph_builder.build()


Load features into the graph
- music features are stored as long tensors (categorical, to be embedded)
- genre, album_img_emb and album_text_emb are stored as numerical features
- track id and playlist id are also included and can be embedded as well

In [11]:
categorical_music_features = [
    'danceability', 'energy', 'loudness', 'speechiness', 'acousticness',
    'instrumentalness', 'liveness', 'valence', 'tempo',
]
# Stored as int64 so each attribute can be fed to an embedding layer later on.
for key in categorical_music_features:
    g.nodes['track'].data[key] = torch.LongTensor(df_tracks[key].values)
    

In [12]:
# Stack the per-track `genre` entries into one array and store it as float32.
# presumably each cell holds a fixed-length numeric vector -- TODO confirm
g.nodes['track'].data['genre'] = torch.tensor(np.asarray(list(df_tracks['genre'].values))).float()

In [None]:
# Stack the per-track album embeddings into dense tensors.  They are cast to
# float32 exactly like `genre`: if the pickled vectors were float64, the
# concatenated features would no longer match the float32 linear layer.
g.nodes['track'].data['album_img_emb'] = torch.tensor(np.asarray(list(df_tracks['album_img_emb'].values))).float()
g.nodes['track'].data['album_text_emb'] = torch.tensor(np.asarray(list(df_tracks['album_text_emb'].values))).float()

In [None]:
# Give every node its own index as an `id` feature (usable for id embeddings).
for node_type in ('playlist', 'track'):
    g.nodes[node_type].data['id'] = torch.arange(g.number_of_nodes(node_type))


In [23]:
# Inspect the features attached to the playlist nodes.
g.nodes['playlist']

NodeSpace(data={'id': tensor([     0,      1,      2,  ..., 999997, 999998, 999999])})

In [24]:
# Inspect the features attached to the track nodes.
g.nodes['track']

NodeSpace(data={'danceability': tensor([3, 3, 3,  ..., 3, 2, 1]), 'energy': tensor([3, 3, 3,  ..., 2, 3, 2]), 'loudness': tensor([2, 3, 2,  ..., 2, 2, 2]), 'speechiness': tensor([3, 3, 3,  ..., 2, 1, 1]), 'acousticness': tensor([1, 1, 1,  ..., 1, 1, 2]), 'instrumentalness': tensor([2, 2, 1,  ..., 2, 1, 2]), 'liveness': tensor([1, 3, 1,  ..., 2, 2, 2]), 'valence': tensor([3, 3, 3,  ..., 2, 3, 1]), 'tempo': tensor([2, 3, 1,  ..., 1, 3, 2]), 'genre': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'album_img_emb': tensor([[0.2030, 1.0508, 1.4206,  ..., 0.1448, 0.7110, 0.6996],
        [0.1082, 2.7684, 1.0451,  ..., 0.1598, 0.2774, 0.0799],
        [0.1201, 0.2322, 1.0679,  ..., 0.0352, 0.0731, 0.4070],
        ...,
        [1.2159, 2.7029, 1.2146,  ..., 0.0867, 0.0533, 0.6647],
       

# 2. Train test splits

In [28]:
# THIS PART IS MISSING, FOR NOW EVERY ENTRY IS IN TRAINING

In [26]:
import numpy as np 

In [27]:
def build_train_graph(g, train_indices, utype, itype, etype, etype_rev):
    """Extract the training subgraph induced by a set of edge ids.

    Parameters
    ----------
    g : heterograph to take the subgraph from.
    train_indices : edge ids kept for BOTH ``etype`` and ``etype_rev``.
    utype, itype : accepted for interface compatibility; not used here.
    etype, etype_rev : names of the forward and reverse edge types.

    Returns
    -------
    The edge subgraph (nodes are not relabelled) with all node features of
    ``g`` attached and the edge features restricted to the kept edges.
    """
    edge_selection = {etype: train_indices, etype_rev: train_indices}
    train_g = g.edge_subgraph(edge_selection, relabel_nodes=False)

    # Node ids are unchanged, so node features carry over as they are.
    for node_type in g.ntypes:
        for feat_name, feat in g.nodes[node_type].data.items():
            train_g.nodes[node_type].data[feat_name] = feat

    # Edge features are indexed through the original edge ids of the kept edges.
    for edge_type in g.etypes:
        for feat_name, feat in g.edges[edge_type].data.items():
            kept_edge_ids = train_g.edges[edge_type].data[dgl.EID]
            train_g.edges[edge_type].data[feat_name] = feat[kept_edge_ids]

    return train_g

In [29]:
# No train/test split exists yet: every row of the playlist-track table (one
# edge per row, in both directions) is used for training.
train_indicies = np.arange(df_playlists.shape[0])
train_g = build_train_graph(
    g,
    train_indicies,
    utype='playlist',
    itype='track',
    etype='contains',
    etype_rev='contained_by',
)

# 3. Build Sampler

## 3.1 HEAD POS NEG Tracks sampler

In [30]:
class ItemToItemBatchSampler(IterableDataset):
    """Endless stream of ``(heads, tails, neg_tails)`` item-id batches.

    Heads are drawn uniformly from the item nodes.  Each positive tail is the
    end of an item -> user -> item random walk started at its head, and each
    negative tail is another uniformly drawn item.  Triples whose walk could
    not be completed (tail ``-1``) are dropped, so a batch may hold fewer than
    ``batch_size`` entries.
    """

    def __init__(self, g, user_type, item_type, batch_size):
        self.g = g
        self.user_type = user_type
        self.item_type = item_type
        self.batch_size = batch_size
        # The first edge type found in each direction between the two node types.
        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]

    def __iter__(self):
        batch_shape = (self.batch_size,)
        while True:
            heads = torch.randint(0, self.g.number_of_nodes(self.item_type), batch_shape)
            walk = dgl.sampling.random_walk(
                self.g,
                heads,
                metapath=[self.item_to_user_etype, self.user_to_item_etype])
            # Column 2 of the traces is the item reached after the two hops.
            tails = walk[0][:, 2]
            neg_tails = torch.randint(0, self.g.number_of_nodes(self.item_type), batch_shape)

            completed = tails != -1
            yield heads[completed], tails[completed], neg_tails[completed]


The (commented-out) block below generates a small graph dataset that can be used for parameter tuning.

In [None]:
# heads = torch.randint(0, train_g.number_of_nodes('track'), (10000,))

# tails = dgl.sampling.random_walk(
#     train_g,
#     heads,
#     metapath=['contained_by', 'contains'])[0]
# playlist_ids = torch.unique(tails[:, 1]).numpy()
# track_ids = torch.unique(torch.cat([tails[:,0],tails[:,-1]])).numpy()
# df_tracks_small = df_tracks[df_tracks['tid'].isin(track_ids)]
# df_playlists_info_small = df_playlists_info[df_playlists_info['pid'].isin(playlist_ids)]
# df_playlists_small  = df_playlists[df_playlists['pid'].isin(playlist_ids) & df_playlists['tid'].isin(track_ids)]
# df_tracks_small = df_tracks_small.reset_index(drop=True)
# df_playlists_info_small = df_playlists_info_small.reset_index(drop=True)
# df_playlists_small = df_playlists_small.reset_index(drop=True)
# new_track_uri_ids ={x:idx for idx, x in enumerate(list(df_tracks_small['tid']))}
# new_playlists_ids ={x:idx for idx, x in enumerate(list(df_playlists_info_small['pid']))}
# df_tracks_small['tid'] = [new_track_uri_ids[x] for x in list(df_tracks_small['tid'])]
# df_playlists_info_small['pid'] = [new_playlists_ids[x] for x in list(df_playlists_info_small['pid'])]
# df_playlists_small['pid'] = [new_playlists_ids[x] for x in list(df_playlists_small['pid'])]
# df_playlists_small['tid'] = [new_track_uri_ids[x] for x in list(df_playlists_small['tid'])]
# data = {
#     'df_playlist': df_playlists_small,
#     'df_playlist_info': df_playlists_info_small,
#     'df_track': df_tracks_small
# }
# pickle.dump(data, open('ns_music_small_data.p', 'wb'))

The sampler generates positive and negative edges:
- sample a batch of heads
- do a track -> playlist -> track random walk from each head to find the positive pairs
- randomly sample a batch of nodes as the tails of the negative pairs

In [32]:
# Pair sampler over the training graph; each batch starts from 32 head tracks.
batch_sampler = ItemToItemBatchSampler(
    g=train_g,
    user_type='playlist',
    item_type='track',
    batch_size=32,
)

In [33]:
# The sampler never stops iterating, so batches are pulled explicitly with next().
batch_iter = iter(batch_sampler)

In [34]:
# Draw one batch of (head, positive tail, negative tail) track ids.
heads, tails, neg_tails = next(batch_iter)

In [36]:
# Display the sampled track ids.
heads, tails, neg_tails

(tensor([ 652604, 1172189, 1171658,  775423, 2072876,  411679,  654767, 1671334,
         1342078,  149534,   43313, 1643706,   29176, 1767177,  481002, 1210897,
         1444569,  537707,  423477, 1186928, 1672399,  465615, 1665923, 1358648,
         1345516,   60195,  138084, 1640883,  496928,  449051,  780300,  843284]),
 tensor([ 142868,  314177,    2517,  775421, 1302117,    3733,    2618, 1671286,
         1342063,  860119,    9366,    8388,    3138,   13769,   10143,  915598,
         1444561,  133257,  347474, 1186908,  579463,  111112,   58208,    7268,
          900040,   27499,   93612,  115300, 1309897,  501567,  780299, 2204382]),
 tensor([ 303276, 1635463,  174702,  581917,  298855, 1163418,  709089,  257351,
         1919806,  401042,  477965, 1527108, 1879487, 2224889, 2140501, 1751727,
         1126472,  167081, 2131186,  755532,  333102,  219050, 1365535,  742104,
         1823141,   35783, 1429677, 1943776, 1665372, 1594853, 1200184, 1386411]))

## 3.2  Neighborhood sampler default

In [42]:

def compact_and_copy(frontier, seeds):
    """Turn ``frontier`` into a block for ``seeds`` and carry its edge features over.

    Every edge feature of the frontier except the edge-id column is copied onto
    the block, indexed by the ids that the block records for its edges.
    """
    block = dgl.to_block(frontier, seeds)
    for name, values in frontier.edata.items():
        if name != dgl.EID:
            block.edata[name] = values[block.edata[dgl.EID]]
    return block

class NeighborSampler(object):
    """Build multi-layer PinSAGE neighbourhoods (blocks) around item nodes.

    One ``dgl.sampling.PinSAGESampler`` is created per layer, all with the same
    random-walk settings.
    """

    def __init__(self, g, user_type, item_type, random_walk_length, random_walk_restart_prob,
                 num_random_walks, num_neighbors, num_layers):
        self.g = g
        self.user_type = user_type
        self.item_type = item_type
        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]

        self.samplers = []
        for _ in range(num_layers):
            self.samplers.append(
                dgl.sampling.PinSAGESampler(
                    g, item_type, user_type, random_walk_length,
                    random_walk_restart_prob, num_random_walks, num_neighbors))

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None):
        """Sample one block per layer, outermost layer first in the returned list.

        When ``heads`` is given, the frontier edges running head -> tail and
        head -> neg_tail are removed before the block is built, so the pairs
        being scored are not themselves used for message passing.
        NOTE(review): edges in the opposite direction (tail -> head) are not
        removed -- confirm this is intended.
        """
        blocks = []
        for sampler in self.samplers:
            frontier = sampler(seeds)
            if heads is not None:
                pair_src = torch.cat([heads, heads])
                pair_dst = torch.cat([tails, neg_tails])
                eids = frontier.edge_ids(pair_src, pair_dst, return_uv=True)[2]
                if len(eids) > 0:
                    frontier = dgl.remove_edges(frontier, eids)
            block = compact_and_copy(frontier, seeds)
            # The source nodes of this block are the seeds of the next, wider one.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks

    def sample_from_item_pairs(self, heads, tails, neg_tails):
        """Return ``(pos_graph, neg_graph, blocks)`` for a batch of item pairs.

        ``pos_graph`` holds the head -> tail edges and ``neg_graph`` the
        head -> neg_tail edges; both are compacted together and the nodes kept
        by the compaction are the seeds for the sampled blocks.
        """
        num_items = self.g.number_of_nodes(self.item_type)
        pos_graph = dgl.graph((heads, tails), num_nodes=num_items)
        neg_graph = dgl.graph((heads, neg_tails), num_nodes=num_items)
        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
        seeds = pos_graph.ndata[dgl.NID]

        return pos_graph, neg_graph, self.sample_blocks(seeds, heads, tails, neg_tails)


In [44]:
def assign_features_to_blocks(blocks, g, ntype='track'):
    """Attach the node features of ``g`` to the two ends of a block list.

    Only the source side of the first block and the destination side of the
    last block are filled in.  For both, every feature of node type ``ntype``
    (except the node-id column) is looked up by the original node ids stored
    under ``dgl.NID``.  The blocks are modified in place; nothing is returned.
    """
    features = g.nodes[ntype].data
    for frame in (blocks[0].srcdata, blocks[-1].dstdata):
        for col in features.keys():
            if col != dgl.NID:
                frame[col] = features[col][frame[dgl.NID]]


In [45]:
class PinSAGECollator(object):
    """``collate_fn`` provider turning sampled item pairs into model inputs."""

    def __init__(self, sampler, g, ntype):
        self.sampler, self.g, self.ntype = sampler, g, ntype

    def collate_train(self, batches):
        """Return ``(pos_graph, neg_graph, blocks)`` for the first entry of ``batches``.

        ``batches[0]`` must be a ``(heads, tails, neg_tails)`` triple; any other
        entries are ignored.
        """
        heads, tails, neg_tails = batches[0]
        # Multi-layer PinSAGE neighbourhood around the sampled pairs.
        sampled = self.sampler.sample_from_item_pairs(heads, tails, neg_tails)
        pos_graph, neg_graph, blocks = sampled
        # Node features are looked up in the graph and stored on the blocks.
        assign_features_to_blocks(blocks, self.g, self.ntype)
        return pos_graph, neg_graph, blocks


**IMPORTANT** This is the neighbor graph generator

- batch sampler: generates positive and negative pairs
- neighbor sampler: generates the neighbor graph of the nodes in the positive + negative pairs

please go through the implementation details if you want to make modifications

The output consists of:

- pos graph: data structure storing positive pairs 
- neg graph: data structure storing negative pairs 
- blocks: data structure facilitating message passing:
    - blocks[0]  frontier 2 -> frontier 1
    - blocks[1]  frontier 1 -> nodes of interest

corresponding nodes features are stored in blocks as well 
- `block.srcdata` src nodes data 
- `block.dstdata` dst nodes data

**IMPORTANT** 
- The destination nodes are the destination ends of the edges in the block.
- The source nodes stored in `block.srcdata` are not only the source ends of the edges; they consist of:
    - the destination end nodes
    - the source end nodes

In [49]:
# Two-layer PinSAGE neighbourhood sampler on the training graph, plus the
# collator that turns sampled pairs into (pos_graph, neg_graph, blocks).
neighbor_sampler = NeighborSampler(
    train_g,
    'playlist',
    'track',
    random_walk_length=2,
    random_walk_restart_prob=0.5,
    num_random_walks=10,
    num_neighbors=3,
    num_layers=2,
)
collator = PinSAGECollator(sampler=neighbor_sampler, g=train_g, ntype='track')


In [50]:
# `batch_sampler` already yields whole (heads, tails, neg_tails) batches, so no
# batch_size is passed here; `collate_train` unpacks the single element it gets.
dataloader = DataLoader(
    batch_sampler,
    collate_fn=collator.collate_train,
    num_workers=8)

In [51]:
# Iterator over the (endless) stream of collated batches.
dataloader_it = iter(dataloader)

In [52]:
# Fetch one collated batch to inspect its structure.
pos_graph,neg_graph, blocks = next(dataloader_it)

In [56]:
# Features stored on the source nodes of the outermost block.
blocks[0].srcdata

{'_ID': tensor([1222908,  685010,  200852,  ..., 1405005, 2251956, 2251949]), 'danceability': tensor([1, 2, 3,  ..., 1, 1, 1]), 'energy': tensor([1, 2, 3,  ..., 3, 3, 3]), 'loudness': tensor([1, 3, 2,  ..., 2, 3, 3]), 'speechiness': tensor([1, 1, 2,  ..., 2, 2, 3]), 'acousticness': tensor([3, 1, 3,  ..., 2, 1, 2]), 'instrumentalness': tensor([1, 1, 2,  ..., 3, 3, 2]), 'liveness': tensor([2, 2, 2,  ..., 2, 3, 3]), 'valence': tensor([1, 2, 3,  ..., 2, 1, 2]), 'tempo': tensor([1, 2, 3,  ..., 3, 2, 3]), 'genre': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'album_img_emb': tensor([[0.1468, 1.9877, 0.6605,  ..., 0.0424, 0.1162, 0.2592],
        [0.1444, 2.0385, 0.3274,  ..., 0.0382, 0.0620, 0.1265],
        [0.3310, 1.2132, 0.9204,  ..., 0.1687, 0.2099, 0.1162],
        ...,
        [

# 4. Modeling

outline of model procedure:

inputs: `pos_graph`, `neg_graph`, `block_0`, `block_1`

- compute the node projection for all nodes; if you understand the section above you will see that all we need to compute are:
    - source nodes of block_0 
    - destination nodes of block_1
    
- run 2 sage layers:
    - layer 1 on block 0: frontier 2 -> frontier 1
    - layer 2 on block 1: frontier 1 -> nodes of interests 
    
- score on nodes of interests
    - for nodes u, v,  dot(u,v) + bias(u) + bias(v)

- loss function: max-margin (hinge) loss, max(0, neg_pair_score - pos_pair_score + 1), averaged over the pairs of the batch




In [63]:
from torch import nn

Feature projections:

- music features:
    - each of the 9 attributes is mapped to an embedding of length 16
    - total size 144
- genre feature: used as is, size 20
- album image feature: used as is, size 2048
- track id feature: mapped to an embedding of length 128


Feature aggregation:

- concatenate the music, genre and album image features: size 2212
- a fully connected layer (2212 x 128) reduces the dimensionality of the concatenated feature
- add the id embedding

FC(concate[music, genre, img_emb]) + id_emb




In [64]:
def disable_grad(module):
    """Freeze ``module``: mark all of its parameters as not requiring gradients."""
    for weight in module.parameters():
        weight.requires_grad_(False)


def _init_input_modules(g, ntype):
    module_dict = nn.ModuleDict()
    
    tracks_data = g.nodes[ntype].data
    
    module_dict['track_id'] = nn.Embedding(tracks_data['id'].max()+1, 128)
    
    for m in ['danceability', 'energy', 'loudness', 'speechiness', 'acousticness', 'instrumentalness', 'liveness', 'valence', 'tempo']: 
        module_dict[m] = nn.Embedding(tracks_data[m].max() + 1, 16)

    return module_dict


class LinearProjector(nn.Module):
    """Project the raw node features into the model's hidden space.

    The nine categorical music attributes are embedded and concatenated with
    the dense ``genre`` and ``album_img_emb`` vectors; one linear layer maps the
    concatenation to the hidden size and the learned id embedding is added::

        FC(concat[music, genre, album_img_emb]) + id_emb

    The width of the linear layer is derived from the features stored on
    ``full_graph`` (144 + 20 + 2048 = 2212 inputs and 128 outputs for the
    feature sizes listed in this notebook) instead of being hard-coded.
    """

    MUSIC_FEATURES = ('danceability', 'energy', 'loudness', 'speechiness', 'acousticness',
                      'instrumentalness', 'liveness', 'valence', 'tempo')

    def __init__(self, full_graph, ntype):
        """
        full_graph : graph whose ``ntype`` nodes carry ``id``, the music
            attributes, ``genre`` and ``album_img_emb``.
        ntype : node type to project.
        """
        super().__init__()

        self.ntype = ntype
        inputs = _init_input_modules(full_graph, ntype)
        node_data = full_graph.nodes[ntype].data
        in_dims = (
            sum(inputs[c].embedding_dim for c in self.MUSIC_FEATURES)
            + node_data['genre'].shape[1]
            + node_data['album_img_emb'].shape[1]
        )
        # The output width matches the id embedding so the two can be summed.
        self.fc = nn.Linear(in_dims, inputs['track_id'].embedding_dim)
        self.inputs = inputs

    def forward(self, ndata):
        """Return one projected vector per node described by ``ndata``."""
        # music features: one embedding per categorical attribute, concatenated
        music_features = torch.cat(
            [self.inputs[c](ndata[c]) for c in self.MUSIC_FEATURES], dim=1)

        # id embedding
        id_embedding = self.inputs['track_id'](ndata['id'])

        # dense features, used as they are
        img_emb = ndata['album_img_emb']
        genre = ndata['genre']

        feature = torch.cat([music_features, genre, img_emb], dim=1)
        # Guard against float64 inputs: the linear layer expects its own dtype.
        feature = feature.to(self.fc.weight.dtype)

        return self.fc(feature) + id_embedding

## 4.2 sage layers

Not much to discuss here; please make sure you understand every line of the following code if you intend to modify the model.

notice the changes in `itemtoitemscorer`

In [66]:
import torch.nn.functional as F


class WeightedSAGEConv(nn.Module):
    """Weighted GraphSAGE-style convolution over one block.

    Source features are transformed by ``Q``, averaged over incoming edges
    using the edge weights, concatenated with the destination features,
    transformed by ``W`` and finally L2-normalised.  Dropout with p=0.5 is
    applied in front of both linear layers.
    """

    def __init__(self, input_dims, hidden_dims, output_dims, act=F.relu):
        super().__init__()

        self.act = act
        self.Q = nn.Linear(input_dims, hidden_dims)
        self.W = nn.Linear(input_dims + hidden_dims, output_dims)
        self.reset_parameters()
        self.dropout = nn.Dropout(0.5)

    def reset_parameters(self):
        """Xavier-initialise both weight matrices (ReLU gain) and zero the biases."""
        relu_gain = nn.init.calculate_gain('relu')
        for linear in (self.Q, self.W):
            nn.init.xavier_uniform_(linear.weight, gain=relu_gain)
        for linear in (self.Q, self.W):
            nn.init.constant_(linear.bias, 0)

    def forward(self, g, h, weights):
        """
        g : graph
        h : pair ``(h_src, h_dst)`` of source and destination node features
        weights : scalar edge weights
        """
        h_src, h_dst = h
        with g.local_scope():
            g.srcdata['n'] = self.act(self.Q(self.dropout(h_src)))
            g.edata['w'] = weights.float()
            # 'n' on the destination side: weighted sum of the neighbour messages.
            g.update_all(fn.u_mul_e('n', 'w', 'm'), fn.sum('m', 'n'))
            # 'ws' on the destination side: total incoming edge weight.
            g.update_all(fn.copy_e('w', 'm'), fn.sum('m', 'ws'))

            # The divisor is clamped to at least 1.
            weight_sum = g.dstdata['ws'].unsqueeze(1).clamp(min=1)
            neighbourhood = g.dstdata['n'] / weight_sum
            z = self.act(self.W(self.dropout(torch.cat([neighbourhood, h_dst], 1))))

            # L2-normalise each row; all-zero rows are left untouched.
            norm = z.norm(2, 1, keepdim=True)
            norm = torch.where(norm == 0, torch.ones_like(norm), norm)
            return z / norm


class SAGENet(nn.Module):
    """Stack of ``WeightedSAGEConv`` layers applied to a list of blocks."""

    def __init__(self, hidden_dims, n_layers):
        """
        hidden_dims : int
            Input, hidden and output size shared by every layer.
        n_layers : int
            Number of convolution layers (one per block).
        """
        super().__init__()

        self.convs = nn.ModuleList([
            WeightedSAGEConv(hidden_dims, hidden_dims, hidden_dims)
            for _ in range(n_layers)
        ])

    def forward(self, blocks, h):
        """Run layer ``i`` on ``blocks[i]``, starting from the source features ``h``.

        The destination features of a block are the first rows of its source
        features; the edge weights are read from ``block.edata['weights']``.
        """
        for conv, block in zip(self.convs, blocks):
            num_dst = block.number_of_nodes('DST/' + block.ntypes[0])
            h = conv(block, (h, h[:num_dst]), block.edata['weights'])
        return h
class ItemToItemScorer(nn.Module):
    """Score item pairs as ``dot(h_u, h_v) + bias[u] + bias[v]``.

    One learnable bias (initialised to zero) is kept for every node of type
    ``ntype`` in ``full_graph``.
    """

    def __init__(self, full_graph, ntype):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(full_graph.number_of_nodes(ntype)))

    def _add_bias(self, edges):
        # The biases are indexed by the original node ids stored under dgl.NID.
        src_ids, dst_ids = edges.src[dgl.NID], edges.dst[dgl.NID]
        return {'s': edges.data['s'] + self.bias[src_ids] + self.bias[dst_ids]}

    def forward(self, item_item_graph, h):
        """
        item_item_graph : graph consists of edges connecting the pairs
        h : hidden state of every node
        """
        with item_item_graph.local_scope():
            item_item_graph.ndata['h'] = h
            item_item_graph.apply_edges(fn.u_dot_v('h', 'h', 's'))
            item_item_graph.edata['s'] = item_item_graph.edata['s'].flatten()
            item_item_graph.apply_edges(self._add_bias)
            return item_item_graph.edata['s']


In [67]:
from sklearn.metrics import roc_auc_score
def compute_auc(pos_score, neg_score):
    """ROC AUC of the scores, treating positive pairs as label 1 and negative pairs as 0."""
    n_pos, n_neg = pos_score.shape[0], neg_score.shape[0]
    labels = torch.cat([torch.ones(n_pos), torch.zeros(n_neg)]).numpy()
    # Detach and move to the CPU: scikit-learn works on numpy arrays.
    scores = torch.cat([pos_score, neg_score]).detach().cpu().numpy()
    return roc_auc_score(labels, scores)

class PinSAGEModel(nn.Module):
    """Feature projector + SAGE layers + pair scorer, trained with a margin loss."""

    def __init__(self, full_graph, ntype, hidden_dims, n_layers):
        super().__init__()

        self.proj = LinearProjector(full_graph, ntype)
        self.sage = SAGENet(hidden_dims, n_layers)
        self.scorer = ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks):
        """Return ``(loss, auc)`` for one batch.

        ``loss`` is the per-pair hinge loss ``max(0, neg_score - pos_score + 1)``
        (not reduced); ``auc`` is the ROC AUC of positive vs. negative scores.
        """
        item_repr = self.get_repr(blocks)
        pos_score = self.scorer(pos_graph, item_repr)
        neg_score = self.scorer(neg_graph, item_repr)

        auc = compute_auc(pos_score, neg_score)
        margin_loss = (neg_score - pos_score + 1).clamp(min=0)
        return margin_loss, auc

    def get_repr(self, blocks):
        """Representation of the output nodes: their own projection plus the SAGE output."""
        src_projection = self.proj(blocks[0].srcdata)
        dst_projection = self.proj(blocks[-1].dstdata)
        return dst_projection + self.sage(blocks, src_projection)


# Run

In [68]:
# 128-dim hidden size, 2 SAGE layers (one per sampled block); lives on the GPU.
model = PinSAGEModel(train_g, 'track', hidden_dims=128, n_layers=2).cuda()

In [69]:
# Fresh iterator over the batch stream for the training loop.
dataloader_it = iter(dataloader)


In [None]:
# Device the sampled graphs and blocks are copied to in the training loop.
device = torch.device('cuda:0')


In [None]:
# Train on the endless batch stream; the loop is meant to be interrupted by hand.
model.train()
opt = torch.optim.Adam(model.parameters(), lr=3e-5)
losses = []     # [mean loss, AUC], recorded every 100th batch
for batch_id in range(100000000):
    pos_graph, neg_graph, blocks = next(dataloader_it)
    # Copy to GPU
    blocks = [block.to(device) for block in blocks]
    pos_graph, neg_graph = pos_graph.to(device), neg_graph.to(device)

    per_pair_loss, auc = model(pos_graph, neg_graph, blocks)
    loss = per_pair_loss.mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

    if batch_id % 100 != 0:
        continue
    print(loss, auc)
    losses.append([loss.item(), auc])

In [None]:
from matplotlib import pyplot as plt
def moving_average(a, n=3) :
    """Simple moving average of ``a`` over windows of ``n`` consecutive values.

    Returns ``len(a) - n + 1`` averages (an empty array when ``a`` is shorter
    than the window).
    """
    totals = np.cumsum(a, dtype=float)
    # Subtracting the running total from n steps earlier leaves the window sum.
    lagged = totals[:-n].copy()
    totals[n:] -= lagged
    return totals[n - 1:] / n

# Smoothed training loss: 100-point moving average over the recorded losses
# (the training loop records one every 100th batch).
plt.plot(moving_average(np.array([x[0] for x in losses]), 100))