In [53]:
import numpy as np
import matplotlib.pyplot as plt
import math
%matplotlib inline

In [None]:
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

# Three Gaussian blobs; the fixed random_state makes the point cloud reproducible.
centers = [[1, 1], [-1, -1], [1, -1]]
X, labels_true = make_blobs(
    n_samples=750, centers=centers, cluster_std=0.4, random_state=0
)

# Standardize each coordinate (zero mean, unit variance) before the distance sweep.
X = StandardScaler().fit_transform(X)

x = X[:, 0]
y = X[:, 1]
n = len(x)

# Colour by the generating blob so the later component counts can be compared
# against the ground-truth clustering.
fig, ax = plt.subplots()
ax.scatter(x, y, c=labels_true)
ax.set(xlabel='x (standardized)', ylabel='y (standardized)',
       title='Point cloud coloured by generating blob')
plt.show()

In [71]:
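# Alternative data source (disabled): remove one leading "# " from each line
# below to replace the blob dataset with n uniformly random points in the unit square.
# # define a random point cloud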

# n = 100
# x = np.random.rand(n)
# y = np.random.rand(n)

# # plot the point cloud
# plt.scatter(x, y)
# plt.show()

In [72]:
# union find data structure

class UnionFind:
    """Disjoint-set (union-find) structure over the elements 0 .. n-1.

    Uses union by rank and path compression.
    """

    def __init__(self, n):
        """Start with n singleton sets {0}, {1}, ..., {n-1}."""
        self.parent = np.arange(n)
        # Ranks are small non-negative integers; np.zeros would default to float64.
        self.rank = np.zeros(n, dtype=np.intp)

    def find(self, x):
        """Return the root (representative) of the set containing x.

        Iterative with path compression, so it never depends on Python's
        recursion limit regardless of tree depth.
        """
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: re-point every node on the path straight at the root.
        while self.parent[x] != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x, y):
        """Merge the sets containing x and y (union by rank).

        Returns True if two distinct sets were merged, False if x and y were
        already in the same set. Callers may ignore the result.
        """
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return False
        # Ensure x_root is the deeper tree, then hang y_root beneath it.
        if self.rank[x_root] < self.rank[y_root]:
            x_root, y_root = y_root, x_root
        self.parent[y_root] = x_root
        # Depth only grows when two trees of equal rank are joined.
        if self.rank[x_root] == self.rank[y_root]:
            self.rank[x_root] += 1
        return True

    def num_components(self):
        """Return the number of disjoint sets (elements that are their own root)."""
        return int(np.count_nonzero(self.parent == np.arange(self.parent.size)))

In [73]:
# define a function to compute the distance between two points
def dist2d(a, b, xs=None, ys=None):
    """Euclidean distance between points with indices a and b.

    a, b   : int index, or integer index arrays (broadcast element-wise,
             returning an array of distances).
    xs, ys : optional coordinate arrays; when omitted, the notebook's global
             point cloud `x`, `y` is used (original behaviour).
    """
    if xs is None:
        xs = x
    if ys is None:
        ys = y
    # Same sqrt-of-squares formula as before so threshold comparisons are unchanged.
    return np.sqrt((xs[a] - xs[b])**2 + (ys[a] - ys[b])**2)

In [74]:
def plotGraph(dist):
    """Draw the point cloud plus every edge of length <= dist.

    Uses the same `<=` criterion as the union-find sweep, so the drawn graph is
    exactly the one whose components are counted. Reads the notebook globals
    `x`, `y` and `dist2d`.
    """
    # All unordered pairs (i < j), filtered in one vectorized call.
    ii, jj = np.triu_indices(len(x), k=1)
    keep = dist2d(ii, jj) <= dist
    ii, jj = ii[keep], jj[keep]

    # A single Line2D for all edges: a NaN after each segment breaks the
    # polyline, which is far faster than one plt.plot call per edge.
    gap = np.full(ii.size, np.nan)
    edge_x = np.column_stack([x[ii], x[jj], gap]).ravel()
    edge_y = np.column_stack([y[ii], y[jj], gap]).ravel()

    fig, ax = plt.subplots()
    ax.scatter(x, y)
    ax.plot(edge_x, edge_y, 'k-')
    ax.set(xlabel='x', ylabel='y', title=f'Edges with length <= {dist:.2f}')
    plt.show()
    # Release the figure; this function is called repeatedly from a loop.
    plt.close(fig)

In [None]:
# Connection radii to sweep: 0.01, 0.02, ... (stop value 1.05 is exclusive).
distances = np.arange(0, 1.05, 0.01)[1:]
# Upper bound on how many intermediate graphs get drawn.
MAX_SNAPSHOTS = 150

# Compute every pairwise distance once and sort the edges by length; each
# threshold then only has to union the edges it newly admits, instead of
# rescanning all O(n^2) pairs at every step.
edge_i, edge_j = np.triu_indices(n, k=1)
edge_len = dist2d(edge_i, edge_j)
order = np.argsort(edge_len)
edge_i, edge_j, edge_len = edge_i[order], edge_j[order], edge_len[order]

u = UnionFind(n)

results = []
# max(1, ...) guards against `iteration % 0` when there are fewer thresholds
# than MAX_SNAPSHOTS (the original floor(len/150) was 0 -> ZeroDivisionError).
plot_every = max(1, len(distances) // MAX_SNAPSHOTS)
next_edge = 0

for iteration, dist in enumerate(distances):
    # side="right" admits every edge with length <= dist (same criterion as before).
    stop = np.searchsorted(edge_len, dist, side="right")
    for i, j in zip(edge_i[next_edge:stop], edge_j[next_edge:stop]):
        u.union(i, j)
    next_edge = stop

    num_components = u.num_components()
    # Record before breaking so the curve shows the point where it reaches 1.
    results.append((dist, num_components))
    if num_components <= 1:
        break

    if iteration % plot_every == 0:
        plotGraph(dist)

In [None]:
# Number of connected components as a function of the connection radius.
radii, counts = zip(*results) if results else ((), ())

fig, ax = plt.subplots()
ax.plot(radii, counts, marker='.')
ax.set(xlabel='distance', ylabel='number of components',
       title='Connected components vs. connection radius')
plt.show()