# Import Statements

In [37]:
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.ndimage import median_filter


# Directory configuration. Each path keeps its trailing slash because it is
# handed to the analysis classes below, which build file paths from it.
#   image_dir  - histology images to be processed
#   export_dir - destination for the processed output images
#   cat_dir    - one sample image per classification category
image_dir, export_dir, cat_dir = (
    'Input Images/',
    'Output Images/',
    'Category Images/',
)

# Analysis Modules

In [49]:
class Cat_Generator:
    """Build a table of classification categories from a folder of sample images.

    The table (``cat_dict``) maps each sample image's file name to a
    two-element list ``[reference_rgb, replacement_rgb]``:

    * ``reference_rgb``   - per-channel median colour of the sample image.
    * ``replacement_rgb`` - colour painted over pixels assigned to the
      category; starts as ``(0, 0, 0)`` and is set with ``cat_color``.
    """

    def __init__(self, cat_dir):
        """Scan ``cat_dir`` and build the category table."""
        self.cat_dict = self.cat_generator(cat_dir)

    def avg_rgb(self, cat_image):
        """Return the representative ``(r, g, b)`` colour of an image array.

        Despite the method name, the per-channel *median* is used rather than
        the mean. Only the first three channels are read.

        Args:
            cat_image: array indexed as ``[row, column, channel]``.

        Returns:
            Tuple of three floats ``(r, g, b)``.
        """
        rgb = np.asarray(cat_image)[:, :, :3].reshape(-1, 3)
        r, g, b = np.median(rgb, axis=0)
        return (float(r), float(g), float(b))

    def cat_generator(self, cat_dir):
        """Return ``{file name: [median rgb, (0, 0, 0)]}`` for images in ``cat_dir``.

        Entries are visited in sorted order so the category order (which
        decides ties during classification) is reproducible. Sub-directories
        and hidden dotfiles are skipped because they are not category images.
        """
        cat_dict = {}
        color_placeholder = (0, 0, 0)
        for file in sorted(os.listdir(cat_dir)):
            path = os.path.join(cat_dir, file)
            if file.startswith('.') or not os.path.isfile(path):
                continue
            # Close the file promptly and normalise to 3-channel RGB so that
            # grayscale / palette / RGBA samples are handled uniformly.
            with Image.open(path) as img:
                cat_image = np.asarray(img.convert('RGB'))
            cat_dict[file] = [self.avg_rgb(cat_image), color_placeholder]

        return cat_dict

    def cat_color(self, cat_key, rgb):
        """Set the replacement colour ``rgb`` for the category ``cat_key``.

        Raises:
            KeyError: if ``cat_key`` is not a known category.
        """
        self.cat_dict[cat_key][1] = rgb

    def d(self):
        """Return the category table (the live dictionary, not a copy)."""
        return self.cat_dict

In [50]:
class Histology_CAT:
    """Classify the pixels of histology images into colour categories.

    For every input image ``kowalski`` runs: K-Means colour quantization,
    a median filter, nearest-category recolouring, and a second (wider)
    median filter. Each stage writes its result to ``export_dir`` as a JPEG.
    """

    def __init__(self, image_dir, export_dir, cat_dict):
        """Load every image in ``image_dir`` into memory.

        Args:
            image_dir: directory containing the images to process.
            export_dir: directory receiving the processed images, or ``None``
                to skip writing files altogether.
            cat_dict: mapping ``name -> [reference_rgb, replacement_rgb]``
                describing the classification categories.
        """
        self.export_dir = export_dir  # where processed images are written
        self.cat_dict = cat_dict  # one entry per classification category
        self.image_key = None  # file name of the image currently being processed

        # Map file name -> RGB pixel array. Sub-directories and hidden
        # dotfiles are skipped; every image is normalised to 3-channel RGB
        # because the later stages reshape to (-1, 3).
        self.image_dict = {}
        for file in sorted(os.listdir(image_dir)):
            path = os.path.join(image_dir, file)
            if file.startswith('.') or not os.path.isfile(path):
                continue
            with Image.open(path) as img:
                self.image_dict[file] = np.asarray(img.convert('RGB'))

        # Indexed list of the category names.
        self.cat_keys = list(self.cat_dict)

    def _save_rgb(self, image_arr, prefix):
        """Write ``image_arr`` as ``<export_dir>/<prefix><image_key>`` (JPEG)."""
        if self.export_dir is None:
            return
        out_path = os.path.join(self.export_dir, prefix + self.image_key)
        # Pillow interprets the raw buffer, so the data must be contiguous uint8.
        rgb = np.ascontiguousarray(image_arr, dtype=np.uint8)
        Image.fromarray(rgb).save(out_path, "JPEG")

    def kowalski(self, n_cluster, krnl_sz):
        """Run the full pipeline on every loaded image.

        Args:
            n_cluster: number of K-Means colour clusters.
            krnl_sz: side length of the first median-filter window; the
                final filter uses twice this size.

        Returns:
            The string ``"All done, Skipper!"``.
        """
        for image_key in self.image_dict:
            self.image_key = image_key
            kmeans_arr = self.kmeans_quant(n_cluster)
            medfilt_arr = self.med_filter(kmeans_arr, krnl_sz)
            recolor_arr = self.naive_recolor(medfilt_arr)
            self.med_filter(recolor_arr, krnl_sz * 2)

        return "All done, Skipper!"

    def visualize_dist(self, image_key):
        """Show a 3-D scatter plot of the image's pixels in RGB space."""
        image_arr = self.image_dict[image_key]
        r, g, b = (image_arr[:, :, channel].ravel() for channel in range(3))

        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(projection='3d')
        ax.scatter(r, g, b, marker='.')
        plt.show()

    def kmeans_quant(self, n_cluster):
        """Quantize the current image to ``n_cluster`` colours with K-Means.

        Returns:
            uint8 array with the shape of the current image in which every
            pixel is replaced by its cluster centre.
        """
        image_arr = self.image_dict[self.image_key]
        flat_arr = image_arr.reshape((-1, 3))
        clustering = KMeans(n_clusters=n_cluster, random_state=42).fit(flat_arr)
        quantized = clustering.cluster_centers_[clustering.labels_]
        # Round to the nearest integer level; a bare uint8 cast would truncate.
        quantized_arr = (
            np.clip(np.rint(quantized), 0, 255).astype(np.uint8).reshape(image_arr.shape)
        )

        self._save_rgb(quantized_arr, str(n_cluster) + " quanta ")

        return quantized_arr

    def med_filter(self, image_arr, krnl_sz):
        """Median-filter ``image_arr`` with a ``krnl_sz`` x ``krnl_sz`` window.

        The window has extent 1 along the colour axis so that each channel is
        filtered independently; a window spanning the channels would mix
        red, green and blue values with one another.
        """
        filtered_arr = median_filter(image_arr, size=(krnl_sz, krnl_sz, 1))

        self._save_rgb(filtered_arr, str(krnl_sz) + " pixel median filter ")

        return filtered_arr

    def naive_recolor(self, image_arr):
        """Paint every pixel with the replacement colour of its nearest category.

        "Nearest" is the smallest Euclidean RGB distance to the category's
        reference colour (see ``diff``); ties go to the category that comes
        first in ``cat_dict``.

        Returns:
            uint8 array of shape ``(rows, columns, 3)``.

        Raises:
            ValueError: if ``cat_dict`` is empty.
        """
        keys = list(self.cat_dict)
        if not keys:
            raise ValueError("cat_dict is empty; at least one category is required")

        # Row i holds the distance of every pixel to category i.
        distances = np.stack([self.diff(image_arr, key) for key in keys])
        nearest = np.argmin(distances, axis=0)

        # Look the replacement colours up in one indexing step. uint8 is
        # required for the array to be written out as an RGB image.
        palette = np.clip(
            np.asarray([self.cat_dict[key][1] for key in keys]), 0, 255
        ).astype(np.uint8)
        rows, cols = np.shape(image_arr)[:2]
        recolored_arr = palette[nearest].reshape(rows, cols, 3)

        self._save_rgb(recolored_arr, " naive categorical recoloring ")

        return recolored_arr

    def diff(self, image_arr, cat_key):
        """Return each pixel's Euclidean RGB distance to a category's reference colour.

        The arithmetic is done in float64 so that unsigned 8-bit pixel data
        cannot wrap around when the reference colour is subtracted.

        Returns:
            1-D float array with one distance per pixel, in row-major order.
        """
        pixels = np.asarray(image_arr, dtype=np.float64)[:, :, :3].reshape(-1, 3)
        reference = np.asarray(self.cat_dict[cat_key][0], dtype=np.float64)[:3]

        return np.sqrt(np.sum(np.square(pixels - reference), axis=1))

# Kowalski, Analysis!

In [53]:
# Build the category table from the sample images and inspect it before
# any replacement colours have been assigned.
cat_dict = Cat_Generator(cat_dir)
category_table = cat_dict.d()
categorizer = Histology_CAT(image_dir, export_dir, category_table)

print(category_table)

# Optional exploration (left disabled):
#categorizer.visualize_dist('4109.jpg')
#for i in range(10,101):
 #   categorizer.kmeans_quant(i)

{'empty.jpg': [(248.0, 242.0, 242.0), (0, 0, 0)], 'implant.jpg': [(213.0, 199.0, 190.0), (0, 0, 0)], 'implant1.jpg': [(232.0, 218.0, 215.0), (0, 0, 0)], 'tissue.jpg': [(72.0, 31.0, 29.0), (0, 0, 0)], 'tissue1.jpg': [(159.0, 117.0, 118.0), (0, 0, 0)]}


In [54]:
# Rebuild the category table, then assign the colour painted over each category.
cat_dict = Cat_Generator(cat_dir)
replacement_colors = {
    "implant.jpg": (0, 0, 0),         # black (implant)
    "implant1.jpg": (0, 0, 0),        # black (implant)
    "empty.jpg": (255, 255, 255),     # white (empty)
    "tissue.jpg": (120, 81, 169),     # purple (tissue)
    "tissue1.jpg": (120, 81, 169),    # purple (tissue)
}
for cat_name, cat_rgb in replacement_colors.items():
    cat_dict.cat_color(cat_name, cat_rgb)

# Run the whole pipeline over every input image.
categorizer = Histology_CAT(image_dir, export_dir, cat_dict.d())
status = categorizer.kowalski(n_cluster = 10, krnl_sz = 5)
print(status)

All done, Skipper!
