# Cluster Photos

This notebook leverages learned features and out-of-the-box clustering techniques to cluster photos with the intent of finding a representative sample.

## Instantiate Encoders

Let's create our image encoders (DINOv2 and ViT, each wrapped in a `FaissIndexImageEncoder`) along with the image processors used to preprocess our images.

In [16]:
import torch

from photo_encode.encoder import FaissIndexImageEncoder
from transformers import AutoImageProcessor, AutoModel, ViTImageProcessor, ViTForImageClassification

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Embedding model 1: DINOv2 (large checkpoint) ---
processor_dino = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
# Extra keyword arguments handed to the batch generator together with the processor.
processor_kwargs_dino = {"return_tensors": "pt"}
dinov2_hf = AutoModel.from_pretrained('facebook/dinov2-large')
# Wrap the Hugging Face model; the wrapper is given the index output path.
dinov2 = FaissIndexImageEncoder(dinov2_hf, index_file_out="./data/test_photos_dinov2.db")
dinov2 = dinov2.eval().to(device)

# --- Embedding model 2: ViT (base, patch 16, 224px, classification head) ---
processor_vit = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# Extra keyword arguments handed to the batch generator together with the processor.
processor_kwargs_vit = {"return_tensors": "pt"}
vit_hf = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# Same wrapping as above, writing to a separate index file.
vit = FaissIndexImageEncoder(vit_hf, index_file_out="./data/test_photos_vit.db")
vit = vit.eval().to(device)

## Create the Embeddings

Loop over the data with different models and aggregation functions to create our embeddings.

In [20]:
import torch

from tqdm import tqdm

from photo_encode.utils import image_batch_from_folder_generator
from torchvision.transforms import Resize, Compose

# Directory containing the photos to embed (relative path)
photo_directory = "./data/test_photos"

# Number of images per batch; passed to image_batch_from_folder_generator below
batch_size = 8

# Mean aggregation
def dinov2_mean_aggregate(x):
    """Average a model output's ``last_hidden_state`` over axis 1.

    Args:
        x: Model output exposing a ``last_hidden_state`` tensor
            (presumably shaped (batch, tokens, hidden) -- TODO confirm).

    Returns:
        ``x.last_hidden_state`` with axis 1 reduced by its mean.
    """
    return x.last_hidden_state.mean(axis=1)

def vit_mean_aggregate(x):
    """Reduce a ViT classification output to one embedding vector per image.

    The classification head's ``logits`` are normally already one vector per
    image, i.e. 2-D (batch, num_labels). Averaging such a tensor over axis 1
    would collapse every image to a single scalar, which is useless as an
    embedding, so 2-D (or lower-rank) logits are returned unchanged. Only when
    an extra sequence axis is present is axis 1 averaged away.

    Args:
        x: Model output exposing a ``logits`` tensor.

    Returns:
        ``x.logits`` itself when it has at most two dimensions, otherwise
        ``x.logits`` with axis 1 reduced by its mean.
    """
    logits = x.logits
    if logits.ndim <= 2:
        # Already one vector per image: keep the full vector.
        return logits
    return logits.mean(axis=1)
    
def _embed_photo_folder(encoder, processor, processor_kwargs, aggregate, desc=None):
    """Feed every batch of photos from ``photo_directory`` through ``encoder``.

    Reads the notebook-level ``photo_directory``, ``batch_size`` and ``device``.

    Args:
        encoder: Callable invoked as ``encoder(batch, files, aggregate)`` for
            each batch, where ``batch`` is ``{"pixel_values": tensor}``.
        processor: Image processor handed to ``image_batch_from_folder_generator``.
        processor_kwargs: Keyword-argument dict handed to the generator
            alongside ``processor``.
        aggregate: Aggregation function forwarded to ``encoder``.
        desc: Optional label shown on the tqdm progress bar.
    """
    batches = image_batch_from_folder_generator(
        photo_directory,
        processor,
        processor_kwargs,
        batch_size,
        # Identity collate: receive the un-collated items and stack them below.
        collate=lambda x: x,
    )
    for items, files in tqdm(batches, desc=desc):
        # Each item exposes "pixel_values"; element 0 is taken before stacking
        # (presumably dropping a leading batch dimension of 1 -- TODO confirm).
        pixel_values = torch.stack([item["pixel_values"][0] for item in items])
        encoder({"pixel_values": pixel_values.to(device)}, files, aggregate)


# Walk over the photo directory once per embedding model, with gradients disabled.
with torch.no_grad():
    _embed_photo_folder(
        dinov2, processor_dino, processor_kwargs_dino, dinov2_mean_aggregate, desc="dinov2"
    )
    _embed_photo_folder(
        vit, processor_vit, processor_kwargs_vit, vit_mean_aggregate, desc="vit"
    )



85it [04:45,  3.36s/it]
85it [03:32,  2.50s/it]


## Create Clusters

Loop over a range of cluster numbers and compute metrics for quality.

In [None]:
from sklearn.cluster import KMeans

# Pull the embeddings held by each encoder into NumPy arrays for scikit-learn.
dinov2_mean_embeddings = dinov2.current_embeddings.detach().cpu().numpy()
vit_mean_embeddings = vit.current_embeddings.detach().cpu().numpy()

# Inclusive range of cluster counts to try.
min_clusters = 4
max_clusters = 50

# Fixed seed so the k-means initialisation, and therefore the resulting
# clusters, are repeatable from one run of the notebook to the next.
kmeans_seed = 0

for n_c in range(min_clusters, max_clusters + 1):
    kmeans = KMeans(n_clusters=n_c, n_init=10, random_state=kmeans_seed)
    # KMeans.fit returns the fitted estimator itself; the per-sample cluster
    # assignments are available on it as `.labels_`.
    dinov2_clusters = kmeans.fit(dinov2_mean_embeddings)

    # TODO: Compute cluster quality metrics
    # NOTE: this `break` stops the sweep after the first cluster count; it is
    # a placeholder until the metrics above are implemented.
    break
