In [1]:
import pytest
import numpy as np
from mst import Graph
from sklearn.metrics import pairwise_distances

In [2]:
def check_mst(adj_mat: np.ndarray, 
              mst: np.ndarray, 
              expected_weight: int, 
              allowed_error: float = 0.0001):
    """
    
    Helper function to check the correctness of the adjacency matrix encoding an MST.
    Note that because the MST of a graph is not guaranteed to be unique, we cannot 
    simply check for equality against a known MST of a graph. 

    Arguments:
        adj_mat: adjacency matrix of full graph
        mst: adjacency matrix of proposed minimum spanning tree
        expected_weight: weight of the minimum spanning tree of the full graph
        allowed_error: allowed difference between proposed MST weight and `expected_weight`

    TODO: Add additional assertions to ensure the correctness of your MST implementation. For
    example, how many edges should a minimum spanning tree have? Are minimum spanning trees
    always connected? What else can you think of?

    """

    def approx_equal(a, b):
        return abs(a - b) < allowed_error

    total = 0
    for i in range(mst.shape[0]):
        for j in range(i+1):
            total += mst[i, j]
            
    # assert approx_equal(total, expected_weight), 'Proposed MST has incorrect expected weight'
    print((total,expected_weight))

In [3]:
file_path = './data/slingshot_example.txt'
coords = np.loadtxt(file_path) # load coordinates of single cells in low-dimensional subspace
dist_mat = pairwise_distances(coords) # compute pairwise distances to form graph


In [4]:
dist_mat

array([[ 0.        ,  0.76216519,  2.1732807 , ..., 13.29560662,
        14.89934831, 13.33280807],
       [ 0.76216519,  0.        ,  1.4111224 , ..., 12.5830895 ,
        14.18923971, 12.61080021],
       [ 2.1732807 ,  1.4111224 ,  0.        , ..., 11.28408483,
        12.89327348, 11.29073429],
       ...,
       [13.29560662, 12.5830895 , 11.28408483, ...,  0.        ,
         1.60955581,  0.47860832],
       [14.89934831, 14.18923971, 12.89327348, ...,  1.60955581,
         0.        ,  1.69297641],
       [13.33280807, 12.61080021, 11.29073429, ...,  0.47860832,
         1.69297641,  0.        ]])

In [5]:
g = Graph(dist_mat)
g.construct_mst()


16.899348313145875


In [7]:
check_mst(g.adj_mat, g.mst, 57.263561605571695)

(1.124194420326916, 57.263561605571695)


In [8]:
(g.mst[g.mst>0])

array([0.27976269, 0.48574831, 0.27976269, 0.35868341, 0.48574831,
       0.35868341])

In [18]:
np.max(g.adj_mat)

14.899348313145875