# Define our functions

In [1]:
import math
import matplotlib.pyplot as plt
import cv2
import numpy as np
import scipy.constants as sc
from sklearn.cluster import KMeans
from skimage import measure
import random
from collections import deque
import os
import math
import PIL
from scipy.fft import fft2, ifft2, fftshift
from joblib import Parallel, delayed
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter  
from skimage import io
from cupy_common import check_cupy_available
# Decide once, at import time, whether the GPU (CuPy) path should be used.
# NOTE(review): presumably True when CuPy is importable and a usable GPU is present;
# check cupy_common.check_cupy_available for the exact criteria.
gpu_accelerated = check_cupy_available()

In [2]:
# Inputs the raw AFM image, binarizes the image and outputs the binarized image
class Preprocess:
    """Binarize a raw AFM image.

    The image ``./inputs/<directory>.png`` is read as grayscale. A global
    threshold mask and an adaptive (Gaussian, inverted) threshold mask are
    OR-ed together, and a morphological opening removes small speckles.
    """

    def __init__(self, directory, output_path=None):
        # NOTE(review): cv2.IMREAD_GRAYSCALE (without IMREAD_ANYDEPTH) yields an
        # 8-bit image, so a threshold of 350 is never exceeded. The global mask
        # is therefore always empty and only the adaptive mask contributes.
        # Confirm this is intended.
        self.global_thresh_value = 350
        # OpenCV requires an odd block size > 1 for adaptiveThreshold.
        self.adaptive_thresh_window_size = 25
        self.adaptive_thresh_C = 2
        self.morph_kernel_size = 2
        self.morph_iterations = 1
        self.directory = directory
        # Destination for save_cleaned_image(). It is optional because the
        # notebook saves the binarized image itself.
        self.output_path = output_path

    def read_image(self):
        """Read ``./inputs/<self.directory>.png`` as a grayscale array.

        Raises:
            FileNotFoundError: if the input file does not exist.
            ValueError: if OpenCV cannot decode the file.
        """
        self.image_path = f'./inputs/{self.directory}.png'
        if not os.path.isfile(self.image_path):
            raise FileNotFoundError(f"Input image not found: {self.image_path}")
        image = cv2.imread(self.image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not decode image: {self.image_path}")
        return image

    def global_threshold(self, image):
        """Return a mask that is 255 where ``image`` exceeds the global threshold, else 0."""
        _, global_thresh_mask = cv2.threshold(image, self.global_thresh_value, 255, cv2.THRESH_BINARY)
        return global_thresh_mask

    def adaptive_threshold(self, image):
        """Return an inverted Gaussian adaptive-threshold mask (255 = darker than its neighborhood)."""
        adaptive_mask = cv2.adaptiveThreshold(
            image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV,
            self.adaptive_thresh_window_size, self.adaptive_thresh_C)
        return adaptive_mask

    def combine_masks(self, global_mask, adaptive_mask):
        """Union of the two masks (pixel-wise OR)."""
        return cv2.bitwise_or(global_mask, adaptive_mask)

    def morphological_operations(self, combined_mask):
        """Apply an elliptical morphological opening to remove small foreground specks."""
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.morph_kernel_size, self.morph_kernel_size))
        return cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel, iterations=self.morph_iterations)

    def save_cleaned_image(self, cleaned_image, output_path=None):
        """Write ``cleaned_image`` to ``output_path``, or to ``self.output_path`` if not given.

        Raises:
            ValueError: if no output path is available.
            OSError: if OpenCV fails to write the file.
        """
        path = output_path if output_path is not None else self.output_path
        if path is None:
            raise ValueError("No output path: pass output_path or set Preprocess.output_path")
        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(path, cleaned_image):
            raise OSError(f"Could not write image: {path}")

    def process_image(self):
        """Run the full pipeline and return the cleaned binary mask (uint8, 0/255)."""
        image = self.read_image()
        global_mask = self.global_threshold(image)
        adaptive_mask = self.adaptive_threshold(image)
        combined_mask = self.combine_masks(global_mask, adaptive_mask)
        cleaned_image = self.morphological_operations(combined_mask)

        return cleaned_image

In [3]:
# Takes the binarized AFM image, applies the Fourier-space color-wheel filter to it and outputs the colorized image
class ColorWheelProcessor:
    """Colorize a binarized image with a Fourier-space color wheel.

    The magnitude spectrum of the image is weighted, per RGB channel, by an
    HSV color wheel whose hue depends on the polar angle in Fourier space.
    The phase is kept unchanged. The result is blurred and color/brightness/
    contrast enhanced.
    """

    def __init__(self, binarized_image, gpu_accelerated):
        self.binarized_image = binarized_image
        self.sym = 2         # the hue cycle repeats ``sym`` times around the wheel
        self.color = 5       # ImageEnhance.Color factor
        self.brightness = 1  # ImageEnhance.Brightness factor
        self.contrast = 5    # ImageEnhance.Contrast factor
        self.gpu_accelerated = gpu_accelerated
        # Resolve the array backend here rather than relying on a ``cp`` global
        # that only exists after a later notebook cell has run.
        if gpu_accelerated:
            import cupy
            self.xp = cupy
        else:
            self.xp = np

    def process_image(self):
        """Return the colorized, blurred and enhanced image as a PIL image."""
        data = np.array(self.binarized_image)
        imnp = self._color_wheel_array(data)
        img2 = Image.fromarray(np.uint8(imnp))
        img2 = img2.filter(ImageFilter.GaussianBlur(radius=0.5))
        img2 = ImageEnhance.Color(img2).enhance(self.color)
        img2 = ImageEnhance.Brightness(img2).enhance(self.brightness)
        img2 = ImageEnhance.Contrast(img2).enhance(self.contrast)
        return img2

    def _color_wheel_array(self, data):
        """Return the color-wheel filtered ``data`` as a NumPy float array of shape (H, W, 3) in [0, 255]."""
        nx, ny = data.shape[0], data.shape[1]
        clrwhl = self._bldclrwhl(nx, ny, self.sym)
        imnp = self._nofft(clrwhl, data, nx, ny)
        return self._normalize(imnp) * 255

    @staticmethod
    def _normalize(a):
        """Shift ``a`` so its minimum is 0 and scale its maximum to 1.

        A constant array is returned as all zeros instead of 0/0 = NaN
        (for example, an all-black mask).
        """
        shifted = a - a.min()
        peak = shifted.max()
        if peak == 0:
            return shifted  # already all zeros
        return shifted / peak

    def _bldclrwhl(self, nx, ny, sym):
        """Build an (nx, ny, 3) RGB color wheel with values in [0, 1].

        Hue is the polar angle around the array center, scaled so it cycles
        ``sym`` times. Saturation and value are 1. Always returns a NumPy array.
        """
        xp = self.xp
        cx = xp.linspace(-nx, nx, nx)
        cy = xp.linspace(-ny, ny, ny)
        cxx, cyy = xp.meshgrid(cy, cx)
        hue = (((xp.arctan2(cxx, cyy) / math.pi) + 1.0) / 2.0) * sym
        sat = xp.ones((nx, ny))
        val = xp.ones((nx, ny))

        # Standard HSV -> RGB conversion by 60-degree sector.
        sector = xp.floor(hue * 6)
        f = hue * 6 - sector
        p = val * (1 - sat)
        q = val * (1 - f * sat)
        t = val * (1 - (1 - f) * sat)
        sector = xp.stack([sector, sector, sector], axis=-1).astype(xp.uint8) % 6
        out = xp.choose(
            sector, xp.stack([xp.stack((val, t, p), axis=-1),
                              xp.stack((q, val, p), axis=-1),
                              xp.stack((p, val, t), axis=-1),
                              xp.stack((p, q, val), axis=-1),
                              xp.stack((t, p, val), axis=-1),
                              xp.stack((val, p, q), axis=-1)]))
        return xp.asnumpy(out) if self.gpu_accelerated else out

    def _nofft(self, whl, img, nx, ny):
        """Filter ``img`` with the color wheel ``whl`` in Fourier space.

        For each RGB channel, the magnitude spectrum of ``img`` is multiplied by
        the fft-shifted wheel while the phase is kept. The channel is then
        inverse-transformed and rescaled to [0, 1]. The output shape follows
        ``img``; ``nx``/``ny`` are accepted for backward compatibility only.
        Always returns a NumPy array.
        """
        xp = self.xp
        fimg = xp.fft.fft2(xp.asarray(img))
        # The wheel is built centered; shift it to FFT ordering (DC at [0, 0]).
        whl = xp.fft.fftshift(xp.asarray(whl))
        magnitude = xp.abs(fimg)[:, :, None]
        phase = xp.angle(fimg)[:, :, None]
        comb = (whl * magnitude) * xp.exp(1j * phase)
        proimg = xp.empty(comb.shape, dtype=float)
        for n in range(3):
            proimg[:, :, n] = self._normalize(xp.real(xp.fft.ifft2(comb[:, :, n])))
        return xp.asnumpy(proimg) if self.gpu_accelerated else proimg

In [4]:
# Takes the color-wheel image and the binarized image: keeps the colored pixels where the binarized image is black and paints all other pixels white, giving a colored, single-phase image
class PhaseSubtraction:
    """Keep the colorized pixels of one phase and blank out the other.

    Where the binarized mask is 0 (black), the color-wheel pixel is kept.
    Every other pixel is set to 255 (white).
    """

    def __init__(self, input_image, binarized_image):
        self.input_image = input_image
        self.binarized_image = binarized_image

    def subtract_black_from_input(self):
        """Return a uint8 PIL image of the input, masked by the black regions of the binarized image."""
        result_array = self._keep_black_regions(
            np.array(self.input_image), np.array(self.binarized_image))
        return Image.fromarray(result_array.astype(np.uint8))

    @staticmethod
    def _keep_black_regions(input_array, mask_array):
        """Return ``input_array`` where ``mask_array`` is 0, and 255 elsewhere.

        Raises:
            ValueError: if the two arrays do not share the same height and width.
        """
        if mask_array.shape[:2] != input_array.shape[:2]:
            raise ValueError(
                f"Mask size {mask_array.shape[:2]} does not match image size {input_array.shape[:2]}")
        # Give a 2-D mask a channel axis only when the input actually has channels,
        # so a grayscale input keeps its 2-D shape.
        if mask_array.ndim == 2 and input_array.ndim == 3:
            mask_array = mask_array[..., np.newaxis]
        return np.where(mask_array == 0, input_array, 255)

In [5]:
# Takes the single-phase image, creates one mask per color cluster (i.e. per orientation) and saves the filtered masks
class ColorMaskProcessor:
    """Split a colorized single-phase image into per-color-cluster masks.

    Pixels are clustered by RGB value with k-means. For each cluster, the
    cluster's pixels keep their color, connected specks smaller than
    ``min_size`` pixels are removed, and the result is saved as
    ``filtered_mask_<i>.tiff`` in ``output_path``.
    """

    def __init__(self, input_image, output_path):
        self.input_image = input_image
        self.output_path = output_path

    def create_color_mask(self, num_clusters):
        """Return one boolean (H, W) mask per k-means color cluster."""
        img_array = np.array(self.input_image)
        # One row per pixel; this assumes a 3-channel (RGB) image.
        reshaped_array = img_array.reshape((-1, 3))

        kmeans = KMeans(n_clusters=num_clusters, random_state=42)
        kmeans.fit(reshaped_array)

        segmented_image = kmeans.labels_.reshape(img_array.shape[:2])
        return [segmented_image == i for i in range(num_clusters)]

    def save_masks_as_images(self, image, masks, min_size=15):
        """Save each mask as ``filtered_mask_<i>.tiff`` after removing clusters below ``min_size`` pixels.

        ``image`` is the (H, W, C) array the masks were computed from.
        """
        for i, mask in enumerate(masks):
            # Keep the original colors inside this cluster, black elsewhere.
            color_mask = np.zeros_like(image)
            color_mask[mask] = image[mask]

            # Drop connected groups of non-black pixels smaller than min_size.
            non_black_pixels = (color_mask[:, :, :3] > 0).any(axis=2)
            non_black_pixels = self.remove_small_clusters(non_black_pixels, min_size=min_size)

            result_img_array = np.zeros_like(color_mask)
            result_img_array[non_black_pixels] = color_mask[non_black_pixels]

            result_img = Image.fromarray(result_img_array, self.input_image.mode)
            result_img.save(os.path.join(self.output_path, f"filtered_mask_{i}.tiff"))

    def remove_small_clusters(self, image, min_size):
        """Zero 8-connected clusters of ``image`` smaller than ``min_size`` pixels.

        ``image`` is modified in place and also returned.
        """
        labeled_image = measure.label(image, connectivity=2)
        return self._drop_small_labels(image, labeled_image, min_size)

    @staticmethod
    def _drop_small_labels(image, labeled_image, min_size):
        """Zero the pixels of ``image`` whose label covers fewer than ``min_size`` pixels.

        Label 0 (background) is never touched. ``image`` is modified in place
        and returned.
        """
        # One pass over the labels instead of one full-image scan per label.
        sizes = np.bincount(labeled_image.ravel(), minlength=1)
        too_small = sizes < min_size
        too_small[0] = False
        image[too_small[labeled_image]] = 0
        return image

    def process_image(self, num_clusters=5, min_size=15):
        """Cluster the input image into ``num_clusters`` masks and save the filtered masks."""
        masks = self.create_color_mask(num_clusters=num_clusters)
        self.save_masks_as_images(np.array(self.input_image), masks, min_size=min_size)

    def process_image(self):
        # Create color masks
        masks = self.create_color_mask(num_clusters=5)

        # Save each mask as a separate image
        image = self.input_image
        self.save_masks_as_images(np.array(image), masks)

In [6]:
from collections import deque
from PIL import Image
import random

class GrainFinder:
    """Group the non-black pixels of one filtered mask into grains.

    Reads ``./<directory>/filtered_mask_<mask>.tiff``. Non-black pixels lying
    within ``neighborhood_radius`` pixels of each other (in both x and y) are
    flood-filled into the same group; pixel colors are not compared.
    Outputs:
      - a randomly colorized image at ``./<directory>/grains/Mask_<mask>.tiff``;
      - the pixel count of every group at least as large as the average, at
        ``./<directory>/grains/grains_<mask>.txt``.
    """

    def __init__(self, mask, directory, neighborhood_radius=8):
        self.mask = mask
        self.directory = directory
        self.image = Image.open(f'./{self.directory}/filtered_mask_{self.mask}.tiff')
        self.pixels = self.image.load()  # fast per-pixel access object
        self.output_path = f'./{self.directory}/grains/Mask_{self.mask}.tiff'
        # A pixel joins a grain when it is within this many pixels of a grain pixel
        # along both axes (a (2r+1) x (2r+1) window).
        self.neighborhood_radius = neighborhood_radius
        self.grouped = set()     # coordinates already assigned to a group
        # (x, y) -> group id. Kept off-image so ids above 255 are not truncated
        # by an 8-bit channel.
        self.pixel_groups = {}
        self.group_id = 1        # id given to the next group
        self.group_sizes = {}    # group id -> pixel count

    def group_pixels(self, x, y):
        """Breadth-first flood fill of the grain containing (x, y) under a new group id."""
        width, height = self.image.width, self.image.height
        radius = self.neighborhood_radius
        group_id = self.group_id
        queue = deque([(x, y)])
        current_group_size = 0

        while queue:
            cx, cy = queue.popleft()
            if not (0 <= cx < width and 0 <= cy < height):
                continue
            if (cx, cy) in self.grouped or self.pixels[cx, cy] == (0, 0, 0):
                continue  # already visited, or background

            self.grouped.add((cx, cy))
            self.pixel_groups[(cx, cy)] = group_id
            current_group_size += 1

            # Enqueue the symmetric window around the pixel, clipped to the image,
            # skipping pixels that can never join the group.
            for qx in range(max(cx - radius, 0), min(cx + radius, width - 1) + 1):
                for qy in range(max(cy - radius, 0), min(cy + radius, height - 1) + 1):
                    if (qx, qy) not in self.grouped and self.pixels[qx, qy] != (0, 0, 0):
                        queue.append((qx, qy))

        self.group_sizes[group_id] = current_group_size
        self.group_id += 1

    def process_image(self):
        """Find all grains, then save the colorized image and the size report."""
        width, height = self.image.width, self.image.height
        report_path = f'./{self.directory}/grains/grains_{self.mask}.txt'

        if all(self.pixels[x, y] == (0, 0, 0) for x in range(width) for y in range(height)):
            print("Image contains only black pixels.")
            # Write an empty report file and stop.
            with open(report_path, "w"):
                pass
            return

        for x in range(width):
            for y in range(height):
                if self.pixels[x, y] != (0, 0, 0) and (x, y) not in self.grouped:
                    self.group_pixels(x, y)

        average_size = (sum(self.group_sizes.values()) / len(self.group_sizes)
                        if self.group_sizes else 0)

        # Only groups at least as large as the average are kept in the report.
        # Every group is still drawn in the colorized image.
        self.group_sizes = {group_id: size for group_id, size in self.group_sizes.items()
                            if size >= average_size}

        # One random color per group id (index 0 is unused).
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
                  for _ in range(self.group_id)]
        for (px, py), group_id in self.pixel_groups.items():
            self.pixels[px, py] = colors[group_id]

        self.image.save(self.output_path)

        with open(report_path, "w") as file:
            file.writelines(f"Group {group_id}: {size} pixels\n"
                            for group_id, size in self.group_sizes.items())

In [8]:
%%time

# Run the pipeline classes on every sample directory.
# The input/output directory paths are hard-coded in cells 2 (Preprocess) and 6 (GrainFinder); change them there if needed.

# Select the array backend: CuPy on the GPU when available, NumPy otherwise.
# ``cp`` is kept as a notebook global because other cells may reference it.
if gpu_accelerated:
    print("Running on GPU")
    cp = __import__("cupy")
    gpu_accel = True
else:
    print("Running on CPU")
    gpu_accel = False
    cp = __import__("numpy")


parent_path = os.getcwd()
directories = ['A', 'B', 'C', 'D', 'E', 'F']
for directory in directories:

    # Per-sample output folders. makedirs creates both levels, and exist_ok
    # lets this cell be re-run without raising FileExistsError.
    path = os.path.join(parent_path, directory)
    grain_path = os.path.join(path, 'grains')
    os.makedirs(grain_path, exist_ok=True)

    # Binarize the raw input image and save it as a TIFF in the sample folder.
    # NOTE: the "biined_" file name is what the overlay cell reads back.
    preprocessor = Preprocess(directory)
    binned_image = Image.fromarray(preprocessor.process_image())
    binned_image.save(os.path.join(path, f"biined_{directory}.tiff"))

    # # Create an instance of ColorWheel
    # processor = ColorWheelProcessor(binned_image, gpu_accel)
    # # Call the instance to run the class
    # processed_img = processor.process_image()

    # # Create an instance of PhaseSubtraction
    # phase_sub = PhaseSubtraction(processed_img, binned_image)
    # one_phase = phase_sub.subtract_black_from_input()

    # # Create an instance of ColorMaskProcessor
    # mask_maker = ColorMaskProcessor(one_phase, path)
    # mask_maker.process_image()

    # for i in range(5):
    #     # Create an instance of GrainFinder
    #     grain_finder = GrainFinder(i, directory)
    #     grain_finder.process_image()

Running on CPU
CPU times: user 97.8 ms, sys: 82 µs, total: 97.9 ms
Wall time: 96.7 ms


In [19]:
from PIL import Image

# Open grain masks 1-4 and the binarized base image for a single sample.
sample = 'F'
image1, image2, image3, image4 = [
    Image.open(f"./output_test/{sample}/grains/Mask_{n}.tiff") for n in range(1, 5)
]
base_image = Image.open(f"./{sample}/biined_{sample}.tiff")

# Normalise every image to RGBA so each pixel unpacks as (r, g, b, a).
image1, image2, image3, image4, base_image = (
    layer.convert("RGBA") for layer in (image1, image2, image3, image4, base_image)
)

# Function to overlay images
def overlay_images(base, *images):
    """Paint the non-black pixels of each image onto ``base`` in place and return ``base``.

    Images are applied in order, so later images win where they overlap.
    For every painted pixel, all four RGBA bands (including alpha) are copied
    from the overlay. Overlays larger than ``base`` are clipped.
    """
    for image in images:
        rgba = image.convert("RGBA")
        # Opaque (255) wherever any of R, G, B is non-zero, transparent elsewhere.
        keep = np.asarray(rgba)[:, :, :3].any(axis=2)
        mask = Image.fromarray(keep.astype(np.uint8) * 255)
        # With a 0/255 mask, paste copies the overlay pixel exactly where the mask is 255.
        base.paste(rgba, (0, 0), mask)
    return base

# Paint masks 1-4 over the base image, then write <sample>_overlay.tiff to the working directory.
result_image = overlay_images(
    base_image,
    image1,
    image2,
    image3,
    image4,
)
result_image.save(f"{sample}_overlay.tiff")
