In [None]:
# env=research

In [None]:
# https://www.kaggle.com/code/fareselmenshawii/kmeans-from-scratch

In [None]:
# TODO: add how to predict and run on new data

### imports

In [18]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

### read data

In [19]:
# Load the Iris table and discard the surrogate "Id" column in one chained step.
df = pd.read_csv("Iris.csv").drop(columns="Id")
print(df.shape)
df.head()

(150, 5)


Unnamed: 0,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa


In [20]:
# Feature matrix: the four measurement columns. Target: the species label.
x = df.loc[:, ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"]]
y = df.loc[:, "Species"]

# Relative frequency of each species in the target column.
y.value_counts(normalize=True)

Species
Iris-setosa        0.333333
Iris-versicolor    0.333333
Iris-virginica     0.333333
Name: proportion, dtype: float64

### eda

In [21]:
# Per-species distribution of sepal length.
fig = px.box(
    data_frame=df,
    x="Species",
    y="SepalLengthCm",
    color="Species",
)
fig.show()

In [22]:
# Per-species distribution of petal length.
fig = px.box(
    data_frame=df,
    x="Species",
    y="PetalLengthCm",
    color="Species",
)
fig.show()

In [23]:
# Per-species distribution of sepal width.
fig = px.box(
    data_frame=df,
    x="Species",
    y="SepalWidthCm",
    color="Species",
)
fig.show()

In [24]:
# Per-species distribution of petal width.
fig = px.box(
    data_frame=df,
    x="Species",
    y="PetalWidthCm",
    color="Species",
)
fig.show()

In [32]:
# Petal length against petal width, coloured by the true species label.
fig = px.scatter(
    data_frame=df,
    x="PetalLengthCm",
    y="PetalWidthCm",
    color="Species",
)
fig.update_layout(
    title="Petal Length vs Petal Width",
    xaxis_title="Petal Length",
    yaxis_title="Petal Width",
)
fig.show()

- setosa has a clearly smaller petal length and a slightly wider sepal width than the other two species
- sepal width is not a very useful feature
- it's a bit difficult to distinguish between versicolor and virginica
- virginica may have outliers

### train

In [25]:
class KMeans:
    """Minimal k-means clustering (Lloyd's algorithm) with Euclidean distance.

    Parameters
    ----------
    k : int
        Number of clusters.
    random_state : int or None, optional
        Seed for centroid initialisation. When ``None`` (the default) the
        global ``np.random`` state is used, so ``np.random.seed(...)`` still
        controls the result.

    Attributes
    ----------
    centroids : np.ndarray or None
        Array of shape ``(k, n_features)`` once ``initialize_centroids`` or
        ``fit`` has run; ``None`` before that.
    """

    def __init__(self, k: int, random_state=None):
        if k < 1:
            raise ValueError("k must be a positive integer")
        self.k = k
        self.random_state = random_state
        self.centroids = None

    def initialize_centroids(self, x: np.ndarray):
        """Pick ``k`` distinct rows of ``x`` as the starting centroids.

        Raises
        ------
        ValueError
            If ``x`` has fewer than ``k`` rows.
        """
        x = np.asarray(x, dtype=float)
        if self.k > x.shape[0]:
            raise ValueError("k cannot exceed the number of samples")
        # Fall back to the global RNG when no seed was given so existing
        # callers that rely on np.random.seed keep working.
        rng = np.random if self.random_state is None else np.random.default_rng(self.random_state)
        self.centroids = x[rng.choice(x.shape[0], self.k, replace=False)]

    def assign_points_centroids(self, x: np.ndarray) -> np.ndarray:
        """Return, for every row of ``x``, the index of the closest centroid."""
        x = np.asarray(x, dtype=float)
        # (n, 1, d) - (k, d) broadcasts to (n, k, d); the norm over the last
        # axis yields the (n, k) point-to-centroid distance matrix.
        distance = np.linalg.norm(np.expand_dims(x, axis=1) - self.centroids, axis=-1)
        return np.argmin(distance, axis=1)

    def compute_mean(self, x: np.ndarray, points: np.ndarray) -> np.ndarray:
        """Compute new centroids as the mean of the points assigned to each cluster.

        A cluster with no assigned points keeps its previous centroid instead
        of becoming NaN (the mean of an empty slice). If no previous centroid
        is available, that row is NaN.
        """
        x = np.asarray(x, dtype=float)
        points = np.asarray(points)
        previous = getattr(self, "centroids", None)
        centroids = np.full((self.k, x.shape[1]), np.nan)
        for i in range(self.k):
            members = x[points == i]
            if members.shape[0] > 0:
                centroids[i] = members.mean(axis=0)
            elif previous is not None:
                # Empty cluster: keep the old position so distances stay finite.
                centroids[i] = previous[i]
        return centroids

    def fit(self, x, epochs, tol=0.0):
        """Cluster ``x``: assign points -> recompute centroids -> repeat.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
        epochs : int
            Maximum number of assign/update iterations.
        tol : float, optional
            Stop early once the centroids move by at most ``tol`` (Frobenius
            norm). The default ``0.0`` stops only when the centroids no longer
            change, which yields the same result as running every epoch.

        Returns
        -------
        (centroids, points)
            The final centroids and the cluster index of every row of ``x``
            with respect to those centroids.
        """
        x = np.asarray(x, dtype=float)
        self.initialize_centroids(x)
        for _ in range(epochs):
            points = self.assign_points_centroids(x)
            new_centroids = self.compute_mean(x, points)
            converged = np.linalg.norm(new_centroids - self.centroids) <= tol
            self.centroids = new_centroids
            if converged:
                break
        # Final assignment keeps labels consistent with the returned centroids
        # and also covers epochs == 0.
        points = self.assign_points_centroids(x)
        return self.centroids, points

    def predict(self, x):
        """Assign new samples to the nearest fitted centroid.

        Accepts a single sample (1-D) or a 2-D array of samples.

        Raises
        ------
        RuntimeError
            If called before centroids have been initialised or fitted.
        """
        if getattr(self, "centroids", None) is None:
            raise RuntimeError("KMeans has no centroids yet; call fit() first")
        return self.assign_points_centroids(np.atleast_2d(np.asarray(x, dtype=float)))

In [26]:
# KMeans draws its starting centroids from NumPy's global RNG; seed it so a
# Restart & Run All reproduces the same clustering.
np.random.seed(42)

# np.asarray (rather than `x.values`) keeps this cell re-runnable: it accepts
# the DataFrame on the first run and the already-converted ndarray afterwards.
x = np.asarray(x, dtype=float)
kmeans = KMeans(k=3)
centroids, points = kmeans.fit(x, epochs=1000)

In [27]:
# Cluster index (0..k-1) assigned to each row of x, as returned by kmeans.fit.
points

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2,
       2, 2, 2, 0, 0, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 0, 0, 2, 2, 2, 2,
       2, 0, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 0])

### viz

In [36]:
fig = go.Figure()

# add points
# k-means cluster ids are arbitrary (they depend on the random initialisation),
# so derive each trace label from the data instead of hard-coding a species
# name per id: every cluster is annotated with its majority true species.
species = y.to_numpy().astype(str)
for cluster_id in range(len(centroids)):
    members = points == cluster_id
    if members.any():
        labels, counts = np.unique(species[members], return_counts=True)
        trace_name = f"cluster {cluster_id} (mostly {labels[counts.argmax()]})"
    else:
        trace_name = f"cluster {cluster_id} (empty)"
    fig.add_trace(
        go.Scatter(x=x[members, 0], y=x[members, 1], mode="markers", name=trace_name)
    )

# add centroids
fig.add_trace(
    go.Scatter(
        x=centroids[:, 0],
        y=centroids[:, 1],
        mode="markers",
        marker_symbol=4,
        marker_size=13,
        name="centroids",
    )
)

# Columns 0 and 1 of x are SepalLengthCm and SepalWidthCm.
fig.update_layout(
    title="K-means clusters (sepal length vs sepal width)",
    xaxis_title="Sepal Length (cm)",
    yaxis_title="Sepal Width (cm)",
)
fig.show()
