In [11]:
import import_ipynb
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from skimage import data, util, filters
from sklearn.datasets import load_sample_image
from scipy.ndimage import gaussian_filter

def crop_png(Brightfield, frame_index=3, width=428):
    """Read an image stack, extract one frame and crop it to a fixed width.

    Args:
        Brightfield: path (or anything ``imageio.imread`` accepts) of the
            brightfield image stack.
        frame_index: index along the first axis of the frame to extract
            (default 3, the previously hard-coded value).
        width: number of leading columns to keep, intended to match the width
            of the reference image (default 428, the previously hard-coded value).

    Returns:
        PIL.Image.Image built from the cropped 2-D frame.
    """
    # The notebook's import cell never imports imageio; without this local
    # import the call below raised NameError.
    import imageio

    stack = imageio.imread(Brightfield)

    # NOTE(review): assumes imread returns a (frames, rows, cols) stack for
    # this file — TODO confirm for the file format actually used.
    frame = stack[frame_index, :, :]

    # Keep only the first `width` columns so the frame lines up with the
    # reference image.
    frame_cropped = frame[:, :width]

    return Image.fromarray(frame_cropped)

def trpos_conversion(filename, pixelsize=117, track_col=2, x_col=4, y_col=5):
    """Return the first localization of every track, converted from nm to pixels.

    Args:
        filename: CSV source (anything ``np.loadtxt`` accepts: path, file
            object or list of lines) with one header row and numeric,
            comma-separated columns.
        pixelsize: pixel size in nm; positions are divided by it.
        track_col: 0-based column holding the track number.
        x_col, y_col: 0-based columns holding the x / y position
            (columns 5 and 6 when counted from 1).

    Returns:
        float array of shape (n_tracks, 2) holding [x, y] in pixels. There is
        one row per unique track number, in ascending track order, taken from
        the first row of that track in the file.
    """
    # ndmin=2 keeps a file with a single data row 2-D so column indexing works.
    TR = np.loadtxt(filename, delimiter=',', skiprows=1, ndmin=2)
    if TR.size == 0:
        return np.zeros((0, 2))

    # return_index yields the first occurrence of each (sorted) track number,
    # replacing a per-track np.where scan over the whole table.
    _, first_rows = np.unique(TR[:, track_col], return_index=True)

    trpos = TR[first_rows][:, [x_col, y_col]] / pixelsize
    return trpos

def get_array_swapped(array):
    """Return ``array`` with the order of its columns (axis 1) reversed.

    For an array whose rows are [x, y] pairs this yields [y, x] pairs.
    """
    reverse_order = slice(None, None, -1)
    return array[:, reverse_order]

def display_plots(image, trpos, x_max=428, y_max=684):
    """Show localizations over the image and, side by side, on their own.

    Args:
        image: image drawn with a gray colormap in the left panel.
        trpos: array whose columns 0 and 1 are x and y positions in pixels.
        x_max: upper x limit of the right panel (default 428).
        y_max: upper y limit of the right panel (default 684).

    The axis labels are derived from ``x_max``/``y_max``. Previously they were
    hard-coded as "600 to 0" while the y-limit was 684.
    """
    fig, (ax_image, ax_points) = plt.subplots(1, 2, figsize=(12, 6))

    x_label = f'X-axis (0 to {x_max})'
    # Image row 0 is at the top, hence "max to 0" for the y-axis.
    y_label = f'Y-axis ({y_max} to 0)'

    # Left panel: image with the localizations overlaid.
    ax_image.imshow(image, cmap='gray')
    ax_image.scatter(trpos[:, 0], trpos[:, 1], s=2, color='r')
    ax_image.set_xlabel(x_label)
    ax_image.set_ylabel(y_label)
    ax_image.set_title('Original Image with Points')

    # Right panel: points alone; the inverted y-axis matches imshow's
    # top-left origin.
    ax_points.scatter(trpos[:, 0], trpos[:, 1], s=2, color='red')
    ax_points.set_xlim(0, x_max)
    ax_points.set_ylim(0, y_max)
    ax_points.set_xlabel(x_label)
    ax_points.set_ylabel(y_label)
    ax_points.invert_yaxis()
    ax_points.set_aspect('equal', adjustable='box')
    ax_points.set_title('Flipped Points')

    # Keep the two subplots from overlapping.
    fig.tight_layout()
    plt.show()


### Fourier image transformer

We use this class to split, modify and analyze the magnitude spectra of an image and of its masked version, in order to identify which image features the mask manages to retain.

In [15]:
class FourierTransformer():
    """Compare the Fourier magnitude spectrum of an image with that of a masked copy.

    Typical (chained) use::

        FourierTransformer().__set__(image, mask).__fit__().plot_image_analysis()

    NOTE(review): the dunder-style method names are kept for compatibility
    with existing callers. Defining ``__set__`` also makes instances data
    descriptors if one is ever stored as a class attribute of another class.
    """

    # Luminance weights used to collapse RGB(A) input to grayscale
    # (the same coefficients as skimage.color.rgb2gray).
    _RGB_WEIGHTS = np.array([0.2125, 0.7154, 0.0721])

    def __init__(self):
        # Unset attributes are None (not the builtin `any`), so checks such as
        # `self.mask is None` behave as intended.
        self._reset_state(image=None, mask=None)

    def _reset_state(self, image, mask):
        """(Re)initialise every attribute; shared by __init__ and __set__."""
        self.image = image
        self.mask = mask
        self.masked_image = None
        self.image_gray = self._to_grayscale(image)

        self.f_transform_shifted = None
        self.log_magnitude_spectrum = None

        # Parameters for shifting
        self.shift_x = 0  # Horizontal shift (positive to the right, negative to the left)
        self.shift_y = 0  # Vertical shift

        # Display params
        self.show_log_mag_spectrum = True
        self.show_radial_intensity = True

        # Results populated by __fit__
        self.log_magnitude_spectrum_original = None
        self.radial_intensity_original = None
        self.log_magnitude_spectrum_masked = None
        self.radial_intensity_masked = None

    @classmethod
    def _to_grayscale(cls, image):
        """Collapse an (H, W, 3) or (H, W, 4) image to (H, W); return anything else unchanged."""
        if image is None:
            return None
        array = np.asarray(image)
        if array.ndim == 3 and array.shape[-1] in (3, 4):
            # The alpha channel, if present, is ignored. Without this, 3-channel
            # input made radial_profile's bincount fail on mismatched lengths.
            return array[..., :3] @ cls._RGB_WEIGHTS
        return image

    def _spectrum_and_profile(self, image):
        """Return (log-magnitude spectrum, radial profile of that spectrum) for ``image``."""
        log_spectrum = self.calculate_magnitude_spectrum(self.perform_fft(image))
        return log_spectrum, self.radial_profile(log_spectrum)

    def __fit__(self, keep=True):
        """Compute spectra and radial profiles for the original and masked image.

        Args:
            keep: used only when a mask was given to ``__set__``. If True,
                pixels where the mask is truthy are kept and the rest are
                zeroed. If False, the masked pixels are zeroed instead.

        If no mask was given, the image is smoothed (Gaussian, sigma=2) and
        Otsu-thresholded. Pixels above the threshold are then zeroed; ``keep``
        is ignored on this path.

        Returns:
            self, so calls can be chained.
        """
        log_spectrum_original, radial_original = self._spectrum_and_profile(self.image_gray)

        if self.mask is None:
            # Smoothing first makes the threshold less sensitive to noise.
            smoothed_image = gaussian_filter(self.image_gray, sigma=2)
            threshold_value = filters.threshold_otsu(smoothed_image)
            segmented_mask = smoothed_image > threshold_value
            masked_image = np.where(segmented_mask, 0, self.image_gray)
        elif keep:
            masked_image = np.where(self.mask, self.image_gray, 0)
        else:
            masked_image = np.where(self.mask, 0, self.image_gray)

        log_spectrum_masked, radial_masked = self._spectrum_and_profile(masked_image)

        self.__set_image_params__(
            log_spectrum_original,
            radial_original,
            masked_image,
            log_spectrum_masked,
            radial_masked
        )
        return self

    def __set__(self, image, mask=None):
        """Load a new image (and optional mask) and reset all previous results.

        RGB(A) images are converted to grayscale for ``image_gray``; the raw
        input is kept in ``image``.

        Returns:
            self, so calls can be chained.
        """
        self._reset_state(image, mask)
        return self

    def __set_shift_x_y___(self, shift_x, shift_y):
        """Set the centre shift parameters (previously raised NameError via `this`)."""
        self.shift_x = shift_x
        self.shift_y = shift_y
        return self

    def __get_image_params__(self):
        """Return the computed spectra, profiles and masked image as a dict."""
        return dict(
            Log_Mag_Spectrum_Original= self.log_magnitude_spectrum_original,
            Radial_Intensity_Original= self.radial_intensity_original,
            Log_Mag_Spectrum_Mask= self.log_magnitude_spectrum_masked,
            Radial_Intensity_Mask= self.radial_intensity_masked,
            Masked_Image = self.masked_image
        )

    def __set_image_params__(
        self,
        log_magnitude_spectrum_original,
        radial_intensity_original,
        masked_image,
        log_magnitude_spectrum_masked,
        radial_intensity_masked
    ):
        """Store the results produced by ``__fit__``."""
        self.log_magnitude_spectrum_original = log_magnitude_spectrum_original
        self.radial_intensity_original = radial_intensity_original
        self.log_magnitude_spectrum_masked = log_magnitude_spectrum_masked
        self.radial_intensity_masked = radial_intensity_masked
        self.masked_image = masked_image

    def __set_display_params__(self, show_log_mag_spectrum, show_radial_intensity):
        """Choose which optional panels ``plot_image_analysis`` draws."""
        self.show_log_mag_spectrum = show_log_mag_spectrum
        self.show_radial_intensity = show_radial_intensity
        return self

    def perform_fft(self, image):
        """Perform a 2-D FFT and shift the zero-frequency component to the centre."""
        f_transform = np.fft.fft2(image)
        return np.fft.fftshift(f_transform)

    def calculate_magnitude_spectrum(self, f_transform_shifted):
        """Return log(1 + |F|), the log-scaled magnitude spectrum."""
        magnitude_spectrum = np.abs(f_transform_shifted)
        return np.log1p(magnitude_spectrum)

    def create_circular_mask(self, shape, radius, shift_x=0, shift_y=0):
        """Return a boolean (h, w) mask, True within ``radius`` of the (shifted) centre.

        Args:
            shape: shape whose first two entries are (h, w).
            radius: radius in pixels (inclusive).
            shift_x, shift_y: offsets applied to the default centre (h // 2, w // 2).
        """
        h, w = shape[:2]
        center = (h // 2 + shift_y, w // 2 + shift_x)  # Apply shifts to the default center

        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center[1])**2 + (Y - center[0])**2)
        return dist_from_center <= radius

    def radial_profile(self, image):
        """Mean of a 2-D ``image`` over integer-radius rings around its centre.

        Distances from the centre are truncated to integers and the pixels are
        averaged per ring. Element k of the result is radius k + 1, because
        radius 0 is dropped.
        """
        center = np.array(image.shape) // 2
        Y, X = np.ogrid[:image.shape[0], :image.shape[1]]
        radii = np.sqrt((X - center[1])**2 + (Y - center[0])**2).astype(np.int64)

        ring_sums = np.bincount(radii.ravel(), image.ravel())
        ring_counts = np.bincount(radii.ravel())
        radial_prof = ring_sums / ring_counts
        return radial_prof[1:]  # Skip the zero radius

    def plot_image_analysis(self):
        """Plot the original/masked images, optionally their spectra and radial profiles."""
        plt.figure(figsize=(18, 9))

        plot_items = [
            {'data': self.image_gray, 'title': 'Original Image', 'type': 'image'},
            {'data': self.masked_image, 'title': 'Image with Mask', 'type': 'image'}
        ]

        if self.show_log_mag_spectrum:
            plot_items.extend([
                {'data': self.log_magnitude_spectrum_original, 'title': 'Magnitude Spectrum (Original)', 'type': 'image'},
                {'data': self.log_magnitude_spectrum_masked, 'title': 'Magnitude Spectrum (Masked)', 'type': 'image'}
            ])

        if self.show_radial_intensity:
            plot_items.append({
                'data': [self.radial_intensity_original, self.radial_intensity_masked],
                'title': 'Spectrum Intensity Profile', 'type': 'plot'
            })

        for idx, item in enumerate(plot_items, start=1):
            plt.subplot(3, 3, idx)
            if item['type'] == 'image':
                plt.imshow(item['data'], cmap='gray')
                plt.title(item['title'])
                plt.axis('off')
            elif item['type'] == 'plot':
                plt.plot(item['data'][0], label='Original Image', color='blue')
                plt.plot(item['data'][1], label='Mask Image', color='red')
                plt.title(item['title'])
                plt.xlabel('Radius')
                plt.ylabel('Intensity')
                plt.legend()

        plt.tight_layout()
        plt.show()


        

In [3]:
# import numpy as np
# import matplotlib.pyplot as plt
# from sklearn.datasets import load_sample_image
# from sklearn.cluster import KMeans

# # Load the sample image
# i_image = load_sample_image('flower.jpg')

# # Normalize image data to be between 0 and 1
# i_image = np.array(i_image, dtype=np.float64) / 255

# # Reshape the image to a 2D array of pixels
# w, h, d = i_image.shape
# image_array = np.reshape(i_image, (w * h, d))

# # Perform k-means clustering to segment the image into 2 clusters
# kmeans = KMeans(n_clusters=2, random_state=0).fit(image_array)
# labels = kmeans.predict(image_array)

# # Reshape the labels to the shape of the original image
# labels = labels.reshape(w, h)

# # Create a mask for the flower (assuming the flower is the larger cluster)
# mask = labels == labels.max()

# # Highlight the flower in the original image
# highlighted_image = np.copy(i_image)
# highlighted_image[~mask] = -1  # Set the background pixels to black

# # Plot the original image, the mask, and the highlighted image
# plt.figure(figsize=(12, 4))

# plt.subplot(1, 3, 1)
# plt.imshow(i_image)
# plt.title('Original Image')

# plt.subplot(1, 3, 2)
# plt.imshow(mask, cmap='gray')
# plt.title('Mask')

# plt.subplot(1, 3, 3)
# plt.imshow(highlighted_image)
# plt.title('Highlighted Image')

# plt.show()


# image = load_sample_image('flower.jpg')
# print("Image Type", type(image))
# print("Shape Image", image.shape)
# transformer = FourierTransformer()
# transformer.__set__(image, mask).__fit__().plot_image_analysis()
# #print(transformer.__get_image_params__())

### Mask-to-image information-retention metric

We use this metric class to find the mask that best retains an image's features, so that we can determine which mask belongs to which image.

In [14]:
class InformationRetainMetric():
    def __init__(self):
        
        self.image = any
        self.mask = any

    