In [None]:
import tensorflow as tf

import os
import numpy as np
import math
import json

from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import pdist, squareform

from collections import Counter
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.offsetbox import OffsetImage,AnnotationBbox


from tensorflow.keras.preprocessing import image

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

### Setup Data

In [None]:
# Mount Google Drive when running on Colab.
# Outside of Colab the `google.colab` package does not exist, so the step is
# skipped instead of aborting a "Restart & Run All".
try:
    from google.colab import drive
except ImportError:
    print('google.colab not available - skipping the drive mount')
else:
    drive.mount('/content/drive')

In [None]:
scraped_images_folder = '/set/the/path/to/your/scraped/images/'
                 
feature_files = [('/path/to/file/BiT-m-r152x4_feature.npz', 0.9),
                 ('/path/to/another/feature_file.npz', None),
                 ('/path/to/another/feature_file.npz', 10)
                 ]

# `feature_files` is a list of tuples of the form (feature_file, components).
# Several feature files can be listed; their features are stacked per image.
# `components` is handed to the per-file PCA as its `n_components` argument:
#   - a float between 0 and 1 controls the fraction of explained variance to keep
#   - an integer controls the number of components directly
#   - None is passed through unchanged (presumably sklearn then keeps all
#     components - confirm against the PCA documentation)

In [None]:
# Fix feature files if necessary: strip stray single quotes from the image
# names that are used as keys inside the .npz archives.
for feature_file,_ in feature_files:
    # The archive is read completely and closed *before* it is overwritten;
    # writing to a file that is still held open can fail (e.g. on Windows).
    with np.load(feature_file) as feature_dict:
        keys = list(feature_dict.keys())
        # look at every key, not only the first one (also safe for empty archives)
        needs_fix = any("'" in k for k in keys)
        new_dict = {k.replace("'",''): feature_dict[k] for k in keys} if needs_fix else None
    if needs_fix:
        print('fixing '+feature_file)
        np.savez(feature_file, **new_dict)

### Load Data

In [None]:
# Collect all the feature files; every file gets its own ("local") PCA.
feature_dicts = []
image_name_lists = []
for feature_file,components in feature_files:
    # the context manager closes the archive once all arrays are in memory
    with np.load(feature_file) as npz_file:
        feature_dict = dict(npz_file)
    image_names = np.array(list(feature_dict.keys()))
    features = np.array(list(feature_dict.values()))
    # do local pca on features
    pca = PCA(n_components = components,whiten = True)
    features = pca.fit_transform(features)
    feature_dict = dict(zip(image_names,features))
    feature_dicts.append(feature_dict)
    image_name_lists.append(feature_dict.keys())
    print(f'using {features.shape[1]} features from {feature_file}')

# make sure we use only image names occurring in all files
image_names = set(image_name_lists[0])
for image_name_list in image_name_lists[1:]:
    image_names &= set(image_name_list)
if not image_names:
    raise ValueError('the feature files have no image name in common')
# Sorting gives a stable row order. The iteration order of a set of strings
# changes between kernel restarts, which would shuffle every index stored in
# the clusterings below.
image_names = sorted(image_names)

# stack the (reduced) features of all files for every image
features = [
    np.concatenate([feature_dict[image_name] for feature_dict in feature_dicts])
    for image_name in image_names
]

image_names = np.array(image_names)
all_features = np.array(features)

In [None]:
all_features.shape # one row per image name shared by all feature files

In [None]:
# Second PCA, this time on the stacked features of all files:
# reduce the number of dimensions to n_components.
# If n_components is between 0 and 1 it controls the fraction of explained variance.
# If it is an integer, it controls the number of components directly.
pca = PCA(n_components = 0.9, whiten = True)
features = pca.fit_transform(all_features)
features.shape

In [None]:
#features = all_features # if you are not running a second PCA
#features.shape

In [None]:
# 2D projections of all samples - for plotting only, not used for clustering
pca2d = PCA(n_components = 2,whiten = True)
features_pca2d = pca2d.fit_transform(features)
# TSNE is stochastic: a fixed seed keeps the layout identical between runs
tsne2d = TSNE(n_components = 2, random_state = 0)
features_tsne2d = tsne2d.fit_transform(features)

In [None]:
features_pca2d.shape # sanity check of the projection (fitted with n_components = 2)

In [None]:
features_tsne2d.shape # sanity check of the projection (fitted with n_components = 2)

### Cluster

In [None]:
max_samples = -1 # set to a part of the data f.e. 20000 or to -1 if all samples should be considered

# A negative value (or None) means "use everything". Slicing with
# features[:-1] would silently drop the last sample instead.
use_all_samples = max_samples is None or max_samples < 0
l = linkage(features if use_all_samples else features[:max_samples], 'ward')

In [None]:
threshold = 236 # cophenetic distance at which the hierarchy is cut into flat clusters

main_clusters_tmp = fcluster(l, criterion='distance', t=threshold)
print(f'{np.max(main_clusters_tmp)} cluster created')

# cluster number -> indices of its members (rows of `features` / `image_names`)
main_clusters = {}
# fcluster numbers the clusters 1..max, so the upper bound has to be inclusive
for cluster_number in range(1, main_clusters_tmp.max() + 1):
    main_clusters[cluster_number] = np.where(main_clusters_tmp == cluster_number)[0]
    print(f'cluster {cluster_number} has {len(main_clusters[cluster_number])} member')

### Filter Copyright

In [None]:
# Global switch read by vis_cluster and do_visual_clustering:
# when True, cluster members are passed through clean_images before plotting.
legalize = True # the filter is switched on (True) or off (False)

In [None]:
def clean_images(members,image_names,copyright_file='/set/the/path/to/the/file/is_public.json'):
    """Keep only the members whose image is marked as public.

    members:        iterable of indices into `image_names`
    image_names:    sequence of image names, used as keys of the JSON mapping
    copyright_file: path to a JSON file mapping image name -> truthy if public

    Returns an integer numpy array with the remaining indices. The dtype is
    forced to int so that an empty result can still be used as an index array
    (a plain np.array([]) would be float and cannot index other arrays).

    Raises KeyError if an image name is missing in the JSON file.
    """
    with open(copyright_file) as json_file:
        is_public = json.load(json_file)
    new_members = [member for member in members if is_public[image_names[member]]]
    return np.array(new_members, dtype=int)

### Visualize cluster

In [None]:
def _local_2d_positions(member_features, use_pca):
    """Project the features of a single cluster to 2D with a freshly fitted PCA or TSNE."""
    n_members = len(member_features)
    if n_members < 2:
        # neither reduction can be fitted on one sample; its position is arbitrary
        return np.zeros((n_members, 2))
    if use_pca:
        return PCA(n_components=2,whiten  =True).fit_transform(member_features)
    # TSNE needs perplexity < n_samples (30 is the sklearn default, so larger
    # clusters are unaffected); the fixed seed makes the layout reproducible
    perplexity = min(30.0, n_members - 1)
    return TSNE(n_components=2, perplexity=perplexity, random_state=0).fit_transform(member_features)


def vis_cluster(cluster_number, clustering, img_per_row,mode = 'tsne_scatter',zoom=0.4,local_reduction = True):
    """ Plot the images of one cluster.

        Uses the notebook globals features, image_names, scraped_images_folder,
        legalize, features_pca2d and features_tsne2d.

        cluster_number: key of the cluster inside `clustering`
        clustering: mapping cluster number -> indices into image_names / features
        img_per_row: number of images per row (grid modes only)
        modes: grid         - arrangement in a grid
               pca_grid     - grid influenced by the PCA
               pca_scatter  - scatterplot influenced by the PCA
               tsne_grid    - grid influenced by the TSNE
               tsne_scatter - scatterplot influenced by the TSNE
        local_reduction: True   - PCA or TSNE is only calculated on the visualized cluster
                         False  - PCA or TSNE is calculated on all data
        zoom: size of the images in the scatterplot

        Raises ValueError for an unknown mode. Returns None; if the cluster has
        no (public) images a message is printed and nothing is drawn.
    """
    valid_modes = ('grid', 'pca_grid', 'pca_scatter', 'tsne_grid', 'tsne_scatter')
    if mode not in valid_modes:
        raise ValueError(f'unknown mode {mode!r}, choose one of {valid_modes}')

    members = clustering[cluster_number]
    if legalize:
        members = clean_images(members,image_names)
    if len(members) == 0:
        # nothing to draw; a figure of height 0 would raise inside matplotlib
        print(f'cluster {cluster_number} has no images to display')
        return

    member_images = [os.path.join(scraped_images_folder, image_names[m].strip("'")) for m in members]
    n_rows = math.ceil(len(members)/img_per_row)

    if mode == 'grid':
        # plain grid: images in the order in which they are stored in the cluster
        plt.figure(figsize=(30, n_rows*(25/img_per_row)))
        for i, im_name in enumerate(member_images):
            plt.subplot(n_rows, img_per_row, i+1)
            im = image.load_img(im_name, target_size=(224,224))
            plt.imshow(im)
            plt.axis(False)
        return

    # all remaining modes need a 2D position for every member
    use_pca = mode.startswith('pca')
    if local_reduction:
        member_positions = _local_2d_positions(features[members], use_pca)
    elif use_pca:
        member_positions = features_pca2d[members]
    else:
        member_positions = features_tsne2d[members]

    if mode.endswith('scatter'):
        plt.figure(figsize = (30,30))
        ax = plt.subplot(111)
        for pos,im_name in zip(member_positions,member_images):
            img = image.load_img(im_name, target_size=(224,224))
            im = OffsetImage(img, zoom=zoom)
            ab = AnnotationBbox(im, pos, xycoords='data', frameon=False)
            ax.add_artist(ab)
        # artists added with add_artist do not affect the data limits, so the
        # limits are set from the positions explicitly - once, after the loop
        ax.update_datalim(member_positions)
        ax.autoscale()
        plt.axis(False)
        return

    # grid influenced by the 2D positions: fill the grid row by row, taking the
    # members with the highest y values first and sorting each row by x
    plt.figure(figsize=(30, n_rows*(25/img_per_row)))
    subplot_ind = 1
    remaining_members = list(range(len(members)))

    while remaining_members:
        # find the highest img_per_row members
        order = np.argsort(member_positions[remaining_members,1])
        top_members = [remaining_members[m] for m in order[-img_per_row:]]
        taken = set(top_members)  # set lookup keeps the filtering linear
        remaining_members = [m for m in remaining_members if m not in taken]
        # sort from left to right
        order = np.argsort(member_positions[top_members,0])
        for o in order:
            plt.subplot(n_rows, img_per_row, subplot_ind)
            im_name = member_images[top_members[o]]
            im = image.load_img(im_name, target_size=(224,224))
            plt.imshow(im)
            plt.axis(False)
            subplot_ind += 1

In [None]:
# Show one main cluster: set `cluster_number` to the cluster you want to visualize.
vis_cluster(cluster_number=34, clustering=main_clusters, img_per_row=10)

#### Visualization: Cluster with "Hierarchical Family Tree (Maincluster)"

In [None]:
def do_visual_clustering(clusters, cluster2subcluster, threshold, image_size):
    """Draw the dendrogram of one cluster next to the images of its leaves.

    Uses the notebook globals features, image_names, scraped_images_folder
    and legalize.

    clusters:           mapping cluster number -> indices into features / image_names
    cluster2subcluster: key of the cluster to draw
    threshold:          color_threshold of the dendrogram (visual cut only)
    image_size:         figure height in inches per image

    Returns the dictionary produced by scipy's dendrogram.
    Raises ValueError if fewer than two images remain, because no hierarchy
    can be built from them.
    """
    members = clusters[cluster2subcluster]
    if legalize:
        members = clean_images(members,image_names)
    n_members = len(members)
    if n_members < 2:
        raise ValueError(f'cluster {cluster2subcluster} has {n_members} image(s), at least 2 are needed for a dendrogram')

    l2 = linkage(features[members], 'ward')
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, n_members*image_size))
    d = dendrogram(l2, orientation='left', color_threshold=threshold, ax=ax1)
    ax1.set_yticklabels([])

    # ax2 only reserves the right half of the figure for the image column
    ax2.axis('off')
    gs = matplotlib.gridspec.GridSpecFromSubplotSpec(subplot_spec=ax2.get_subplotspec(), ncols = 1, nrows=n_members)
    for ix, img_nr in enumerate(d['leaves']):
        # With orientation='left' scipy draws the first leaf at the bottom,
        # whereas grid row 0 is at the top - so the column is filled bottom-up
        # to keep every image next to its own leaf.
        img_ax = fig.add_subplot(gs[n_members - 1 - ix])
        im_name = os.path.join(scraped_images_folder, image_names[members[img_nr]].strip("'"))
        im = image.load_img(im_name, target_size=(224,224))
        img_ax.imshow(im)
        img_ax.axis('off')
    return d

In [None]:
# Family tree of one main cluster:
#   cluster2subcluster - the cluster you want to visualize
#   threshold          - only a visual cluster cut
#   image_size         - determines how large images are printed
d = do_visual_clustering(clusters=main_clusters, cluster2subcluster=8, threshold=20, image_size=1.5)

### Subclustering

In [None]:
def do_subclustering(clusters, cluster2subcluster, threshold):
    """Split one cluster into subclusters with a new ward linkage.

    Uses the notebook global features.

    clusters:           mapping cluster number -> indices into features
    cluster2subcluster: key of the cluster that should be divided
    threshold:          distance at which the new hierarchy is cut

    Returns a dict subcluster number -> list of indices into features
    (the same index space as the input clustering).
    """
    main_cluster_members = clusters[cluster2subcluster]
    l2 = linkage(features[main_cluster_members], 'ward')
    sub_clusters_tmp = fcluster(l2, criterion='distance', t=threshold)
    print(f'{np.max(sub_clusters_tmp)} cluster created')

    sub_clusters = {}
    # fcluster numbers the clusters 1..max, so the upper bound has to be inclusive
    for cluster_number in range(1, sub_clusters_tmp.max() + 1):
        # translate positions inside the main cluster back to global indices
        sub_clusters[cluster_number] = [main_cluster_members[i] for i in np.where(sub_clusters_tmp == cluster_number)[0]]
        print(f'cluster {cluster_number} has {len(sub_clusters[cluster_number])} member')

    return sub_clusters

In [None]:
# Divide a large main cluster (`cluster2subcluster`) into subclusters.
sub_cluster = do_subclustering(clusters=main_clusters, cluster2subcluster=6, threshold=2)

In [None]:
# Show one subcluster: set `cluster_number` to the subcluster you want to visualize.
vis_cluster(cluster_number=2, clustering=sub_cluster, img_per_row=10)

#### Visualization: Cluster with "Hierarchical Family Tree (Subcluster)"

In [None]:
# Family tree of one subcluster:
#   cluster2subcluster - the subcluster you want to visualize
#   threshold          - only a visual cluster cut
#   image_size         - determines how large images are printed
d = do_visual_clustering(clusters=sub_cluster, cluster2subcluster=2, threshold=5, image_size=3)