In [28]:
from nd2reader import ND2Reader
import matplotlib.pyplot as plt
import matplotlib.patches as Patches
import numpy as np
from skimage.filters import gaussian, threshold_yen
from skimage import measure
from skimage.util import crop
from scipy.ndimage import find_objects, measurements
import multiprocessing as mp
import gzip
import re
from numba import jit
import time
import cv2
import cython
import os
import sys
from PIL import Image

# Minimum number of foreground pixels a labelled region must have before
# get_centroids() keeps it.
min_area = 30
# Default edge length (in pixels) of the square patches cut around each
# centroid; ImageStack starts from this value.
PATCH_SIZE = 256
# Placeholder global; no code visible in this file reads it or assigns it a value.
gimg = None

def main():
    """Build an ImageStack for every nd2 acquisition listed in data/Hackathon.

    For each ``*.nd2`` file name found in ``data/Hackathon`` the tiff frames
    in ``data/Tiffs/`` whose names start with the nd2 stem are collected.
    Each frame is stored in a list at the position given by the frame number
    embedded in its file name, and the ordered list is handed to
    ``ImageStack``, whose constructor performs the processing.

    Processing is best effort: an error for one acquisition is printed and
    the loop continues with the next acquisition.
    """
    path = "data/Tiffs/"  # path for the individual tiff files.

    # Stems of the original nd2 files.  endswith() replaces the unanchored
    # regex ".*nd2" so that names such as "x.nd2.bak" are not picked up and
    # then mangled by the [:-4] slice.
    all_files = [
        filename[:-4]
        for filename in os.listdir("data/Hackathon")
        if filename.endswith(".nd2")
    ]

    # NOTE: a single-file command-line mode (stem taken from sys.argv[1])
    # used to live here and is currently disabled.

    # The frame number is the run of digits directly in front of the
    # extension dot, e.g. "..._000012.tif" -> 12.
    index_re = re.compile(r"\d*(?=...)[.]")

    for file in all_files:
        try:
            # The stem is literal text, so escape it before using it as the
            # prefix of a pattern.
            stem_re = re.compile(re.escape(file) + ".*tif")
            matching = [name for name in os.listdir(path) if stem_re.match(name)]

            # Slots are filled by frame number; a gap or duplicate number
            # surfaces as an error below and skips this acquisition.
            ordered_files = [None] * len(matching)
            for filename in matching:
                idx = int(index_re.search(filename).group(0)[:-1])
                ordered_files[idx] = os.path.join(path, filename)

            # The constructor does all the work; the instance is not kept so
            # its arrays can be released before the next acquisition.
            ImageStack(ordered_files)
        except Exception as e:
            print("\n" + str(e) + "\n")
            continue

def show_image(img):
    """Open a new 5x5 inch matplotlib figure and draw ``img`` (Y, X) in it."""
    side = 5
    plt.figure(figsize=(side, side))
    plt.imshow(img)

def illum_filter(img, gauss_sigma=150):
    """Remove the smooth illumination background (lens aberration) from a frame.

    A heavily blurred copy of the image is used as the background estimate
    and subtracted; negative results are clipped to zero.

    :param img: image array, or None.
    :param gauss_sigma: sigma of the Gaussian blur used for the background
        estimate (larger values give a smoother estimate).
    :return: the corrected image as uint32, or None when ``img`` is None.
    """
    if img is None:
        return None
    # preserve_range=True keeps the blur in the same value range as the
    # input.  Without it, integer images are rescaled to [0, 1] and the
    # subtraction below would leave the image practically unchanged.
    background = gaussian(img, gauss_sigma, preserve_range=True)
    corrected = img - background
    corrected[corrected < 0] = 0
    return corrected.astype(dtype=np.uint32, copy=False)

def get_threshold(img, ratio=10):
    """Return the Yen threshold of a single frame, or 0 to ignore the frame.

    The frame is ignored (0 is returned) when the threshold is lower than
    ``ratio`` times the mean of the pixels below the threshold.

    :param img: image array for one frame.
    :param ratio: required factor between the threshold and the mean of the
        below-threshold pixels.
    :return: the Yen threshold, or 0 when the frame is rejected.
    """
    thr = threshold_yen(img)
    background = img[img < thr]
    if background.size == 0:
        # No pixel lies below the threshold, so there is nothing to compare
        # against; keep the threshold (and avoid a mean-of-empty warning).
        return thr
    cutoff = ratio * np.mean(background)
    if thr < cutoff:
        # Report the value that was actually compared against.
        print("thr:{0:.5f}, ratio*mean:{1:.5f}".format(thr, cutoff))
        return 0
    return thr
        
def is_valid_patch(img, c, patch_size):
    """Tell whether a square patch centred on ``c`` lies inside ``img``.

    :param img: 2-D image in (Y, X) layout; only its shape is used.
    :param c: centroid to examine, as (y, x).
    :param patch_size: edge length of the square patch.
    :return: True when the patch does not cross any image border.
    """
    rows, cols = img.shape  # YX format
    half = patch_size * 0.5
    centre_y, centre_x = c[0], c[1]
    crosses_x = centre_x + half > cols or centre_x - half < 0
    crosses_y = centre_y + half > rows or centre_y - half < 0
    return not (crosses_x or crosses_y)

#FAST
def get_centroids(input_tuple):
    """Label connected foreground regions and return their midpoints.

    :param input_tuple: ``(img, patch_size)`` where ``img`` is a 2-D
        thresholded image (0 is background) and ``patch_size`` is the edge
        length of the patch that must fit around each midpoint.
    :return: array of ``[y, x]`` midpoints (full-image coordinates) of all
        regions with at least ``min_area`` pixels whose patch fits entirely
        inside the image.
    """
    img, patch_size = input_tuple
    labels = measure.label(img, background=0, connectivity=1)
    centroids = []
    # find_objects returns one bounding box per label value, in label order
    # starting at 1.
    for label_id, bbox in enumerate(find_objects(labels), start=1):
        if bbox is None:  # label value not present in the image
            continue
        # Restrict to the pixels of THIS component: a bounding box can also
        # contain pixels of neighbouring components, which would otherwise
        # inflate the area and shift the centre of mass.
        component = labels[bbox] == label_id
        if component.sum() < min_area:
            continue
        local_y, local_x = measurements.center_of_mass(component)
        # Translate from bounding-box coordinates to full-image coordinates.
        centroid = [int(local_y) + bbox[0].start, int(local_x) + bbox[1].start]
        if is_valid_patch(img, centroid, patch_size):
            centroids.append(centroid)
    return np.asarray(centroids)

def debug_show_patches(img, centroids, patch_size):
    """Plot one frame and outline every patch location in red.

    :param img: the frame to display.
    :param centroids: iterable of (y, x) patch centres.
    :param patch_size: edge length of the square outlines.
    """
    half = int(patch_size * 0.5)
    _, axes = plt.subplots(1, figsize=(10, 10))
    axes.imshow(img)
    for row, col in centroids:
        # Rectangle is anchored at its (x, y) corner, hence the swap.
        corner = (col - half, row - half)
        outline = Patches.Rectangle(corner, patch_size, patch_size,
                                    linewidth=1, edgecolor='r', facecolor='none')
        axes.add_patch(outline)
    plt.show()


def patchify(input_tuple):
    """Cut square patches out of an image around the given centroids.

    :param input_tuple: ``(img, centroids, patch_size)``; ``centroids`` is an
        iterable of (y, x) pairs.
    :return: list with one slice of ``img`` per centroid, each spanning
        ``int(patch_size * 0.5)`` pixels to either side of the centre.
    """
    frame, centres, edge = input_tuple
    half = int(edge * 0.5)
    return [
        frame[row - half:row + half, col - half:col + half]
        for row, col in centres
    ]
    
def label_patches(patches):
    """Labelling stub: for now it only prints the shape of ``patches``.

    The per-patch loop that would call label_patch() is not implemented.
    """
    shape = patches.shape
    print(shape)
    
def label_patch(patch):
    """Labelling stub for a single patch: prints the shape of ``patch``."""
    shape = patch.shape
    print(shape)
    
class ImageStack:
    """A flat stack of tiff frames addressed by position, channel and z-slice.

    Frames are stored in ``self.images`` in the order
    ``index = c + n_channels * (z + max_z * v)``, where ``v`` is the position
    index, ``c`` the channel and ``z`` the z-slice.  Constructing an instance
    loads the frames, finds nuclei centroids per position, cuts a patch
    around every centroid in every frame and saves the patches to a ``.npy``
    file in the current directory.
    """

    def __init__(self, img_paths):
        """Load and process the frames listed in ``img_paths``.

        :param img_paths: tiff file paths in frame order.
        """
        self.images = {}
        self.np_data = None
        self.patches = None
        self.paths = img_paths
        self.patch_size = PATCH_SIZE
        self.n_channels = 4
        self.max_z = 7
        self.max_v = 0
        self.centroids = []
        ## Read Images in nd2
        self.read_tiffs(img_paths)

    def read_tiffs(self, paths):
        """Load the frames, extract patches around nuclei and save them.

        :param paths: tiff file paths in frame order.
        :raises ValueError: when ``paths`` is empty, when it holds fewer
            frames than one complete position, or when no output name can be
            derived from the first path.
        """
        if not paths:
            raise ValueError("no tiff files to read")
        print("Starting " + paths[0])

        # Derive the output name first so a bad file name fails before the
        # expensive processing.
        name_match = re.search(r"\d*_\d*X_\d*._\d*", self.paths[0])
        if name_match is None:
            raise ValueError("cannot derive an output name from " + self.paths[0])
        file = 'data_' + name_match.group(0) + '.npy'

        # Close every file as soon as its pixels have been copied out.
        frames = []
        for tiff_path in paths:
            with Image.open(tiff_path) as frame:
                frames.append(np.array(frame))
        self.images = np.asarray(frames)
        print(self.images.shape)

        # Number of complete positions; surplus frames are ignored.
        self.max_v = self.images.shape[0] // (self.n_channels * self.max_z)
        print(int(self.max_v))
        if self.max_v == 0:
            raise ValueError(
                "need at least {0} frames for one position, got {1}".format(
                    self.n_channels * self.max_z, self.images.shape[0]))

        if "40X" in paths[0]:
            print("40X!")
            self.patch_size = 384

        print("getting centroids...")
        self.centroids = [
            get_centroids((self.get_merged_nuclei(v), self.patch_size))
            for v in range(self.max_v)
        ]

        print("patchify...")
        # Same centroids for every frame of a position; frames are visited in
        # (v, z, c) order so the flat list can be reshaped below.
        flat_patches = [
            patchify((self.get_frame_vcz(v, c, z), self.centroids[v], self.patch_size))
            for v in range(self.max_v)
            for z in range(self.max_z)
            for c in range(self.n_channels)
        ]
        print("finished patchify")

        # NOTE(review): this needs the same number of centroids at every
        # position; differing counts give a ragged list that NumPy cannot
        # turn into a regular array.
        flat_patches = np.asarray(flat_patches)
        # (v*z*c, ...) -> (v, z, c, ...)
        self.patches = flat_patches.reshape(
            (self.max_v, self.max_z, self.n_channels) + flat_patches.shape[1:])

        print("generating labels...")
        label_patches(self.patches)

        print("saving data...")
        np.save(file, self.patches, allow_pickle=True)
        # Report the file that was actually written (plain .npy, not gzipped).
        print('saved ' + file)

    def get_merged_nuclei(self, v):
        """Return the union of the thresholded frames of one position.

        Channels 1 .. n_channels-1 of every z-slice are thresholded and
        OR-ed together; channel 0 is not used.

        :param v: position index.
        :return: boolean image, True where any of those frames is above its
            threshold.
        """
        start = time.time()
        merged = None
        for z in range(self.max_z):
            for c in range(1, self.n_channels):
                mask = self.get_threshold_vcz(v, c, z)
                merged = mask if merged is None else (mask | merged)
        end = time.time()
        print("get_merged_nuclei: ", end - start, "seconds")
        return merged

    def get_threshold_vcz(self, v, c, z):
        """Threshold one frame.

        :param v: position index.
        :param c: channel index.
        :param z: z-slice index.
        :return: boolean image, True where the frame exceeds its threshold.
        """
        img = self.get_frame_vcz(v, c, z)
        thr = get_threshold(img)
        return img > thr

    def get_frame_vcz(self, v, c, z):
        """Return the frame for position ``v``, channel ``c``, z-slice ``z``."""
        return self.images[c + self.n_channels * (z + self.max_z * v)]

    def get_c_stack(self, v, z):
        """Return all channels of one position and z-slice as one array.

        :param v: position index.
        :param z: z-slice index.
        :return: integer array with the channel as first axis, followed by
            the axes of a single frame.
        """
        channels = [self.get_frame_vcz(v, c, z) for c in range(self.n_channels)]
        return np.stack(channels).astype(int)
    
    
# Run the batch only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

data/Tiffs/190710_20X_50K_0005_000000.tif
28
data/Tiffs/190705_20X_10K_0001_000000.tif
28
data/Tiffs/190627_20X_25K_0001_000000.tif
28
data/Tiffs/190709_20X_50K_0001_000000.tif
28
data/Tiffs/190705_20X_5K_0001_000000.tif
28
data/Tiffs/190625_40X_100K_0001_000000.tif
28
data/Tiffs/190625_40X_12K_0001_000000.tif
28
data/Tiffs/190710_20X_50K_0002_000000.tif
28
data/Tiffs/190701_40X_10K_0002_000000.tif
28
data/Tiffs/190704_20X_5K_0001_000000.tif
28
data/Tiffs/190625_20X_25K_0001_000000.tif
28
data/Tiffs/190709_20X_25K_0002_000000.tif
28
data/Tiffs/190627_40X_12K_0001_000000.tif
28
data/Tiffs/190710_40X_50K_0001_000000.tif
28
data/Tiffs/190705_40X_10K_0001_000000.tif
28
data/Tiffs/190627_20X_100K_0001_000000.tif
28
data/Tiffs/190625_40X_25K_0001_000000.tif
28
data/Tiffs/190627_20X_12K_0001_000000.tif
28
data/Tiffs/190627_40X_50K_0001_000000.tif
28
data/Tiffs/190704_40X_10K_0002_000000.tif
28
data/Tiffs/190625_40X_50K_0001_000000.tif
28
data/Tiffs/190709_20X_50K_0003_000000.tif
28
data/Tiffs