## Imports

In [None]:
import os
import json
import pickle
import math
from copy import deepcopy
import numpy as np
import csv
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances

In [None]:
from source.constants import RANDOM_SEED
from source.constants import NUM_CLASS_IMAGES, NUM_CLASS_PROTOTYPES, AVG_NUM_DUPLICATES_PER_CLASS_PROTOTYPE
from source.constants import DATA_DIR, FEATURE_VECTORS_SAVE_DIR, ANNOTATIONS_SAVE_DIR
from source.constants import ALL_CANCER_TYPES, ALL_EXTRACTOR_MODELS, ALL_IMG_NORMS, ALL_DISTANCE_METRICS, ALL_DIMENSIONALITY_REDUCTION_METHODS, ALL_CLUSTERING_ALGORITHMS

# Echo the imported project configuration so each run of the notebook records
# the settings it was executed with.
print("RANDOM_SEED:", RANDOM_SEED)
print("NUM_CLASS_IMAGES:", NUM_CLASS_IMAGES)
print("NUM_CLASS_PROTOTYPES:", NUM_CLASS_PROTOTYPES)
print("AVG_NUM_DUPLICATES_PER_CLASS_PROTOTYPE:", AVG_NUM_DUPLICATES_PER_CLASS_PROTOTYPE)
print()

print(f"DATA_DIR: {DATA_DIR}")
print(f"FEATURE_VECTORS_SAVE_DIR: {FEATURE_VECTORS_SAVE_DIR}")
print(f"ANNOTATIONS_SAVE_DIR: {ANNOTATIONS_SAVE_DIR}")
print()

# Every list of allowed options that the notebook constants are checked against.
print("ALL_CANCER_TYPES:", ALL_CANCER_TYPES)
print("ALL_EXTRACTOR_MODELS:", ALL_EXTRACTOR_MODELS)
# IMG_NORM is checked against this list further down, so show it with the others.
print("ALL_IMG_NORMS:", ALL_IMG_NORMS)
print("ALL_DISTANCE_METRICS:", ALL_DISTANCE_METRICS)
print("ALL_DIMENSIONALITY_REDUCTION_METHODS:", ALL_DIMENSIONALITY_REDUCTION_METHODS)
print("ALL_CLUSTERING_ALGORITHMS:", ALL_CLUSTERING_ALGORITHMS)

In [None]:
from source.eval_utils import reduce_feature_dimensionality, get_clustering_labels, get_clustering_centroids

In [None]:
from source.interactive_clustering_utils import (
    visualise_kmeans_cluster,
    display_image_pairs,            # manual accepting and rejecting of pairs: anchor, candidate
    visualize_cluster_results_processed_file_path,
    kmeans_and_review,              # cluster and purify rejected images
    merge_clusters_interactively,   # merge pure accepted and pure rejected clusters
)

## Autoreload

In [None]:
# Load IPython's autoreload extension and switch it to mode 2.
# NOTE(review): mode 2 presumably re-imports all modules before each cell is
# executed, so edits under `source/` are picked up without restarting the
# kernel — confirm against the IPython autoreload documentation.
%load_ext autoreload
%autoreload 2

## Notebook Constants

In [None]:
# TODO: Set the cancer type and extractor name
# Settings that select which pre-extracted features are clustered and how.
CANCER_TYPE = 'colon_n'
EXTRACTOR_NAME = 'UNI'
IMG_NORM = 'resize_only'
DISTANCE_METRIC = 'euclidean'
DIMENSIONALITY_REDUCTION_METHOD = 'UMAP-8'
CLUSTERING_ALGORITHM = 'kmeans'

# Fail fast if a setting is not one of the options declared in source.constants
# (checked in the same order as the settings are declared above).
assert CANCER_TYPE in ALL_CANCER_TYPES
assert EXTRACTOR_NAME in ALL_EXTRACTOR_MODELS
assert IMG_NORM in ALL_IMG_NORMS
assert DISTANCE_METRIC in ALL_DISTANCE_METRICS
assert DIMENSIONALITY_REDUCTION_METHOD in ALL_DIMENSIONALITY_REDUCTION_METHODS
assert CLUSTERING_ALGORITHM in ALL_CLUSTERING_ALGORITHMS

In [None]:
# Directory that holds the pre-extracted features for the selected configuration.
features_save_dir = '/'.join(
    (FEATURE_VECTORS_SAVE_DIR, CANCER_TYPE, EXTRACTOR_NAME, IMG_NORM)
)
print("features_save_dir:", features_save_dir)

# Inputs of this notebook; both files must already exist.
features_npy_path = f'{features_save_dir}/features.npy'
ids_2_imgpaths_json_path = f'{features_save_dir}/ids_2_img_paths.json'
assert os.path.isfile(features_npy_path)
assert os.path.isfile(ids_2_imgpaths_json_path)

# Outputs of this notebook go to the same sub-path under ANNOTATIONS_SAVE_DIR.
results_dir = features_save_dir.replace(FEATURE_VECTORS_SAVE_DIR, ANNOTATIONS_SAVE_DIR)
os.makedirs(results_dir, exist_ok=True)

# The pickle file name encodes the clustering configuration.
clustering_model_path = (
    f'{results_dir}/'
    f'distance_metric={DISTANCE_METRIC}'
    f'#dimensionality_reduction={DIMENSIONALITY_REDUCTION_METHOD}'
    f'#clustering={CLUSTERING_ALGORITHM}.pkl'
)
print("clustering_model_path:", clustering_model_path)

## Clustering

In [None]:
# Feature matrix saved at features_npy_path.
features = np.load(features_npy_path)
print(features.shape)

# JSON mapping from ids to image paths; later cells look paths up by str(row index).
with open(ids_2_imgpaths_json_path) as f:
    ids_2_imgpaths = json.load(f)
print(ids_2_imgpaths)

In [None]:
if DISTANCE_METRIC == 'cosine':
    # KMeans has no cosine option, so the metric cannot simply be passed to it.
    # Scaling every row to unit length instead makes euclidean distance
    # equivalent to cosine distance.
    features = features / np.linalg.norm(features, axis=1, keepdims=True)

# Replace the features with their lower-dimensional representation.
features = reduce_feature_dimensionality(
    features=features,
    method=DIMENSIONALITY_REDUCTION_METHOD,
)
print(features.shape, features.dtype)

In [None]:
# Cluster the reduced features into NUM_CLASS_PROTOTYPES groups and keep the
# fitted model so it can be stored and compared in the next cell.
labels, clustering_model = get_clustering_labels(
    features=features,
    method=CLUSTERING_ALGORITHM,
    n_clusters=NUM_CLASS_PROTOTYPES,
    random_state=RANDOM_SEED,
    return_model=True,
)

# Centroids are derived from the features and labels via the helper.
cluster_centers = get_clustering_centroids(features, labels)
print(labels.shape, cluster_centers.shape)
print(labels)

In [None]:
# Persist the clustering on the first run. On later runs, compare the fresh
# clustering with the stored one and let the user choose which labels to use.
if not os.path.exists(clustering_model_path):
    print(f"No model at {clustering_model_path}. \n Saving the current model.")
    with open(clustering_model_path, 'wb') as f:
        pickle.dump(clustering_model, f)
    labels = clustering_model.labels_
else:
    print(f"KMeans model at {clustering_model_path} exists. \n Loading and checking if it's the same as the current model.")
    # NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
    with open(clustering_model_path, 'rb') as f:
        loaded_kmeans = pickle.load(f)

    # The clusterings match only when *every* label agrees; a single differing
    # label is already a mismatch. np.array_equal also returns False (instead
    # of raising) when the label arrays have different shapes.
    same_labels = np.array_equal(clustering_model.labels_, loaded_kmeans.labels_)

    # Not every clustering model exposes cluster_centers_ (agglomerative
    # clustering does not), so only compare the centres when both models have them.
    new_centers = getattr(clustering_model, 'cluster_centers_', None)
    loaded_centers = getattr(loaded_kmeans, 'cluster_centers_', None)
    if new_centers is None or loaded_centers is None:
        same_centers = new_centers is None and loaded_centers is None
    else:
        same_centers = (
            np.shape(new_centers) == np.shape(loaded_centers)
            and np.allclose(new_centers, loaded_centers, atol=1e-6)
        )

    if not (same_labels and same_centers):
        print("Loaded KMeans model is not the same as the current model.")
        user_input = input("Do you want to use the loaded model (l) or the new model (n)?")
        # Tolerate surrounding whitespace and upper-case answers.
        user_input = user_input.strip().lower()

        if user_input == 'n':
            labels = clustering_model.labels_
        elif user_input == 'l':
            labels = loaded_kmeans.labels_
        else:
            raise NotImplementedError("Choose between 'l' and 'n'.")
    else:
        print("Loaded KMeans model is the same as the current model.")
        labels = loaded_kmeans.labels_

print(labels)

# kmeans has the centroids attribute, but agglomerative clustering does not,
# so the centroids are always recomputed from the features and the chosen labels.
centroids = get_clustering_centroids(features, labels)

In [None]:
# Group image paths and feature vectors by cluster label. Row `idx` of
# `features` belongs to the image stored under key str(idx) in ids_2_imgpaths.
labels_2_imgpaths = {}
labels_2_features = {}
for idx, label in enumerate(labels):
    img_path = ids_2_imgpaths[str(idx)]
    labels_2_imgpaths.setdefault(label, []).append(img_path)
    labels_2_features.setdefault(label, []).append(features[idx])

# For every cluster, pick the member whose feature vector lies nearest to the
# cluster centroid. The position of a centroid in `centroids` is used as its label.
centroid_imgpaths = {}
for label, centroid in enumerate(centroids):
    member_distances = euclidean_distances([centroid], labels_2_features[label])
    centroid_imgpaths[label] = labels_2_imgpaths[label][np.argmin(member_distances)]

In [None]:
# Quick visual check of a single cluster (index 2) with three examples.
visualise_kmeans_cluster(
    cluster_index=2,
    num_examples=3,
    centroids=centroids,
    labels_2_imgpaths=labels_2_imgpaths,
    labels_2_features=labels_2_features,
)

## Manual Annotation

In [None]:
# Files under results_dir used by the annotation session: the raw results,
# the saved session state and the processed results.
results_csv_file_path, session_state_file_path, results_processed_file_path = (
    f'{results_dir}/{file_name}'
    for file_name in ('results.csv', 'session_state.json', 'results_processed.json')
)

In [None]:
# Start the session from the last state
# Runs up to five annotation rounds; a round that returns 'q' ends the session early.
# NOTE(review): `i` is not passed to display_image_pairs, so the "Cluster {i}"
# label is only this run's round counter — presumably the cluster actually
# shown is derived from session_state_file_path; confirm in
# source.interactive_clustering_utils.
for i in range(5):
    print(f"Cluster {i}")
    interrupt_status = display_image_pairs(
        labels_2_imgpaths=labels_2_imgpaths,
        labels_2_features=labels_2_features,
        features=features,
        centroid_imgpaths=centroid_imgpaths,
        results_csv_file_path=results_csv_file_path,
        results_processed_file_path=results_processed_file_path,
        session_state_file_path=session_state_file_path,
    )
    if interrupt_status == 'q':
        break

## Check for Skipped and Non-Belonging Images

In [None]:
# Read the processed annotation results: one entry per anchor image.
with open(results_processed_file_path) as f:
    results_processed = json.load(f)

# Collect every image that was annotated (anchor, accepted, rejected) and,
# separately, only the rejected ones.
processed_images_paths = []
non_belonging_images_paths = []
for anchor_img_path, details in results_processed.items():
    rejected_paths = details['non_belonging_image_paths']
    processed_images_paths += [
        anchor_img_path,
        *details['belonging_image_paths'],
        *rejected_paths,
    ]
    non_belonging_images_paths += rejected_paths

# No image may appear twice in either collection.
assert len(processed_images_paths) == len(set(processed_images_paths))
assert len(non_belonging_images_paths) == len(set(non_belonging_images_paths))

processed_images_paths.sort()
non_belonging_images_paths.sort()

print(f"Total processed images: {len(processed_images_paths)}")
print(f"Total non-belonging images: {len(non_belonging_images_paths)}")

# Flatten the per-cluster image lists into one list of all clustered images.
all_images_paths = [
    img_path
    for img_list in labels_2_imgpaths.values()
    for img_path in img_list
]
assert len(all_images_paths) == len(set(all_images_paths))
all_images_paths.sort()

# Images that are part of a cluster but were never annotated.
skipped_images_paths = set(all_images_paths).difference(processed_images_paths)
print(f"Skipped Images: {len(skipped_images_paths)}")

In [None]:
# Example: show the processed annotation results of cluster 0.
visualize_cluster_results_processed_file_path(results_processed_file_path=results_processed_file_path, cluster_id=0)

In [None]:
# Example: show the processed annotation results of cluster 60.
visualize_cluster_results_processed_file_path(
    results_processed_file_path=results_processed_file_path,
    cluster_id=60,
)

## Group Non-Belonging Images

In [None]:
# JSON file under results_dir that holds the purified clusters of rejected images.
pure_rejected_clusters_json_path = '/'.join((results_dir, 'pure_rejected_clusters.json'))

In [None]:
# Invert ids_2_imgpaths so an image path gives its integer row index in `features`.
img_paths_2_int_ids = {
    img_path: int(img_id) for img_id, img_path in ids_2_imgpaths.items()
}

# Stack the feature vectors of the rejected images, in the order of
# non_belonging_images_paths, into one numpy array.
non_belonging_features = np.array(
    [features[img_paths_2_int_ids[img_path]] for img_path in non_belonging_images_paths]
)
print(non_belonging_features.shape)

In [None]:
# Make twice as many clusters as expected to increase the chance of finding
# pure clusters.
n_rejected_clusters = math.ceil(
    2 * len(non_belonging_images_paths) / AVG_NUM_DUPLICATES_PER_CLASS_PROTOTYPE
)

pure_clusters = kmeans_and_review(
    non_belonging_features=non_belonging_features,
    n_clusters=n_rejected_clusters,
    non_belonging_images_paths=non_belonging_images_paths,
    pure_rejected_clusters_json_path=pure_rejected_clusters_json_path,
)

## Merge Rejected Purified Clusters and Originally-selected Clusters until there are <= 250 clusters in total

In [None]:
# Read back the purified clusters of rejected images written by the review step.
with open(pure_rejected_clusters_json_path) as f:
    pure_rejected_clusters = json.load(f)

# Every rejected image must have ended up in exactly as many cluster slots as
# there are rejected images.
num_rejected_images = sum(map(len, pure_rejected_clusters.values()))
assert num_rejected_images == len(non_belonging_images_paths)

print(f"Total pure clusters: {len(pure_rejected_clusters)}")
print(f"Total rejected images: {num_rejected_images}")

In [None]:
# --- Re-read the inputs of the merge step from disk ---------------------------------

# `features` is re-loaded from features.npy here, i.e. without the
# normalisation / dimensionality reduction applied earlier in the notebook.
features = np.load(features_npy_path)

# id -> image path mapping and its inverse (image path -> integer id)
with open(ids_2_imgpaths_json_path) as f:
    ids_2_imgpaths = json.load(f)
img_paths_2_int_ids = {
    img_path: int(img_id) for img_id, img_path in ids_2_imgpaths.items()
}

# purified clusters of rejected images
with open(pure_rejected_clusters_json_path) as f:
    pure_rejected_clusters = json.load(f)

# processed annotation results
with open(f'{results_dir}/results_processed.json') as f:
    results_processed = json.load(f)

# --- Combine accepted and rejected clusters -----------------------------------------

# Accepted clusters: the anchor image (the dict key) followed by its accepted images.
all_clusters = {
    int(details['cluster_index']): [anchor_img_path] + details['belonging_image_paths']
    for anchor_img_path, details in results_processed.items()
}

# Rejected clusters are shifted by NUM_CLASS_PROTOTYPES to get their own ids.
all_clusters.update(
    (NUM_CLASS_PROTOTYPES + int(rejected_id), rejected_images)
    for rejected_id, rejected_images in pure_rejected_clusters.items()
)

# just to make sure we are not modifying the original clusters
all_clusters = deepcopy(all_clusters)

print("Total clusters:", len(all_clusters))

total = sum(len(cluster_images) for cluster_images in all_clusters.values())
print(f"Total images: {total}")

# --- Interactive merging ------------------------------------------------------------

merge_results = merge_clusters_interactively(
    clusters=all_clusters,
    features=features,
    img_paths_2_int_ids=img_paths_2_int_ids,
    linkage='single',
    max_num_clusters=250,
    patience=10,
)

final_clusters = merge_results['clusters']
print("Final clusters:", final_clusters)

In [None]:
# Can be interrupted and restarted: merging continues from the clusters
# produced by the previous merge run.
merge_results = merge_clusters_interactively(
    clusters=final_clusters,
    features=features,
    img_paths_2_int_ids=img_paths_2_int_ids,
    linkage='single',
    max_num_clusters=250,
    patience=10,
)

final_clusters = merge_results['clusters']
print("Final clusters:", final_clusters)

## Last check and Save

In [None]:
# Sanity-check the merged clusters, plot their sizes and export them.
print(f"Total clusters: {len(final_clusters)}")

images_per_cluster_list = [len(img_paths) for img_paths in final_clusters.values()]
total = sum(images_per_cluster_list)
unique_images = set().union(*final_clusters.values())

# Report the counts before asserting so they are visible when a check fails.
print(f"Total images: {total}")
print(f"Unique images: {len(unique_images)}")

# NOTE(review): 5000 is presumably NUM_CLASS_IMAGES — confirm and use the constant.
assert total == 5000, f"Expected 5000 images across all clusters, found {total}."
assert len(unique_images) == 5000, (
    f"Expected 5000 unique images, found {len(unique_images)}."
)

plt.bar(range(len(images_per_cluster_list)), images_per_cluster_list)
plt.xlabel('Cluster Index')
plt.ylabel('Number of Images')

# Relabel the clusters with contiguous integers 0..n_clusters-1 (in the
# iteration order of final_clusters) and sort the image paths in each cluster.
final_clusters_contiguous_indices = {
    new_label: sorted(img_paths)
    for new_label, img_paths in enumerate(final_clusters.values())
}

# Save the final clusters in a csv file: one row per (cluster_label, img_path).
final_clusters_csv_path = f'{results_dir}/final_clusters.csv'
with open(final_clusters_csv_path, 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(['cluster_label', 'img_path'])
    for new_label, sorted_img_paths in final_clusters_contiguous_indices.items():
        csvwriter.writerows([new_label, img_path] for img_path in sorted_img_paths)

# Save the final clusters in a json file (the integer labels become string keys).
final_clusters_json_path = f'{results_dir}/final_clusters.json'
with open(final_clusters_json_path, 'w') as f:
    json.dump(final_clusters_contiguous_indices, f, indent=4)