<a href="https://colab.research.google.com/github/Dennis-Farias/cursoIAeML/blob/main/KMeans.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from sklearn import datasets
from sklearn.metrics import confusion_matrix
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram

In [None]:
# Load the Iris dataset that ships with scikit-learn; the cells below read
# its `.data` (feature matrix) and `.target` (class labels) attributes.
iris = datasets.load_iris()
# Bare expression on the last line of the cell so the notebook displays the loaded object.
iris

In [1]:
def plot_clusters(data, labels, title):
  """Scatter-plot clustered samples, drawing one colour per cluster label.

  Column 0 of ``data`` goes on the x axis and column 3 on the y axis (the
  axis captions name them as sepal length and petal width, in Portuguese).

  Args:
    data: 2-D array-like with at least 4 columns, one row per sample.
    labels: 1-D array-like of integer cluster ids, one per row of ``data``.
      The id ``-1`` is drawn as noise (black crosses).
    title: Text shown as the figure title.

  Notes:
    * Ids 0, 1 and 2 are captioned 'Setosa', 'Versicolor' and 'Virginica'
      purely by position. Clustering algorithms number their clusters
      arbitrarily, so the caption is not guaranteed to match the real species.
    * Only labels that actually occur are drawn, so the legend has no empty
      'Noise' entry when the algorithm produced no noise points.
    * Ids above 2 are drawn too, captioned 'Cluster <id>', instead of being
      silently left out of the figure.
  """
  # Local import: the notebook's import cell does not bring numpy into scope.
  import numpy as np

  data = np.asarray(data)
  labels = np.asarray(labels)

  # Explicit id -> caption / colour tables (previously the noise colour was
  # reached through negative list indexing).
  names = {-1: 'Noise', 0: 'Setosa', 1: 'Versicolor', 2: 'Virginica'}
  colors = {-1: 'black', 0: 'red', 1: 'green', 2: 'purple'}
  # Fallback palette for any cluster id outside the tables above.
  extra_colors = ['orange', 'blue', 'brown', 'cyan', 'magenta', 'olive']

  plt.figure(figsize=(8,4))
  present = np.unique(labels)  # sorted, so noise (-1) is still drawn first
  for raw_id in present:
    i = int(raw_id)
    mask = labels == i  # computed once and reused for both axes
    color = colors.get(i, extra_colors[(i - 3) % len(extra_colors)])
    name = names.get(i, f'Cluster {i}')
    # Noise points use a cross marker; clusters keep matplotlib's default marker.
    marker_kw = {'marker': 'x'} if i == -1 else {}
    plt.scatter(data[mask, 0], data[mask, 3], c=color, label=name, alpha=0.5, s=50, **marker_kw)
  if present.size:
    # Skip the legend when nothing was plotted to avoid an empty-legend warning.
    plt.legend()
  plt.title(title)
  plt.xlabel('Comprimento da Sépala')
  plt.ylabel('Largura da Pétala')
  plt.show()

In [None]:
# Display the class labels stored with the dataset; later cells compare the
# cluster assignments against them.
iris.target

In [None]:
# Partition the samples into 3 clusters with K-Means.
# random_state pins the random centroid initialisation so the printed labels,
# the confusion matrix and the scatter plot are the same on every run.
kmeans = KMeans(n_clusters=3, n_init='auto', random_state=42)
kmeans.fit(iris.data)
# One cluster id per sample, in the same order as the rows of iris.data.
print(kmeans.labels_)

In [None]:
# Cross-tabulate the dataset's class labels (rows) against the K-Means cluster ids (columns).
# NOTE: cluster ids are numbered arbitrarily by the algorithm, so the columns
# need not line up with the class order of the rows -- read this as a
# contingency table rather than expecting the counts on the diagonal.
resultados = confusion_matrix(iris.target, kmeans.labels_)
print(resultados)

In [None]:
# Scatter-plot the samples coloured by their K-Means cluster id.
plot_clusters(iris.data, kmeans.labels_, 'Cluster KMeans')

In [None]:
# Density-based clustering: eps is the neighbourhood radius and min_samples
# the neighbour count required for a core point (see scikit-learn's DBSCAN docs).
dbscan = DBSCAN(eps=0.5, min_samples=3)
# fit_predict fits the model and returns one label per sample; the plotting
# helper treats the label -1 as noise.
dbscan_labels = dbscan.fit_predict(iris.data)
print(dbscan_labels)

In [None]:
# Scatter-plot the samples coloured by their DBSCAN label.
plot_clusters(iris.data, dbscan_labels, 'Cluster DBSCAN')

In [None]:
# Hierarchical (agglomerative) clustering, stopped at 3 clusters.
agglo = AgglomerativeClustering(n_clusters=3)
# fit_predict fits the model and returns one cluster id per sample.
agglo_labels = agglo.fit_predict(iris.data)
print(agglo_labels)

In [None]:
# Cross-tabulate the dataset's class labels (rows) against the agglomerative cluster ids (columns).
# NOTE: cluster ids are numbered arbitrarily by the algorithm, so the columns
# need not line up with the class order of the rows. This rebinds `resultados`.
resultados = confusion_matrix(iris.target, agglo_labels)
print(resultados)

In [None]:
# Scatter-plot the samples coloured by their agglomerative-clustering label.
plot_clusters(iris.data, agglo_labels, 'Cluster Hierárquico')

In [None]:
# Draw a dendrogram of the hierarchical clustering of the Iris samples.
plt.figure(figsize=(12,6))
plt.title("Cluster Hierárquico: Dendograma")
plt.xlabel("Índice")
plt.ylabel("Distância")
# Build the merge tree with Ward linkage.
linkage_matrix = linkage(iris.data, method='ward')
# Show only the last 15 merged groups so the tree stays readable.
# NOTE(review): per the SciPy docs, truncated leaves are labelled with the
# sample count in parentheses rather than a sample index -- confirm the
# "Índice" x-axis caption is what is intended.
dendrogram(linkage_matrix, truncate_mode='lastp', p=15)
# Dashed horizontal reference line at distance 7 -- presumably the height at
# which the tree is cut into the clusters used above; TODO confirm.
plt.axhline(y=7, color='gray', lw=1, linestyle='dashed')
plt.show()