In [1]:
from tifffile import *
import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import label, regionprops, find_contours
import time
import cv2
%matplotlib inline


In [2]:
tiff_path = '/media/jamesrowland/DATA/plab/j031/4/2019-02-28_J031_t-004.tif'
tiff = imread(tiff_path)
tiff.shape

(1942, 512, 512)

In [3]:
#stim occured from frame 60-68
stim_start = 0
stim_end = 200

In [4]:
def find_threshold(tiff, stim_start, sigma=4):
    
    '''
    tiff: tiff stack 
    stim_start: frame on which the stimulus started
    sigma: sigma value over which to threshold pixels

    thresh: thresholded single frame
    '''
         
    base_vals = tiff[stim_start-10:stim_start-2, :, :]
    base_mean = np.mean(base_vals, 0)
    base_std =  np.std(base_vals, 0)
    thresh = base_mean + base_std*sigma
    
    return thresh 

In [5]:
def binarise_frame(frame, thresh):
    
    '''binarises frame, where all pixels > thresh = 1'''
    
    assert frame.shape == thresh.shape
    
    return np.greater(frame, thresh).astype('int')

In [6]:
def process_frame(frame, frame_bin, width_thresh=10):
    
    '''
    finds regions of connected pixels in a frame and their widths
    returns:
    labelled: array of frame dimensions with each pixels region labelled
    widths: the width (x len) of each region
    '''
    #find regions of pixel connectivity
    labelled, num_labels = label(frame_bin, connectivity=1, return_num=True)
    
    # the properties of connected regions
    regions = regionprops(labelled)

    for i,props in enumerate(regions):
        
        #if i==0:continue
        
        coords = props['Coordinates']
        
        rows = coords[:,0]
        cols = coords[:,1]
        
        width_rows = max(rows) - min(rows)
        width_cols = max(cols) - min(cols)
        
        width = (width_rows if width_rows<width_cols else width_cols)
        
        # the width of the labelled region is thin or it is very asymmetrical
        if width < width_thresh or props['major_axis_length'] / props['minor_axis_length'] > 2:
            
            frame[rows, cols] = 0
            
            #useful for debugging
            #labelled[rows, cols] = 0
        else:
            pass
            #labelled[rows, cols] = width

    
    return frame    

In [7]:
def main(tiff, stim_start, stim_end, width_thresh=10):
    
    thresh = find_threshold(tiff, stim_start)
    
    for frame_idx in range(stim_start, stim_end):

        frame = tiff[frame_idx, :, :]

        frame_bin = binarise_frame(frame, thresh)
        
        processed_frame = process_frame(frame, frame_bin, width_thresh)

        tiff[frame_idx, :, :] = processed_frame


    
    return tiff

In [8]:
t1 = time.time()
processed_tiff = main(tiff, stim_start, tiff.shape[0])
t2 = time.time()
t2-t1

269.7435395717621

In [11]:
(t2-t1) / 60

4.495725659529368

In [12]:
imsave('original_stack.tiff', tiff[0:200,:,:])
imsave('processed_stack.tiff', processed_tiff)