# In [7]:
import networkx as nx
from gensim.models import Word2Vec
import numpy as np

# In [139]:
class deepwalk:
    """DeepWalk-style node embeddings.

    Random walks are sampled from a graph (edge-weight-proportional
    transitions), every visited node is rendered with ``str()``, and the
    resulting "sentences" are passed to gensim's ``Word2Vec``.

    The graph object must support ``graph[node]`` returning a mapping
    ``neighbor -> edge-attribute dict`` and ``graph.nodes`` returning an
    iterable of nodes (this is what ``networkx.Graph`` / ``DiGraph``
    provide). Multigraph adjacency (neighbor -> key -> attributes) is not
    handled specially.
    """

    def __init__(self, graph, embedding_size=128, num_walks=80, window_size=10, walk_length=40, hs=1, sg=1, negative=0):
        """Store the graph and the walk / Word2Vec hyper-parameters.

        Args:
            graph: Graph to embed (see class docstring for the expected API).
            embedding_size: Passed to Word2Vec as ``vector_size``.
            num_walks: Number of walks started from every node.
            window_size: Passed to Word2Vec as ``window``.
            walk_length: Maximum number of nodes in one walk.
            hs: Passed to Word2Vec as ``hs``.
            sg: Passed to Word2Vec as ``sg``.
            negative: Passed to Word2Vec as ``negative``.
        """
        self.G = graph
        self.embedding_size = embedding_size
        self.num_walks = num_walks
        self.window_size = window_size
        self.walk_length = walk_length
        self.hs = hs
        self.sg = sg
        self.negative = negative
        # Set by train(); declared here so the attribute always exists.
        self.wv_model = None

    def sample_one_walk(self, start_node):
        """Sample one random walk beginning at ``start_node``.

        At each step the next node is drawn from the current node's
        neighbors with probability proportional to the edge's ``'weight'``
        attribute; an edge without that attribute counts as weight 1.0.

        Args:
            start_node: Node at which the walk begins.

        Returns:
            List of ``str(node)`` for the visited nodes. It has at most
            ``walk_length`` entries and is shorter when the walk reaches a
            node without neighbors (e.g. a sink in a directed graph).
        """
        current_node = start_node
        walk = [str(current_node)]
        for _ in range(self.walk_length - 1):
            nbr_attrs = self.G[current_node]
            # Dead end: checked on every step, not only for the start node,
            # because a directed walk can arrive at a node with no out-edges.
            if not nbr_attrs:
                break
            nbrs = list(nbr_attrs)
            weights = np.array(
                [nbr_attrs[nbr].get('weight', 1.0) for nbr in nbrs],
                dtype=float,
            )
            # Draw an index rather than the node itself so that node labels
            # are never coerced into a NumPy array (tuples, mixed types).
            choice = np.random.choice(len(nbrs), p=weights / weights.sum())
            current_node = nbrs[choice]
            walk.append(str(current_node))
        return walk

    def sample_walks(self):
        """Sample ``num_walks`` walks from every node of the graph.

        Each of the ``num_walks`` passes visits all nodes in a fresh random
        order and starts one walk per node.

        Returns:
            List of walks, each a list of node strings as produced by
            ``sample_one_walk``.
        """
        # Keep nodes in a plain list so their original Python types survive.
        nodes = list(self.G.nodes)
        walks = []
        for _ in range(self.num_walks):
            for idx in np.random.permutation(len(nodes)):
                walks.append(self.sample_one_walk(nodes[idx]))
        return walks

    def train(self, epochs=5, min_count=0, workers=3):
        """Sample walks and fit a Word2Vec model on them.

        Args:
            epochs: Passed to Word2Vec as ``epochs``.
            min_count: Passed to Word2Vec as ``min_count``.
            workers: Passed to Word2Vec as ``workers``.

        Returns:
            The trained Word2Vec model, which is also stored on
            ``self.wv_model``.
        """
        model = Word2Vec(
            sentences=self.sample_walks(),
            epochs=epochs,
            vector_size=self.embedding_size,
            min_count=min_count,
            sg=self.sg,
            hs=self.hs,
            negative=self.negative,
            workers=workers,
            window=self.window_size,
        )
        self.wv_model = model
        return model

    def get_embedding(self, node):
        """Return the embedding vector of ``node``.

        Raises:
            AttributeError: If ``train()`` has not been called yet.
        """
        return self.wv_model.wv[str(node)]

    def get_embeddings(self):
        """Return the embeddings of all graph nodes as one array.

        Rows follow the iteration order of ``self.G.nodes``.
        """
        return np.array([self.get_embedding(node) for node in self.G.nodes])

    def similarity(self, node1, node2):
        """Return the similarity between ``node1`` and ``node2``.

        Delegates to the trained model's ``wv.similarity``.
        """
        return self.wv_model.wv.similarity(str(node1), str(node2))

    def most_similar(self, node, topn=10):
        """Return the ``topn`` nodes most similar to ``node``.

        Delegates to the trained model's ``wv.most_similar``.
        """
        return self.wv_model.wv.most_similar(str(node), topn=topn)