In [None]:
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import re
import pandas as pd
import random
import webcolors

from sklearn.cluster import KMeans, MeanShift, estimate_bandwidth
from sklearn.metrics import silhouette_score, silhouette_samples

from tqdm import tqdm

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE


def load_images(root_path):
    """Load the readable images in *root_path* as RGB arrays, in random order.

    The directory listing is shuffled with the global ``random`` state before
    loading. Entries that are not regular files (e.g. sub-directories) are
    skipped. Files that OpenCV cannot decode (``cv2.imread`` returns ``None``)
    are skipped with a message instead of crashing ``cv2.cvtColor``.

    Args:
        root_path: directory whose entries are read with ``cv2.imread``.

    Returns:
        A list of dicts ``{'image': <RGB ndarray>, 'name': <file name>}``.
    """
    files = os.listdir(root_path)
    random.shuffle(files)
    images_original = []
    for file_name in files:
        path = os.path.join(root_path, file_name)
        if not os.path.isfile(path):
            continue
        img = cv2.imread(path)
        if img is None:
            # Not something OpenCV can decode (hidden files, text files, ...).
            print(f"Skipping unreadable file: {file_name}")
            continue
        # OpenCV loads BGR; convert so matplotlib shows the true colors.
        images_original.append({'image': cv2.cvtColor(img, cv2.COLOR_BGR2RGB), 'name': file_name})
    return images_original

def pre_process_image(images_original, flatten=False, normalize=False, target_size=None):
    """Optionally resize, flatten and normalize a list of loaded images.

    Args:
        images_original: list of dicts with an ``'image'`` array (as built by
            ``load_images``); only the ``'image'`` entry is read.
        flatten: if True, reshape each image to ``(-1, 3)`` (one row per pixel).
        normalize: if True, divide pixel values by 255.0 (result is float).
        target_size: if given, resize so the longer side becomes
            ``target_size`` pixels, keeping the aspect ratio. If None, the
            original size is kept.

    Returns:
        A list of processed arrays in the same order as ``images_original``.
        With all options off, the original arrays themselves are returned.
    """
    images_process = []
    for data in images_original:
        # Start from the original image so the no-resize path works too
        # (previously ``image`` was unbound when target_size was None).
        image = data['image']
        if target_size is not None:
            height, width = image.shape[:2]
            ratio = float(target_size) / max(height, width)
            # Clamp to 1 so very elongated images never get a zero-sized axis.
            new_height = max(1, int(height * ratio))
            new_width = max(1, int(width * ratio))
            image = cv2.resize(image, (new_width, new_height))
        if flatten:
            image = image.reshape((-1, 3))
        if normalize:
            image = image / 255.0
        images_process.append(image)
    return images_process


def kmeans_image(img, num_clusters):
    """Cluster the rows of *img* with K-means (``random_state=42``).

    Args:
        img: 2-D array of samples, e.g. flattened pixels of shape ``(n, 3)``.
        num_clusters: number of clusters K.

    Returns:
        Tuple ``(cluster_centers_, labels_)`` taken from the fitted model.
    """
    model = KMeans(n_clusters=num_clusters, random_state=42).fit(img)
    centers, assignments = model.cluster_centers_, model.labels_
    return centers, assignments

def rgb_to_hex(rgb):
    """Format the first three components of *rgb* as a ``#rrggbb`` string.

    Each component is truncated with ``int()`` (not rounded) and written as
    at least two lowercase hex digits; extra components are ignored.
    """
    red, green, blue = (int(rgb[idx]) for idx in range(3))
    return f'#{red:02x}{green:02x}{blue:02x}'

def dimensionality_reduction(img, colors, tsne=False, p=3):
    """Project samples to 2-D with PCA (default) or t-SNE and scatter-plot them.

    Args:
        img: array of samples to project, one row per sample (e.g. the
            flattened, normalized pixels from ``pre_process_image``).
        colors: per-row RGB colors in the 0-1 range (e.g. ``centroids[labels]``).
            Each point is drawn in its own color, with one legend entry per
            distinct hex color.
        tsne: if True, use t-SNE (``perplexity=1``) instead of PCA.
        p: currently unused; kept so existing calls keep working.
            NOTE(review): possibly intended as the t-SNE perplexity — confirm.

    Returns:
        The 2-D projection returned by ``fit_transform``.
    """
    # Use a separate name so the ``tsne`` flag is not rebound to the model.
    if tsne:
        reducer = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=1)
    else:
        reducer = PCA(n_components=2)
    x_train_reduced = reducer.fit_transform(img)
    print('>>', x_train_reduced.shape)

    # Scale to 0-255 and turn every row into '#rrggbb'; that string serves both
    # as the legend label and as the matplotlib color, so no lookup dict is needed.
    hex_colors = np.array([rgb_to_hex(color) for color in np.asarray(colors) * 255])

    for category in np.unique(hex_colors):
        mask = hex_colors == category
        plt.scatter(x_train_reduced[mask, 0], x_train_reduced[mask, 1],
                    label=category, color=category, edgecolors='black')

    plt.xlabel('X-axis Label')
    plt.ylabel('Y-axis Label')
    plt.title('Scatter Plot with Categories')
    plt.legend()
    plt.show()
    return x_train_reduced

def rgb_to_color_hex(rgb_tuple):
    """Return ``webcolors.rgb_to_hex(rgb_tuple)``, or ``'Unknown'`` on ValueError."""
    try:
        return webcolors.rgb_to_hex(rgb_tuple)
    except ValueError:
        return 'Unknown'

def draw_image_palette(image_original, centroids=None):
    """Show an image and, optionally, its color palette side by side.

    Args:
        image_original: dict with ``'image'`` (array passed to ``imshow``) and
            ``'name'`` (file name; its extension is removed for the title).
        centroids: optional array of palette colors in the 0-1 range, one row
            per color (e.g. K-means centers of normalized pixels). Each color
            is drawn as a swatch labelled with its hex code. When None, only
            the image is drawn.
    """
    # Strip the real extension (.jpg, .jpeg, .png, ...) instead of a fixed
    # three characters, then collapse non-alphanumerics into spaces.
    title = os.path.splitext(image_original['name'])[0]
    title = re.sub(r'[^A-Za-z0-9\-]+', ' ', title).upper()
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title(f"{title}")
    plt.imshow(image_original['image'])
    plt.axis('off')

    if centroids is not None:
        # Scaling happens only here; previously it ran unconditionally and
        # raised TypeError for the default centroids=None.
        palette = np.asarray(centroids) * 255
        plt.subplot(1, 2, 2)
        plt.title('Palette')
        color_names = [rgb_to_color_hex(tuple(color.astype(int))) for color in palette]
        # A 1 x K image of swatches; uint8 so imshow reads 0-255 values.
        plt.imshow(np.expand_dims(palette, axis=0).astype(np.uint8))
        plt.xticks([])
        plt.yticks([])
        for i, name in enumerate(color_names):
            # One label under each swatch, centered on its column.
            plt.text(i, .5, name, ha='center', va='top', color='black', fontsize=8)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


def plot_colors(colors):
    """Draw each color as a 100x100 swatch in one row of subplots.

    Args:
        colors: sequence of RGB colors; each is written into a float64
            swatch, so for ``imshow`` the values should lie in [0, 1].
    """
    n_colors = len(colors)
    plt.figure(figsize=(8, 6))
    for position, color in enumerate(colors, start=1):
        swatch = np.zeros((100, 100, 3))
        swatch[...] = color
        plt.subplot(1, n_colors, position)
        plt.imshow(swatch)
        plt.axis('off')
    plt.show()

def elbow_method(data, max_clusters=10):
    """Plot K-means inertia for K = 1..max_clusters (the elbow method).

    Args:
        data: samples passed directly to ``KMeans.fit``.
        max_clusters: largest K to try (inclusive).
    """
    ks = range(1, max_clusters + 1)
    distortions = [KMeans(n_clusters=k, random_state=42).fit(data).inertia_ for k in ks]
    plt.plot(ks, distortions, marker='o')
    plt.xlabel('Number of clusters')
    plt.ylabel('Distortion')
    plt.title('Elbow Method')
    plt.show()

def silhouette_analysis(data, max_clusters=10):
    """Plot the mean silhouette score of K-means for K = 2..max_clusters.

    Args:
        data: samples to cluster and score.
        max_clusters: largest K to try (inclusive).
    """
    ks = range(2, max_clusters + 1)
    silhouette_scores = []
    for k in ks:
        assignments = KMeans(n_clusters=k, random_state=42).fit_predict(data)
        silhouette_scores.append(silhouette_score(data, assignments))
    plt.plot(ks, silhouette_scores, marker='o')
    plt.xlabel('Number of clusters')
    plt.ylabel('Silhouette Score')
    plt.title('Silhouette Analysis')
    plt.show()


def silhouette_plot(data, max_clusters=10):
    """Tabulate and plot the silhouette coefficient for K = 2..max_clusters.

    Each K uses ``KMeans(n_init=10, random_state=42)``. The scores are shown
    as a DataFrame via ``display`` (provided by IPython/Jupyter) and then
    plotted against K on a grid.

    Args:
        data: samples to cluster and score.
        max_clusters: largest K to try (inclusive).
    """
    ks = range(2, max_clusters + 1)
    scores = []
    for k in ks:
        fitted = KMeans(n_clusters=k, n_init=10, random_state=42).fit(data)
        scores.append(silhouette_score(data, fitted.labels_))
    display(pd.DataFrame({'K': ks, 'Coeficiente': scores}))
    plt.plot(ks, scores, marker='o')
    plt.xlabel('Número de clústeres')
    plt.ylabel('Silhouette Score')
    plt.grid()
    plt.show()

# Directory of sample images; set SAMPLE_IMAGES_DIR to use another location
# without editing the notebook.
root_path = os.environ.get(
    'SAMPLE_IMAGES_DIR',
    '/home/satoru/repos/u_andes/maia/mlns/micro_projects/one/sample',
)
images_original = load_images(root_path)
images_process = pre_process_image(images_original, flatten=True, normalize=True, target_size=100)

num_clusters = 7  # palette size (K) per image
num_images = 2    # how many of the shuffled images to analyse
processed_images = 0

# Slice (clamped at 0) instead of counting and breaking: with the old loop the
# first image was always processed, even when num_images <= 0.
for i, image in enumerate(images_process[:max(num_images, 0)]):
    print(f"****************************************{i}************************************")
    centroids, labels = kmeans_image(image, num_clusters)
    draw_image_palette(images_original[i], centroids)
    # Color every pixel with its cluster centroid for the 2-D scatter plot.
    x_train_reduced = dimensionality_reduction(image, centroids[labels])

    # NOTE(review): these analyse the 2-D projection, not the raw pixels that
    # kmeans_image clustered — confirm that is intended.
    elbow_method(x_train_reduced)
    silhouette_plot(x_train_reduced)
    processed_images += 1