## Visualize K-means clustered ImageNet Images

In [1]:
import os
import re
import json
import matplotlib.pyplot as plt
from PIL import Image

In [3]:
# List all files in the directory and match the specific pattern
img_dir = '/shared/nas2/blume5/fa23/ecole/data/imagenet/subset-whole_unit-100'
data_dir = './subset-whole_unit-100-index'
pattern = re.compile(r'^nearest_img_paths_.*\.json$')
files_matching = [f for f in os.listdir(data_dir) if pattern.match(f)]

print(files_matching)

['nearest_img_paths_top_k_5-n_10-iter_200.json', 'nearest_img_paths_top_k_5-n_10-iter_1000.json', 'nearest_img_paths_top_k_5-n_100-iter_200.json', 'nearest_img_paths_top_k_5-n_50-iter_200.json', 'nearest_img_paths_top_k_5-n_10-iter_500.json']


In [6]:
# Visualize 'top_k' images per centroid
centroid2imgs_list = []
for fpath in files_matching:
    with open(os.path.join(data_dir, fpath)) as json_reader:
        print(f"File: {fpath}")
        centroid2imgs = json.load(json_reader)
        print(f"Num. of Centroids (n) : ", len(centroid2imgs))
        centroid2imgs_list.append(centroid2imgs)

File: nearest_img_paths_top_k_5-n_10-iter_200.json
Num. of Centroids (n) :  10
File: nearest_img_paths_top_k_5-n_10-iter_1000.json
Num. of Centroids (n) :  10
File: nearest_img_paths_top_k_5-n_100-iter_200.json
Num. of Centroids (n) :  100
File: nearest_img_paths_top_k_5-n_50-iter_200.json
Num. of Centroids (n) :  50
File: nearest_img_paths_top_k_5-n_10-iter_500.json
Num. of Centroids (n) :  10


In [None]:
for centroid2imgs in centroid2imgs_list:
    for centroid_idx, image_paths in centroid2imgs.items():
        centroid_idx = int(centroid_idx)    
        # Load the images
        images = [Image.open(path) for path in image_paths]

        # Determine the grid size
        grid_size = (len(images) // 2, 2) if len(images) % 2 == 0 else (len(images) // 2 + 1, 2)

        # Create a matplotlib figure with the determined grid size
        fig, axes = plt.subplots(grid_size[0], grid_size[1], figsize=(10, 10))

        # Flatten the axes for easy iteration
        axes_flat = axes.flatten()

        # Display each image in the grid
        for i, img in enumerate(images):
            axes_flat[i].imshow(img)
            axes_flat[i].axis('off')  # Hide the axes

        # Hide any unused axes if the number of images is odd
        if len(images) % 2 != 0:
            axes_flat[-1].axis('off')

        # Show the grid of images
        plt.tight_layout()
        plt.show()