# Random Walk on Whole Graph

In [1]:
import networkx as nx
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
# NOTE(review): presumably added to work around the Intel OpenMP
# "libiomp5 already initialized" abort that occurs when several libraries
# (e.g. torch and numpy/MKL) each load their own OpenMP runtime. This only
# suppresses the check; it does not remove the duplicate runtime.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# Seed every RNG this notebook draws from:
# - numpy drives the random walks;
# - torch drives embedding init, DataLoader shuffling and negative sampling.
# Seeding numpy alone leaves training non-reproducible.
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

## Read Graph

In [2]:
# Load the graph from a whitespace-delimited edge list. With default arguments
# networkx builds an undirected nx.Graph and keeps node labels as strings.
graph = nx.read_edgelist(path='../graph/whole_undirected_graph.g')

In [3]:
# Fix one node ordering up front; node2idx below derives each node's
# embedding row index from this order.
nodes = list(graph)
print(len(nodes))

2826


## Hyperparameters

In [4]:
# --- random-walk corpus ---
WALK_LENGTH = 20        # steps per walk; each walk holds WALK_LENGTH + 1 nodes
WALK_PER_VERTEX = 20    # walks started from every vertex
WINDOW_SIZE = 2         # context nodes taken on each side of a center node

# --- skip-gram with negative sampling ---
K = 5                   # negative samples drawn per positive context node
BATCH_SIZE = 128
EMBED_DIM = 8

# --- optimisation ---
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
LEARNING_RATE = 0.2
NUM_EPOCHS = 20

## Dataset

In [5]:
# Row index of every node label, following the order of `nodes`.
node2idx = dict(zip(nodes, range(len(nodes))))
# Negative-sampling weights: each node's degree, normalised to sum to 1.
node_count = np.array([graph.degree[label] for label in nodes])
node_freq = node_count / node_count.sum()

In [6]:
def RandomWalk(graph,start_node,walk_length):
    """Return one uniform random walk of ``walk_length`` steps from ``start_node``.

    Args:
        graph: adjacency mapping; ``graph[node]`` must be iterable over the
            neighbours of ``node`` (a networkx graph, or a plain dict of
            lists / dict of dicts).
        start_node: node the walk begins at.
        walk_length: number of steps to take.

    Returns:
        List of visited nodes in order. It has ``walk_length + 1`` entries,
        the first being ``start_node``.

    Raises:
        ValueError: if the walk reaches a node with no neighbours.

    Each step draws from ``np.random``, so walks are reproducible under
    ``np.random.seed``.
    """
    path = [start_node]
    node = start_node
    for _ in range(walk_length):
        # Materialise the neighbours explicitly instead of handing the
        # adjacency view to numpy.
        neighbors = list(graph[node])
        if not neighbors:
            raise ValueError(
                "random walk reached node {!r}, which has no neighbours".format(node)
            )
        # Pick by index rather than np.random.choice(neighbors). Coercing the
        # list to an array breaks for tuple-valued nodes and turns str labels
        # into np.str_.
        node = neighbors[np.random.randint(len(neighbors))]
        path.append(node)
    return path

In [7]:
# Walk corpus: WALK_PER_VERTEX rounds, each starting one walk from every
# vertex in graph.nodes order.
corpus = [
    RandomWalk(graph, start, WALK_LENGTH)
    for _round in range(WALK_PER_VERTEX)
    for start in graph.nodes
]

print(len(corpus))

56520


In [8]:
# Skip-gram positive pairs. Only positions with a full window of WINDOW_SIZE
# nodes on each side are used; positions too close to either end are skipped.
# Each such position becomes [center, context], where context lists the
# WINDOW_SIZE left neighbours followed by the WINDOW_SIZE right neighbours.
pos_pairs = []
for path in corpus:
    # Bound by this path's own length (not the global WALK_LENGTH), so walks
    # of any length are windowed correctly.
    for i in range(WINDOW_SIZE, len(path) - WINDOW_SIZE):
        context = path[i - WINDOW_SIZE:i] + path[i + 1:i + WINDOW_SIZE + 1]
        pos_pairs.append([path[i], context])

In [9]:
class GraphDataset(tud.Dataset):
    """Skip-gram dataset yielding (center, positive-context, negative) index tensors.

    Args:
        pos_pairs: list of ``[center_label, [context_label, ...]]`` pairs. All
            context lists must have the same length so they stack into one
            tensor.
        node2idx: mapping from node label to integer row index.
        node_freq: per-node sampling weights, indexed by row index.
        K: number of negative samples drawn per positive context node.
    """

    def __init__(self,pos_pairs,node2idx,node_freq,K):
        super().__init__()

        # Build the index tensors directly as int64. The previous
        # torch.Tensor(...) path went through float32, which silently rounds
        # indices above 2**24.
        self.center_node = torch.tensor(
            [node2idx[pair[0]] for pair in pos_pairs], dtype=torch.long
        )
        self.pos_pairs = torch.tensor(
            [[node2idx[p] for p in pair[1]] for pair in pos_pairs], dtype=torch.long
        )
        self.node_freq = torch.tensor(node_freq, dtype=torch.float32)
        self.K = K

    def __len__(self):
        return len(self.center_node)

    def __getitem__(self,idx):
        """Return (center index, context indices, freshly drawn negative indices)."""
        center_node = self.center_node[idx]
        pos_nodes = self.pos_pairs[idx]
        # Negatives are redrawn on every access, with replacement, in proportion
        # to node_freq. They are not filtered, so they may coincide with the
        # center or a positive node.
        neg_nodes = torch.multinomial(self.node_freq, self.K * pos_nodes.shape[0], True)

        return center_node, pos_nodes, neg_nodes

In [10]:
# Wrap the positive pairs in a Dataset (negatives are sampled per item) and
# serve them in shuffled batches.
dataset = GraphDataset(pos_pairs=pos_pairs, node2idx=node2idx, node_freq=node_freq, K=K)
dataloader = tud.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

## Model

In [11]:
class NodeEmbedding(nn.Module):
    """Skip-gram-with-negative-sampling model over integer node indices.

    Holds two embedding tables: ``in_embed`` for center nodes and
    ``out_embed`` for context and negative nodes. ``get_embed`` exports
    ``in_embed``.
    """

    def __init__(self,node_size,embed_dim):
        super().__init__()
        self.node_size = node_size
        self.embed_dim = embed_dim

        self.in_embed = nn.Embedding(node_size,embed_dim)
        self.out_embed = nn.Embedding(node_size,embed_dim)

        # Small symmetric uniform init in [-0.5/embed_dim, 0.5/embed_dim].
        initrange = 0.5/embed_dim
        self.in_embed.weight.data.uniform_(-initrange, initrange)
        self.out_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, center_node, pos_nodes, neg_nodes):
        """Return the per-example negative-sampling loss.

        Args:
            center_node: (batch,) integer indices of center nodes.
            pos_nodes: (batch, n_pos) integer indices of context nodes.
            neg_nodes: (batch, n_neg) integer indices of negative samples.

        Returns:
            (batch,) tensor equal to
            ``-(sum log sigmoid(u_pos . v_c) + sum log sigmoid(-u_neg . v_c))``.
        """
        center_emb = self.in_embed(center_node).unsqueeze(2)  # (batch, dim, 1)
        pos_emb = self.out_embed(pos_nodes)                    # (batch, n_pos, dim)
        neg_emb = self.out_embed(neg_nodes)                    # (batch, n_neg, dim)

        # Squeeze only the trailing singleton dim. A bare squeeze() would also
        # drop the batch dim when batch == 1, and the sum over dim 1 would fail.
        pos_score = torch.bmm(pos_emb, center_emb).squeeze(2)   # (batch, n_pos)
        neg_score = torch.bmm(neg_emb, -center_emb).squeeze(2)  # (batch, n_neg)

        loss = F.logsigmoid(pos_score).sum(1) + F.logsigmoid(neg_score).sum(1)

        return -loss

    def get_embed(self):
        """Return the center (input) embedding table as a nested Python list."""
        return self.in_embed.weight.data.cpu().numpy().tolist()
        

In [12]:
# One embedding row per graph node, moved to DEVICE, optimised with plain SGD
# over both embedding tables.
model = NodeEmbedding(len(graph.nodes), EMBED_DIM).to(DEVICE)
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [13]:
losses = []  # mean loss of every batch, in training order
for epoch in range(NUM_EPOCHS):
    for step, batch in enumerate(dataloader):
        # Batch = (center, positive contexts, negatives); cast to int64 on DEVICE.
        center_node, pos_nodes, neg_nodes = (t.long().to(DEVICE) for t in batch)

        optimizer.zero_grad()
        loss = model(center_node, pos_nodes, neg_nodes).mean()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        losses.append(batch_loss)
        if step % 2000 == 0:
            print("epoch: {}, iter: {}, loss: {}".format(epoch, step, batch_loss))

epoch: 0, iter: 0, loss: 16.635639190673828
epoch: 0, iter: 2000, loss: 16.632469177246094
epoch: 0, iter: 4000, loss: 16.53997039794922
epoch: 0, iter: 6000, loss: 14.52586555480957
epoch: 1, iter: 0, loss: 12.094035148620605
epoch: 1, iter: 2000, loss: 10.832967758178711
epoch: 1, iter: 4000, loss: 10.425745010375977
epoch: 1, iter: 6000, loss: 9.422791481018066
epoch: 2, iter: 0, loss: 8.974201202392578
epoch: 2, iter: 2000, loss: 7.940958023071289
epoch: 2, iter: 4000, loss: 7.051218509674072
epoch: 2, iter: 6000, loss: 6.262929916381836
epoch: 3, iter: 0, loss: 5.802253246307373
epoch: 3, iter: 2000, loss: 4.948734283447266
epoch: 3, iter: 4000, loss: 4.57375431060791
epoch: 3, iter: 6000, loss: 4.370632171630859
epoch: 4, iter: 0, loss: 3.911592960357666
epoch: 4, iter: 2000, loss: 3.665191888809204
epoch: 4, iter: 4000, loss: 3.224729537963867
epoch: 4, iter: 6000, loss: 3.1978776454925537
epoch: 5, iter: 0, loss: 3.340299129486084
epoch: 5, iter: 2000, loss: 3.0534088611602783


## Output Embedding

In [14]:
# Write one CSV row per node: the index is the node label, the columns are
# embedding dimensions, taken from the center (in_embed) table. Row i of
# in_embed belongs to the node whose node2idx value is i, which matches the
# key order of node2idx.
# Create the output directory first so a finished training run is not lost
# to a missing folder.
os.makedirs('embedding', exist_ok=True)
embed_dict = dict(zip(node2idx.keys(), model.get_embed()))
pd.DataFrame(embed_dict).T.to_csv(f'embedding/random_walk_whole_graph_{EMBED_DIM}.csv')

In [15]:
# Persist the per-batch training losses; create the folder if it is missing.
os.makedirs('loss', exist_ok=True)
np.save(f'loss/random_walk_whole_graph_{EMBED_DIM}.npy', losses)