# Prob Walk on Stratified Graph

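Stratified walk: the graph contains several railway types (G, D, K, T, Z, C, N), each stored as a separate layer. Each walk first picks one layer, with probability proportional to the start node's weighted degree in that layer, and then walks only on that single railway graph.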

In [1]:
import networkx as nx
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import os
# NOTE(review): presumably works around the Intel OpenMP "duplicate runtime"
# abort seen when numpy and torch each ship libiomp — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# Seed every RNG the notebook consumes: NumPy drives the random walks, torch
# drives embedding init, negative sampling and DataLoader shuffling.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

def merge(lists):
    """Flatten one level of nesting.

    Concatenates every iterable in *lists*, in order, into a new list.
    The inputs are not modified.
    """
    return [item for sublist in lists for item in sublist]

## Read Data

In [2]:
railway_type = ['g', 'd', 'k', 't', 'c', 'z', 'n']
# One undirected layer per railway type, read from ../graph/<type>_undirected_graph.g.
# Later cells read edge['weight'], so each edge line is expected to carry a weight.
graph = {
    rail: nx.read_edgelist(f'../graph/{rail}_undirected_graph.g')
    for rail in railway_type
}

In [3]:
# Union of node names across all railway layers, in sorted order. A bare
# list(set(...)) of strings depends on per-process hash randomization, which
# would reshuffle node2idx and the random-walk sequence on every run even
# with a fixed NumPy seed.
nodes = sorted({node for layer in graph.values() for node in layer.nodes})
print(len(nodes))

2826


## Hyperparameters

In [4]:
# Random-walk corpus
WALK_LENGTH = 20        # steps per walk; each path holds WALK_LENGTH + 1 nodes
WALK_PER_VERTEX = 20    # walks started from every node

# Skip-gram objective
WINDOW_SIZE = 2         # context nodes taken on each side of the centre node
K = 5                   # negatives sampled per positive context node
EMBED_DIM = 8

# Optimisation
BATCH_SIZE = 128
LEARNING_RATE = 0.2     # plain SGD step size
NUM_EPOCHS = 20
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

## Dataset

In [5]:
# Node name -> contiguous integer id (its position in `nodes`).
node2idx = dict(zip(nodes, range(len(nodes))))

In [6]:
# For each railway layer, a negative-sampling distribution over node ids: a
# node's weight is the sum of its incident edge weights in that layer (0 if
# the node is absent from the layer), normalised to sum to 1.
node_freq = {}
for rail in railway_type:
    rail_graph = graph[rail]
    weighted_degree = np.array([
        sum(attrs['weight'] for attrs in rail_graph[node].values())
        if node in rail_graph.nodes else 0
        for node in nodes
    ])
    node_freq[rail] = weighted_degree / weighted_degree.sum()

In [7]:
def StratifiedProbWalk(graph, node, walk_length):
    """Run one weighted random walk restricted to a single railway layer.

    The layer is drawn among the layers of ``graph`` that contain ``node``.
    Each layer's probability is proportional to the sum of ``node``'s incident
    edge weights in that layer. The walk then takes ``walk_length`` steps inside
    that layer. Each step moves to a neighbour with probability proportional
    to the edge weight.

    Args:
        graph: mapping layer name -> layer graph whose ``[n]`` gives a mapping
            neighbour -> edge-attribute dict with a ``'weight'`` key.
        node: start node.
        walk_length: number of steps to take.

    Returns:
        ``(path, layer)``. ``path`` has ``walk_length + 1`` nodes, starting with
        ``node``; later entries come from ``np.random.choice``. ``layer`` is the
        chosen layer name.

    NOTE(review): assumes every visited node has at least one positive-weight
    edge in the chosen layer; otherwise the normalisation divides by zero.
    """
    layer_names = []
    layer_weights = []
    for name, layer_graph in graph.items():
        if node in layer_graph.nodes:
            layer_names.append(name)
            layer_weights.append(sum(attrs['weight'] for attrs in layer_graph[node].values()))
    layer_p = np.array(layer_weights)
    layer_p = layer_p / np.sum(layer_p)
    layer = np.random.choice(layer_names, 1, p=layer_p)[0]
    walk_graph = graph[layer]

    path = [node]
    current = node
    for _ in range(walk_length):
        adjacency = walk_graph[current]
        step_p = np.array([attrs['weight'] for attrs in adjacency.values()])
        step_p = step_p / np.sum(step_p)
        current = np.random.choice(list(adjacency), 1, p=step_p)[0]
        path.append(current)

    return path, layer

In [8]:
# WALK_PER_VERTEX walks from every node. The outer loop is the repetition and
# the inner loop the start node, which fixes the order in which the global
# NumPy RNG is consumed.
walks = [
    StratifiedProbWalk(graph, start, WALK_LENGTH)
    for _ in range(WALK_PER_VERTEX)
    for start in nodes
]
corpus = [walk_path for walk_path, _ in walks]
rail_type = [walk_layer for _, walk_layer in walks]

print(len(corpus))
print(Counter(rail_type))

56520
Counter({'k': 19068, 'n': 17509, 'g': 6865, 'd': 4222, 't': 4143, 'c': 3852, 'z': 861})


In [9]:
# Build skip-gram training pairs: each centre node with WINDOW_SIZE context
# nodes on either side, keeping only centres that have a full window. Every
# pair is labelled with the railway layer its walk came from *in the same
# loop*, so pos_pairs and rail_type stay aligned for any WINDOW_SIZE and walk
# length. A per-walk count of WALK_LENGTH-WINDOW_SIZE-1 is only correct when
# WINDOW_SIZE == 2; the true count is len(path) - 2*WINDOW_SIZE.
pos_pairs = []
pair_rail_type = []
for path, walk_rail in zip(corpus, rail_type):
    last = len(path) - 1
    for i in range(WINDOW_SIZE, last - WINDOW_SIZE + 1):
        context = path[i - WINDOW_SIZE:i] + path[i + 1:i + WINDOW_SIZE + 1]
        pos_pairs.append([path[i], context])
        pair_rail_type.append(walk_rail)

rail_type = pair_rail_type

In [10]:
print(len(pos_pairs), len(rail_type))
# Each training pair needs exactly one rail label for negative sampling; a
# mismatch would silently pair contexts with the wrong layer's distribution.
if len(pos_pairs) != len(rail_type):
    raise ValueError(
        f"pos_pairs ({len(pos_pairs)}) and rail_type ({len(rail_type)}) are misaligned"
    )

960840 960840


In [11]:
class GraphDataset(tud.Dataset):
    """Skip-gram examples with negative sampling specific to each railway layer.

    Item ``idx`` is ``(center_node, pos_nodes, neg_nodes)``:
    - ``center_node``: the centre node id.
    - ``pos_nodes``: its context node ids.
    - ``neg_nodes``: ``K * len(pos_nodes)`` ids drawn with replacement from
      ``node_freq[rail_type[idx]]``.
    """

    def __init__(self, pos_pairs, rail_type, node2idx, node_freq, K):
        """
        Args:
            pos_pairs: list of ``[center, [context, ...]]`` pairs of node names.
            rail_type: railway layer label for each pair (same length as pos_pairs).
            node2idx: node name -> integer id.
            node_freq: layer label -> per-node sampling weights indexed by node id.
            K: negatives sampled per context node.

        Raises:
            ValueError: if ``pos_pairs`` and ``rail_type`` differ in length.
        """
        super().__init__()
        if len(pos_pairs) != len(rail_type):
            raise ValueError(
                f"rail_type has {len(rail_type)} labels for {len(pos_pairs)} pairs"
            )
        # Build integer tensors directly instead of round-tripping ids through float.
        self.center_node = torch.tensor(
            [node2idx[pair[0]] for pair in pos_pairs], dtype=torch.long
        )
        self.pos_pairs = torch.tensor(
            [[node2idx[p] for p in pair[1]] for pair in pos_pairs], dtype=torch.long
        )
        self.rail_type = rail_type
        self.node_freq = {
            key: torch.as_tensor(value, dtype=torch.get_default_dtype())
            for key, value in node_freq.items()
        }
        self.K = K

    def __len__(self):
        return len(self.center_node)

    def __getitem__(self, idx):
        center_node = self.center_node[idx]
        pos_nodes = self.pos_pairs[idx]
        rail = self.rail_type[idx]
        # Negatives come from the distribution of the layer this pair's walk used.
        neg_nodes = torch.multinomial(self.node_freq[rail], self.K * pos_nodes.shape[0], True)

        return center_node, pos_nodes, neg_nodes

In [12]:
# Shuffled mini-batches of (centre, context, negatives) triples.
dataset = GraphDataset(
    pos_pairs=pos_pairs,
    rail_type=rail_type,
    node2idx=node2idx,
    node_freq=node_freq,
    K=K,
)
dataloader = tud.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

## Model

In [13]:
class NodeEmbedding(nn.Module):
    """Skip-gram node embedding trained with negative sampling.

    Each node has an input ("centre") vector and an output ("context") vector.
    ``get_embed`` exports the input vectors.
    """

    def __init__(self, node_size, embed_dim):
        super().__init__()
        self.node_size = node_size
        self.embed_dim = embed_dim

        self.in_embed = nn.Embedding(node_size, embed_dim)
        self.out_embed = nn.Embedding(node_size, embed_dim)

        # Small symmetric uniform init, scaled by 1/embed_dim.
        initrange = 0.5 / embed_dim
        self.in_embed.weight.data.uniform_(-initrange, initrange)
        self.out_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, center_node, pos_nodes, neg_nodes):
        """Return the per-example negative-sampling loss.

        Args:
            center_node: LongTensor of node ids, shape (batch,).
            pos_nodes: LongTensor of context ids, shape (batch, n_pos).
            neg_nodes: LongTensor of negative ids, shape (batch, n_neg).

        Returns:
            Tensor of shape (batch,):
            ``-(sum log sigmoid(u_pos . v_c) + sum log sigmoid(-u_neg . v_c))``.
        """
        center_emb = self.in_embed(center_node).unsqueeze(2)  # (batch, dim, 1)
        pos_emb = self.out_embed(pos_nodes)                   # (batch, n_pos, dim)
        neg_emb = self.out_embed(neg_nodes)                   # (batch, n_neg, dim)

        # Squeeze only the trailing singleton axis. A bare .squeeze() would also
        # drop the batch axis when batch == 1 (or n_pos == 1) and break .sum(1).
        pos_score = torch.bmm(pos_emb, center_emb).squeeze(2)   # (batch, n_pos)
        neg_score = torch.bmm(neg_emb, -center_emb).squeeze(2)  # (batch, n_neg)

        loss = F.logsigmoid(pos_score).sum(1) + F.logsigmoid(neg_score).sum(1)
        return -loss

    def get_embed(self):
        """Return the input embedding matrix as a nested list (node_size x embed_dim)."""
        return self.in_embed.weight.detach().cpu().numpy().tolist()
        

In [14]:
# One embedding row per node, moved to the training device, optimised with plain SGD.
model = NodeEmbedding(len(nodes), EMBED_DIM).to(DEVICE)
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [15]:
losses = []  # mean batch loss recorded after every optimisation step
for epoch in range(NUM_EPOCHS):
    for step, (center_node, pos_nodes, neg_nodes) in enumerate(dataloader):
        center_node = center_node.long().to(DEVICE)
        pos_nodes = pos_nodes.long().to(DEVICE)
        neg_nodes = neg_nodes.long().to(DEVICE)

        optimizer.zero_grad()
        batch_loss = model(center_node, pos_nodes, neg_nodes).mean()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        losses.append(loss_value)
        if step % 2000 == 0:
            print("epoch: {}, iter: {}, loss: {}".format(epoch, step, loss_value))

epoch: 0, iter: 0, loss: 16.635223388671875
epoch: 0, iter: 2000, loss: 16.616764068603516
epoch: 0, iter: 4000, loss: 14.976537704467773
epoch: 0, iter: 6000, loss: 11.731965065002441
epoch: 1, iter: 0, loss: 10.44923210144043
epoch: 1, iter: 2000, loss: 9.22749137878418
epoch: 1, iter: 4000, loss: 8.145805358886719
epoch: 1, iter: 6000, loss: 7.470710754394531
epoch: 2, iter: 0, loss: 6.97938346862793
epoch: 2, iter: 2000, loss: 6.119506359100342
epoch: 2, iter: 4000, loss: 5.633045196533203
epoch: 2, iter: 6000, loss: 4.989376068115234
epoch: 3, iter: 0, loss: 4.6959547996521
epoch: 3, iter: 2000, loss: 4.5140838623046875
epoch: 3, iter: 4000, loss: 3.805485248565674
epoch: 3, iter: 6000, loss: 3.6393494606018066
epoch: 4, iter: 0, loss: 3.7159042358398438
epoch: 4, iter: 2000, loss: 3.2905123233795166
epoch: 4, iter: 4000, loss: 2.970101833343506
epoch: 4, iter: 6000, loss: 3.0207319259643555
epoch: 5, iter: 0, loss: 2.8857178688049316
epoch: 5, iter: 2000, loss: 2.743900775909424


## Output Embed

In [16]:
# Look up each node's learned input-embedding row by its id explicitly, rather
# than relying on node2idx's key order matching the row order.
embeddings = model.get_embed()
embed_dict = {node: embeddings[idx] for node, idx in node2idx.items()}
# to_csv does not create missing directories.
os.makedirs('embedding', exist_ok=True)
pd.DataFrame(embed_dict).T.to_csv('embedding/prob_walk_stratified_graph_'+str(EMBED_DIM)+'.csv')

In [17]:
# np.save does not create missing directories.
os.makedirs('loss', exist_ok=True)
np.save('loss/prob_walk_stratified_graph_'+str(EMBED_DIM)+'.npy',losses)