In [None]:
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.model_selection import train_test_split

# K-means

In this notebook, we will explore the k-means algorithm in its simplest form.
The data can be generated as in the previous notebook.

In [None]:
X,y,centers = make_blobs(
    n_samples=1000,
    n_features=2,
    return_centers=True,
)

#TODO: Shuffle and split the data

# Plot the training data
plt.scatter(X_train[:,0], X_train[:,1], c=y_train)

plt.scatter(centers[:,0], centers[:,1], c='r', s=120)
plt.legend(["data", "centers"])
plt.xlabel(r"$x_1$")
plt.ylabel(r"$x_2$")
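
One possible way to fill in the shuffle-and-split step is sketched below. It relies on `train_test_split`, which is already imported at the top of the notebook. The 80/20 ratio (`test_size=0.2`) and the name `X_test` are assumptions made for illustration; the notebook does not prescribe them.

```python
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

# Same data generation as in the cell above
X, y, centers = make_blobs(
    n_samples=1000,
    n_features=2,
    return_centers=True,
)

# shuffle=True is the default; test_size=0.2 is an assumed 80/20 split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True
)
```

With 1000 samples, this leaves 800 points for training and 200 for testing.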

Implement k-means as a class, structured in the same way as for KNN, with the following differences:
- In the constructor, we will save the number of clusters, as well as the maximum number of iterations and the tolerance.
- In the fit method, we will:
    - Initialize the centroids randomly from the data points
    - For each iteration:
        - Assign each data point to the nearest centroid
        - Update the centroids by taking the mean of the assigned points
        - Check for convergence: stop when the decrease in loss (sum of squared distances to the nearest centroid) falls below the tolerance
- In the predict method, we will assign each data point to the nearest centroid.

As presented in the lecture, use the matrix $\gamma$ to keep track of the assignments.
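
A minimal sketch of a single assignment and update step, assuming that $\gamma$ is the one-hot matrix of shape (N, K) with $\gamma_{nk} = 1$ when point $n$ belongs to centroid $k$ (check this against the lecture's definition). The helper names `assign` and `update` and the toy data are hypothetical; distances are computed with NumPy broadcasting, though `cdist` from the imports would work just as well. This is not the full class, only the core of one iteration.

```python
import numpy as np

def assign(X, C):
    """Return the one-hot assignment matrix gamma (N x K) and the loss."""
    # Squared Euclidean distance between every point and every centroid, shape (N, K)
    d2 = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2)
    nearest = d2.argmin(axis=1)
    gamma = np.zeros_like(d2)
    gamma[np.arange(X.shape[0]), nearest] = 1.0
    # Loss: sum of squared distances to the assigned (nearest) centroid
    loss = (gamma * d2).sum()
    return gamma, loss

def update(X, gamma, C_old):
    """Move each centroid to the mean of the points assigned to it."""
    counts = gamma.sum(axis=0)  # number of points per centroid, shape (K,)
    C_new = C_old.copy()
    nonempty = counts > 0       # centroids without points keep their old position
    C_new[nonempty] = (gamma.T @ X)[nonempty] / counts[nonempty, None]
    return C_new

# Toy data: two well-separated pairs of points
X_toy = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
# Initial centroids taken from the data (fixed here instead of random, for reproducibility)
C_toy = X_toy[[0, 2]].copy()

gamma, loss = assign(X_toy, C_toy)
C_toy = update(X_toy, gamma, C_toy)
```

In a full `fit`, these two steps would be repeated until the loss stops decreasing by more than `tol` or `max_iters` is reached.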

In [None]:
class kmeans_demo:
    def __init__(self, K, max_iters=200, tol=1e-8):
        pass

    def fit(self,X):
        pass

    def predict(self, X):
        pass

Once this is done, we can instantiate the object and call its `fit` method on the data.

In [None]:
# TODO: Instantiate, fit and predict

The obtained solution is **not** the unique solution of the problem. The optimization procedure converges to a __local minimum__. As such, it is sensitive to the initial conditions.

<details>
    <summary>What influences the final result?</summary>
    The choice of the initial values of the centroids `C` is the *only* random parameter that we initialize.
</details>

<details>
    <summary>What can be done about it?</summary>
    We could start from different random positions! Then, select the best model out of those we obtained.
</details>


Create a class that applies k-means several times, and saves the best model (the one with the lowest loss).
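
The restart idea can be sketched as follows. To keep the example self-contained, scikit-learn's `KMeans` is swapped in as a stand-in for `kmeans_demo`, restricted to a single random initialization per run (`init="random"`, `n_init=1`), and its `inertia_` attribute plays the role of the loss. The names `X_demo`, `best_model` and `losses` are illustrative; the class you write should wrap the same loop around your own implementation.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Illustrative data set with four blobs
X_demo, _ = make_blobs(n_samples=300, n_features=2, centers=4, random_state=0)

best_model, losses = None, []
for seed in range(10):
    # One k-means run from one random initialization
    model = KMeans(n_clusters=4, init="random", n_init=1, random_state=seed)
    model.fit(X_demo)
    losses.append(model.inertia_)
    # Keep the run with the lowest loss seen so far
    if best_model is None or model.inertia_ < best_model.inertia_:
        best_model = model
```

Printing `losses` shows how the final loss of each run compares with the one kept in `best_model`.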

In [None]:
class kmeans_demo_multiple:
    def __init__(self, K, repeats=10, max_iters=200, tol=1e-8):
        self.K = K
        self.repeats = repeats
        self.max_iters=max_iters
        self.tol=tol

# TODO: Implement the required methods

In [None]:
# Instantiate and fit the model

## How to know which $K$ to use?

So far, we assumed that we have a reasonable intuition for `K`. This may not always be the case.
Let us create a new example, with more centroids, and this time plot it without the labels. See if you can guess the true number of centroids (which is 5).

Tip: run the following cell several times, as it produces a different result each time.

In [None]:
X2,y2 = make_blobs(
    n_samples=1000,
    n_features=2,
    centers=5
)
plt.scatter(X2[:,0], X2[:,1])
plt.legend(["data"])
plt.xlabel(r"$x_1$")
plt.ylabel(r"$x_2$")

In those cases, we can proceed in the following way:

- select a set of possible values for `K`
- fit k-means with each of them and keep track of the error
- plot the error for each `K`

In [None]:
Ks = [2,3,4,5,6,7,8,9,10]
errors = []

for k in Ks:
    m = kmeans_demo_multiple(K=k, repeats=20)
    m.fit(X2)
    errors.append(m.error)

plt.plot(Ks, errors)
plt.xlabel("K")
plt.ylabel("Error")

In most cases, the plot shows that beyond a certain value of `K`, further increases yield only a marginal decrease in the error.

<details>
<summary>What is the maximal value for K?</summary>
The maximal value for K is the number of data points.
</details>

<details>
<summary>Why? What happens at the maximal value for K?</summary>
If K is equal to the number of data points, there exists one centroid for each data point. Any additional centroid would never be assigned a data point.
</details>

<details>
<summary>What about the error?</summary>
If K is equal to the number of data points, the error is 0, because each centroid coincides with its respective data point.</details>
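
The zero-error case can be checked numerically with a small sketch. It assumes the loss is the sum of squared distances to the nearest centroid, as defined earlier; the names `X_small` and `C_all` and the six random points are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X_small = rng.normal(size=(6, 2))  # six data points
C_all = X_small.copy()             # K = N: one centroid placed on each point

# Squared distance from every point to every centroid, shape (N, N)
d2 = ((X_small[:, None, :] - C_all[None, :, :]) ** 2).sum(axis=2)

# Each point's nearest centroid is the one sitting on it, at distance 0
loss = d2.min(axis=1).sum()
print(loss)
```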