<a href="https://colab.research.google.com/github/M-H-Amini/MachineLearning-AUT/blob/master/MLe_Lec6_KMeans.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# In The Name Of ALLAH
# Machine Learning *elementary* Course
## Amirkabir University of Technology
### Mohammad Hossein Amini (mhamini@aut.ac.ir)
# Lecture 6 - K-Means Clustering

<img src="https://drive.google.com/uc?id=144SDpgv7EEy6Og1ZFNIv_nBaugKGiSCE" width="400">



# Introduction

The theory was covered in the video lectures, so here we focus on the implementation.

We compress an image by clustering its colors: every pixel is replaced by the center of the cluster its color belongs to.

First, we import the modules we need.

In [0]:
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import cv2

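Let's load the image. I use a picture of my favorite professor, **Gilbert Strang**. OpenCV reads images in BGR channel order, so we convert to RGB before displaying anything with matplotlib.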

In [0]:
image = cv2.imread('gilbert.jpeg')  # returns None if the file cannot be read
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
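
To see what the BGR to RGB conversion does, here is a minimal sketch on a single hypothetical pixel. It uses plain NumPy slicing instead of `cv2.cvtColor`; for a 3-channel image the two should give the same result, since the conversion only reverses the channel order.

```python
import numpy as np

# A hypothetical 1x1 image holding pure blue, stored in OpenCV's BGR order.
bgr_pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)

# Reversing the last axis swaps the first and third channels (B <-> R).
rgb_pixel = bgr_pixel[..., ::-1]

print(rgb_pixel)  # the same blue pixel, now written as (R, G, B) = (0, 0, 255)
```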

In [0]:
print(image.shape)

Now we flatten the image into a list of pixels and count how many unique colors it has.

In [0]:
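# One row per pixel, one column per color channel.
colors = np.reshape(image, (image.shape[0]*image.shape[1], image.shape[2]))
# axis=0 treats every row (one RGB triple) as a single item.
unique_colors = np.unique(colors, axis=0)
print('Number of unique colors:', unique_colors.shape[0])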
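
The same two steps are easier to follow on a tiny example. The sketch below uses a hypothetical 2x2 image (the names `tiny_image`, `tiny_colors` and `tiny_unique` are made up for illustration) and also asks `np.unique` for the number of pixels per color.

```python
import numpy as np

# A hypothetical 2x2 RGB image: three red pixels and one blue pixel.
tiny_image = np.array(
    [[[255, 0, 0], [255, 0, 0]],
     [[0, 0, 255], [255, 0, 0]]],
    dtype=np.uint8,
)

# reshape(-1, 3) lets NumPy infer the number of rows (height * width).
tiny_colors = tiny_image.reshape(-1, 3)

# return_counts=True also reports how many pixels share each unique color.
tiny_unique, counts = np.unique(tiny_colors, axis=0, return_counts=True)

print(tiny_colors.shape)  # (4, 3): four pixels, three channels
print(tiny_unique)        # the two distinct colors, sorted row-wise
print(counts)             # pixels per color, in the same order
```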

In [0]:
print(colors)
print(colors.shape)

# K-Means Clustering

Running K-Means clustering takes a single line in *sklearn*. We ask for 10 clusters and fix `random_state` so that the result is reproducible.

In [0]:
kmeans = KMeans(n_clusters=10, random_state=0).fit(colors)

Clustering is done! Let's look at the centers (the 10 representative colors) and the labels (the cluster index assigned to each pixel).

In [0]:
centers = kmeans.cluster_centers_.astype(np.uint8)
labels = kmeans.labels_
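
One detail about the cast: `astype(np.uint8)` drops the fractional part of each center instead of rounding it. The difference is at most one intensity level per channel, so it is hardly visible, but if you prefer rounding, a small sketch with hypothetical values looks like this:

```python
import numpy as np

# Hypothetical float centers, similar in form to what K-Means returns.
float_centers = np.array([[12.9, 200.4, 99.5]])

truncated = float_centers.astype(np.uint8)         # fractional part is dropped
rounded = np.rint(float_centers).astype(np.uint8)  # round to nearest first

print(truncated)  # 12, 200, 99
print(rounded)    # 13, 200, 100
```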

In [0]:
#print(centers)
print(labels)
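
For intuition about what happens inside `fit`, here is a minimal NumPy sketch of the basic K-Means loop (Lloyd's algorithm): assign every point to its nearest center, move every center to the mean of its points, repeat. This is a simplified illustration, not sklearn's actual implementation. The function name `lloyd_kmeans`, its parameters and the toy data are all made up for this example, and it builds an `(n, k, 3)` distance array, so it is meant for small inputs only.

```python
import numpy as np

def lloyd_kmeans(points, k, n_iters=20, seed=0):
    """Simplified K-Means (Lloyd's algorithm). Returns (centers, labels)."""
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=np.float64)
    # Start from k data points picked at random without replacement.
    centers = points[rng.choice(len(points), size=k, replace=False)]

    def assign(current_centers):
        # Squared distance from every point to every center, shape (n, k).
        diffs = points[:, None, :] - current_centers[None, :, :]
        return (diffs ** 2).sum(axis=2).argmin(axis=1)

    for _ in range(n_iters):
        labels = assign(centers)          # assignment step
        for j in range(k):                # update step
            members = points[labels == j]
            if len(members) > 0:          # an empty cluster keeps its center
                centers[j] = members.mean(axis=0)
    # Final assignment so the labels match the returned centers.
    return centers, assign(centers)

# Hypothetical data: three dark and three bright gray pixels.
toy_points = np.array([[0, 0, 0], [2, 2, 2], [1, 1, 1],
                       [250, 250, 250], [252, 252, 252], [254, 254, 254]])
toy_centers, toy_labels = lloyd_kmeans(toy_points, k=2)
print(toy_centers)  # one center per group of pixels
print(toy_labels)   # cluster index of every pixel
```

sklearn adds a smarter initialization and a convergence check on top of this loop, which is why we use it for the real image.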

# Transforming The Image

Now all that remains is to replace each color in the original image with its corresponding cluster center.

In [0]:
# Indexing the centers with the whole label array looks up every pixel's color at once.
transformed_flattened_image = centers[labels]
transformed_image = np.reshape(transformed_flattened_image, image.shape)
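
Replacing each pixel by its center is just a palette lookup. The sketch below, on hypothetical data, checks that indexing the palette with the whole label array gives the same result as a per-pixel loop. The indexed form avoids a Python-level loop over every pixel, which matters for large images.

```python
import numpy as np

# Hypothetical palette of 3 colors and a label for each of 4 pixels.
demo_centers = np.array([[0, 0, 0], [255, 0, 0], [0, 0, 255]], dtype=np.uint8)
demo_labels = np.array([2, 0, 1, 2])

# Per-pixel loop: look up one center at a time.
looped = np.array([demo_centers[demo_labels[i], :]
                   for i in range(demo_labels.shape[0])])

# Integer array indexing: look up all centers in one step.
indexed = demo_centers[demo_labels]

print(np.array_equal(looped, indexed))  # True
```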

Let's see what we've gained.

In [0]:
plt.figure()
plt.imshow(transformed_image)
plt.show()
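
Note that `transformed_image` is still a full RGB array, so nothing is saved in memory yet. The saving appears only if we store the palette of centers plus one small index per pixel. Here is a rough estimate of that, as a sketch under simplifying assumptions: raw uncompressed storage, no file-format compression such as JPEG, and hypothetical image dimensions. The function name `estimate_sizes` is made up for this example.

```python
import math

def estimate_sizes(height, width, n_colors):
    """Return (raw_bytes, indexed_bytes) for an uncompressed RGB image."""
    n_pixels = height * width
    raw_bytes = n_pixels * 3                  # 3 bytes (R, G, B) per pixel
    # Bits needed to tell n_colors palette entries apart (at least 1).
    bits_per_index = max(1, math.ceil(math.log2(n_colors)))
    index_bytes = math.ceil(n_pixels * bits_per_index / 8)
    palette_bytes = n_colors * 3              # one RGB triple per cluster
    return raw_bytes, index_bytes + palette_bytes

# Hypothetical 480x640 image reduced to 10 colors.
raw, indexed = estimate_sizes(480, 640, 10)
print(raw, indexed, round(raw / indexed, 1))
```

With 10 colors each pixel needs 4 bits instead of 24, so the indexed form is about 6 times smaller in this estimate.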

# compress Function
To keep things simple, we wrap all of the steps above in a single function.

In [0]:
def compress(image_name, no_of_colors=10):
  """Load an image and reduce it to `no_of_colors` colors with K-Means."""
  image = cv2.imread(image_name)
  if image is None:  # cv2.imread returns None instead of raising on failure
    raise FileNotFoundError('Could not read image: {}'.format(image_name))
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  # One row per pixel, one column per color channel.
  colors = np.reshape(image, (-1, image.shape[2]))
  kmeans = KMeans(n_clusters=no_of_colors, random_state=0).fit(colors)
  centers = kmeans.cluster_centers_.astype(np.uint8)
  labels = kmeans.labels_
  # Replace every pixel with its cluster center, then restore the image shape.
  transformed_image = np.reshape(centers[labels], image.shape)
  return transformed_image

Also, for ease of use, we define a **show** function.

In [0]:
def show(image, title='Image'):
  plt.figure(figsize=(10, 8))
  plt.imshow(image)
  plt.title(title)
  plt.axis('off')
  plt.show()

Now, it's as simple as a single line!

In [0]:
image1 = compress('gilbert.jpeg', 5)
show(image1, 'Compressed (5 colors)')
show(image, 'Original')

# compare Function
Finally, for educational purposes, we define the **compare** function, which shows the original image next to three compressed versions.

In [0]:
def compare(image_name, no_of_colors_list=(5, 10, 15)):
  """Show the original image next to up to three compressed versions."""
  image = cv2.imread(image_name)
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  plt.figure(figsize=(12, 9))
  plt.subplot(2, 2, 1)
  plt.imshow(image)
  plt.title('Original')
  plt.axis('off')
  # The 2x2 grid has room for three compressed images after the original.
  color_counts = no_of_colors_list[:3]
  for i, no_of_colors in enumerate(color_counts):
    print('Compressing {} of {}...'.format(i + 1, len(color_counts)))
    plt.subplot(2, 2, i + 2)
    plt.imshow(compress(image_name, no_of_colors))
    plt.title('Colors: {}'.format(no_of_colors))
    plt.axis('off')
  print('Compressions done!')
  plt.show()

In [0]:
compare('gilbert.jpeg', [2, 10, 30])
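
Besides comparing the images by eye, we can put a number on how far each compressed image is from the original. One common choice is the mean squared error over all pixels and channels. The helper below is only a sketch: the name `mean_squared_error` is made up here, and it assumes both images have the same shape.

```python
import numpy as np

def mean_squared_error(original, compressed):
    """Average squared per-channel difference between two same-shaped images."""
    # Cast to float first: subtracting uint8 arrays directly would wrap around.
    diff = original.astype(np.float64) - compressed.astype(np.float64)
    return float(np.mean(diff ** 2))

# Hypothetical 1x1 images that differ by 3 in the red channel only.
pixel_a = np.array([[[0, 0, 0]]], dtype=np.uint8)
pixel_b = np.array([[[3, 0, 0]]], dtype=np.uint8)
print(mean_squared_error(pixel_a, pixel_b))  # 3.0, i.e. (3**2 + 0 + 0) / 3
```

In the notebook this could be called as `mean_squared_error(image, compress('gilbert.jpeg', 5))`. The error should shrink as the number of colors grows.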