In [33]:
import numpy as np
import random

class Friedman_Rafsky:
    r"""
    Friedman-Rafsky (FR) test statistic and p-value.

    This is a multivariate extension of the Wald-Wolfowitz
    runs test for randomness. The normal concept of a 'run'
    is replaced by a minimum spanning tree (MST) calculated between
    the points of the pooled data sets, with edge weights defined
    as the Euclidean distance between two such points. After the MST
    has been determined, every edge whose two nodes do not belong to
    the same class is severed and the number of resulting independent
    trees is counted. This test is consistent against similar tests
    :footcite:p:`GSARMultivariateAdaptationofWaldWolfowitzTest`

    """

    class Graph:
        r"""
        Helper class that finds the MST of a weighted graph using
        Kruskal's algorithm with a union-find (disjoint set) structure.

        Parameters
        ----------

        vertex : int
            Number of vertices; vertices are labelled ``0 .. vertex - 1``.

        Notes
        -----
        Edges are registered with ``add_edge(v1, v2, w)`` and the tree is
        obtained from ``kruskal()``, which returns the list of ``[v1, v2]``
        pairs that are connected in the final MST.

        """

        def __init__(self, vertex):
            self.V = vertex
            # Edge list; every entry is [v1, v2, weight].
            self.graph = []

        def add_edge(self, v1, v2, w):
            """Register an undirected edge ``v1 -- v2`` with weight ``w``."""
            self.graph.append([v1, v2, w])

        def search(self, parent, i):
            """Return the root of the set containing vertex ``i``.

            Implemented iteratively (with path halving) so that long parent
            chains cannot hit Python's recursion limit.
            """
            while parent[i] != i:
                parent[i] = parent[parent[i]]
                i = parent[i]
            return i

        def apply_union(self, parent, rank, x, y):
            """Merge the sets containing ``x`` and ``y`` (union by rank)."""
            xroot = self.search(parent, x)
            yroot = self.search(parent, y)
            if rank[xroot] < rank[yroot]:
                parent[xroot] = yroot
            else:
                parent[yroot] = xroot
                if rank[xroot] == rank[yroot]:
                    # Equal ranks: the merged tree is one level deeper.
                    rank[xroot] += 1

        def kruskal(self):
            """Return the MST as a list of ``[v1, v2]`` vertex pairs.

            Edges are visited in order of increasing weight (stable with
            respect to insertion order) and kept whenever they join two
            different components. If the graph is disconnected the result
            is a spanning forest with fewer than ``V - 1`` edges.
            """
            result = []
            self.graph = sorted(self.graph, key=lambda item: item[2])
            parent = list(range(self.V))
            rank = [0] * self.V
            for v1, v2, _ in self.graph:
                if len(result) >= self.V - 1:
                    break
                x = self.search(parent, v1)
                y = self.search(parent, v2)
                if x != y:
                    result.append([v1, v2])
                    self.apply_union(parent, rank, x, y)
            return result

    def Prim(self, weight_mat):
        r"""
        Find the MST of a complete graph with Prim's algorithm.

        Parameters
        ----------

        weight_mat : array_like, shape (V, V)
            Symmetric matrix of edge weights. Every off-diagonal entry is
            treated as an edge, including zero weights (coincident points).

        Returns
        -------
        MST_connections : list
            List of ``[x, y]`` index pairs connected in the final MST, in
            the order they were added; the tree is grown from vertex 0.

        Raises
        ------
        ValueError
            If ``weight_mat`` is not a square 2-D matrix.

        """
        weights = np.asarray(weight_mat, dtype=float)
        if weights.ndim != 2 or weights.shape[0] != weights.shape[1]:
            raise ValueError('weight_mat must be a square 2-D matrix')

        n_vertices = weights.shape[0]
        if n_vertices == 0:
            return []

        in_tree = np.zeros(n_vertices, dtype=bool)
        in_tree[0] = True

        MST_connections = []
        while len(MST_connections) < n_vertices - 1:
            # Only edges leaving the current tree are candidates; everything
            # else is masked with +inf so it can never be the minimum.
            crossing = in_tree[:, None] & ~in_tree[None, :]
            candidates = np.where(crossing, weights, np.inf)
            # argmin scans row-major, so ties resolve to the smallest (x, y).
            x, y = divmod(int(np.argmin(candidates)), n_vertices)
            MST_connections.append([x, y])
            in_tree[y] = True

        return MST_connections

    def MST(self, data, algorithm):
        r"""
        Helper function to read input data and calculate Euclidean distance
        between each possible pair of points before finding MST.

        Parameters
        ----------

        data : ndarray
            Dataset such that each column corresponds to a point of data.
        algorithm : str
            Either ``'Kruskal'`` or ``'Prim'``.

        Returns
        -------
        MST_connections : list
            List of pairs of nodes connected in final MST.

        Raises
        ------
        ValueError
            If ``algorithm`` is not one of the supported names.

        """
        if algorithm not in ('Kruskal', 'Prim'):
            raise ValueError(
                "algorithm must be 'Kruskal' or 'Prim', got %r" % (algorithm,)
            )

        data = np.asarray(data)
        n_points = data.shape[1]

        if algorithm == 'Kruskal':
            g = self.Graph(n_points)
            for i in range(n_points):
                for j in range(i + 1, n_points):
                    weight = np.linalg.norm(data[:, i] - data[:, j])
                    g.add_edge(i, j, weight)
            return g.kruskal()

        # algorithm == 'Prim': build the full symmetric distance matrix.
        G = np.zeros((n_points, n_points))
        for i in range(n_points):
            for j in range(i + 1, n_points):
                weight = np.linalg.norm(data[:, i] - data[:, j])
                G[i][j] = weight
                G[j][i] = weight
        return self.Prim(G)

    def num_runs(self, labels, MST_connections):
        r"""
        Helper function to determine number of independent
        'runs' from MST connections

        Parameters
        ----------

        labels, MST_connections : ndarray, list
            Labels corresponding to respective classes of points and
            MST_connections defining which pairs of points remain connected
            in final MST.

        Returns
        -------
        run_count : int
            Number of runs after severing all such edges with nodes of
            differing class labels.

        """
        # Cutting k edges of a tree leaves k + 1 connected pieces.
        severed = sum(1 for a, b in MST_connections if labels[a] != labels[b])
        return 1 + severed

    def permutation(self, nperm, labels, MST_connections):
        r"""
        Helper function that randomizes labels and calculates respective
        'run' counts for a specified number of iterations.

        Parameters
        ----------

        nperm : int
            Number of random label permutations to draw.
        labels : ndarray or list
            Class label of every point.
        MST_connections : list
            Pairs of nodes connected in the MST.

        Returns
        -------
        runs : list
            Run count obtained for each permutation of the labels.

        """
        # random.sample only accepts real sequences, so ndarray labels are
        # converted to a list first.
        label_pool = list(labels)
        n_labels = len(label_pool)

        runs = []
        for _ in range(nperm):
            lab_shuffle = random.sample(label_pool, n_labels)
            runs.append(self.num_runs(lab_shuffle, MST_connections))

        return runs

    def test_stat(self, val):
        r"""
        Helper function that calculates the Friedman Rafsky test-statistic.

        Standardizes ``val`` by subtracting its mean and dividing by its
        (population) standard deviation. If all entries are equal the
        standard deviation is zero and the result is not finite.

        """
        val = np.asarray(val, dtype=float)
        return (val - np.mean(val)) / np.std(val)

    def pval(self, perm_stat, true_stat):
        r"""
        Helper function that calculates the Friedman Rafsky p-value.

        The p-value is lower-tailed: the proportion of permuted statistics
        that are less than or equal to the observed one, with one added to
        numerator and denominator so the estimate is never exactly zero.

        """
        perm_stat = np.asarray(perm_stat)
        pvalue = (np.sum(perm_stat <= true_stat) + 1) / (len(perm_stat) + 1)

        return pvalue

    def statistics(self, data, labels, nperm, algorithm):
        r"""
        Function to take in data, labels, nperm and return both test-statistic
        and p-value.

        Parameters
        ----------

        data, labels, nperm : ndarray, ndarray, int
            Data is array such that each column corresponds to a given point.
            Labels are the corresponding data labels for each point. It is of
            note that data and labels must have the same number of columns.
            Nperm is the number of randomized iterations to be used in
            calculating the test_statistic and p-value.
        algorithm : str
            MST algorithm to use, ``'Kruskal'`` or ``'Prim'``.

        Return
        ------

        W_true, p_value : float, float
            Corresponding test-statistic and p-value for the given data,
            labels, and number of randomized permutations used.

        Raises
        ------
        IndexError
            If the number of columns of ``data`` differs from the number of
            labels.
        ValueError
            If ``algorithm`` is not supported.

        """
        data = np.asarray(data)
        num_rows, num_cols = data.shape

        if num_cols != len(labels):
            raise IndexError('Number of features and labels not equal')

        MST_connections = self.MST(data, algorithm)

        runs_true = self.num_runs(labels, MST_connections)

        runs = self.permutation(nperm, labels, MST_connections)

        sd_runs = np.std(runs)
        mu_runs = np.mean(runs)

        # Standardize permuted and observed run counts on the same scale.
        W_perm = self.test_stat(runs)
        W_true = (runs_true - mu_runs) / sd_runs

        pvalue = self.pval(W_perm, W_true)

        return W_true, pvalue

In [34]:
# Demo input: ten points stored column-wise (one column per point).
data = np.array([
    [1, 5, 3, 6, 4, 2, 5, 3, 2, 7],
    [2, 5, 7, 1, 8, 3, 3, 7, 5, 2],
    [4, 5, 3, 3, 6, 5, 2, 4, 3, 7],
    [4, 5, 3, 3, 6, 5, 2, 4, 3, 7],
])

# One class label per column of `data`.
labels = [1, 2, 2, 1, 2, 1, 2, 2, 1, 1]

test = Friedman_Rafsky()

# Repeat the permutation test 50 times (10000 permutations each) and keep
# only the test statistic, to see how much it varies between runs.
vals = [test.statistics(data, labels, 10000, 'Prim')[0] for _ in range(50)]

# Interval around the mean statistic: mean +/- 3.291 standard errors
# (3.291 is presumably the two-sided 99.9% normal critical value).
center = np.mean(vals)
half_width = 3.291 * (np.std(vals) / np.sqrt(len(vals)))

print(center - half_width, center + half_width)


-0.7014857070600278 -0.6903100313509287
