In [None]:
import torch

import os
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import Dataset, DataLoader
from datasets import SolarDataset, normalize_standard, preprocess_clip_wrapper, preprocess_dino,load_file_names_and_classes_for_test, prepare_dataloaders, find_all_fits_files, load_filenames, load_filenames_inverse
from utils import test_visualize_images_all_cluster_zero, visualize_batch_images_from_cluster, visualize_histograms_from_cluster, visualize_np_images_from_cluster
import matplotlib.patches as patches

from general import DATA_PATH, TEST_PATH
from utils import *
import os
import random
import glob
from sklearn.utils import check_random_state 

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from PIL import Image, ImageOps


# Ask PyTorch to fall back to CPU for operators the MPS (Apple GPU) backend does not implement.
# NOTE(review): PyTorch may only honour this if it is set before `import torch` — TODO confirm
# and, if so, move this line above the imports.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Seed every RNG the notebook touches (stdlib, NumPy, PyTorch, scikit-learn) for reproducible runs.
random_seed = 45
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random_state = check_random_state(random_seed)

In [None]:
# Create the output folders used later; exist_ok makes the cell safe to re-run
# and avoids the check-then-create race of os.path.exists + os.mkdir.
for output_dir in ("outputs", "outputs/isolation_forest", "outputs/weights"):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Image side length passed to the dataloaders.
im_size = 32
# inverse=False (default): visualize the files listed in saved_filtered_filenames_final.txt.
# NOTE(review): inverse=True presumably selects the complementary file set — confirm in prepare_dataloaders.
inverse = False
# Feature extractor; other options: dino, msn, mae, standard.
mode = "clip"

# Placeholder pipeline definition (the clustering cell further down redefines it).
initial_steps = [
    dict(type='tsne', params=dict(n_components=2)),
    dict(type='histogram_kmeans', params=dict(n_clusters=20)),
]

In [None]:
def order_clusters_by_median(cluster_medians, cluster_labels_unique, start_index=3):
    """Order cluster labels as a greedy nearest-value chain.

    Starting from ``start_index``, repeatedly append the remaining cluster whose
    value in ``cluster_medians`` is closest to the value of the cluster appended last.

    Args:
        cluster_medians: mapping from cluster label to a scalar summary value.
            (The caller in this notebook passes per-cluster means.)
        cluster_labels_unique: iterable of the cluster labels to order.
        start_index: label to start the chain from. Defaults to 3, the value that
            was previously hard-coded. If it is not among ``cluster_labels_unique``,
            the smallest label is used instead of failing with a KeyError.

    Returns:
        List of labels in chain order; empty if ``cluster_labels_unique`` is empty.
        Ties are broken in favour of the smallest label.
    """
    # Use the parameter; the previous version read the global `cluster_labels_unqiue`.
    labels = sorted(set(cluster_labels_unique))
    if not labels:
        return []

    first = start_index if start_index in labels else labels[0]
    ordered_indices = [first]
    # Kept sorted so min() resolves ties deterministically to the smallest label.
    remaining_indices = [label for label in labels if label != first]

    while remaining_indices:
        last_value = cluster_medians[ordered_indices[-1]]
        # Closest remaining cluster by summary value, excluding already chosen ones.
        closest_index = min(remaining_indices, key=lambda x: abs(cluster_medians[x] - last_value))
        ordered_indices.append(closest_index)
        remaining_indices.remove(closest_index)

    return ordered_indices

def reorder_images_by_cluster_order(images, labels, ordered_cluster_indices):
    """Regroup ``images`` and ``labels`` cluster by cluster in the given cluster order.

    Within each cluster, items keep their current relative order.

    Args:
        images: array indexed along axis 0, parallel to ``labels``.
        labels: 1-D array of cluster labels.
        ordered_cluster_indices: cluster labels in the order their groups should appear.

    Returns:
        (images, labels) concatenated group by group.
    """
    masks = [labels == cluster for cluster in ordered_cluster_indices]
    grouped_images = np.concatenate([images[mask] for mask in masks])
    grouped_labels = np.concatenate([labels[mask] for mask in masks])
    return grouped_images, grouped_labels

def reorder_properties(train_properties, ordered_indices):
    """Return a new dict where every ndarray property is indexed by ``ordered_indices``.

    Non-ndarray values (strings, lists, scalars, ...) are copied over unchanged.
    No length check is made, so ``ordered_indices`` may select a subset of entries.
    """
    return {
        key: (value[ordered_indices] if isinstance(value, np.ndarray) else value)
        for key, value in train_properties.items()
    }

def filter_images_by_cluster(reordered_properties, reordered_indices, sorted_labels, indices_reduce, n=8):
    """Subsample selected clusters while preserving the cluster-grouped order.

    Walks ``sorted_labels`` position by position. Positions whose cluster is in
    ``indices_reduce`` are kept only for the first ``n`` occurrences of that
    cluster. Positions belonging to every other cluster are always kept.

    Args:
        reordered_properties: dict of properties already in the same order as
            ``sorted_labels``. ndarray values whose first dimension equals
            ``len(reordered_indices)`` are filtered; everything else passes through.
        reordered_indices: sequence parallel to ``sorted_labels`` (in this notebook,
            the original dataset index of each position). Only its length is used.
        sorted_labels: cluster label at each position.
        indices_reduce: cluster labels to cap at ``n`` images.
        n: maximum number of positions kept per cluster listed in ``indices_reduce``.

    Returns:
        (filtered_properties, selected_positions, selected_labels), where
        ``selected_positions`` are indices into the reordered arrays (not original
        dataset indices) and ``selected_labels[k] == sorted_labels[selected_positions[k]]``.
    """
    clusters_to_cap = set(indices_reduce)
    kept_per_cluster = {}
    selected_reordered_indices = []

    # Visit every position. The old range(np.max(reordered_indices)) stopped one short
    # and always dropped the last image (and crashed on empty input).
    for position, cluster_label in enumerate(sorted_labels):
        already_kept = kept_per_cluster.get(cluster_label, 0)
        if cluster_label in clusters_to_cap and already_kept >= n:
            continue
        selected_reordered_indices.append(position)
        kept_per_cluster[cluster_label] = already_kept + 1

    # Filter only per-item arrays; other properties are passed through untouched.
    n_items = len(reordered_indices)
    filtered_properties = {
        key: (value[selected_reordered_indices]
              if isinstance(value, np.ndarray) and value.shape[0] == n_items
              else value)
        for key, value in reordered_properties.items()
    }

    # Selected positions index sorted_labels directly. The old lookup through
    # reordered_indices.index(i) mixed positions with original dataset indices and was O(n^2).
    new_sorted_labels = [sorted_labels[position] for position in selected_reordered_indices]

    return filtered_properties, selected_reordered_indices, new_sorted_labels


def visualize_grid_with_clusters(images, cluster_ids, grid_size=(90, 90), figsize=(18, 18), border_size=10, other_props=None, save_path=None):
    """Plot images as tiles on a grid, each framed in its cluster's colour, then save and show the figure.

    Tiles are filled row by row starting at the top-left. Images beyond
    ``grid_size[0] * grid_size[1]`` are not drawn. Each tile is annotated with the
    uid parsed from its filename (text after ``uid_`` up to the next ``_``) and its
    cluster number. Filenames containing ``4-10keV`` use the 'inferno' colormap;
    all others use 'viridis'.

    Args:
        images: image stack; ``images[k]`` is drawn with imshow, and ``images.shape[1]``
            and ``images.shape[2]`` are used as tile height and width
            (assumed (N, H, W) — TODO confirm for other inputs).
        cluster_ids: integer cluster label per image; used as the colour index into 'tab20b'.
        grid_size: (rows, cols) of the grid.
        figsize: matplotlib figure size in inches.
        border_size: controls the frame line width and the frame/label offsets.
        other_props: dict whose "filename" entry is aligned with ``images``. If None,
            tiles get no uid label and use 'viridis'.
        save_path: output PNG path. By default it is
            ``outputs/test_visualized_clusters_{mode}{extra_line}_{len(images)}.png``,
            built from the notebook-level ``mode`` and ``extra_line`` variables.
    """
    n_rows, n_cols = grid_size
    tile_h, tile_w = images.shape[1], images.shape[2]
    # cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9; pyplot.get_cmap still takes `lut`.
    colormap = plt.get_cmap('tab20b', max(cluster_ids) + 1)
    filenames = other_props["filename"] if other_props is not None else None

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xlim(0, n_cols * tile_w + 2 * border_size)
    ax.set_ylim(0, n_rows * tile_h + 2 * border_size)
    ax.axis('off')

    for idx, img in enumerate(images):
        if idx >= n_rows * n_cols:
            print(f"Grid full: drawing {n_rows * n_cols} of {len(images)} images")
            break

        row, col = divmod(idx, n_cols)
        x = col * tile_w
        # Row 0 is drawn at the top of the axes (y grows upwards).
        y = (n_rows - 1 - row) * tile_h

        filename = str(filenames[idx]) if filenames is not None else ""
        cmap = 'inferno' if "4-10keV" in filename else 'viridis'
        ax.imshow(img, extent=(x, x + tile_w, y, y + tile_h), cmap=cmap, aspect='auto')

        # Previously a missing "uid_" made find() return -1 and sliced from index 3, giving garbage.
        uid_pos = filename.find("uid_")
        uid = filename[uid_pos + 4:].split("_")[0] if uid_pos != -1 else ""
        cluster_color = colormap(cluster_ids[idx])
        cluster_nb = str(cluster_ids[idx])
        fontsize = 4

        rect = patches.Rectangle(
            (x + border_size / 2, y + border_size / 2),
            tile_w - 2 * border_size / 4,
            tile_h - 2 * border_size / 4,
            linewidth=border_size * 3 / 2,
            edgecolor=cluster_color,
            facecolor='none',
        )
        ax.add_patch(rect)

        ax.text(x + border_size + 2, y + tile_h - 4.0 * border_size, uid, color='white', fontsize=fontsize, ha='left', va='bottom')
        ax.text(x + border_size + 2, y + 1.5 * border_size, cluster_nb, color='white', fontsize=fontsize, ha='left', va='bottom')

    if save_path is None:
        save_path = f"outputs/test_visualized_clusters_{mode}{extra_line}_{len(images)}.png"
    plt.savefig(save_path, dpi=350)
    plt.show()

In [None]:

# Suffix added to output filenames (figures, cluster lists) so inverse runs do not overwrite normal ones.
extra_line = "_inverse" if inverse else ""

train_dataloader, test_dataloader = prepare_dataloaders(
    mode=mode,
    im_size=im_size,
    batch_size=16,
    inverse=inverse,
)
    

In [None]:
# Per-sample image shape: 'standard' mode keeps single-channel im_size x im_size images;
# the other modes are assumed to yield 3 x 224 x 224 inputs.
im_shape = (1, im_size, im_size) if mode == "standard" else (3, 224, 224)

# Properties collected alongside the features from each dataloader.
property_names = ['thermal_component', 'class', "data", "filename"]
train_features, train_properties = extract_features(train_dataloader, property_names)
test_features, test_properties = extract_features(test_dataloader, property_names)

In [None]:
train_properties['thermal_component_vectorized'] = replace_string_values(train_properties['thermal_component'])
test_properties['thermal_component_vectorized'] = replace_string_values(test_properties['thermal_component'])

print(train_properties["data"].shape)
# If the image data came back flattened, restore it to (N, C, H, W).
# Use im_shape's channel count instead of a hard-coded 3: in 'standard' mode im_shape is
# (1, im_size, im_size), and reshaping to 3 channels would fail with a size mismatch.
if train_properties["data"].ndim != 4:
    train_properties["data"] = train_properties["data"].reshape(train_features.shape[0], *im_shape)
    test_properties["data"] = test_properties["data"].reshape(test_features.shape[0], *im_shape)
print(train_properties["data"].shape)

In [None]:
# Inspect the first channel of every test image whose class label is 4.
is_class_4 = test_properties["class"] == 4
curs = test_properties["data"][is_class_4]
for image_index in range(len(curs)):
    plt.imshow(curs[image_index][0])
    plt.show()

In [None]:
# Overlay the density-normalised value distributions of train vs. test features.
fig, ax = plt.subplots()
ax.hist(train_features.ravel(), bins=100, density=True)
ax.hist(test_features.ravel(), bins=100, density=True, alpha=0.5)
plt.show()

In [None]:
# Build the dimensionality-reduction + clustering pipeline for each candidate cluster count
# (steps are presumably applied in sequence by train_and_store_models — confirm in utils).
# After the loop, `initial_steps` and `models_info` hold the last run.
for n_clusters in [30]:
    initial_steps = [
        {
            'type': 'tsne',
            'params': {'n_components': 2, 'learning_rate': 'auto', 'init': 'random', 'perplexity': 300},
        },
        {'type': 'kmeans', 'params': {'n_clusters': n_clusters}},
    ]

    models_info, train_transformed = train_and_store_models(train_features, initial_steps)
    models_info, test_transformed = apply_models_to_test_data(test_features, models_info)

    train_tsne = models_info["tsne_transformed_train_features"]
    test_tsne = models_info["tsne_transformed_test_features"]
    print(models_info.keys())

In [None]:

from skimage.transform import resize

# First channel of every training image; assumes `data` is (N, C, H, W).
input_images = train_properties["data"][:,:1]
# Position of each image in the original (unsorted) training order.
original_indices = np.arange(len(input_images))

# Resize images to 40x40 for the grid plot.
resized_images = np.array([resize(image[0], (40, 40)) for image in input_images])

cluster_labels = models_info[f"{initial_steps[-1]['type']}_labels"]
# Sort images by cluster label. A stable sort keeps images in their original relative
# order inside each cluster, so the grid layout is deterministic.
sorted_indices = cluster_labels.argsort(kind="stable")
sorted_images = resized_images[sorted_indices]
sorted_labels = cluster_labels[sorted_indices]
sorted_original_indices = original_indices[sorted_indices]

# Unique cluster labels and their count. The misspelt name is kept on purpose
# because `order_clusters_by_median` may read `cluster_labels_unqiue` as a global.
cluster_labels_unqiue = np.unique(sorted_labels)
num_clusters = cluster_labels_unqiue.size

# Mean pixel value of each cluster's resized images. (The name says "median",
# but the statistic computed here is the mean.)
cluster_medians = {i: np.mean(sorted_images[sorted_labels == i]) for i in cluster_labels_unqiue}

# Chain clusters so that neighbouring clusters have similar mean pixel values.
ordered_cluster_indices = order_clusters_by_median(cluster_medians, cluster_labels_unqiue)

# Regroup images, labels and original indices cluster by cluster in that order.
reordered_images, reordered_labels = reorder_images_by_cluster_order(sorted_images, sorted_labels, ordered_cluster_indices)
reordered_original_indices, _ = reorder_images_by_cluster_order(sorted_original_indices, sorted_labels, ordered_cluster_indices)

In [None]:


# Apply the final cluster order to every ndarray property in train_properties
# (assumes each holds one entry per training image), so position k in each array
# lines up with reordered_images[k].
reordered_train_properties = reorder_properties(train_properties, reordered_original_indices)

# Cluster labels to cap at `n` images each; an empty list means no cluster is subsampled.
indices_reduce = []#[0,2,5,6,7,12,13,17,18,20,25,26,28,29, 3, 23, 14, 1, 9, 21, 10, 19, 8]

# Apply the function to filter images while maintaining order.
# The second return value holds positions into the reordered arrays, not original dataset indices.
filtered_train_properties, new_reordered_original_indices, new_sorted_labels = filter_images_by_cluster(
    reordered_train_properties, reordered_original_indices.tolist(), reordered_labels.tolist(), indices_reduce, n=37
)
# NOTE(review): the three lines below recompute labels, images and properties directly from
# the returned positions, overriding the function's `new_sorted_labels` and
# `filtered_train_properties`. They also make new_sorted_labels an ndarray, which the later
# cells rely on for np.unique and boolean masks.
new_sorted_labels = reordered_labels[new_reordered_original_indices]
new_reordered_images = reordered_images[new_reordered_original_indices]
# Unlike filter_images_by_cluster, reorder_properties indexes every ndarray property regardless of its length.
filtered_train_properties = reorder_properties(reordered_train_properties, new_reordered_original_indices)

In [None]:


# Draw the filtered, cluster-ordered images on a 70x70 grid and save the figure.
visualize_grid_with_clusters(
    new_reordered_images,
    new_sorted_labels,
    grid_size=(70, 70),
    figsize=(40, 40),
    border_size=2,
    other_props=filtered_train_properties,
)


In [None]:
# Group the training filenames by cluster label and write them out:
# each label on its own line, followed by that cluster's file paths, one per line.
current_properties = train_properties
cluster_labels = models_info[f"{initial_steps[-1]['type']}_labels"]
names = {
    label: current_properties["filename"][cluster_labels == label]
    for label in np.unique(cluster_labels)
}

clusters_path = f'outputs/clusters_seed_{random_seed}{extra_line}_{initial_steps[-1]["params"]["n_clusters"]}.txt'
with open(clusters_path, 'w') as file:
    for key, file_paths in names.items():
        file.write(f"{key}\n")
        file.writelines(f"{file_path}\n" for file_path in file_paths)

In [None]:
# For every cluster shown in the grid, print its label and its size in the full
# training set (cluster_labels), then plot its images from the filtered properties.
for label in np.unique(new_sorted_labels):
    print(label)
    print((cluster_labels == label).sum())

    visualize_np_images_from_cluster(
        label,
        new_sorted_labels,
        filtered_train_properties["data"],
        max_images=500,
        norm_func=lambda img: img,
    )
    plt.show()

# This creates a new file. If you want to use it, rename it as saved_filtered_filenames_final.txt

In [None]:
def save_filenames(filenames, filepath):
    """Write each entry of ``filenames`` on its own line to ``filepath``, overwriting the file."""
    with open(filepath, 'w') as out:
        out.writelines(name + '\n' for name in filenames)
            
# Cluster labels whose (filtered) files should be kept.
cluster_to_keep = [1, 3]
indexes = np.array(range(len(new_sorted_labels)))

# Collect the filenames of the kept clusters, cluster by cluster in the order listed above.
filenames = [
    name
    for cluster_label in cluster_to_keep
    for name in filtered_train_properties["filename"][new_sorted_labels == cluster_label].tolist()
]

# filtered_train_properties["filename"] holds the names of the filtered, cluster-ordered training images.
save_filepath = 'saved_filtered_filenames_new.txt'

# Save the filenames to a text file.
save_filenames(filenames, save_filepath)