# Import Statements

In [118]:
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt


# Directory paths, relative to the notebook's working directory.
# The trailing '/' is kept so plain concatenation with a file name is a valid path.
image_dir, export_dir, cat_dir = (
    'Input Images/',     # histology images to classify
    'Output Images/',    # recoloured results are written here
    'Category Images/',  # one sample image per classification category
)

# Analysis Modules

In [72]:
class Cat_Generator:
    """Build a classification-category table from a directory of sample images.

    Each image file in ``cat_dir`` defines one category. The key is the file
    name and the value is a two-element list ``[avg_rgb, replacement_rgb]``:

    * ``avg_rgb``: mean (R, G, B) of the sample image.
    * ``replacement_rgb``: colour to paint for this category. It starts as
      ``(0, 0, 0)`` and is set with ``cat_color``.
    """

    def __init__(self, cat_dir):
        self.cat_dict = self.cat_generator(cat_dir)

    def avg_rgb(self, cat_image):
        """Return the mean (R, G, B) of an image array as a 3-tuple of floats.

        ``cat_image`` is indexed as (rows, cols, channels). Only channels 0-2
        are averaged, so an alpha channel (if any) is ignored.
        """
        return tuple(np.mean(cat_image[:, :, channel]) for channel in range(3))

    def cat_generator(self, cat_dir):
        """Scan ``cat_dir`` and return ``{file_name: [avg_rgb, (0, 0, 0)]}``.

        Hidden entries (e.g. ``.DS_Store``, ``.ipynb_checkpoints``) and
        sub-directories are skipped. Files are read in sorted order so the
        resulting dict order is reproducible across machines.
        """
        color_placeholder = (0, 0, 0)
        cat_dict = {}
        for file_name in sorted(os.listdir(cat_dir)):
            path = os.path.join(cat_dir, file_name)
            if file_name.startswith('.') or not os.path.isfile(path):
                continue
            # `with` closes the file handle. convert('RGB') normalises
            # greyscale / palette / RGBA images to three 8-bit channels.
            with Image.open(path) as img:
                cat_image = np.asarray(img.convert('RGB'))
            cat_dict[file_name] = [self.avg_rgb(cat_image), color_placeholder]
        return cat_dict

    def cat_color(self, cat_key, rgb):
        """Set the replacement (R, G, B) colour for category ``cat_key``.

        Raises KeyError, listing the known category names, if ``cat_key`` is
        not a category.
        """
        if cat_key not in self.cat_dict:
            raise KeyError(
                f"unknown category {cat_key!r}; known categories: {sorted(self.cat_dict)}"
            )
        # Store a tuple so later mutation of a caller's list cannot leak in.
        self.cat_dict[cat_key][1] = tuple(rgb)

    def d(self):
        """Return the category dict itself (not a copy).

        Later ``cat_color`` calls are therefore visible to anyone holding it.
        """
        return self.cat_dict

In [138]:
class Histology_CAT:
    """Assign every pixel of each histology image to its nearest category and recolour it.

    Parameters
    ----------
    image_dir : str
        Directory of input images. Every non-hidden regular file is loaded as RGB.
    export_dir : str
        Directory that recoloured images are written to (created if missing).
    cat_dict : dict
        ``{category_key: [avg_rgb, replacement_rgb]}``, e.g. from
        ``Cat_Generator.d()``. It is stored by reference, so replacement colours
        changed after construction are picked up at recolour time.
    """

    def __init__(self, image_dir, export_dir, cat_dict):
        self.export_dir = export_dir
        self.cat_dict = cat_dict

        # Load every image once, normalised to (rows, cols, 3) uint8 RGB.
        self.image_dict = {}
        for file_name in sorted(os.listdir(image_dir)):
            path = os.path.join(image_dir, file_name)
            if file_name.startswith('.') or not os.path.isfile(path):
                continue  # skip .DS_Store, .ipynb_checkpoints, sub-directories
            with Image.open(path) as img:
                self.image_dict[file_name] = np.asarray(img.convert('RGB'))

        # Fixed category order. Row i of the distance stack corresponds to cat_keys[i].
        self.cat_keys = list(self.cat_dict)

    def kowalski(self, fmt="JPEG"):
        """Recolour every loaded image and save it under ``export_dir``.

        ``fmt`` is the Pillow save format and defaults to ``"JPEG"``. JPEG is
        lossy, so saved pixels will not exactly equal the category colours.
        Pass ``"PNG"`` if the output will be quantified downstream.
        """
        if self.export_dir:
            os.makedirs(self.export_dir, exist_ok=True)
        for image_key in self.image_dict:
            out_path = os.path.join(self.export_dir, image_key)
            self.img_recolor(image_key).save(out_path, fmt)
        return "All done, Skipper!"

    def img_recolor(self, image_key):
        """Return ``image_dict[image_key]`` recoloured by category as a PIL image."""
        # Pillow infers mode 'RGB' for a (rows, cols, 3) uint8 array.
        return Image.fromarray(self._recolor_array(image_key))

    def _recolor_array(self, image_key):
        """Classify each pixel to its nearest category and return the (rows, cols, 3) uint8 result."""
        if not self.cat_keys:
            raise ValueError("no categories to classify pixels against")
        rows, cols = self.image_dict[image_key].shape[:2]

        # (n_categories, n_pixels) distance matrix. The smallest entry in each
        # column is the pixel's category; np.argmin takes the first one on ties.
        distances = np.stack([self.diff(image_key, key) for key in self.cat_keys])
        nearest = np.argmin(distances, axis=0)

        # Look up each pixel's replacement colour in a per-category palette.
        palette = np.asarray(
            [self.cat_dict[key][1] for key in self.cat_keys], dtype=np.float64
        )
        palette = np.clip(palette, 0, 255).round().astype(np.uint8)
        return palette[nearest].reshape(rows, cols, 3)

    def diff(self, image_key, cat_key):
        """Return a flat array of per-pixel colour distances to one category.

        Each entry is ``|R - R_cat| + |G - G_cat| + |B - B_cat|``, where
        ``*_cat`` is the category's average colour. Pixels are in row-major
        order. Absolute values are required: summed signed differences cancel
        out (e.g. +100 red, -100 green gives 0), so they do not measure distance.
        """
        # Cast to float first, so uint8 minus an integer colour cannot wrap around.
        pixels = self.image_dict[image_key][:, :, :3].reshape(-1, 3).astype(np.float64)
        cat_rgb = np.asarray(self.cat_dict[cat_key][0], dtype=np.float64)[:3]
        return np.abs(pixels - cat_rgb).sum(axis=1)

# Kowalski, Analysis!

In [139]:
# Replacement colour for each category image in `cat_dir` (keys are file names).
BLACK = (0, 0, 0)          # implant
WHITE = (255, 255, 255)    # empty space
PURPLE = (120, 81, 169)    # tissue

CATEGORY_COLORS = {
    "edge.jpg": BLACK,      # implant
    "implant.jpg": BLACK,   # implant
    "empty.jpg": WHITE,     # empty
    "stain_e.jpg": WHITE,   # empty
    "tissue.jpg": PURPLE,   # tissue
    "stain_t.jpg": PURPLE,  # tissue
}

# A Cat_Generator object (its dict is obtained via .d()), so not named `cat_dict`.
cat_gen = Cat_Generator(cat_dir)

# A category image without an assigned colour would keep the (0, 0, 0)
# placeholder and be painted like implant. Fail loudly instead.
uncolored = set(cat_gen.d()) - set(CATEGORY_COLORS)
assert not uncolored, f"no colour assigned for categories: {sorted(uncolored)}"

for cat_key, rgb in CATEGORY_COLORS.items():
    cat_gen.cat_color(cat_key, rgb)

categorizer = Histology_CAT(image_dir, export_dir, cat_gen.d())
print(categorizer.kowalski())

(5688, 3)
79 72
(1536, 3)
24 64
All done, Skipper!
