In [19]:
#percentile technique 11.9:10:33
from scipy.stats import rankdata, percentileofscore, scoreatpercentile
import seaborn as sns
from nd2reader import ND2Reader
import matplotlib.pyplot as plt
import matplotlib.patches as Patches
import numpy as np
from skimage.filters import gaussian, threshold_yen
from skimage import measure
from skimage.util import crop
from scipy.ndimage import find_objects, measurements
import multiprocessing as mp
import gzip
import re
import time
import os
import sys
from PIL import Image

from skimage.filters import try_all_threshold

# Minimum number of foreground pixels a labelled region must have for
# get_centroids to keep it.
min_area = 100
# Default side length (pixels) of the square patches; ImageStack copies this
# into its own patch_size attribute.
PATCH_SIZE = 256
# Debug handle: read_nd2/normalize store intermediate patch arrays here so
# that later notebook cells can inspect them.
gimg = None
# Only referenced from the disabled single-z patchify loop in read_nd2.
KEEP_Z = 3
# Half of the logical CPUs minus one, clamped to at least one worker: the
# unclamped expression yields 0 or -1 on machines with fewer than 4 CPUs.
NUM_CPU = max(1, int(mp.cpu_count()/2)-1)
# Running index used by the plotting helpers to name saved figures.
CNT = 0

def main(nd2_path='/home/shared/kevinv/uhack_2019/data/Hackathon/190710_20X_50K_0004.nd2',
         batch_dir=None):
    """Process one .nd2 file, or every 20X .nd2 file of a directory.

    Called without arguments this processes the single default file, exactly
    as before. The batch mode used to sit behind an unconditional ``return``
    and is now reachable by passing ``batch_dir``.

    :param nd2_path: file handed to ImageStack when ``batch_dir`` is None.
    :param batch_dir: directory scanned for names matching ``.*20X.*nd2``;
        each match is processed in a worker process.
    :return: None
    """
    if batch_dir is None:
        ImageStack(nd2_path)
        return

    all_files = []
    for filename in sorted(os.listdir(batch_dir)):
        if re.match(".*20X.*nd2", filename):
            print(filename)
            all_files.append(os.path.join(batch_dir, filename))

    if not all_files:
        # mp.Pool(0) raises ValueError, so bail out explicitly instead.
        print("no 20X .nd2 files found in " + str(batch_dir))
        return

    # One worker per file, but never more workers than logical CPUs.
    n_workers = min(len(all_files), mp.cpu_count() or 1)
    with mp.Pool(n_workers) as p:
        p.map(ImageStack, all_files)

def show_image(img, save=False):
    """Display an image (Y,X), optionally saving it as a borderless PNG.

    The docstring now comes first; previously ``global CNT`` preceded it, so
    the function had no ``__doc__``. The save path used to be unreachable
    behind an early ``return`` and is now selected with ``save=True``.

    :param img: 2-D array passed to ``plt.imshow``.
    :param save: when False (default) only draw a 5x5 inch figure. When True
        draw a 10x10 inch figure without axes, write it to ``"<CNT>.png"`` in
        the working directory and increment the global counter ``CNT``.
    :return: None
    """
    global CNT
    if not save:
        plt.figure(figsize = (5,5))
        plt.imshow(img)
        return

    plt.figure(figsize=(10,10))
    plt.imshow(img)
    plt.axis('off')
    plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
    # `frameon='false'` was a truthy string and the keyword is no longer
    # accepted by savefig; transparent=True already removes the background.
    plt.savefig(str(CNT)+".png", bbox_inches='tight', transparent=True, pad_inches=0.0)
    CNT += 1

def illum_filter(img,gauss_sigma=150):
    """ Filter to remove lens aberration (uneven illumination).

    A heavily blurred copy of the image is used as the background estimate
    and subtracted; negative results are clamped to zero.

    :param img: image frame, or None.
    :param gauss_sigma: standard deviation of the Gaussian blur; larger values
        give a smoother background estimate.
    :return: background-subtracted image as uint32, or None if ``img`` is None.
    """
    if img is None:
        return None
    # preserve_range=True keeps the blur in the input's intensity units.
    # Without it skimage converts integer images to floats in [0, 1], so the
    # subtraction below would remove almost nothing.
    background = gaussian(img, gauss_sigma, preserve_range=True)
    #show_image(background) #debug image
    corrected = img - background
    corrected[corrected < 0] = 0
    return corrected.astype(dtype=np.uint32,copy=False)

def get_threshold(img, ratio=10):
    """ Yen threshold for a single image frame. #FAST

    :param img: image frame handed to ``threshold_yen``.
    :param ratio: currently unused. It belonged to a disabled check that
        returned 0 when the threshold was below ``ratio`` times the background
        mean; the parameter is kept so existing callers keep working.
    :return: the threshold value computed by ``threshold_yen``.
    """
    # The former `background = img[img<thr]` mask only served the disabled
    # ratio check, so it is no longer computed.
    return threshold_yen(img)
        
def is_valid_patch(img, c, patch_size):
    """Tell whether a patch centred on ``c`` lies completely inside ``img``.

    :param img: 2-D image in (Y, X) layout; only its shape is used.
    :param c: centroid to examine, given as (row, column).
    :param patch_size: side length of the square patch.
    :return: True when no side of the patch crosses the image border.
    """
    n_rows, n_cols = img.shape  # YX format
    half = patch_size * 0.5
    row, col = c[0], c[1]

    # Guard clauses, checked in the same order as the original expression.
    if col + half > n_cols:
        return False
    if col - half < 0:
        return False
    if row + half > n_rows:
        return False
    if row - half < 0:
        return False
    return True

#FAST
def get_centroids(input_tuple):
    """ Identifies connected regions and returns their midpoints.

    :param input_tuple: ``(img, patch_size)``. ``img`` is a 2-D mask that is
        labelled with 4-connectivity (``connectivity=1``, background 0).
    :return: integer array of shape ``(n, 2)`` holding (row, column)
        centroids of every region that has at least ``min_area`` pixels and
        whose patch fits inside the image; shape ``(0, 2)`` when none qualifies.
    """
    img, patch_size = input_tuple
    labels = measure.label(img, background=0, connectivity=1)
    centroids = []
    for label_id, offset in enumerate(find_objects(labels), start=1):
        if offset is None:  # find_objects yields None for absent labels
            continue
        # Select only this label's pixels. A bounding box can also contain
        # pixels of neighbouring regions, which previously inflated the area
        # and shifted the centre of mass.
        mask = labels[offset] == label_id
        if mask.sum() >= min_area:
            local = measurements.center_of_mass(mask)
            # Convert from bounding-box coordinates to image coordinates.
            centroid = [int(local[0]) + offset[0].start,
                        int(local[1]) + offset[1].start]
            if is_valid_patch(img, centroid, patch_size):
                centroids.append(centroid)
    # reshape keeps a consistent (n, 2) layout even when nothing was found.
    return np.asarray(centroids, dtype=int).reshape(-1, 2)

def debug_show_patches(img, centroids, patch_size):
    """ Draw patch locations on image (single frame).

    The docstring now comes first; previously ``global CNT`` preceded it, so
    the function had no ``__doc__``.

    :param img: 2-D image shown as the background.
    :param centroids: iterable of (row, column) patch centres.
    :param patch_size: side length of the red squares drawn around each centre.
    :return: None. The global counter ``CNT`` is incremented once per call.
    """
    global CNT
    hsize = int(patch_size *0.5)
    fig,ax = plt.subplots(1,figsize=(20,20))
    plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
    ax.imshow(img)
    for y, x in centroids:
        # Rectangle takes its lower-left corner as (x, y), hence the swap.
        rect = Patches.Rectangle((x -hsize,y - hsize),patch_size,patch_size,linewidth=1,edgecolor='r',facecolor='none')
        ax.add_patch(rect)
    plt.show()

    #plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
    #plt.savefig(str(CNT)+".png", frameon='false', bbox_inches='tight', transparent=True, pad_inches=0.0)
    CNT += 1


def patchify(input_tuple):
    """Cut square windows out of one frame, one per centroid.

    :param input_tuple: ``(img, centroids, patch_size)``; ``centroids`` is an
        iterable of (row, column) pairs.
    :return: array stacking the crops in centroid order. Each crop spans
        ``int(patch_size * 0.5)`` pixels on either side of its centroid.
    """
    frame, centers, side = input_tuple
    half = int(side * 0.5)
    crops = [
        frame[row - half:row + half, col - half:col + half]
        for row, col in centers
    ]
    return np.asarray(crops)
    
def label_patches(patches):
    """Placeholder for labelling: currently only reports the array shape.

    :param patches: array-like object exposing ``.shape``.
    :return: None
    """
    shape = patches.shape
    print(shape)
    # Per-patch labelling (a label_patch call for every patch) is not enabled yet.
    
def label_patch(patch):
    """Placeholder for labelling one patch: prints the patch's shape.

    :param patch: array-like object exposing ``.shape``.
    :return: None
    """
    shape = patch.shape
    print(shape)
    
class ImageStack:
    """Reads one .nd2 file and extracts normalised patches around bright regions.

    Constructing an instance runs the whole pipeline (see ``read_nd2``); the
    result is stored in ``self.patches``.
    """

    def __init__(self, img_path):
        """Set up defaults and immediately process ``img_path``.

        :param img_path: path of the .nd2 file to read.
        """
        self.images = {}
        self.np_data = None
        self.patches = None
        self.path = img_path
        self.patch_size = PATCH_SIZE
        self.n_channels = 4
        self.max_z = 7
        self.max_v = 0
        self.frames = []
        ## Read Images in nd2
        self.read_nd2(img_path)

    def read_nd2(self, path):
        """Detect centroids, cut patches, merge and normalise them.

        Side effects: fills ``self.centroids`` and ``self.patches`` and stores
        the raw patch array in the module-level ``gimg`` for debugging.

        :param path: path of the .nd2 file; a name matching ``.*40X.*``
            switches the patch size to 384.
        """
        global gimg

        with ND2Reader(path) as images:
            print("Starting " + path)
            # NOTE(review): the reader is kept on the instance, but it is
            # presumably closed once this `with` block exits - confirm before
            # calling get_frame_vcz from outside read_nd2.
            self.images = images

            # Obtain metadata & construct struct container:
            self.width = images.metadata['width']
            self.height = images.metadata['height']
            self.z_levels = images.metadata['z_levels']
            self.channels = images.metadata['channels']
            self.max_v = images.metadata['fields_of_view'].stop

            if re.match(".*40X.*", path):
                print("40X!")
                self.patch_size = 384

            # Debug limit: look at no more than the first two fields of view.
            # Clamped rather than forced to 2, so files with fewer fields do
            # not request frames that do not exist.
            self.max_v = min(self.max_v, 2)

            print("getting centroids...")
            self.centroids = []

            # Single threaded
            for v in range(self.max_v):
                self.centroids.append(get_centroids((self.get_merged_nuclei(v), self.patch_size)))

            print("patchify...")
            self.patches = []

            # z-slices 2, 3 and 4 are kept; the split below depends on this count.
            z_levels = range(2, 5)
            for _v in range(self.max_v):
                for _z in z_levels:
                    for _c in range(self.n_channels):
                        self.patches.append(patchify((self.get_frame_vcz(_v,_c,_z), self.centroids[_v], self.patch_size)))
            # Disabled alternative: take only z-slice KEEP_Z for every field
            # and channel instead of the three slices above.

            print("finished patchify")

            # Regroup the flat (v, z, c)-ordered list into [v][z][c].
            self.patches = np.asarray(self.patches)
            self.patches = np.asarray(np.split(self.patches, len(z_levels)*self.max_v))
            self.patches = np.asarray(np.split(self.patches, self.max_v))

            gimg = self.patches

            print("generating labels...")
            label_patches(self.patches)
            self.patches = converge3z(self.patches)
            self.patches = normalize(self.patches)

            # Saving is intentionally disabled by this early return.
            return
            print("saving data...")
            # Raw strings: "\d" in a normal string is an invalid escape sequence.
            file = 'patches/norm_20X/norm_'+re.search(r"\d*_\d*X_\d*._\d*", path).group(0)+'.npy'
            np.save(file, self.patches, True)
            print('saved norm_'+re.search(r"\d*_\d*X_\d*._\d*", path).group(0)+'.npy.gz')

    def get_merged_nuclei(self, v):
        """ OR together the thresholded frames of one field of view.

        Uses z-slices 0 and 1 and channels 1 .. ``n_channels - 1``; channel 0
        is not included. Assumes at least 3 channels in total.

        :param v: field-of-view index.
        :return: combined boolean mask.
        """
        img = self.get_threshold_vcz(v,1,0)
        for z in range(2):
            if z > 0:
                # channel 1 at z == 0 is already the starting mask
                img = self.get_threshold_vcz(v,1,z) | img
            for _c in range(2,self.n_channels):
                img = self.get_threshold_vcz(v,_c,z) | img

        #show_image(img)
        return img

    def get_threshold_vcz(self, v, c, z):
        """ Threshold for a channel & z.
        :param v: the image index from nd2.
        :param c: channel index.
        :param z: z-slice index.
        :return: the thresholded image (boolean mask ``frame > threshold``).
        """
        img = self.get_frame_vcz(v,c,z)
        thr = get_threshold(img)
        return img > thr

    def get_frame_vcz(self, v, c, z):
        """Fetch the frame for field ``v``, channel ``c`` and z-slice ``z``.

        The remaining arguments of ``get_frame_vczyx`` are fixed to 0
        (presumably time index and x/y offsets - TODO confirm).
        """
        return self.images.get_frame_vczyx(v,c,0,z,0,0)

    def get_c_stack(self, v, z):
        """Stack all channels of one field / z-slice into a (C, Y, X) array.

        :param v: field-of-view index.
        :param z: z-slice index.
        :return: integer array of shape ``(len(self.channels), height, width)``.
        """
        # Frames are laid out (Y, X), so the buffer must be (height, width).
        # The previous (width, height) order failed for non-square images.
        stack = np.zeros(shape=(len(self.channels), self.height, self.width), dtype=int)
        for c in range(len(self.channels)):
            stack[c] = self.get_frame_vcz(v,c,z)
        return stack
    
def converge3z(img):
    """Flatten per-field patch stacks into one (patch, z, channel, y, x) array.

    :param img: array-like indexed as ``[field][z][channel]``, where every
        entry is a stack of patches ``(n_patches, y, x)``. Fields may hold
        different numbers of patches, including none.
    :return: float array of shape ``(total_patches, n_z, n_channels, y, x)``
        with patches ordered by field, or None if the input cannot be merged
        (the error text is printed).
    """
    try:
        data = np.asarray(img)
        # Read the z and channel counts from the data rather than assuming 3 and 4.
        n_fields, n_z, n_channels = data.shape[:3]

        per_field = []
        for v in range(n_fields):
            # A field without centroids yields empty stacks. Skip it, so an
            # empty first field no longer makes the whole merge fail.
            if np.asarray(data[v][0][0]).shape[0] == 0:
                continue
            field = np.asarray(
                [[np.asarray(data[v][z][c]) for c in range(n_channels)]
                 for z in range(n_z)],
                dtype=float)
            # (z, c, patch, y, x) -> (patch, z, c, y, x)
            per_field.append(np.moveaxis(field, 2, 0))

        n_patches = sum(block.shape[0] for block in per_field)
        print(str(n_patches) + " patches found")

        if not per_field:
            return np.empty((0, n_z, n_channels, 0, 0))
        return np.concatenate(per_field, axis=0)
    except Exception as e:
        # Best effort: report the problem and let the caller see None.
        print("\n" + str(e) + "\n")
        return None
        
def normalize(patches):
    """Clip every channel to its 1st-99th percentile and rescale it to [0, 1].

    The array is modified in place and also returned. For debugging, a copy
    of the input is stored in the global ``gimg`` and a copy of the result in
    the global ``gimg2``.

    :param patches: float array indexed ``[patch, z, channel, y, x]``.
    :return: the same array object, normalised channel by channel. A channel
        whose clipped values are all equal becomes all zeros instead of NaN.
    """
    global gimg, gimg2
    print(patches.shape)
    gimg = np.copy(patches)
    for c in range(patches.shape[2]):
        channel = patches[:, :, c, :, :]  # view into `patches`
        if channel.size == 0:
            continue  # nothing to normalise; percentiles are undefined
        vals = channel.flatten()
        low = scoreatpercentile(vals, 1)
        high = scoreatpercentile(vals, 99)
        scaled = np.clip(channel, low, high)
        scaled = scaled - scaled.min()
        peak = scaled.max()
        if peak > 0:
            # Guard against 0/0 for a constant channel.
            scaled = scaled / peak
        patches[:, :, c, :, :] = scaled
    gimg2 = np.copy(patches)
    return patches
    
# Entry point: run the pipeline only when this code is executed directly
# (as it is in a notebook cell or as a script), not when it is imported.
if __name__ == "__main__":
    main() 

Starting /home/shared/kevinv/uhack_2019/data/Hackathon/190710_20X_50K_0004.nd2
getting centroids...
patchify...
finished patchify
generating labels...
(2, 3, 4)
15 patches found
(15, 3, 4, 256, 256)


In [29]:
def get_cell_state(gimg2):
    """Call a cell-cycle state for every patch from its green/red signal.

    For each patch the central 50x50 window of z-slice 1 is summarised per
    channel (1 = 'g', 2 = 'r') as the percentile rank of its mean intensity
    among all values of that channel in the whole array.

    :param gimg2: array indexed ``[patch, z, channel, y, x]``.
    :return: list with one of 'Uncalled', 'Early S', 'G1' or 'S/G2/M' per patch.
    """
    mid = int(gimg2.shape[-1]/2)
    w = 25  # half-width of the central window
    channel_names = ((1, 'g'), (2, 'r'))

    # The reference population of a channel is the same for every patch, so
    # flatten it once instead of once per patch.
    populations = {n: gimg2[:, :, c, :, :].flatten() for c, n in channel_names}

    states = []
    for p in range(gimg2.shape[0]):
        percs = {}
        for c, n in channel_names:
            patch = gimg2[p, 1, c, :, :]
            center = patch[mid-w:mid+w, mid-w:mid+w]
            # `perc` was never assigned, so every call raised NameError.
            # NOTE(review): the mean of the window is assumed to be the
            # intended statistic - confirm against the original analysis.
            percs[n] = percentileofscore(populations[n], center.mean())
        if percs['r'] < 20 and percs['g'] < 20:
            state = 'Uncalled'
        elif np.abs(percs['r']-percs['g']) < 10:
            state = 'Early S'
        elif percs['r'] > percs['g']:
            state = 'G1'
        else:
            state = 'S/G2/M'
        states.append(state)
    return states

In [None]:
# Scratch cell: compare thresholding methods on one patch.
print(gimg.shape)

img = gimg[0][0]  # first entry along the two leading axes of gimg

weights = []
# `img_thr` and `thr` were used before being defined. Compute the Yen
# threshold of channel 3 first, as the commented-out line intended.
thr = threshold_yen(img[3])
img_thr = img[3] > thr
show_image(img_thr)
# NOTE(review): try_all_threshold now receives the grey-level channel rather
# than the already binarised mask; this assumes the goal is to compare
# thresholding methods on the raw channel - confirm.
fig, ax = try_all_threshold(img[3], figsize=(10, 8), verbose=False)
plt.show()
for c in range(4):
    # per-channel threshold (was computed on the whole stack for every c)
    thr = threshold_yen(img[c])
    show_image(img[c])
    

In [22]:
# Print the state called for each patch stored in the global `gimg`
# (defined by an earlier cell, so that cell must have been run first).
print(get_cell_state(gimg))

['S/G2/M', 'G1', 'G1', 'G1', 'S/G2/M', 'S/G2/M', 'Early S', 'G1', 'Early S', 'G1', 'S/G2/M', 'G1', 'G1', 'Early S', 'G1']


In [None]:
def get_cell_state(all_patches)
    for filename in os.listdir("patches/norm_20X"): # original nd2 filenames
            if re.match("*.npy", filename):