In [None]:
import os
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from tensorflow.keras import backend as K


# --- 1. Model Definition ---
class ClusteringLayer(Layer):
    """
    Clustering layer converts input sample (feature) to soft label.

    For each sample z_i and centroid mu_j it computes the Student's t kernel

        q_ij = (1 + ||z_i - mu_j||^2 / alpha) ^ (-(alpha + 1) / 2)

    and normalises each row so that sum_j q_ij == 1.

    # Arguments
        n_clusters: number of clusters.
        weights: list of Numpy array with shape `(n_clusters, n_features)` representing initial cluster centers.
        alpha: degrees of freedom of the Student's t-distribution (typically 1.0).
    """
    def __init__(self, n_clusters, weights=None, alpha=1.0, **kwargs):
        # Legacy convenience: allow `input_dim=d` as shorthand for `input_shape=(d,)`.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.initial_weights = weights
        # InputSpec lives in tf.keras.layers; it is not an attribute of Layer.
        self.input_spec = tf.keras.layers.InputSpec(ndim=2)

    def build(self, input_shape):
        """Create the `(n_clusters, input_dim)` centroid matrix, seeded from `weights` if given."""
        if len(input_shape) != 2:
            raise ValueError(f"ClusteringLayer expects 2-D input, got shape {input_shape}")
        input_dim = input_shape[1]
        self.input_spec = tf.keras.layers.InputSpec(dtype=K.floatx(), shape=(None, input_dim))
        self.clusters = self.add_weight(shape=(self.n_clusters, input_dim), initializer='glorot_uniform', name='clusters')
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            # The initial centroids are only needed once; drop the reference.
            self.initial_weights = None
        self.built = True

    def call(self, inputs, **kwargs):
        """Return soft assignments of shape `(batch, n_clusters)`; each row sums to 1."""
        # (batch, 1, dim) - (n_clusters, dim) broadcasts to (batch, n_clusters, dim).
        sq_dist = tf.reduce_sum(tf.square(tf.expand_dims(inputs, axis=1) - self.clusters), axis=2)
        q = 1.0 / (1.0 + sq_dist / self.alpha)
        q = q ** ((self.alpha + 1.0) / 2.0)
        # Row-normalise so the assignments of each sample form a distribution.
        return q / tf.reduce_sum(q, axis=1, keepdims=True)

    def compute_output_shape(self, input_shape):
        if not input_shape or len(input_shape) != 2:
            raise ValueError(f"ClusteringLayer expects 2-D input, got shape {input_shape}")
        return input_shape[0], self.n_clusters

    def get_config(self):
        """Serialise constructor arguments (initial centroid weights are not included)."""
        config = {'n_clusters': self.n_clusters, 'alpha': self.alpha}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

def build_autoencoder(input_dim, latent_dim, hidden_dims=(500, 500, 2000)):
    """Build a fully connected (MLP) autoencoder.

    The encoder stacks ReLU Dense layers of sizes ``hidden_dims`` and ends in a
    linear Dense bottleneck of size ``latent_dim``. The decoder mirrors the
    hidden layers in reverse order and ends in a sigmoid Dense layer of size
    ``input_dim``.

    Args:
        input_dim: Number of input features.
        latent_dim: Size of the latent code.
        hidden_dims: Iterable of hidden-layer widths, listed from input to bottleneck.

    Returns:
        ``(autoencoder, encoder, decoder)``. The three Keras Models share layers, and
        ``autoencoder`` is ``decoder(encoder(x))``. The encoder and decoder are named
        ``'encoder'`` and ``'decoder'``.
    """
    # Materialise once: we iterate forwards for the encoder and backwards for the decoder.
    hidden_dims = tuple(hidden_dims)

    input_layer = Input(shape=(input_dim,), name='input')
    h = input_layer
    # Encoder
    for i, dim in enumerate(hidden_dims):
        h = Dense(dim, activation='relu', name=f'encoder_hidden_{i}')(h)
    latent = Dense(latent_dim, name='latent_space')(h)  # Linear activation for latent space
    encoder = Model(inputs=input_layer, outputs=latent, name='encoder')

    # Decoder
    latent_input = Input(shape=(latent_dim,), name='decoder_input')
    h = latent_input
    for i, dim in enumerate(reversed(hidden_dims)):
        h = Dense(dim, activation='relu', name=f'decoder_hidden_{i}')(h)
    # Sigmoid activation if input X is normalized to [0,1]. Otherwise, 'linear' or 'tanh' might be better.
    output_layer = Dense(input_dim, activation='sigmoid', name='reconstruction_output')(h)
    decoder = Model(inputs=latent_input, outputs=output_layer, name='decoder')

    autoencoder = Model(inputs=input_layer, outputs=decoder(encoder(input_layer)), name='autoencoder')
    return autoencoder, encoder, decoder

# --- 2. Loss Functions & Target Distribution ---
def target_distribution(q):
    """
    Sharpen soft assignments ``q`` into the auxiliary target distribution ``p``.

    With cluster frequencies f_j = sum_i q_ij:

        p_ij = (q_ij^2 / f_j) / sum_k (q_ik^2 / f_k)

    Squaring emphasises confident assignments. Dividing by f_j stops large
    clusters from dominating. Each output row sums to 1.
    """
    cluster_freq = tf.reduce_sum(q, axis=0)          # f_j, one value per cluster
    sharpened = tf.square(q) / cluster_freq           # q_ij^2 / f_j (broadcast over rows)
    row_totals = tf.reduce_sum(sharpened, axis=1, keepdims=True)
    return sharpened / row_totals

# --- 3. Training Logic ---
def train_deep_joint_clustering(X, y=None, n_clusters=10,
                                latent_dim=10, ae_hidden_dims=(500, 500, 2000),
                                ae_epochs=50, ae_batch_size=256,
                                joint_epochs=100, joint_batch_size=256,  # joint_batch_size for model.fit
                                update_interval_epochs=10,  # How often (in epochs) to update target P
                                gamma=0.1,  # Coefficient for clustering loss
                                tol=1e-3,  # Tolerance for stopping criterion based on label changes
                                random_seed=42):
    """Train an IDEC-like deep joint clustering model.

    Phases:
      1. Pre-train an MLP autoencoder on ``X`` with an MSE reconstruction loss.
      2. Run K-means on the encoder's latent codes to initialise the centroids.
      3. Jointly fine-tune the encoder, decoder and centroids. The loss is
         ``gamma * KLD(P, Q) + MSE(X, X_hat)``. ``Q`` are the Student's-t soft
         assignments from ``ClusteringLayer``. ``P`` is the sharpened
         ``target_distribution(Q)``, recomputed every ``update_interval_epochs``
         epochs.

    Args:
        X: 2-D array ``(n_samples, n_features)``. The decoder ends in a sigmoid,
            so ``X`` is expected to lie in ``[0, 1]``.
        y: Optional ground-truth labels, used only to print NMI/ARI.
        n_clusters: Number of clusters.
        latent_dim: Size of the latent code.
        ae_hidden_dims: Hidden-layer widths of the autoencoder (encoder order).
        ae_epochs, ae_batch_size: Pre-training schedule.
        joint_epochs, joint_batch_size: Joint-training schedule.
        update_interval_epochs: Refresh ``P`` (and check convergence) every this
            many epochs; must be >= 1.
        gamma: Weight of the clustering (KL) loss relative to reconstruction.
        tol: Stop early once the fraction of samples whose cluster changed
            between two refreshes falls below this value.
        random_seed: Seed for TensorFlow, NumPy and K-means.

    Returns:
        ``(y_pred_final, idec_model)``. ``y_pred_final`` holds the argmax of the
        final soft assignments, one integer label per sample.

    Raises:
        ValueError: if ``update_interval_epochs`` < 1.
    """
    if update_interval_epochs < 1:
        raise ValueError(f"update_interval_epochs must be >= 1, got {update_interval_epochs}")

    tf.random.set_seed(random_seed)
    np.random.seed(random_seed)

    input_dim = X.shape[1]

    # Phase 1: Autoencoder Pre-training
    print("--- Phase 1: Autoencoder Pre-training ---")
    autoencoder, encoder, decoder = build_autoencoder(input_dim, latent_dim, ae_hidden_dims)
    autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
    autoencoder.fit(X, X, batch_size=ae_batch_size, epochs=ae_epochs, verbose=1)
    print("Autoencoder pre-training finished.")

    # Phase 2: Initialize Cluster Centroids
    print("\n--- Phase 2: Initializing Cluster Centroids ---")
    latent_representations = encoder.predict(X)
    # n_init='auto' needs scikit-learn >= 1.2 (it became the default in 1.4).
    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random_seed)
    y_pred_kmeans = kmeans.fit_predict(latent_representations)
    cluster_centers_init = kmeans.cluster_centers_
    print("Cluster centroids initialized using K-means on latent space.")
    if y is not None:
        print(f"K-means on latent NMI: {normalized_mutual_info_score(y, y_pred_kmeans):.4f}, "
              f"ARI: {adjusted_rand_score(y, y_pred_kmeans):.4f}")

    # Phase 3: Joint Training
    print("\n--- Phase 3: Joint Clustering and Representation Learning ---")
    model_input = Input(shape=(input_dim,), name='idec_input')
    latent_z = encoder(model_input)  # Re-use the pre-trained encoder
    q_output = ClusteringLayer(n_clusters, weights=[cluster_centers_init], name='clustering')(latent_z)
    reconstruction_output = decoder(latent_z)  # Re-use the pre-trained decoder

    # Output names 'clustering' (the layer) and 'decoder' (the nested Model) key the losses.
    idec_model = Model(inputs=model_input, outputs=[q_output, reconstruction_output], name='idec_model')
    # This model shares all layers (and therefore weights) with idec_model but skips
    # the decoder. Refreshing Q then does not build an (n_samples, input_dim)
    # reconstruction array only to throw it away.
    q_model = Model(inputs=model_input, outputs=q_output, name='idec_soft_assignments')

    idec_model.compile(optimizer=Adam(learning_rate=0.001),  # Can use a smaller LR for fine-tuning
                       loss={'clustering': 'kld',  # Kullback-Leibler divergence for clustering output
                             'decoder': 'mse'},    # Mean Squared Error for reconstruction output
                       loss_weights={'clustering': gamma, 'decoder': 1.0})

    y_pred_last = y_pred_kmeans
    for epoch in range(joint_epochs):
        # Epoch 0 always refreshes, so p_target is set before the first fit.
        refresh = epoch % update_interval_epochs == 0
        if refresh:
            print(f"\nEpoch {epoch}/{joint_epochs}")
            q_pred_current = q_model.predict(X, verbose=0)
            p_target = target_distribution(tf.convert_to_tensor(q_pred_current)).numpy()  # Update target P

            y_pred_current = np.argmax(q_pred_current, axis=1)
            if y is not None:
                nmi = normalized_mutual_info_score(y, y_pred_current)
                ari = adjusted_rand_score(y, y_pred_current)
                print(f"NMI = {nmi:.4f}, ARI = {ari:.4f}")

            # Convergence: fraction of samples whose hard label changed since the last refresh.
            delta_label = np.mean(y_pred_current != y_pred_last)
            y_pred_last = y_pred_current
            if epoch > 0 and delta_label < tol:
                print(f"Reached label change tolerance ({delta_label:.4f} < {tol}). Stopping training.")
                break

        # One epoch of joint training. X is both the input and the reconstruction
        # target; p_target is the target for the clustering head.
        history = idec_model.fit(X, {'clustering': p_target, 'decoder': X},
                                 batch_size=joint_batch_size, epochs=1, verbose=0)

        if refresh:
            logs = history.history
            # Per-output key names depend on the Keras version; fall back to NaN, not KeyError.
            total_loss = logs['loss'][0]
            clus_loss = logs.get('clustering_loss', [float('nan')])[0]
            recon_loss = logs.get('decoder_loss', [float('nan')])[0]
            print(f"Total Loss: {total_loss:.4f}, Clustering Loss: {clus_loss:.4f}, Recon Loss: {recon_loss:.4f}")

    print("Joint training finished.")
    y_pred_final = np.argmax(q_model.predict(X, verbose=0), axis=1)

    return y_pred_final, idec_model

# --- 4. Image Loading and Preprocessing Function ---
def load_and_preprocess_images(image_dir, target_size=(64, 64), dtype=np.float32):
    """Load the images in ``image_dir``, resize them, and flatten them for an MLP.

    Files are visited in sorted filename order and kept only if their extension
    (case-insensitive) is a known image type. The directory is not searched
    recursively. Each image is converted to RGB, resized with Lanczos
    resampling, and scaled to ``[0, 1]``. Files that fail to load are reported
    and skipped.

    Args:
        image_dir: Directory containing the images.
        target_size: Passed straight to ``PIL.Image.resize``, which reads it as
            ``(width, height)``.
        dtype: Floating dtype of the returned array. The default is float32,
            which Keras trains in and which uses half the memory of float64.

    Returns:
        ``(X, filenames, input_dim)``:
        - ``X`` has shape ``(n_images, height * width * 3)``.
        - ``filenames[i]`` is the file that row ``i`` came from.
        - ``input_dim`` equals ``X.shape[1]``.
        Returns ``(None, None, None)`` if the directory does not exist or no
        image could be loaded.
    """
    image_filenames = []
    processed_images_list = []

    print(f"Loading images from: {image_dir}")
    valid_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff")

    if not os.path.isdir(image_dir):
        print(f"Error: Directory not found: {image_dir}")
        return None, None, None

    for filename in sorted(os.listdir(image_dir)):
        if not filename.lower().endswith(valid_extensions):
            continue
        img_path = os.path.join(image_dir, filename)
        try:
            # `with` closes the underlying file; convert() loads the pixel data first.
            with Image.open(img_path) as img:
                rgb = img.convert('RGB')  # handles grayscale, RGBA, palette, etc.
            resized = rgb.resize(target_size, Image.Resampling.LANCZOS)
            img_array = np.asarray(resized, dtype=dtype) / 255.0  # normalise to [0, 1]
        except Exception as e:
            # Best-effort loading: report and skip unreadable/corrupt files.
            print(f"Warning: Could not load/process image {filename}. Error: {e}")
            continue
        # Record the filename only after the pixel data was produced, so rows and names stay aligned.
        processed_images_list.append(img_array)
        image_filenames.append(filename)

    if not processed_images_list:
        print("No images were successfully processed.")
        return None, None, None

    print(f"{len(processed_images_list)} images loaded and preprocessed successfully.")

    # Stack to (num_images, height, width, channels), then flatten each image for the MLP autoencoder.
    X_images_raw = np.stack(processed_images_list, axis=0)
    num_images, height, width, channels = X_images_raw.shape
    X_images_flattened = X_images_raw.reshape(num_images, height * width * channels)

    current_input_dim = X_images_flattened.shape[1]
    print(f"Flattened image data shape: {X_images_flattened.shape}")

    return X_images_flattened, image_filenames, current_input_dim


# --- 5. Example Usage with Your Image Data ---
if __name__ == "__main__":
    # --- Parameters for your image data ---
    image_dir = "./storage/clean/blood_cell/segmenter"  # YOUR IMAGE DIRECTORY

    # User-specified parameters:
    # Passed straight to PIL's Image.resize, which expects (width, height).
    TARGET_IMAGE_SIZE = (360, 360)
    N_CLUSTERS_IMAGES = 5          # Number of clusters
    LATENT_DIM_IMAGES = 128        # Latent dimension for image features (encoding dim)

    # After flattening, the input dimension is width * height * 3 (RGB): 388,800 for
    # 360x360. load_and_preprocess_images reports the actual value.

    # MLP AE hidden layers for images. Reducing 388,800 inputs to a 128-D latent
    # space with an MLP is a very large reduction; a convolutional AE would be
    # much better. Widths should shrink towards the bottleneck, and the last one
    # should ideally be larger than LATENT_DIM_IMAGES.
    AE_HIDDEN_DIMS_IMAGES = [2048, 1024, 512]

    AE_EPOCHS_IMAGES = 30      # Adjust as needed. May need more for large inputs.
    JOINT_EPOCHS_IMAGES = 50   # Adjust as needed.
    # WARNING: With 360x360 images, batch sizes might need to be very small
    AE_BATCH_SIZE_IMAGES = 16    # Small batch size due to large image dimensions
    JOINT_BATCH_SIZE_IMAGES = 16
    GAMMA_IMAGES = 0.1           # Weight for clustering loss

    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    print("!!! WARNING: Using MLP Autoencoder for 360x360 images is computationally !!!")
    print("!!! intensive and likely suboptimal. Consider using a Convolutional      !!!")
    print("!!! Autoencoder (CAE) for better results and efficiency with images.     !!!")
    print("!!! You may encounter memory issues or very long training times.         !!!")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")

    # Load and preprocess your images
    X_images_flattened, image_filenames, input_dim_images = load_and_preprocess_images(
        image_dir,
        target_size=TARGET_IMAGE_SIZE
    )

    if X_images_flattened is None:
        print("Exiting due to image loading/processing issues.")
    else:
        print(f"\nStarting deep joint clustering for image data with actual input_dim: {input_dim_images}")
        if input_dim_images != TARGET_IMAGE_SIZE[0] * TARGET_IMAGE_SIZE[1] * 3:
            print(f"Warning: Calculated input_dim ({input_dim_images}) doesn't match expected from TARGET_IMAGE_SIZE * 3.")
            print("This might happen if some images were not RGB or had issues during loading.")

        y_images = None  # Or load your labels if available

        # The input dimension is taken from X itself. train_deep_joint_clustering has
        # no input_dim parameter, and passing one raises TypeError.
        y_pred_final_images, trained_idec_model_images = train_deep_joint_clustering(
            X_images_flattened,
            y=y_images,
            n_clusters=N_CLUSTERS_IMAGES,
            latent_dim=LATENT_DIM_IMAGES,
            ae_hidden_dims=AE_HIDDEN_DIMS_IMAGES,
            ae_epochs=AE_EPOCHS_IMAGES,
            ae_batch_size=AE_BATCH_SIZE_IMAGES,
            joint_epochs=JOINT_EPOCHS_IMAGES,
            joint_batch_size=JOINT_BATCH_SIZE_IMAGES,
            gamma=GAMMA_IMAGES,
            tol=0.001,
            random_seed=42
        )

        print("\n--- Final Image Clustering Results ---")
        print(f"Predicted cluster assignments shape: {y_pred_final_images.shape}")

        # Row i of the predictions corresponds to image_filenames[i]; show the first few.
        for fname, cluster_id in zip(image_filenames[:10], y_pred_final_images[:10]):
            print(f"Image: {fname}, Cluster ID: {cluster_id}")
        if len(image_filenames) > 10:
            print("...")

        unique_clusters, counts = np.unique(y_pred_final_images, return_counts=True)
        print("\nCluster distribution:")
        for cluster_id, count in zip(unique_clusters, counts):
            print(f"Cluster {cluster_id}: {count} images")

        if y_images is not None and len(y_images) == len(X_images_flattened):
            final_nmi_images = normalized_mutual_info_score(y_images, y_pred_final_images)
            final_ari_images = adjusted_rand_score(y_images, y_pred_final_images)
            print(f"Final NMI for images: {final_nmi_images:.4f}")
            print(f"Final ARI for images: {final_ari_images:.4f}")
        else:
            print("\nGround truth labels (y_images) not provided, so NMI/ARI are not calculated.")

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!! intensive and likely suboptimal. Consider using a Convolutional      !!!
!!! Autoencoder (CAE) for better results and efficiency with images.     !!!
!!! You may encounter memory issues or very long training times.         !!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

Loading images from: ./storage/clean/blood_cell/segmenter


KeyboardInterrupt: 