In [63]:
import networkx as nx

def new_graph(node_types, first_node):
    g = nx.MultiDiGraph()
    node_id = 0
    ntype = first_node['ntype']
    node_data = {attr: first_node[attr] for attr in node_types[ntype]}
    g.add_node(node_id, ntype=ntype, **node_data)
    return g

def find_node_id_by_name(graph, name):
    for node_id, node_data in graph.nodes(data=True):
        if 'name' in node_data and node_data['name'] == name:
            return node_id
    return 0

def get_last_node_id(graph):
    last_node_id = list(graph.nodes())[-1]
    return last_node_id
        
def node_add(graph, node):
    node_id = get_last_node_id(graph) + 1
    ntype = node['ntype']
    node_data = {attr: node[attr] for attr in node_types[ntype]}
    graph.add_node(node_id, ntype=ntype, **node_data)
        
def edge_add(graph, item, board,relation):
    item_id = find_node_id_by_name(graph,item)
    board_id = find_node_id_by_name(graph,board)
    
    graph.add_edge(item_id, board_id, relation=relation)
    
def watch_graph(graph):
    print("Nodes:", graph.nodes(data=True))
    print("Edges:", graph.edges(data=True))

In [64]:
node_types = {
    'user': {'name': str, 'age': int, 'job': str},
    'post': {'name': str, 'content': str}
}

first_node = {'ntype': 'user', 'name': 'Hyeon Woo', 'age': 24, 'job': 'Student'}

g = new_graph(node_types, first_node)

In [65]:
watch_graph(g)

Nodes: [(0, {'ntype': 'user', 'name': 'Hyeon Woo', 'age': 24, 'job': 'Student'})]
Edges: []


In [66]:
print(get_last_node_id(g))

new_node = {'ntype': 'post', 'name': 'New Post', 'content': 'This is a new post!'}

node_add(g,new_node)

0


In [67]:
watch_graph(g)

Nodes: [(0, {'ntype': 'user', 'name': 'Hyeon Woo', 'age': 24, 'job': 'Student'}), (1, {'ntype': 'post', 'name': 'New Post', 'content': 'This is a new post!'})]
Edges: []


In [68]:
edge_add(g, 'Hyeon Woo', 'New Post','authored')

In [69]:
watch_graph(g)

Nodes: [(0, {'ntype': 'user', 'name': 'Hyeon Woo', 'age': 24, 'job': 'Student'}), (1, {'ntype': 'post', 'name': 'New Post', 'content': 'This is a new post!'})]
Edges: [(0, 1, {'relation': 'authored'})]


In [70]:
import os
import re
import argparse
import pickle

import pandas as pd
import torch

In [92]:
movies = []
with open('./dataset/movielens/movies.dat', encoding='latin1') as f:
    for l in f:
        id_, title, genres = l.strip().split('::')
        genres_set = set(genres.split('|'))

        # extract year
        assert re.match(r'.*\([0-9]{4}\)$', title)
        year = title[-5:-1]
        title = title[:-6].strip()

        data = {'movie_id': int(id_), 'title': title, 'year': year, 'genre': genres.split("|")}
        for g in genres_set:
            data[g] = True
        movies.append(data)
movies = pd.DataFrame(movies).astype({'year': 'int'})

In [93]:
movies

Unnamed: 0,movie_id,title,year,genre,Comedy,Children's,Animation,Fantasy,Adventure,Romance,...,Crime,Action,Horror,Sci-Fi,Documentary,War,Musical,Mystery,Film-Noir,Western
0,1,Toy Story,1995,"[Animation, Children's, Comedy]",True,True,True,,,,...,,,,,,,,,,
1,2,Jumanji,1995,"[Adventure, Children's, Fantasy]",,True,,True,True,,...,,,,,,,,,,
2,3,Grumpier Old Men,1995,"[Comedy, Romance]",True,,,,,True,...,,,,,,,,,,
3,4,Waiting to Exhale,1995,"[Comedy, Drama]",True,,,,,,...,,,,,,,,,,
4,5,Father of the Bride Part II,1995,[Comedy],True,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3878,3948,Meet the Parents,2000,[Comedy],True,,,,,,...,,,,,,,,,,
3879,3949,Requiem for a Dream,2000,[Drama],,,,,,,...,,,,,,,,,,
3880,3950,Tigerland,2000,[Drama],,,,,,,...,,,,,,,,,,
3881,3951,Two Family House,2000,[Drama],,,,,,,...,,,,,,,,,,


In [94]:
ratings = []
with open('./dataset/movielens/ratings.dat', encoding='latin1') as f:
    for l in f:
        user_id, movie_id, rating, timestamp = [int(_) for _ in l.split('::')]
        ratings.append({
            'user_id': user_id,
            'movie_id': movie_id,
            'rating': rating,
            'timestamp': timestamp,
            })
ratings = pd.DataFrame(ratings)

In [95]:
ratings

Unnamed: 0,user_id,movie_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291
...,...,...,...,...
1000204,6040,1091,1,956716541
1000205,6040,1094,5,956704887
1000206,6040,562,5,956704746
1000207,6040,1096,4,956715648


In [96]:
merged_ratings = pd.merge(ratings, movies, on=['movie_id'])
merged_ratings = merged_ratings[['movie_id', 'rating', 'genre']]
merged_ratings = merged_ratings.explode('genre')
genres = pd.DataFrame(merged_ratings['genre'].unique()).reset_index()
genres.columns = ['genre_id', 'genre']
merged_ratings = pd.merge(merged_ratings, genres, on='genre')
distinct_movies_in_ratings = merged_ratings['movie_id'].unique()
movies = movies[movies['movie_id'].isin(distinct_movies_in_ratings)]
genres = pd.DataFrame(genres).astype({'genre_id': 'category'})

In [97]:
movies

Unnamed: 0,movie_id,title,year,genre,Comedy,Children's,Animation,Fantasy,Adventure,Romance,...,Crime,Action,Horror,Sci-Fi,Documentary,War,Musical,Mystery,Film-Noir,Western
0,1,Toy Story,1995,"[Animation, Children's, Comedy]",True,True,True,,,,...,,,,,,,,,,
1,2,Jumanji,1995,"[Adventure, Children's, Fantasy]",,True,,True,True,,...,,,,,,,,,,
2,3,Grumpier Old Men,1995,"[Comedy, Romance]",True,,,,,True,...,,,,,,,,,,
3,4,Waiting to Exhale,1995,"[Comedy, Drama]",True,,,,,,...,,,,,,,,,,
4,5,Father of the Bride Part II,1995,[Comedy],True,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3878,3948,Meet the Parents,2000,[Comedy],True,,,,,,...,,,,,,,,,,
3879,3949,Requiem for a Dream,2000,[Drama],,,,,,,...,,,,,,,,,,
3880,3950,Tigerland,2000,[Drama],,,,,,,...,,,,,,,,,,
3881,3951,Two Family House,2000,[Drama],,,,,,,...,,,,,,,,,,


In [98]:
ratings

Unnamed: 0,user_id,movie_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291
...,...,...,...,...
1000204,6040,1091,1,956716541
1000205,6040,1094,5,956704887
1000206,6040,562,5,956704746
1000207,6040,1096,4,956715648


In [99]:
genres

Unnamed: 0,genre_id,genre
0,0,Drama
1,1,Animation
2,2,Children's
3,3,Musical
4,4,Romance
5,5,Comedy
6,6,Action
7,7,Adventure
8,8,Fantasy
9,9,Sci-Fi


In [100]:
merged_ratings

Unnamed: 0,movie_id,rating,genre,genre_id
0,1193,5,Drama,0
1,1193,5,Drama,0
2,1193,4,Drama,0
3,1193,4,Drama,0
4,1193,5,Drama,0
...,...,...,...,...
2101810,404,5,Documentary,17
2101811,404,3,Documentary,17
2101812,2198,3,Documentary,17
2101813,2198,5,Documentary,17


In [102]:
movies['year']

0       1995
1       1995
2       1995
3       1995
4       1995
        ... 
3878    2000
3879    2000
3880    2000
3881    2000
3882    2000
Name: year, Length: 3706, dtype: int32

In [114]:
from multisage.builder import PandasGraphBuilder

graph_builder = PandasGraphBuilder()
graph_builder.add_entities(genres, 'genre_id', 'genre')
graph_builder.add_entities(movies, 'movie_id', 'movie')
graph_builder.add_binary_relations(merged_ratings, 'genre_id', 'movie_id', 'define')
graph_builder.add_binary_relations(merged_ratings, 'movie_id', 'genre_id', 'define-by')
g = graph_builder.build()

g.nodes['genre'].data['id'] = torch.LongTensor(genres['genre_id'].cat.codes.values)
movies = pd.DataFrame(movies).astype({'year': 'category'})
genre_columns = movies.columns.drop(['movie_id', 'title', 'year', 'genre'])
movies[genre_columns] = movies[genre_columns].fillna(False).astype('bool')
g.nodes['movie'].data['year'] = torch.LongTensor(movies['year'].cat.codes.values)
g.nodes['movie'].data['genre'] = torch.FloatTensor(movies[genre_columns].values)
g.edges['define'].data['rating'] = torch.LongTensor(merged_ratings['rating'].values)
g.edges['define-by'].data['rating'] = torch.LongTensor(merged_ratings['rating'].values)

  g.nodes['genre'].data['id'] = torch.LongTensor(genres['genre_id'].cat.codes.values)


In [118]:
import dgl

output_path = 'graph_data.dgl'
dgl.save_graphs(output_path, [g])

In [119]:
import os
import dgl

import numpy as np
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

from multisage import layers
from multisage.sampler import ItemToItemBatchSampler, NeighborSampler, PinSAGECollator


class MultiSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, ctype, hidden_dims, n_layers, gat_num_heads):
        super().__init__()
        self.nodeproj = layers.LinearProjector(full_graph, ntype, hidden_dims)
        self.contextproj = layers.LinearProjector(full_graph, ctype, hidden_dims)
        self.multisage = layers.MultiSAGENet(hidden_dims, n_layers, gat_num_heads)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)

    def forward(self, pos_graph, neg_graph, blocks, context_blocks):
        h_item = self.get_representation(blocks, context_blocks)
        pos_score = self.scorer(pos_graph, h_item)
        neg_score = self.scorer(neg_graph, h_item)
        return (neg_score - pos_score + 1).clamp(min=0)

    def get_representation(self, blocks, context_blocks, context_id=None):
        if context_id:
            return self.get_context_query(blocks, context_blocks, context_id)
        else:
            h_item = self.nodeproj(blocks[0].srcdata)
            h_item_dst = self.nodeproj(blocks[-1].dstdata)
            z_c = self.contextproj(context_blocks[0])
            z_c_dst = self.contextproj(context_blocks[-1])
            h = h_item_dst + self.multisage(blocks, h_item, (z_c, z_c_dst))
            return h

    def get_context_query(self, blocks, context_blocks, context_id):
        # check sub-graph contains context id
        context_id = context_blocks[-1]['_ID'][0].item()
        print(context_id)
        print(context_blocks[-1]['_ID'])
        context_index = (context_id == context_blocks[-1]['_ID']).nonzero(as_tuple=True)[0]
        if context_index.size()[0] == 0:  # if context id not in sub-graph, only random sample context using for repr
            print("context not in sub graph")
            return self.get_representation(blocks, context_blocks)
        else:  # if context id in sub-graph, get MultiSAGE's context query
            print("execute context query")
            attn_index = torch.ones(context_blocks[-1]['_ID'].shape[0], dtype=bool)
            attn_index[context_index] = False
            h_item = self.nodeproj(blocks[0].srcdata)
            h_item_dst = self.nodeproj(blocks[-1].dstdata)
            z_c = self.contextproj(context_blocks[0])
            z_c_dst = self.contextproj(context_blocks[-1])
            h = h_item_dst + self.multisage(blocks, h_item, (z_c, z_c_dst), attn_index)
            return h

In [120]:
graphs, _ = dgl.load_graphs("graph_data.dgl")
g = graphs[0]  # 저장된 그래프를 다시 불러옴

# 학습된 모델 가중치 로드
load_dict = torch.load('./multisage/MultiSAGE_weights.pth')

In [121]:
g

Graph(num_nodes={'genre': 18, 'movie': 3706},
      num_edges={('genre', 'define', 'movie'): 2101815, ('movie', 'define-by', 'genre'): 2101815},
      metagraph=[('genre', 'movie', 'define'), ('movie', 'genre', 'define-by')])

In [122]:
model = MultiSAGEModel(g, 'movie', 'genre', 512, 2, 3)
model.load_state_dict(load_dict)

<All keys matched successfully>

In [124]:
batch_sampler = ItemToItemBatchSampler(g, 'genre', 'movie', 512)
neighbor_sampler = NeighborSampler(
    g, 'genre', 'movie', 2, 0.5, 10, 5, 2)
collator = PinSAGECollator(neighbor_sampler, g, 'movie', 'genre')

index_id = 4
with torch.no_grad():
    blocks, context_blocks = collator.collate_point(index_id=index_id)
    context_batch = model.get_representation(blocks, context_blocks, context_id=4)

DGLError: Invalid key "0". Must be one of the edge types.