In [47]:
import torch
import torch.nn as nn
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from sklearn.linear_model import LogisticRegression

import os
import os.path as osp
from torch_geometric.datasets import AMiner

In [48]:
class MetaPath2Vec(nn.Module):
    r"""MetaPath2Vec node embeddings for heterogeneous graphs.

    Random walks follow ``metapath`` (cycling through it when ``walk_length``
    exceeds its length). Every window of ``context_size`` consecutive walk
    nodes is a positive skip-gram sample. Windows cut from walks over
    uniformly random nodes of the same types serve as negative samples.

    All node types that occur on ``metapath`` share one ``Embedding`` table:
    node ``i`` of type ``t`` is row ``start[t] + i``. One extra row,
    ``dummy_idx``, stands in for walk positions reached after a node that has
    no outgoing edge of the required relation.

    Args:
        edge_index_dict (dict): maps ``(source_node_type, relation_type,
            target_node_type)`` tuples to ``[2, num_edges]`` edge-index tensors.
        embedding_dim (int): size of each embedding vector.
        metapath (list): ``(source_node_type, relation_type,
            target_node_type)`` tuples that every walk follows, in order.
        walk_length (int): number of steps per walk; a walk visits
            ``walk_length + 1`` nodes.
        context_size (int): skip-gram window size; must satisfy
            ``2 <= context_size <= walk_length + 1``.
        walks_per_node (int, optional): number of walks started from each
            start node. (default: ``1``)
        num_negative_samples (int, optional): number of negative walks per
            positive walk. (default: ``1``)
        num_nodes_dict (dict, optional): number of nodes per node type;
            inferred from ``edge_index_dict`` when ``None``.
        sparse (bool, optional): if ``True``, the embedding table produces
            sparse gradients (see ``torch.nn.Embedding``). (default: ``False``)

    Raises:
        ValueError: if ``metapath`` is empty, consecutive relations do not
            chain, ``context_size`` is out of range, or
            ``walk_length > len(metapath)`` while ``metapath`` does not end at
            the node type it starts from.
        KeyError: if a relation of ``metapath`` is missing from
            ``edge_index_dict`` or a node type is missing from
            ``num_nodes_dict``.
    """

    def __init__(self, edge_index_dict, embedding_dim, metapath, walk_length,
                 context_size, walks_per_node=1, num_negative_samples=1,
                 num_nodes_dict=None, sparse=False):
        super().__init__()

        if len(metapath) == 0:
            raise ValueError("'metapath' must contain at least one relation")
        for prev_keys, next_keys in zip(metapath[:-1], metapath[1:]):
            if prev_keys[-1] != next_keys[0]:
                raise ValueError(
                    f"metapath relations {prev_keys} and {next_keys} do not chain")
        # A walk yields walk_length + 2 - context_size windows, and each window
        # needs a start node plus at least one context node.
        if not 2 <= context_size <= walk_length + 1:
            raise ValueError(
                f"'context_size' must be in [2, walk_length + 1] "
                f"(got context_size={context_size}, walk_length={walk_length})")
        # Walks longer than the metapath wrap around to metapath[0], which is
        # only type-consistent when the metapath ends where it starts.
        if walk_length > len(metapath) and metapath[0][0] != metapath[-1][-1]:
            raise ValueError(
                "'walk_length' exceeds len(metapath), but 'metapath' does not "
                "end at the node type it starts from")

        # Infer node counts per type from the largest id seen on either side.
        if num_nodes_dict is None:
            num_nodes_dict = {}
            for keys, edge_index in edge_index_dict.items():
                for node_type, ids in ((keys[0], edge_index[0]),
                                       (keys[-1], edge_index[1])):
                    n = int(ids.max()) + 1 if ids.numel() > 0 else 0
                    num_nodes_dict[node_type] = max(
                        n, num_nodes_dict.get(node_type, 0))

        # Only the relations that walks actually traverse need an adjacency.
        adj_dict = {}
        for keys in dict.fromkeys(metapath):
            if keys not in edge_index_dict:
                raise KeyError(
                    f'metapath relation {keys} is not in edge_index_dict')
            sizes = (num_nodes_dict[keys[0]], num_nodes_dict[keys[-1]])
            adj_dict[keys] = self._build_adj(edge_index_dict[keys], sizes)

        self.adj_dict = adj_dict
        self.embedding_dim = embedding_dim
        self.metapath = metapath
        self.walk_length = walk_length
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.num_negative_samples = num_negative_samples
        self.num_nodes_dict = num_nodes_dict

        # Every node type on the metapath gets a contiguous slice
        # [start[t], end[t]) of the shared embedding table.
        node_types = sorted({keys[0] for keys in metapath} |
                            {keys[-1] for keys in metapath})
        count = 0
        self.start, self.end = {}, {}
        for node_type in node_types:
            self.start[node_type] = count
            count += num_nodes_dict[node_type]
            self.end[node_type] = count

        # offset[k] is the first embedding row of the node type visited at
        # walk position k; adding it maps local node ids to global rows.
        offset = [self.start[metapath[0][0]]]
        offset += ([self.start[keys[-1]] for keys in metapath]
                   * (walk_length // len(metapath) + 1))
        self.offset = torch.tensor(offset[:walk_length + 1])

        # Row `count` is the extra dummy node used for dead-end walk positions.
        self.dummy_idx = count
        self.embedding = Embedding(count + 1, embedding_dim, sparse=sparse)

        self.reset_parameters()

    @staticmethod
    def _build_adj(edge_index, sizes):
        """Build the CPU ``SparseTensor`` adjacency of one relation."""
        row, col = edge_index
        return SparseTensor(row=row, col=col, sparse_sizes=sizes).to('cpu')

    def reset_parameters(self):
        """Re-initialise the embedding table."""
        self.embedding.reset_parameters()

    def forward(self, node_type, batch=None):
        """Return the embeddings of all nodes of ``node_type``, or only of the
        local node ids in ``batch`` when it is given."""
        emb = self.embedding.weight[self.start[node_type]:self.end[node_type]]
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs):
        """Return a ``DataLoader`` over the local ids of the metapath's start
        node type.

        ``collate_fn=self.sample`` turns each list of ids into a
        ``(pos_rw, neg_rw)`` pair. Remaining ``kwargs`` (e.g. ``batch_size``,
        ``shuffle``) are forwarded to ``DataLoader``.
        """
        return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]),
                          collate_fn=self.sample, **kwargs)

    @staticmethod
    def _sample_neighbors(rowptr, col, batch, dummy_idx):
        """Pick one uniformly random out-neighbour (from CSR ``rowptr``/``col``)
        for every entry of ``batch``.

        Entries that are not valid rows (i.e. earlier dummy markers) or that
        have no out-neighbour become ``dummy_idx``.
        """
        num_rows = rowptr.numel() - 1
        out = torch.full_like(batch, dummy_idx)
        alive = batch < num_rows
        src = batch[alive]
        deg = rowptr[src + 1] - rowptr[src]
        has_nbr = deg > 0
        src, deg = src[has_nbr], deg[has_nbr]
        # floor(U * deg) with U ~ Uniform[0, 1) is a uniform slot in [0, deg);
        # the min() guards against float rounding up to deg.
        slot = (torch.rand(deg.numel(), device=deg.device) * deg).long()
        slot = torch.min(slot, deg - 1)
        mask = alive.clone()
        mask[alive] = has_nbr
        out[mask] = col[rowptr[src] + slot]
        return out

    def _to_context_windows(self, rw):
        """Cut every walk (row of ``rw``) into all windows of ``context_size``
        consecutive nodes, stacked window-position-major."""
        num_windows = rw.size(1) + 1 - self.context_size
        return torch.cat([rw[:, j:j + self.context_size]
                          for j in range(num_windows)], dim=0)

    def pos_sample(self, batch):
        """Sample metapath-guided random walks and cut them into positive
        skip-gram windows.

        Args:
            batch (Tensor): 1-D long tensor of local ids of nodes of type
                ``metapath[0][0]``; each id starts ``walks_per_node`` walks.

        Returns:
            Tensor: ``[len(batch) * walks_per_node * num_windows,
            context_size]`` embedding-row indices, where ``num_windows =
            walk_length + 2 - context_size``. Positions after a dead end hold
            ``dummy_idx``.
        """
        batch = batch.repeat(self.walks_per_node)
        walk = [batch]
        for step in range(self.walk_length):
            keys = self.metapath[step % len(self.metapath)]
            rowptr, col, _ = self.adj_dict[keys].csr()
            batch = self._sample_neighbors(rowptr, col, batch, self.dummy_idx)
            walk.append(batch)
        rw = torch.stack(walk, dim=-1) + self.offset.view(1, -1)
        # Dummy markers were shifted by the offset too; pin them back.
        rw = rw.clamp(max=self.dummy_idx)
        return self._to_context_windows(rw)

    def neg_sample(self, batch):
        """Build negative windows from walks whose nodes after the start are
        drawn uniformly from the node type required at each metapath step.

        Each id in ``batch`` starts ``walks_per_node * num_negative_samples``
        walks. The output layout matches :meth:`pos_sample`.
        """
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)
        walk = [batch]
        for step in range(self.walk_length):
            keys = self.metapath[step % len(self.metapath)]
            batch = torch.randint(0, self.num_nodes_dict[keys[-1]],
                                  (batch.size(0),), dtype=torch.long)
            walk.append(batch)
        rw = torch.stack(walk, dim=-1) + self.offset.view(1, -1)
        return self._to_context_windows(rw)

    def sample(self, batch):
        """Return ``(pos_sample(batch), neg_sample(batch))``; a list of ids
        (as produced by ``DataLoader``) is converted to a long tensor first."""
        if not isinstance(batch, torch.Tensor):
            # Explicit dtype: torch.tensor([]) would otherwise be float.
            batch = torch.tensor(batch, dtype=torch.long)
        return self.pos_sample(batch), self.neg_sample(batch)

    def _walk_loss(self, rw, positive):
        """Mean skip-gram loss between the first node of each window and the
        other nodes of that window."""
        start, rest = rw[:, 0], rw[:, 1:].contiguous()
        h_start = self.embedding(start).view(rw.size(0), 1, self.embedding_dim)
        h_rest = self.embedding(rest.view(-1)).view(rw.size(0), -1,
                                                    self.embedding_dim)
        score = (h_start * h_rest).sum(dim=-1).view(-1)
        # -log(sigmoid(x)) for positives; -log(1 - sigmoid(x)) equals
        # -log(sigmoid(-x)) for negatives. logsigmoid is the stable form.
        return -nn.functional.logsigmoid(score if positive else -score).mean()

    def loss(self, pos_rw, neg_rw):
        """Compute the negative-sampling skip-gram loss for the windows
        returned by :meth:`pos_sample` and :meth:`neg_sample`."""
        return (self._walk_loss(pos_rw, positive=True)
                + self._walk_loss(neg_rw, positive=False))

In [103]:
# Small int64 demo tensor, used by the repeat() experiment in the next cell.
a = torch.arange(1, 4)
a

tensor([1, 2, 3])

In [106]:
# Tile the 1-D tensor twice along its only dimension (same call as repeat(2)).
[a.repeat((2,))]

[tensor([1, 2, 3, 1, 2, 3])]

[('author', 'wrote', 'paper'),
 ('paper', 'published in', 'venue'),
 ('venue', 'published', 'paper'),
 ('paper', 'written by', 'author')]

In [None]:
    # NOTE(review): stand-alone copy of the intended MetaPath2Vec.loader. In its
    # own cell it is not part of the class body, so it does not attach to
    # MetaPath2Vec; keep the class definition as the single source of truth.
    def loader(self, **kwargs):
        """Return a DataLoader over range(number of nodes of the metapath's
        start type), using ``self.sample`` as ``collate_fn`` so each batch of
        start-node ids becomes whatever ``self.sample`` returns; ``kwargs``
        are forwarded to ``DataLoader``."""
        return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]),
                          collate_fn=self.sample, **kwargs)

In [97]:
# With the default batch_size=1 and default collate, DataLoader wraps each int
# of range(5) into a 1-element tensor. Distinct loop names avoid clobbering the
# demo tensor `a` defined in an earlier cell.
for step, node_ids in enumerate(DataLoader(range(5))):
    print(step, node_ids)

0 tensor([0])
1 tensor([1])
2 tensor([2])
3 tensor([3])
4 tensor([4])


In [111]:
# NOTE(review): `data` is not created in any cell shown here; presumably it is
# the graph from `AMiner(root)[0]` (AMiner is imported above). Confirm it is
# defined in an earlier cell so Restart & Run All works.
data

Data(
  edge_index_dict={
    ('paper', 'written by', 'author')=[2, 9323605],
    ('author', 'wrote', 'paper')=[2, 9323605],
    ('paper', 'published in', 'venue')=[2, 3194405],
    ('venue', 'published', 'paper')=[2, 3194405]
  },
  num_nodes_dict={
    paper=3194405,
    author=1693531,
    venue=3883
  },
  y_dict={
    author=[246678],
    venue=[134]
  },
  y_index_dict={
    author=[246678],
    venue=[134]
  }
)

In [112]:
# Presumably per-type node class labels; the output lists integer tensors for
# 'author' and 'venue'.
data.y_dict

{'author': tensor([0, 2, 5,  ..., 0, 1, 5]),
 'venue': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])}

In [113]:
# Presumably the node ids that the labels in `data.y_dict` belong to; per the
# Data repr, each type here has the same length as its `y_dict` entry.
data.y_index_dict

{'author': tensor([ 168866, 1327323,     870,  ...,  168759,  254769,  264374]),
 'venue': tensor([1741, 2245,  111,  837, 2588, 2116, 2696, 3648, 3784,  313, 3414,  598,
         2995, 2716, 1423,  783, 1902, 3132, 1753, 2748, 2660, 3182,  775, 3339,
         1601, 3589,  156, 1145,  692, 3048,  925, 1587,  820, 1374, 3719,  819,
          492, 3830, 2777, 3001, 3693,  517, 1808, 2353, 3499, 1763, 2372, 1030,
          721, 2680, 3355, 1217, 3400, 1271, 1970, 1127,  407,  353, 1471, 1095,
          477, 3701,   65, 1009, 1899, 1442, 2073, 3143, 2466,  289, 1996, 1070,
         3871, 3695,  281, 3633,   50, 2642, 1925, 1285, 2587, 3814, 3582, 1873,
         1339, 3450,  271, 2966,  453, 2638, 1354, 3211,  391, 1588, 3875, 2216,
         2146, 3765, 2486,  661, 3367,  426,  750, 2158,  519,  230, 1677,  839,
         2945, 1313, 1037, 2879, 2225, 3523, 1247,  448,  227, 3385,  529, 2849,
         1584, 1229,  373, 2235, 1819, 1764, 3155, 2852, 2789, 3474, 1571, 2088,
          208,  462

In [119]:
# Distinct author label values, sorted ascending.
torch.unique(data.y_dict['author'])

tensor([0, 1, 2, 3, 4, 5, 6, 7])

In [120]:
# Use a locally seeded generator so the permutation (and this cell's output) is
# reproducible under Restart & Run All without touching the global RNG state.
perm_rng = torch.Generator().manual_seed(0)
perm = torch.randperm(10, generator=perm_rng)
perm

tensor([8, 1, 4, 2, 9, 5, 6, 7, 3, 0])

In [121]:
# Author node ids that presumably carry a label; same tensor as the 'author'
# entry printed by `data.y_index_dict` above.
data.y_index_dict['author']

tensor([ 168866, 1327323,     870,  ...,  168759,  254769,  264374])