In [None]:
import fast_mst as fm
import matplotlib.pyplot as plt
import numpy as np
import sklearn.datasets as dsets
from dataeval._internal.metrics.utils import minimum_spanning_tree as old_mst

In [None]:
# Read the sample points stored on disk.
small = np.load('./data/clusterable_data.npy')

# Marker styling: semi-transparent dots of size 50 with no outline.
plot_kwds = {"alpha": 0.5, "s": 50, "linewidths": 0}

# Plot the first two coordinates on a large square canvas.
fig, ax = plt.subplots(figsize=(20, 20))
ax.scatter(small.T[0], small.T[1], **plot_kwds)
plt.show()

In [None]:
# Replace `small` with a synthetic two-ring dataset; the fixed seed keeps
# the sample identical from run to run.
small, small_label = dsets.make_circles(
    n_samples=20000,
    factor=0.5,
    noise=0.05,
    random_state=30,
)

# Marker styling: semi-transparent dots of size 50 with no outline.
plot_kwds = {"alpha": 0.5, "s": 50, "linewidths": 0}

# Plot the first two coordinates on a large square canvas.
fig, ax = plt.subplots(figsize=(20, 20))
ax.scatter(small.T[0], small.T[1], **plot_kwds)

# Optional debugging aid (disabled): write each selected point's index next to it.
# NOTE(review): `edge_points` is not defined in any visible cell - define it before enabling.
# that = np.concatenate(edge_points)
# for i, (x, y) in enumerate(small[that,:2]):
#     plt.annotate(str(that[i]), (x, y), textcoords="offset points", xytext=(0, 1), ha="center")

plt.show()

In [None]:
# Build a labelled test set from three independent `make_blobs` draws.
# Every draw has its own fixed seed, so the combined dataset is reproducible.
# `n_features` stays at 2 because the plotting cells only use the first two columns.

# Main body: four clusters (labels 0-3).
blob, blob_label = dsets.make_blobs(  # type: ignore
    n_samples=8000,
    n_features=2,
    centers=4,
    center_box=(-250, 200),
    cluster_std=35,
    random_state=31,
)

# A single, wider cluster (label 0 here, re-labelled to 4 below).
blob2, blob_label2 = dsets.make_blobs(  # type: ignore
    n_samples=1500,
    n_features=2,
    centers=1,
    center_box=(300, 350),
    cluster_std=50,
    random_state=35,
)

# A single, tighter cluster (label 0 here, re-labelled to 5 below).
# `center_box` is documented as (min, max). The bounds used to be passed in
# reverse order, which only worked because of NumPy behavior that is
# documented as undefined when high < low. With the bounds in the right order
# the center is still drawn from the same [-400, -350] range, although its
# exact position for this seed differs from the reversed-bounds version.
blob3, blob_label3 = dsets.make_blobs(  # type: ignore
    n_samples=500,
    n_features=2,
    centers=1,
    center_box=(-400, -350),
    cluster_std=25,
    random_state=33,
)

# Stack the three draws; offset the single-cluster labels so that every
# cluster in the combined set has its own integer id (0-5).
small = np.concatenate([blob, blob2, blob3])
small_label = np.concatenate([blob_label, blob_label2 + 4, blob_label3 + 5])

In [None]:
# Palette indexed by integer label: label k is drawn in label_to_color[k].
# Six entries are available, so labels must lie in 0-5.
label_to_color = np.array(["b", "r", "g", "y", "m", "c"])

# Integer-array indexing translates every label to its color in one step.
color_array = label_to_color[small_label]

# Marker styling: semi-transparent dots of size 50 with no outline.
plot_kwds = {"alpha": 0.5, "s": 50, "linewidths": 0}

plt.figure(figsize=(20, 20))

# Scatter the first two coordinates, colored by label.
plt.scatter(small.T[0], small.T[1], c=color_array, **plot_kwds)

# Optional debugging aid (disabled): write each selected point's index next to it.
# NOTE(review): `edge_points` is not defined in any visible cell - define it before enabling.
# check = np.nonzero(edge_points)[0]
# for i, (x, y) in enumerate(small[check,:2]):
#     plt.annotate(str(check[i]), (x, y), textcoords="offset points", xytext=(0, 1), ha="center")

# Render explicitly, like the other plotting cells, so the cell does not end
# by echoing the PathCollection repr returned by plt.scatter.
plt.show()

In [None]:
# Compute the neighbor data that the MST construction cell consumes.
# NOTE(review): the second argument (20) is presumably the number of neighbors
# requested per point, and the two results presumably hold neighbor indices
# and the matching distances - confirm against fast_mst.
neighbors, distances = fm.calculate_neighbor_distances(small, 20)

In [None]:
# Build the spanning tree from the points and the `neighbors` / `distances`
# arrays computed beforehand.
# The plotting cell reads columns 0 and 1 of `mst` as the point indices of
# each edge's two endpoints.
# NOTE(review): the meaning of any further columns is not visible here - confirm against fast_mst.
mst = fm.minimum_spanning_tree(small, neighbors, distances)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.collections as mc

# Point coordinates, plus the endpoint indices of every tree edge
# (columns 0 and 1 of `mst`, cast to integers so they can index the points).
xs, ys = small[:, 0], small[:, 1]
sources = mst[:, 0].astype(np.int64)
targets = mst[:, 1].astype(np.int64)

# One ((x0, y0), (x1, y1)) segment per edge.
edge_segments = [
    ((xs[src], ys[src]), (xs[dst], ys[dst]))
    for src, dst in zip(sources, targets)
]
# zorder=-1 keeps the edges underneath the scatter markers.
lc = mc.LineCollection(edge_segments, color="k", alpha=0.5, linewidth=1, zorder=-1)

plt.figure(figsize=(20, 16))
plt.scatter(xs, ys, c=color_array, s=25, edgecolors="none", linewidth=0)

# Overlay the edges, then strip all decoration so only the data is shown.
ax = plt.gca()
ax.add_collection(lc)
ax.set_aspect("equal")
plt.subplots_adjust(0, 0, 1, 1)
plt.axis("off")

# Optional debugging aid (disabled): write every point's index next to it.
# for i, (x, y) in enumerate(small[:,:2]):
#     plt.annotate(str(i), (x, y), textcoords="offset points", xytext=(0, 1), ha="center")

plt.show()

In [None]:
# This cell only exists for speed up comparison
import time

# The comparison is skipped for datasets with more than 10,000 samples.
if small.shape[0] <= 10000:
    print(f"Comparison using (n_samples, n_features): {small.shape}")

    # time.perf_counter() is used instead of time.time(): it is monotonic and
    # uses the highest-resolution clock available, so the measured interval
    # cannot be distorted by system clock adjustments.
    print("Old MST algorithm")
    start = time.perf_counter()
    old_result = old_mst(small)
    print(time.perf_counter() - start, "secs")

    # The new pipeline is timed end to end: the neighbor search plus the tree
    # construction that consumes its output.
    print("New MST algorithm")
    start = time.perf_counter()
    neighbors, distances = fm.calculate_neighbor_distances(small, 20)
    new_result = fm.minimum_spanning_tree(small, neighbors, distances)
    print(time.perf_counter() - start, "secs")