In [3]:
import heapq
import os

import cv2
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are written to files; plt.show() will not display them
from matplotlib import pyplot as plt
import numpy as np
from scipy.ndimage import label
from sklearn.cluster import KMeans

from GMM import GaussianMM
from Utils import HUES, image_load

# Load the input image from the images/ folder.
# file_name = input('Enter Image name: ')
file_name = "DJI_20240308125440_0001_D.JPG"
image_path = 'images/' + file_name
image = image_load(image_path)
if image is None:
    # NOTE(review): guards loaders that signal a missing/unreadable file by returning None
    # (as cv2.imread does); whether image_load behaves that way is not visible here.
    raise FileNotFoundError(f"Could not load image: {image_path}")

# Flatten the image into an (n_pixels, n_channels) matrix for clustering.
# image_type: 0 = multi-channel (H, W, C) image, 1 = single-channel (H, W) image;
# it is passed to the GMM and used by the result cells to pick the masking path.
if image.ndim == 3:
    image_height, image_width, image_channels = image.shape
    image_type = 0
elif image.ndim == 2:
    image_height, image_width = image.shape
    image_channels = 1
    image_type = 1
else:
    raise ValueError(f"Expected a 2-D (grey) or 3-D (colour) image, got shape {image.shape}")

image_pixels = np.reshape(image, (-1, image_channels))

# Number of mixture components (clusters).
# K_param = int(input('Input K: '))
K_param = 2
KMEANS_SEED = 0  # fixed seed so K-Means, and therefore the GMM initialisation, is reproducible

# Initialise the GMM from a K-Means clustering:
#   cluster centres                      -> component means
#   per-cluster sample covariance        -> component covariances
#   fraction of pixels in each cluster   -> mixing weights
k_means = KMeans(n_clusters=K_param, random_state=KMEANS_SEED)
labels = k_means.fit_predict(image_pixels)
cl_means = k_means.cluster_centers_

cl_weights = []
cl_covariance = []

for i in range(K_param):
    # Boolean-mask indexing selects this cluster's pixels in one vectorised step
    # (same array as a per-pixel Python loop, without n_pixels * K interpreted iterations).
    # Transposed to (n_channels, n_members) because np.cov treats rows as variables.
    dt = image_pixels[labels == i].T
    cl_covariance.append(np.cov(dt))
    cl_weights.append(dt.shape[1] / float(len(labels)))

# Create the GMM from the K-Means initialisation above.
gmm = GaussianMM(cl_means, cl_covariance, cl_weights, K_param, image_type)

# Run EM until the log-likelihood stops changing or MAX_EM_ITERS is reached.
MAX_EM_ITERS = 100
# Convergence is judged relative to the log-likelihood's magnitude. Here |LL| is ~3e8,
# where adjacent float64 values are ~6e-8 apart, so an absolute threshold of 1e-10 could
# only trigger on a bit-identical repeat. Tune EM_REL_TOL as needed.
EM_REL_TOL = 1e-6
logs = []  # log-likelihood history, one entry per iteration
prev_log_likelihood = None

for i in range(MAX_EM_ITERS):
    # E-step under the current parameters (log_likelihood is evaluated before the update).
    pos_prob, log_likelihood = gmm.expectation(image_pixels)
    gmm.maximization(image_pixels, pos_prob)  # M-step
    print(f"Iteration {i+1} - Log_Likelihood: {log_likelihood}")
    logs.append(log_likelihood)

    if (prev_log_likelihood is not None
            and abs(log_likelihood - prev_log_likelihood) <= EM_REL_TOL * abs(prev_log_likelihood)):
        break
    prev_log_likelihood = log_likelihood



Iteration 1 - Log_Likelihood: -282136774.88648885
Iteration 2 - Log_Likelihood: -281776070.7284787
Iteration 3 - Log_Likelihood: -281683557.4990346
Iteration 4 - Log_Likelihood: -281641472.3430676
Iteration 5 - Log_Likelihood: -281614554.3263725
Iteration 6 - Log_Likelihood: -281593717.6802985
Iteration 7 - Log_Likelihood: -281576039.63724387
Iteration 8 - Log_Likelihood: -281560450.55156887
Iteration 9 - Log_Likelihood: -281546534.19376826
Iteration 10 - Log_Likelihood: -281534129.40401125
Iteration 11 - Log_Likelihood: -281523171.6954679
Iteration 12 - Log_Likelihood: -281513617.7240136
Iteration 13 - Log_Likelihood: -281505407.545726
Iteration 14 - Log_Likelihood: -281498450.7781008
Iteration 15 - Log_Likelihood: -281492628.7883171
Iteration 16 - Log_Likelihood: -281487805.2784564
Iteration 17 - Log_Likelihood: -281483838.76855296
Iteration 18 - Log_Likelihood: -281480592.98796606
Iteration 19 - Log_Likelihood: -281477943.70199686
Iteration 20 - Log_Likelihood: -281475782.13207346
I

### Save every connected component larger than the size threshold (run this cell)

In [None]:
# Show result: save every connected component whose size exceeds MIN_COMPONENT_COUNT.
MIN_COMPONENT_COUNT = 1_000_000  # size threshold; the size measure is described at `nonzero_per_pixel`
os.makedirs('results', exist_ok=True)  # plt.imsave raises if the output folder is missing

pos_prob, log_likelihood = gmm.expectation(image_pixels)
map_pos_prob = np.reshape(pos_prob, (image_height, image_width, K_param))

# Hard-assign every pixel to its most probable component and lay the labels out as an image.
pixel_clusters = np.argmax(pos_prob, axis=-1)
cluster_map = np.reshape(pixel_clusters, (image_height, image_width))

# Number of non-zero channel values at each pixel. Summed over a component this equals
# np.count_nonzero(image masked to that component), which is the size measure this cell uses.
# NOTE(review): for RGB that counts channel values, not pixels (up to 3x the area), and it
# assumes `image` is uint8 -- consider thresholding on the mask area instead.
nonzero_per_pixel = (image != 0).reshape(image_height, image_width, -1).sum(axis=-1)

kernel = np.ones((3, 3), np.uint8)  # 3x3 morphological closing fills small gaps in each cluster mask
component_images = []  # every saved component image, in save order

for cluster_index in range(K_param):
    binary_mask = (cluster_map == cluster_index).astype(np.uint8)
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

    # 8-connected labelling: 0 is background, 1..num_features are the components.
    labeled_mask, num_features = label(binary_mask, structure=np.ones((3, 3)))

    # Size every component in one pass. A full-size image is then built only for the
    # components that pass the threshold, not for every (often tiny) component.
    component_sizes = np.bincount(labeled_mask.ravel(), weights=nonzero_per_pixel.ravel(),
                                  minlength=num_features + 1)
    component_sizes[0] = 0  # never keep the background label

    for component_id in np.flatnonzero(component_sizes > MIN_COMPONENT_COUNT):
        non_zero_pixel_count = int(component_sizes[component_id])
        component_mask = (labeled_mask == component_id).astype(np.uint8)
        if image_type == 0:
            # Broadcast the mask over every channel of the colour image.
            component_image = (image * component_mask[:, :, None]).astype(np.uint8)
        else:
            component_image = image * component_mask
        component_images.append(component_image)
        print(non_zero_pixel_count)

        output_filename = f'results/component_{cluster_index}_{component_id}.png'
        # cmap only affects single-channel arrays (grey instead of the default false-colour map);
        # it is ignored for RGB(A) arrays.
        plt.imsave(output_filename, component_image.astype(np.uint8), cmap='gray')
        print(f"Saved: {output_filename}")

# Display the first saved component (only visible with an interactive/inline backend).
if component_images:
    fig, ax = plt.subplots()
    ax.imshow(component_images[0], cmap='gray')
    ax.set_title('First Valid Connected Component')
    ax.axis('off')
    plt.show()

### Alternatively, save only the 10 largest components above the size threshold (run this cell instead)

In [None]:
# Show result: keep only the TOP_N largest components above MIN_COMPONENT_COUNT and save them.
TOP_N = 10
MIN_COMPONENT_COUNT = 1_000_000  # size threshold; the size measure is described at `nonzero_per_pixel`
os.makedirs('results', exist_ok=True)  # plt.imsave raises if the output folder is missing

pos_prob, log_likelihood = gmm.expectation(image_pixels)
map_pos_prob = np.reshape(pos_prob, (image_height, image_width, K_param))

# Hard-assign every pixel to its most probable component and lay the labels out as an image.
pixel_clusters = np.argmax(pos_prob, axis=-1)
cluster_map = np.reshape(pixel_clusters, (image_height, image_width))

# Number of non-zero channel values at each pixel. Summed over a component this equals
# np.count_nonzero(image masked to that component), which is the size measure this cell uses.
# NOTE(review): for RGB that counts channel values, not pixels, and assumes `image` is uint8.
nonzero_per_pixel = (image != 0).reshape(image_height, image_width, -1).sum(axis=-1)

kernel = np.ones((3, 3), np.uint8)  # 3x3 morphological closing fills small gaps in each cluster mask

# Min-heap of (size, cluster_index, component_id). The root is the SMALLEST kept component,
# so once TOP_N entries are held, heappushpop evicts the smallest and the heap retains the
# largest. Images are not stored in the heap: (cluster_index, component_id) is unique, so
# tuple comparison never reaches an array. Full-size images are built only for the winners.
top_components = []
labeled_masks = {}  # cluster_index -> labelled component mask, needed to rebuild the winners

for cluster_index in range(K_param):
    binary_mask = (cluster_map == cluster_index).astype(np.uint8)
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

    # 8-connected labelling: 0 is background, 1..num_features are the components.
    labeled_mask, num_features = label(binary_mask, structure=np.ones((3, 3)))
    labeled_masks[cluster_index] = labeled_mask

    # Size every component in one pass instead of building a full image per component.
    component_sizes = np.bincount(labeled_mask.ravel(), weights=nonzero_per_pixel.ravel(),
                                  minlength=num_features + 1)
    component_sizes[0] = 0  # never keep the background label

    for component_id in np.flatnonzero(component_sizes > MIN_COMPONENT_COUNT):
        entry = (int(component_sizes[component_id]), cluster_index, int(component_id))
        if len(top_components) < TOP_N:
            heapq.heappush(top_components, entry)
        else:
            heapq.heappushpop(top_components, entry)

# Save largest first so the rank in the file name matches the component size.
for idx, (_, cluster_index, component_id) in enumerate(sorted(top_components, reverse=True)):
    component_mask = (labeled_masks[cluster_index] == component_id).astype(np.uint8)
    if image_type == 0:
        # Broadcast the mask over every channel of the colour image.
        component_image = (image * component_mask[:, :, None]).astype(np.uint8)
    else:
        component_image = image * component_mask
    output_filename = f'results/top_component_{idx+1}_cluster_{cluster_index}_id_{component_id}.png'
    # cmap only affects single-channel arrays; it is ignored for RGB(A) arrays.
    plt.imsave(output_filename, component_image.astype(np.uint8), cmap='gray')
    print(f"Saved: {output_filename}")