# Clusters Visualization

In [None]:
import os
import json
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from ipywidgets import interact, Layout, Dropdown
from IPython.display import display
import sys
sys.path.append(str(Path("../src").resolve()))
from config import *


def load_clustering_data() -> dict:
    '''
    Load clustering data as a dictionary.
    
    Parameters:
        None
    
    Returns:
        dict: The dictionary in the form {k, data}
    '''
    json_files = list(Path(IMAGES_CLUSTERING_DIRECTORY).glob("*.json"))

    # check if the json is valid (to do)...

    def get_k_from_json(json_file):
        # get k value from a json
        with open(json_file, "r") as f:
            data = json.load(f)
        return data.get("k", 0)

    # sort for k
    json_files = sorted(json_files, key=get_k_from_json)

    for json_file in json_files:
        with open(json_file, "r") as f:
            data = json.load(f)
        clustering_data[data['k']] = data

    return clustering_data



def find_image_path(image_name: str) -> str | None:
    '''
    Resolve an image name (without extension) to an existing file path.

    Every extension in IMAGES_EXTENSIONS is tried, first lower-cased and then
    upper-cased, inside IMAGES_DIRECTORY joined after "../".

    Parameters:
        image_name (str): The name of the image, without its extension

    Returns:
        str or None: The first candidate path that exists, None if there is none
    '''
    # lazy generator: candidates are built and probed one at a time, in the
    # order extension -> (lower, upper)
    candidates = (
        os.path.join("../", IMAGES_DIRECTORY, image_name + variant)
        for extension in IMAGES_EXTENSIONS
        for variant in (extension.lower(), extension.upper())
    )
    return next((path for path in candidates if os.path.exists(path)), None)


def show_cluster_samples(cluster_samples: list, max_per_row=5) -> None:
    '''
    Show the samples of a cluster as a grid of images.

    Parameters:
        cluster_samples (list): The list of the image names
        max_per_row (int): The maximum amount of images to show in every row

    Returns:
        None:

    Raises:
        ValueError: If cluster_samples contains something that is not a string,
            or if max_per_row is lower than 1
    '''
    if not all(isinstance(x, str) for x in cluster_samples):
        raise ValueError("Not valid cluster_samples: cluster_samples has to be a list of names")

    if max_per_row < 1:
        raise ValueError("Not valid max_per_row: max_per_row has to be at least 1")

    n = len(cluster_samples)
    if n == 0:
        # an empty cluster would give n_cols == 0 and a division by zero below
        print("No images to show for this cluster")
        return

    n_cols = min(max_per_row, n)
    n_rows = (n + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always returns a 2-D array of axes, so the single-image
    # case needs no special handling
    _, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), squeeze=False)
    axes = axes.flatten()

    for ax, img_name in zip(axes, cluster_samples):
        image_path = find_image_path(img_name)
        if image_path:
            try:
                # the context manager releases the file handle once the
                # image has been handed to matplotlib
                with Image.open(image_path) as img:
                    ax.imshow(img)
            except OSError:
                # unreadable/corrupted file: mark the tile instead of
                # aborting the whole grid
                ax.text(0.5, 0.5, f"File {img_name} not readable", ha='center', va='center', color='red')
        else:
            ax.text(0.5, 0.5, f"File {img_name} not found", ha='center', va='center', color='red')
        ax.set_title(img_name, fontsize=8)
        ax.axis('off')

    # hide the unused tiles of the last row
    for ax in axes[n:]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def visualize_cluster(k: int, cluster_id: int) -> None:
    '''
    Visualize a specified cluster (cluster_id) in a specified clustering (k).

    Prints the scores stored for the clustering, then shows the images of the
    selected cluster.

    Parameters:
        k (int): The number of clusters, that determines the clustering
        cluster_id (int): The id of the cluster to visualize

    Returns:
        None:

    Raises:
        ValueError: If k is outside [MIN_N_CLUSTERS, MAX_N_CLUSTERS], if no
            clustering was loaded for k, or if cluster_id is outside [0, k - 1]
    '''
    if k < MIN_N_CLUSTERS or k > MAX_N_CLUSTERS:
        raise ValueError(f"Not valid k: k not in [{MIN_N_CLUSTERS}, {MAX_N_CLUSTERS}]")

    if cluster_id < 0 or cluster_id >= k:
        raise ValueError(f"Not valid cluster_id: cluster id not in [0, {k - 1}]")

    # a k inside the configured range may still have no JSON file on disk
    if k not in clustering_data:
        raise ValueError(f"Not valid k: no clustering loaded for k={k}")

    data = clustering_data[k]
    samples = data['samples']
    inertia_score = data.get("inertia_score")
    silhouette_score = data.get("silhouette_score")
    cluster_counts = data.get("cluster_counts")

    print(f"Clustering for K={k}")
    print(f"Inertia Score: {inertia_score}")
    print(f"Silhouette Score: {silhouette_score}")
    print(f"Cluster counts: {cluster_counts}")

    cluster_samples = samples[cluster_id]

    # "cluster_counts" is optional (read with .get above): when it is
    # missing, count the listed samples instead of subscripting None
    if cluster_counts is None:
        n_images = len(cluster_samples)
    else:
        n_images = cluster_counts[cluster_id]

    print(f"\nCluster {cluster_id} ({n_images} images)\n")
    show_cluster_samples(cluster_samples)


# ==============================================================

# read the clustering JSON files into a dictionary keyed by k
clustering_data = load_clustering_data()

# two dropdowns: one to pick the clustering (k), one to pick a cluster in it;
# the cluster dropdown starts empty and is filled once a k is selected
k_widget = Dropdown(
    options=sorted(clustering_data),
    description='Select K:',
)
cluster_widget = Dropdown(
    options=[],
    description='Cluster ID',
)

def update_widgets(*args):
    '''
    Refresh the cluster dropdown so it lists the clusters of the selected k.

    Parameters:
        *args: The change notification passed by the widget observer; ignored

    Returns:
        None:
    '''
    k = k_widget.value

    # with no clustering loaded the k dropdown has no selection (value is
    # None), so there is nothing to look up
    if k not in clustering_data:
        cluster_widget.options = []
        return

    n_clusters = len(clustering_data[k]['samples'])
    cluster_widget.options = range(n_clusters)

    # 0 is only a valid choice when at least one cluster is listed
    if n_clusters > 0:
        cluster_widget.value = 0

# keep the cluster dropdown in sync with the selected k ...
k_widget.observe(update_widgets, names='value')
# ... and fill it once for the initially selected k
update_widgets()

# render the interactive view driven by the two dropdowns
interact(
    visualize_cluster,
    k=k_widget,
    cluster_id=cluster_widget,
)
