# Import Packages

In [1]:
import torch
import torch.nn as nn
from torch.nn import Embedding
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from sklearn.linear_model import LogisticRegression
import torch_cluster

from torch_geometric.datasets import Planetoid

# Load Dataset

In [3]:
# Load the Cora citation dataset (cached under /tmp/Cora) and take its first graph.
dataset = Planetoid(name='Cora', root='/tmp/Cora')
data = dataset[0]

In [19]:
# Echo the graph object so the notebook displays its attribute shapes.
data

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])

# Model

In [None]:
class Node2Vec(nn.Module):
    """
    Args:
        edge_index(LongTensor)
        embedding_dim(int)
        walk_length(int)
        context_size(int) : 윈도우 사이즈
        walks_per_node(int, optional)
        p (float, optional) : 다시 돌아올 가능도
        q (float, optional) : 멀리갈 가능도
        num_negative_samples 
        num_nodes(int, optional)
        sparse(bool, optional)
    """
    
    def __init__(self, edge_index, embedding_dim, walk_length, context_size,
                 walks_per_node=1, p=1, q=1, num_negative_samples=1,
                 num_nodes=None, sparse=False):
        """Build the graph adjacency and the trainable node-embedding table.

        Args:
            edge_index (LongTensor): graph connectivity; it is unpacked into
                a ``(row, col)`` pair of index tensors.
            embedding_dim (int): size of each node embedding vector.
            walk_length (int): length of each random walk. ``self.walk_length``
                stores ``walk_length - 1``, which is the value handed to the
                random-walk sampler.
            context_size (int): window size used to cut walks into training
                sequences; must not exceed ``walk_length``.
            walks_per_node (int, optional): how many walks start from each
                node of a batch.
            p (float, optional): likelihood of walking back to the previous
                node (return parameter).
            q (float, optional): likelihood of moving further away
                (in-out parameter).
            num_negative_samples (int, optional): negative-sampling ratio,
                stored as ``self.num_negative_samples``.
            num_nodes (int, optional): number of nodes in the graph. If
                ``None`` it is inferred as ``edge_index.max() + 1`` (0 for an
                empty ``edge_index``).
            sparse (bool, optional): if ``True`` the embedding layer produces
                sparse gradients.
        """
        super(Node2Vec, self).__init__()

        assert walk_length >= context_size

        # Build the model from the arguments, not from the notebook-global
        # ``data``, so it works for whatever graph the caller passes in.
        if num_nodes is None:
            num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
        N = num_nodes
        row, col = edge_index
        # SparseTensor's keyword is ``sparse_sizes`` (plural).
        self.adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))

        self.embedding_dim = embedding_dim
        # The sampler counts steps after the start node, hence the ``- 1``.
        self.walk_length = walk_length - 1
        self.context_size = context_size
        self.walks_per_node = walks_per_node
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples

        self.embedding = Embedding(N, embedding_dim, sparse=sparse)

        self.reset_parameters()
        
    def reset_parameters(self):
        self.embedding.reset_parameters()
        
    def forward(self, batch=None):
        emb = self.embedding.weight
        
        return emb if batch is None else emb[batch]
    
    def loader(self, **kwargs):
        """Return a DataLoader over all node indices ``0 .. N-1``.

        ``self.sample`` is used as the collate function, so every batch of
        node ids comes out as ``(pos_sample(batch), neg_sample(batch))``.
        Extra keyword arguments (e.g. ``batch_size``, ``shuffle``) are
        forwarded to ``DataLoader``.
        """
        node_ids = range(self.adj.sparse_size(0))
        return DataLoader(node_ids, collate_fn=self.sample, **kwargs)
    
    def pos_sample(self, batch):
        """Sample positive context windows from biased (p, q) random walks.

        Args:
            batch (Tensor): 1-D tensor of start-node indices.

        Returns:
            Tensor: every length-``context_size`` sliding window of every
            walk, stacked along dim 0.
        """
        # Start ``walks_per_node`` walks from every node in the batch.
        batch = batch.repeat(self.walks_per_node)
        # The random-walk op consumes CSR (rowptr, col) instead of COO.
        rowptr, col, _ = self.adj.csr()
        # ``random_walk`` was never defined in this notebook. The sampler is the
        # op that ``import torch_cluster`` registers under ``torch.ops``.
        rw = torch.ops.torch_cluster.random_walk(
            rowptr, col, batch, self.walk_length, self.p, self.q)
        # Some torch_cluster versions return a tuple; its first element is
        # the node sequence.
        if not isinstance(rw, torch.Tensor):
            rw = rw[0]

        # Number of windows per walk. This assumes each walk holds
        # ``self.walk_length + 1`` nodes (start node included).
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        walks = [rw[:, j:j + self.context_size] for j in range(num_walks_per_rw)]
        return torch.cat(walks, dim=0)
        
    
    def neg_sample(self, batch):
        
        
    def sample(self, batch):
        """Collate a batch of node ids into ``(positive, negative)`` samples.

        ``batch`` may be a tensor or a plain sequence of ints (as a DataLoader
        yields); non-tensors are converted with ``torch.tensor``.
        """
        nodes = batch if isinstance(batch, torch.Tensor) else torch.tensor(batch)
        return self.pos_sample(nodes), self.neg_sample(nodes)
        
    def loss(self, pos_rw, neg_rw):
        
        # positive loss
        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()
        
        
        
        # negative loss
        