In [1]:
import torch
import torch.nn as nn
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from sklearn.linear_model import LogisticRegression

import os
import os.path as osp
from torch_geometric.datasets import AMiner

In [2]:
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data','AMiner')

NameError: name '__file__' is not defined

In [6]:
dataset = AMiner(root='/tmp/Aminer')#, name='Aminer')

Downloading https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1
Extracting /tmp/Aminer/net_aminer.zip?dl=1
Downloading https://www.dropbox.com/s/nkocx16rpl4ydde/label.zip?dl=1
Extracting /tmp/Aminer/raw/label.zip?dl=1
Processing...
Done!


In [60]:
data = dataset[0]
metapath = [
    ('author', 'wrote', 'paper'),
    ('paper', 'published in', 'venue'),
    ('venue', 'published', 'paper'),
    ('paper', 'written by', 'author'),
]
embedding_dim = 128
walk_length = 50
context_size=7
walks_per_node=5
num_negative_samples=5

In [14]:
data

Data(
  edge_index_dict={
    ('paper', 'written by', 'author')=[2, 9323605],
    ('author', 'wrote', 'paper')=[2, 9323605],
    ('paper', 'published in', 'venue')=[2, 3194405],
    ('venue', 'published', 'paper')=[2, 3194405]
  },
  num_nodes_dict={
    paper=3194405,
    author=1693531,
    venue=3883
  },
  y_dict={
    author=[246678],
    venue=[134]
  },
  y_index_dict={
    author=[246678],
    venue=[134]
  }
)

In [29]:
adj_dict = {}
for keys, edge_index in data.edge_index_dict.items():
    sizes = (data.num_nodes_dict[keys[0]], data.num_nodes_dict[keys[-1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=sizes)
    adj_dict[keys] = adj

In [48]:
adj_dict

{('paper',
  'written by',
  'author'): SparseTensor(row=tensor([      0,       1,       2,  ..., 3194404, 3194404, 3194404]),
              col=tensor([     0,      1,      2,  ...,   4393,  21681, 317436]),
              size=(3194405, 1693531), nnz=9323605, density=0.00%),
 ('author',
  'wrote',
  'paper'): SparseTensor(row=tensor([      0,       0,       0,  ..., 1693528, 1693529, 1693530]),
              col=tensor([      0,   45988,  124807,  ..., 3194371, 3194387, 3194389]),
              size=(1693531, 3194405), nnz=9323605, density=0.00%),
 ('paper',
  'published in',
  'venue'): SparseTensor(row=tensor([      0,       1,       2,  ..., 3194402, 3194403, 3194404]),
              col=tensor([2190, 2190, 2190,  ..., 3148, 3148, 3148]),
              size=(3194405, 3883), nnz=3194405, density=0.03%),
 ('venue',
  'published',
  'paper'): SparseTensor(row=tensor([   0,    0,    0,  ..., 3882, 3882, 3882]),
              col=tensor([2203069, 2203070, 2203071,  ...,  952391,  952392

In [47]:
types = set(x[0] for x in metapath) | set(x[-1] for x in metapath)
sorted(list(types))

['author', 'paper', 'venue']

In [49]:
count = 0
start, end = {}, {}

In [51]:
for key in types:
    start[key] = count
    count += data.num_nodes_dict[key]
    end[key] = count

In [52]:
start

{'paper': 0, 'author': 3194405, 'venue': 4887936}

In [53]:
end

{'paper': 3194405, 'author': 4887936, 'venue': 4891819}

In [77]:
offset = [start[metapath[0][0]]]
offset += [start[keys[-1]] for keys in metapath] * int((walk_length / len(metapath)) + 1)
offset = offset[:walk_length + 1]
assert len(offset) == walk_length + 1
offset = torch.tensor(offset)

In [78]:
offset

tensor([3194405,       0, 4887936,       0, 3194405,       0, 4887936,       0,
        3194405,       0, 4887936,       0, 3194405,       0, 4887936,       0,
        3194405,       0, 4887936,       0, 3194405,       0, 4887936,       0,
        3194405,       0, 4887936,       0, 3194405,       0, 4887936,       0,
        3194405,       0, 4887936,       0, 3194405,       0, 4887936,       0,
        3194405,       0, 4887936,       0, 3194405,       0, 4887936,       0,
        3194405,       0, 4887936])

In [84]:
embedding = Embedding(count, embedding_dim, sparse=True)

In [208]:
count

4891819

In [209]:
def loader(self, **kwargs):
    return DataLoader(range(data.num_nodes_dict[metapath[0][0]]),
                     collate_fn = sample, **kwargs
                     )

In [212]:
def sample(self, batch):
    if not isinstance(batch, torch.Tensor):
        batch = torch.tensor(batch)
    return pos_sample(batch), neg_sample(batch)

In [213]:
def pos_sample(self, batch):
    # device = self.embedding.weight.device

    batch = batch.repeat(walks_per_node)

    rws = [batch]
    for i in range(walk_length):
        keys = metapath[i % len(metapath)]
        adj = adj_dict[keys]
        batch = adj.sample(num_neighbors=1, subset=batch).squeeze()
        rws.append(batch)

    rw = torch.stack(rws, dim=-1)
    rw.add_(offset.view(1, -1))

    walks = []
    num_walks_per_rw = 1 + walk_length + 1 - context_size
    for j in range(num_walks_per_rw):
        walks.append(rw[:, j:j + context_size])
    return torch.cat(walks, dim=0)

In [214]:
def neg_sample(self, batch):
    batch = batch.repeat(walks_per_node * num_negative_samples)

    rws = [batch]
    for i in range(walk_length):
        keys = metapath[i % len(metapath)]
        batch = torch.randint(0, num_nodes_dict[keys[-1]],
                              (batch.size(0), ), dtype=torch.long)
        rws.append(batch)

    rw = torch.stack(rws, dim=-1)
    rw.add_(offset.view(1, -1))

    walks = []
    num_walks_per_rw = 1 + walk_length + 1 - context_size
    for j in range(num_walks_per_rw):
        walks.append(rw[:, j:j + context_size])
    return torch.cat(walks, dim=0)

In [215]:
loader = DataLoader(range(data.num_nodes_dict[metapath[0][0]]), 
                   collate_fn = sample,
                    batch_size=128,
                    shuffle=True,
                    num_workers=12
                   ) 

In [216]:
for i, (pos_rw, neg_rw) in enumerate(loader):
    print(pos_rw)
    print(neg_rw)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/junseok/anaconda3/envs/study2/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/junseok/anaconda3/envs/study2/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
TypeError: sample() missing 1 required positional argument: 'batch'


In [219]:
class MetaPath2Vec(torch.nn.Module):

    def __init__(self, edge_index_dict, embedding_dim, metapath, walk_length,
                 context_size, walks_per_node=1, num_negative_samples=1,
                 num_nodes_dict=None, sparse=False):
        super(MetaPath2Vec, self).__init__()

        if num_nodes_dict is None:
            num_nodes_dict = {}
            for keys, edge_index in edge_index_dict.items():
                key = keys[0]
                N = int(edge_index[0].max() + 1)
                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

                key = keys[-1]
                N = int(edge_index[1].max() + 1)
                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))

        adj_dict = {}
        for keys, edge_index in edge_index_dict.items():
            sizes = (num_nodes_dict[keys[0]], num_nodes_dict[keys[-1]])
            row, col = edge_index
            adj = SparseTensor(row=row, col=col, sparse_sizes=sizes)
            adj = adj.to('cpu')
            adj_dict[keys] = adj

        assert metapath[0][0] == metapath[-1][-1]
        assert walk_length >= context_size

        self.adj_dict = adj_dict
        self.embedding_dim = embedding_dim
        self.metapath = metapath
        self.walk_length = walk_length
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.num_negative_samples = num_negative_samples
        self.num_nodes_dict = num_nodes_dict

        types = set([x[0] for x in metapath]) | set([x[-1] for x in metapath])
        types = sorted(list(types))

        count = 0
        self.start, self.end = {}, {}
        for key in types:
            self.start[key] = count
            count += num_nodes_dict[key]
            self.end[key] = count

        offset = [self.start[metapath[0][0]]]
        offset += [self.start[keys[-1]] for keys in metapath
                   ] * int((walk_length / len(metapath)) + 1)
        offset = offset[:walk_length + 1]
        assert len(offset) == walk_length + 1
        self.offset = torch.tensor(offset)

        self.embedding = Embedding(count, embedding_dim, sparse=sparse)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()

    def forward(self, node_type, batch=None):
        """Returns the embeddings for the nodes in :obj:`subset` of type
        :obj:`node_type`."""
        emb = self.embedding.weight[self.start[node_type]:self.end[node_type]]
        return emb if batch is None else emb[batch]

    def loader(self, **kwargs):
        return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]),
                          collate_fn=self.sample, **kwargs)

    def pos_sample(self, batch):
        # device = self.embedding.weight.device

        batch = batch.repeat(self.walks_per_node)

        rws = [batch]
        for i in range(self.walk_length):
            keys = self.metapath[i % len(self.metapath)]
            adj = self.adj_dict[keys]
            batch = adj.sample(num_neighbors=1, subset=batch).squeeze()
            rws.append(batch)

        rw = torch.stack(rws, dim=-1)
        rw.add_(self.offset.view(1, -1))

        walks = []
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j:j + self.context_size])
        return torch.cat(walks, dim=0)

    def neg_sample(self, batch):
        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)

        rws = [batch]
        for i in range(self.walk_length):
            keys = self.metapath[i % len(self.metapath)]
            batch = torch.randint(0, self.num_nodes_dict[keys[-1]],
                                  (batch.size(0), ), dtype=torch.long)
            rws.append(batch)

        rw = torch.stack(rws, dim=-1)
        rw.add_(self.offset.view(1, -1))

        walks = []
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j:j + self.context_size])
        return torch.cat(walks, dim=0)

    def sample(self, batch):
        if not isinstance(batch, torch.Tensor):
            batch = torch.tensor(batch)
        return self.pos_sample(batch), self.neg_sample(batch)


In [220]:
model = MetaPath2Vec(data.edge_index_dict, embedding_dim=128,
                         metapath=metapath, walk_length=50, context_size=7,
                         walks_per_node=5, num_negative_samples=5,
                         sparse=True).to(device)

In [221]:
model

MetaPath2Vec(
  (embedding): Embedding(4891819, 128, sparse=True)
)

In [222]:
loader = model.loader(batch_size=128)

In [229]:
data.num_nodes_dict[metapath[0][0]]

1693531

In [227]:
for i, (pos_rw, neg_rw) in enumerate(loader):
    print("pos_rw")
    print(pos_rw.shape)
    print("neg_rw")
    print(neg_rw.shape)

pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800, 7])
neg_rw
torch.Size([144000, 7])
pos_rw
torch.Size([28800

KeyboardInterrupt: 