# In[1]:
import numpy as np

""" Snowball Sampling
    INPUT:
        graph:        adjacency matrix of the graph
        tr_fraction:  fraction of nodes used for training, default is 0.1
    OUTPUT:
        Train:        indices of the training nodes
"""
def snowball_sampling(graph, tr_fraction=0.1):
    """Split the nodes of a graph into train/test sets by snowball sampling.

    The test set is grown as a "snowball".  It starts from a few random seed
    nodes (about 2% of the test-set size, at least one).  Each wave randomly
    adds ``int(k * (1 - tr_fraction) / 2)`` of the ``k`` not-yet-selected
    neighbours of the previous wave.  When a wave adds nothing, fresh random
    seeds are drawn from all still-unselected nodes.  Every node not in the
    test set is returned as a training node.

    The test-set size is ``floor(m * (1 - tr_fraction))``, where ``m`` is the
    number of nodes incident to at least one edge.  Isolated nodes can only
    enter the test set through re-seeding.

    Args:
        graph: Square adjacency matrix, assumed to be a dense NumPy array or
            a SciPy sparse matrix.  Edge direction is ignored because the
            matrix is symmetrised first.  The caller's object is not modified.
        tr_fraction: Fraction of training nodes, in ``[0, 1]``; default 0.1.

    Returns:
        Sorted 1-D integer array with the indices of the training nodes.

    Raises:
        ValueError: If ``tr_fraction`` lies outside ``[0, 1]``.
    """
    if not 0 <= tr_fraction <= 1:
        raise ValueError("tr_fraction must lie in [0, 1], got %r" % (tr_fraction,))
    test_fraction = 1 - tr_fraction

    # Symmetrise (creates a new matrix) so neighbourhoods ignore edge direction.
    graph = graph + graph.T
    graph[graph > 0] = 1
    n_total = graph.shape[0]

    # Nodes incident to at least one edge define the test-set size and the
    # pool for the initial seeds.
    nodes = np.unique(graph.nonzero())
    n_test = int(np.floor(len(nodes) * test_fraction))

    n_seed = int(np.ceil(n_test * 0.02))
    seed = np.random.choice(nodes, n_seed, replace=False)

    # O(1) membership test instead of scanning the selected list per neighbour.
    is_selected = np.zeros(n_total, dtype=bool)
    is_selected[seed] = True
    # Waves are kept in selection order so the final truncation drops nodes
    # from the last wave, not the highest node indices.
    waves = [seed]
    n_selected = len(seed)

    while n_selected < n_test:
        neighbors = _snowball_neighbors(graph, seed)
        candidates = neighbors[~is_selected[neighbors]]

        wave = np.empty(0, dtype=np.intp)
        if len(candidates) > 0:
            wave = np.random.choice(
                candidates, int(len(candidates) * test_fraction / 2), replace=False
            )

        if len(wave) == 0:
            # The snowball stalled: restart from fresh random seeds.  Because
            # n_selected < n_test <= n_total, at least one node is unselected.
            unselected = np.flatnonzero(~is_selected)
            wave = np.random.choice(
                unselected, min(n_seed, len(unselected)), replace=False
            )

        is_selected[wave] = True
        waves.append(wave)
        n_selected += len(wave)
        seed = wave

    test = np.concatenate(waves)[:n_test]
    return np.setdiff1d(np.arange(n_total), test, assume_unique=True)


def _snowball_neighbors(graph, seed):
    """Return the sorted node indices adjacent to any node in ``seed``.

    Column indices of the adjacency matrix are node IDs, so they are returned
    directly.  They must not be remapped through the list of non-isolated
    nodes.  ``np.asarray(...).ravel()`` flattens both ndarray sums and the
    1xN ``np.matrix`` produced by sparse matrices.
    """
    column_sums = np.asarray(graph[seed].sum(axis=0)).ravel()
    return np.flatnonzero(column_sums)