In [None]:
import os
import numpy as np
import umap
import clip
import torch
from PIL import Image
import plotly.graph_objects as go
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from IPython.display import display, HTML
import pandas as pd

In [None]:
# Run CLIP on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load hands back the model together with its image preprocessing callable.
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
def generate_embeddings(image_paths):
    """Encode images with the notebook-level CLIP model and stack the results.

    Relies on the notebook-level ``model``, ``preprocess`` and ``device``
    objects created by ``clip.load`` in an earlier cell.

    Args:
        image_paths: Iterable of image file paths that PIL can open.

    Returns:
        The per-image embeddings stacked row-wise with ``np.vstack``, in the
        same order as ``image_paths``.

    Raises:
        ValueError: If ``image_paths`` yields no paths.
    """
    embeddings = []
    # Inference only: a single no_grad context covers the whole loop.
    with torch.no_grad():
        for path in image_paths:
            # Image.open keeps the file handle open until the image is closed;
            # the context manager releases it once preprocessing has produced
            # the tensor, so large folders do not exhaust file descriptors.
            with Image.open(path) as img:
                image = preprocess(img).unsqueeze(0).to(device)
            embedding = model.encode_image(image)
            embeddings.append(embedding.cpu().numpy())

    if not embeddings:
        # np.vstack([]) would fail with a far less helpful message.
        raise ValueError("image_paths is empty; no embeddings to generate")
    return np.vstack(embeddings)

def compute_similarity(embeddings):
    """Return the pairwise cosine-similarity matrix between embedding rows.

    Args:
        embeddings: 2-D array with one embedding per row.

    Returns:
        Square matrix whose entry ``[i, j]`` is the cosine similarity between
        rows ``i`` and ``j`` of ``embeddings``.
    """
    # cosine_similarity L2-normalizes the rows itself, so the raw embeddings are
    # passed straight through. Dividing by the row norms here first was
    # redundant and turned an all-zero row into NaNs (0 / 0).
    return cosine_similarity(embeddings)

def find_similar_images(query_index, similarity_matrix, image_paths, top_n=5, exclude_query=False):
    """Return the ``top_n`` images most similar to the query image.

    Args:
        query_index: Row of ``similarity_matrix`` that belongs to the query image.
        similarity_matrix: Square matrix of pairwise similarities; row ``i``
            holds the similarities of image ``i`` to every image.
        image_paths: Sequence of image paths aligned with the matrix rows.
        top_n: Maximum number of results to return.
        exclude_query: When True, the query image itself is left out of the
            results. Defaults to False, which keeps the query in the ranking
            (it is normally its own best match).

    Returns:
        List of ``(image_path, similarity)`` tuples sorted by descending
        similarity; ties keep ascending index order.
    """
    # asarray lets plain nested lists work too (unary minus needs an ndarray).
    similarities = np.asarray(similarity_matrix[query_index])
    # A stable sort makes the order of equal scores deterministic.
    ranked = np.argsort(-similarities, kind="stable")
    if exclude_query:
        # Indexing a range maps a negative query_index to its actual position.
        query_position = range(len(similarities))[query_index]
        ranked = ranked[ranked != query_position]
    return [(image_paths[i], similarities[i]) for i in ranked[:top_n]]

def find_best_kmeans(projections, min_k=5, max_k=20):
    """Pick the KMeans cluster count with the highest silhouette score.

    Fits ``KMeans(random_state=42)`` for every ``k`` in the (possibly
    narrowed) range ``[min_k, max_k]`` and keeps the labelling whose
    silhouette score is highest; the first ``k`` wins on ties.

    Args:
        projections: Sample matrix to cluster, one row per sample.
        min_k: Smallest number of clusters to try.
        max_k: Largest number of clusters to try.

    Returns:
        ``(best_k, cluster_labels)``. Both are ``None`` when the requested
        range is empty (``min_k > max_k``).

    Raises:
        ValueError: If there are too few samples for any ``k`` in the
            requested range to be scored.
    """
    n_samples = len(projections)
    # The silhouette score is only defined for 2 <= k <= n_samples - 1, so the
    # range is narrowed up front instead of failing midway through the sweep.
    lowest_k = max(min_k, 2)
    highest_k = min(max_k, n_samples - 1)
    if min_k <= max_k and lowest_k > highest_k:
        raise ValueError(
            f"Cannot evaluate k in [{min_k}, {max_k}] with {n_samples} samples; "
            "silhouette scoring needs 2 <= k <= n_samples - 1"
        )
    if (lowest_k, highest_k) != (min_k, max_k):
        print(f"Adjusted k range to [{lowest_k}, {highest_k}] for {n_samples} samples")

    best_k = None
    best_score = -1
    cluster_labels = None

    for k in range(lowest_k, highest_k + 1):
        kmeans = KMeans(n_clusters=k, random_state=42)
        labels = kmeans.fit_predict(projections)
        score = silhouette_score(projections, labels)
        print(f"k={k}, Silhouette Score={score:.4f}")

        # Strict comparison: the smallest k wins when scores are equal.
        if score > best_score:
            best_k = k
            best_score = score
            cluster_labels = labels

    print(f"Best k: {best_k}, Best Silhouette Score: {best_score:.4f}")
    return best_k, cluster_labels

def visualize_clusters(projections, cluster_labels):
    """Show an interactive 3-D Plotly scatter of the points, colored by cluster.

    Args:
        projections: 2-D array; columns 0, 1 and 2 are plotted as x, y and z.
        cluster_labels: One label per row of ``projections``; used as the
            marker color and shown in the hover text.
    """
    x_values, y_values, z_values = (projections[:, axis] for axis in range(3))
    points = pd.DataFrame(
        {"x": x_values, "y": y_values, "z": z_values, "cluster": cluster_labels}
    )

    marker_style = dict(size=6, color=points["cluster"], opacity=0.8)
    scatter = go.Scatter3d(
        x=points["x"],
        y=points["y"],
        z=points["z"],
        mode="markers",
        marker=marker_style,
        hovertemplate="<b>Cluster:</b> %{marker.color}<extra></extra>",
    )

    fig = go.Figure(data=scatter)
    axis_titles = dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z")
    fig.update_layout(
        title="UMAP + KMeans 3D Clustering",
        scene=axis_titles,
        height=800,
    )

    fig.show()

def show_similarity_results_with_index(query_index, similarity_matrix, image_paths, top_n=5):
    """Print and render the images most similar to a query image.

    Prints the query and one line per ranked result, then displays the result
    images side by side as an HTML block in the notebook output.

    Args:
        query_index: Index of the query image in ``image_paths``.
        similarity_matrix: Pairwise similarity matrix aligned with ``image_paths``.
        image_paths: List of image paths.
        top_n: Number of similar images to show.
    """
    # Local import: the notebook's import cell does not provide html.escape.
    from html import escape

    similar_images = find_similar_images(query_index, similarity_matrix, image_paths, top_n=top_n)

    # Resolve each path's position once (first occurrence, like list.index)
    # instead of scanning the whole list for every printed or rendered row.
    first_index = {}
    for position, path in enumerate(image_paths):
        first_index.setdefault(path, position)

    print(f"Query Index: {query_index}")
    print(f"Query Image: {image_paths[query_index]}")
    print(f"Top {top_n} Similar Images:")
    for rank, (img, score) in enumerate(similar_images, start=1):
        print(f"Rank {rank}: Index: {first_index[img]} | {img} (Similarity: {score:.4f})")

    # Iterating over the pairs directly also copes with an empty result list,
    # where unpacking zip(*similar_images) would raise.
    cards = ['<div style="display: flex; flex-wrap: wrap; justify-content: flex-start;">']
    for img, score in similar_images:
        # Escape the path so quotes or angle brackets cannot break the markup.
        src = escape(str(img), quote=True)
        cards.append(f"""
        <div style="margin: 10px; text-align: center;">
            <img src="{src}" style="height: 150px; margin-bottom: 5px;" />
            <p style="font-size: 14px; margin: 0;">Index: {first_index[img]}</p>
            <p style="font-size: 14px; margin: 0;">Sim: {score:.4f}</p>
        </div>
        """)
    cards.append('</div>')

    display(HTML("".join(cards)))

In [None]:
image_folder = "/home/xuanl/UMAP_local/static"
# os.listdir returns entries in arbitrary order; sorting keeps image indices
# (e.g. query_index) pointing at the same files on every run. Extensions are
# matched case-insensitively so files such as "photo.JPG" are not skipped.
image_paths = [
    os.path.join(image_folder, img)
    for img in sorted(os.listdir(image_folder))
    if img.lower().endswith((".png", ".jpg", ".jpeg"))
]

In [None]:
# Encode every collected image with CLIP and report the resulting array shape.
print("Generating CLIP embeddings...")
embeddings = generate_embeddings(image_paths=image_paths)
print(f"Generated embeddings of shape: {embeddings.shape}")

In [None]:
# Project the CLIP embeddings down to three dimensions for plotting/clustering.
print("Reducing dimensions with UMAP...")
umap_reducer = umap.UMAP(
    n_components=3,
    n_neighbors=15,
    min_dist=0.1,
    random_state=42,  # fixed seed for a repeatable layout
)
projections = umap_reducer.fit_transform(embeddings)

In [None]:
# Sweep k over [5, 20] and keep the labelling with the best silhouette score.
print("Performing KMeans clustering...")
best_k, cluster_labels = find_best_kmeans(
    projections=projections,
    min_k=5,
    max_k=20,
)

In [None]:
# Pairwise cosine similarity on the full CLIP embeddings (not the UMAP projection).
print("Computing cosine similarity matrix...")
similarity_matrix = compute_similarity(embeddings=embeddings)

In [None]:
# Render the 3-D UMAP projection, colored by the selected KMeans labels.
print("Visualizing 3D clustering with similarity search...")
visualize_clusters(
    projections=projections,
    cluster_labels=cluster_labels,
)

In [None]:
# Replace with the desired query index and number of similar images to show.
query_index, top_n = 0, 5
show_similarity_results_with_index(
    query_index=query_index,
    similarity_matrix=similarity_matrix,
    image_paths=image_paths,
    top_n=top_n,
)