In [24]:
import numpy as np

# Distance definition
def calculate_squared_distance(point, centroid):
    return np.sum((point - centroid) ** 2)

In [25]:
# Caculate the centroid position with given membership
# Data points
points = np.array([[1.0, 3.0], [2.0, 5.0], [4.0, 8.0], [7.0, 9.0], [9.0, 12.0]])

# Initial membership
membership = np.array([[0.8, 0.7, 0.5, 0.3, 0.1],
                       [0.2, 0.3, 0.5, 0.7, 0.9]]).T

#M-step: Update centroid position
for i in range(2):
    # Centroid that will optimizate SSE
    centroids_numerator = np.sum(membership[:, i][:, None] ** 2 * points, axis=0).astype(float)
    centroids_denominator = np.sum(membership[:, i]** 2).astype(float)
    centroids[i] = centroids_numerator / centroids_denominator

print("Centroids with given membership:")
print(centroids)

Centroids with given membership:
[[2.25675676 4.93243243]
 [7.10714286 9.94047619]]


In [26]:
# Iterative Execution of EM Algorithm
iter_time = 0
while 1:
    iter_time += 1
    
    # Calculate the square distance from each point to each center point
    squared_distances = np.array([[calculate_squared_distance(point, centroid) \
                                   for centroid in centroids] \
                                  for point in points])
    
    # E-step: Update membership based on square distance
    new_membership_denominator = np.sum(squared_distances, axis=1)

    new_membership_numerator = squared_distances
    new_membership_numerator[:, [0, 1]] = new_membership_numerator[:, [1, 0]]

    new_membership =  new_membership_numerator/ new_membership_denominator[:, None]
    
    mean_change_of_membership = np.mean(np.abs(new_membership-membership))
    membership = new_membership
    
    #M-step: Update centroid position
    for i in range(len(centroids)):
        # New centroid that will optimizate SSE
        centroids_numerator = np.sum(new_membership[:, i][:, None] ** 2 * points, axis=0).astype(float)
        centroids_denominator = np.sum(new_membership[:, i]** 2).astype(float)
        centroids[i] = centroids_numerator / centroids_denominator
    
    print(f"Iteration {iter_time}:")
    print("Distance:")
    print(squared_distances)
    print("New Membership:")
    print(membership)
    print("New Centroids:")
    print(centroids)
    print("Mean change of Menbership:")
    print(mean_change_of_membership)
    print("----------")
    
    # Two shutdown criteria
    if mean_change_of_membership < 0.00001:
        break
    if iter_time > 20:
        break

Iteration 1:
Distance:
[[8.54674036e+01 5.31373265e+00]
 [5.04912132e+01 7.04894083e-02]
 [1.34197846e+01 1.24488678e+01]
 [8.95975057e-01 3.90434624e+01]
 [7.82454649e+00 9.54218408e+01]]
New Membership:
[[0.94146655 0.05853345]
 [0.99860587 0.00139413]
 [0.51876628 0.48123372]
 [0.02243334 0.97756666]
 [0.07578518 0.92421482]]
New Centroids:
[[ 1.85854048  4.57240718]
 [ 7.48562706 10.12986197]]
Mean change of Menbership:
0.15212403662985943
----------
Iteration 2:
Distance:
[[ 92.89829012   3.20955612]
 [ 56.40758811   0.20284641]
 [ 16.68590803  16.33424136]
 [  1.51242172  46.03818409]
 [  5.79074164 106.16957904]]
New Membership:
[[0.96660464 0.03339536]
 [0.9964168  0.0035832 ]
 [0.50532503 0.49467497]
 [0.03180657 0.96819343]
 [0.05172137 0.94827863]]
New Centroids:
[[ 1.81711109  4.50607855]
 [ 7.50785987 10.17469156]]
Mean change of Menbership:
0.014841089128195067
----------
Iteration 3:
Distance:
[[ 93.82843918   2.93594314]
 [ 57.11395318   0.27740675]
 [ 17.0343643   16.9