In [1]:
cd /Users/jz-rolling/Desktop/OMEGA_GUI/notebooks/

/Users/jz-rolling/Desktop/OMEGA_GUI/notebooks


In [2]:
import momia2 as mo
from skimage import measure, feature, filters, transform, morphology
import numpy as np
import pandas as pd
import read_roi as rr
import matplotlib.pyplot as plt
import glob,os
import pickle as pk

In [153]:
# Trained momia2 pixel classifier; passed to image2labels in the ROI cropping loop.
# NOTE(review): the absolute paths in this cell are machine-specific; consider deriving
# them from one configurable base directory.
model=mo.classify.classifier.load_model('/Users/jz-rolling/Desktop/Size_calculation/training_set/model_20220101.pk')
# NOTE(review): prepare_gabor_kernels is defined in a later cell (In [205]); on a fresh
# kernel this line raises NameError unless that cell has been executed first.
gabor_kernels = prepare_gabor_kernels()
# Pixels of context added around each ROI for drift registration; cropped off again afterwards.
pad=20
# Input: one sub-directory per plate containing day-prefixed TIFFs and "<plate>_roi.zip".
src = '/Users/jz-rolling/Desktop/Size_calculation/20220118/Alignment_20220120/'
# Output directory: one pickle per (plate, ROI).
dst = '/Users/jz-rolling/Desktop/Size_calculation/20220118/Cropped_20220120/'

In [221]:
# Crop every ROI of each plate out of the day-ordered TIFF stack, re-register the crops
# over time, classify each frame, and pickle the per-ROI results to `dst`.
# NOTE(review): correct_tl / image2labels are defined in a later cell (In [205]);
# run that cell first on a fresh kernel.
data_dict = {}
for f in sorted(glob.glob(src+'Plate-7/')):
    plate = f.split('/')[-2]
    rois = rr.read_roi_zip(f+'{}_roi.zip'.format(plate))
    images = []
    days = []
    for img in sorted(glob.glob(f+'*.tif')):
        images.append(plt.imread(img))
        # File names are expected to start with the day number, e.g. "3_<...>.tif".
        days.append(int(os.path.basename(img).split('_')[0]))
    if not images:
        # No frames for this plate: skip instead of failing on images[0] below.
        continue
    args = np.argsort(days)
    sorted_days = np.array(days)[args]
    sorted_images = np.array(images)[args]
    img_h,img_w = images[0].shape
    for k, v in rois.items():
        # ROI bounding box; x indexes rows (top/height), y indexes columns (left/width).
        x1,y1 = v['top'],v['left']
        x2,y2 = x1+v['height'],y1+v['width']
        roi_h, roi_w = x2-x1, y2-y1

        # Pad the window for registration without running past the image.
        # Bottom/right pads are clamped at 0: for an ROI touching the lower/right
        # border `img_h-x2-1` is negative and would shrink the window.
        pad_top = min(pad,x1)
        pad_left = min(pad,y1)
        pad_bottom = max(0, min(pad,img_h-x2-1))
        pad_right = max(0, min(pad,img_w-y2-1))
        window = sorted_images[:,x1-pad_top:x2+pad_bottom,y1-pad_left:y2+pad_right]
        corrected,_w = correct_tl(window)
        # Strip the padding by explicit extent. Negative-index slicing (`x[p:-q]`)
        # returns an empty crop whenever q == 0, because -0 == 0.
        corrected_cropped = corrected[:,pad_top:pad_top+roi_h,pad_left:pad_left+roi_w]

        masks = []
        mask_probs = []
        edges = []
        for s in corrected_cropped:
            prob,fg,edge= image2labels(s,model,
                                       gabor_kernels=gabor_kernels,
                                       edge_threshold=0.2)
            masks.append(fg)
            mask_probs.append(prob)
            edges.append(edge)
        masks=np.array(masks)
        mask_probs=np.array(mask_probs)
        edges = np.array(edges)
        roi_data = {
            'Days': sorted_days,
            'Images': corrected_cropped,
            'Masks': masks,
            'Mask_probs': mask_probs,
            'Edges': edges,
        }
        fname = dst+'{}_{}.pk'.format(plate,k)
        # `with` guarantees the file is flushed and closed, even if pickling fails.
        with open(fname,'wb') as fh:
            pk.dump(roi_data,fh)

In [205]:
from scipy import ndimage as ndi
from skimage.filters import gabor_kernel
def prepare_gabor_kernels(n_theta=4, sigmas=(2.8,), frequencies=(0.23,)):
    """Build a bank of real-valued Gabor kernels.

    One kernel is produced for every (orientation, sigma, frequency)
    combination. Orientation varies slowest and frequency fastest, which is
    the order consumers enumerating the list will see.

    :param n_theta: number of orientations; angles are k*pi/n_theta for k in range(n_theta).
    :param sigmas: Gaussian envelope widths; each is used for both sigma_x and sigma_y.
        Any iterable works. The default is an immutable tuple, so it can never be
        shared-and-mutated across calls the way a list default could.
    :param frequencies: spatial frequencies of the harmonic factor (any iterable).
    :return: list of 2-D float arrays, the real part of each complex Gabor kernel.
    """
    kernels = []
    for k in range(n_theta):
        # Keep the loop counter separate from the angle it produces.
        theta = (k / n_theta) * np.pi
        for sigma in sigmas:
            for frequency in frequencies:
                kernel = gabor_kernel(frequency, theta=theta,
                                      sigma_x=sigma, sigma_y=sigma)
                kernels.append(np.real(kernel))
    return kernels


def min_max(data):
    """Linearly rescale ``data`` to the range [0, 1].

    NOTE(review): this function is defined twice in this notebook; the later
    definition shadows this one, so keep the two in sync (or delete one).

    :param data: array-like of numbers (any shape); converted with ``np.asarray``.
    :return: float array of the same shape. When ``data`` is constant the range is
        zero, and an array of zeros is returned instead of dividing by zero
        (which would yield NaNs and a RuntimeWarning).
    """
    data = np.asarray(data)
    lo = data.min()
    span = data.max() - lo
    if span == 0:
        return np.zeros(data.shape, dtype=float)
    return (data - lo) / span

def power(image, kernel):
    """Median-normalised, smoothed energy response of ``image`` to ``kernel``.

    The image is z-scored first so that responses are comparable between images.
    The real and imaginary parts of the kernel are convolved separately, with
    periodic ('wrap') boundaries. Their magnitude is divided by its median and
    blurred with a Gaussian of sigma 0.8.

    NOTE(review): kernels from ``prepare_gabor_kernels`` are already real, so the
    imaginary convolution is zero and this reduces to |real response|. A constant
    image (std 0) or a zero median yields inf/NaN values.

    :param image: 2-D image.
    :param kernel: 2-D (possibly complex) convolution kernel.
    :return: filtered float image with the same shape as ``image``.
    """
    standardized = (image - image.mean()) / image.std()
    real_response = ndi.convolve(standardized, np.real(kernel), mode='wrap')
    imag_response = ndi.convolve(standardized, np.imag(kernel), mode='wrap')
    magnitude = np.sqrt(real_response ** 2 + imag_response ** 2)
    normalised = magnitude / np.median(magnitude)
    return filters.gaussian(normalised, sigma=0.8)

def image2labels(target,model,gabor_kernels,edge_threshold = 0.05):
    """Classify every pixel of ``target`` with a trained pixel classifier.

    Pipeline:
      1. Rescale intensities so that 240 maps to 255, and clip at 255.
      2. Band-pass filter with momia2.
      3. Compute multiscale momia2 pixel features, plus one Gabor-energy
         feature (``power``) per kernel.
      4. Predict a label and class probabilities for each pixel.

    :param target: 2-D image (intensity range presumably about 0-240; TODO confirm).
    :param model: trained classifier exposing ``predict(features, probability=...)``.
    :param gabor_kernels: kernels passed to ``power``. Presumably these must be the
        same bank the model was trained with; confirm.
    :param edge_threshold: cutoff on probability channel 2 for the edge mask.
    :return: (mask_prob, fg, edge)
        mask_prob: (H, W, 3) array of per-pixel class probabilities.
        fg: boolean mask of pixels labelled 1, after binary opening and filling of
            small holes (area threshold 200 px).
        edge: boolean mask where channel-2 probability > edge_threshold, after opening.
        Class 1 = foreground and class 2 = edge is inferred from the variable names;
        verify against the training labels.
    """
    # Signed-int copy; the division below promotes it to float.
    target = target.astype(int)
    # Stretch so that intensity 240 maps to 255, then clip the overshoot.
    target = (target/240)*255
    target[target>255]=255
    # NOTE(review): astype(np.uint8) does not clip. Any value dual_bandpass returns
    # outside [0, 255] will not saturate; confirm its output range.
    target = mo.utils.dual_bandpass(target,
                                    pixel_microns=1,
                                    min_structure_scale=1,
                                    max_structure_scale=400).astype(np.uint8)
    # uint8 is always > -1, so this lists the (row, col) of every pixel in row-major order.
    x,y = np.where(target>-1)
    gaussians,img_features = mo.metrics.image_feature.multiscale_image_feature(target,sigmas=(1,2.5,5),
                                                                               shapeindex=True,
                                                                               rog=True)
    pixel_stats = mo.metrics.image_feature.local_stat(img_features,x,y)
    # One extra feature column per kernel, named Gabor_0, Gabor_1, ...
    for i,g in enumerate(gabor_kernels):
        filtered = power(target, g)
        pixel_stats['Gabor_{}'.format(i)] = filtered[x,y]
    mask = np.zeros(list(target.shape))
    # Three probability channels per pixel (one per class).
    mask_prob = np.zeros(list(target.shape)+[3])
    # NOTE(review): the model runs twice (labels, then probabilities). If the label is
    # just the argmax of the probabilities, a single call would suffice; confirm.
    predictions = model.predict(pixel_stats,probability=False)
    probs = model.predict(pixel_stats,probability=True)
    mask_prob[x,y,:] = probs
    mask[x,y,] = predictions
    fg = morphology.binary_opening(mask==1)
    fg = morphology.remove_small_holes(fg,200)
    edge = (mask_prob[:,:,2])>edge_threshold
    edge = morphology.binary_opening(edge)
    return mask_prob,fg,edge

def min_max(data):
    """Linearly rescale ``data`` to the range [0, 1].

    NOTE(review): this function is defined twice in this notebook; this later
    definition shadows the earlier one, so keep the two in sync (or delete one).

    :param data: array-like of numbers (any shape); converted with ``np.asarray``.
    :return: float array of the same shape. When ``data`` is constant the range is
        zero, and an array of zeros is returned instead of dividing by zero
        (which would yield NaNs and a RuntimeWarning).
    """
    data = np.asarray(data)
    lo = data.min()
    span = data.max() - lo
    if span == 0:
        return np.zeros(data.shape, dtype=float)
    return (data - lo) / span

def xydrift_correction(target_img, shift, max_drift=25):
    """Apply a sub-pixel drift correction unless the drift is too large.

    :param target_img: image to correct.
    :param shift: per-axis shift in pixels (e.g. from ``get_xydrift``).
    :param max_drift: largest absolute per-axis shift that will still be applied.
    :return: the shifted image, or ``target_img`` itself (untouched) when the largest
        absolute shift component is not within ``max_drift``.
    """
    # Guard clause: shifts beyond the allowed drift are not applied.
    if not max(np.abs(shift)) <= max_drift:
        return target_img
    return shift_image(target_img, np.array(shift))
    
def get_xydrift(ref_img, target_img):
    """Estimate the translation that registers ``target_img`` onto ``ref_img``.

    Uses scikit-image phase cross-correlation with ``upsample_factor=10``, i.e.
    sub-pixel precision of 1/10 px.

    :return: the estimated shift, one value per image axis. The error and
        phase-difference outputs are discarded.
    """
    from skimage.registration import phase_cross_correlation
    estimated_shift, _error, _phasediff = phase_cross_correlation(
        ref_img, target_img, upsample_factor=10)
    return estimated_shift

def shift_image(img, shift):
    """Correct xy drift by translating ``img`` by a (sub-pixel) ``shift``.

    The shift uses the Fourier shift theorem, so it is circular: content leaving
    one edge re-enters at the opposite edge. In this notebook it is used to undo
    drift between time-lapse frames.

    :param img: input image (n-D array).
    :param shift: per-axis shift in pixels; may be fractional.
    :return: drift-corrected image as a float array, rounded to whole values.
    """
    # The docstring must be the first statement; previously it followed this import
    # and was a no-op string (shift_image.__doc__ was None).
    from scipy import ndimage as ndi
    shifted = np.fft.ifftn(ndi.fourier_shift(np.fft.fftn(img), shift))
    return np.round(shifted.real)


def correct_tl(pos_data,max_drift=20):
    """Align a time-lapse stack to its first frame by chaining frame-to-frame drift.

    Each frame t >= 1 is registered against frame t-1. The shifts are summed, and
    the running total is applied to frame t via ``xydrift_correction``. Frame 0 is
    returned as-is.

    :param pos_data: sequence/array of frames in temporal order (presumably 2-D images).
    :param max_drift: per-axis pixel limit with two roles. It flags a suspicious
        frame-to-frame shift, and (through ``xydrift_correction``) it leaves a frame
        uncorrected once its accumulated shift exceeds the limit.
    :return: (corrected frames as a numpy array, list of 0/1 flags). A flag is 1
        where that frame's shift against the previous frame exceeded ``max_drift``;
        the first entry is always 0.

    NOTE(review): flagged shifts are still added to the running total, so one bad
    registration offsets every later frame. Confirm this is intended.
    """
    weird_flags = [0]
    cumulative_shift = np.array([0,0])
    aligned = [pos_data[0]] if len(pos_data) else []
    for t in range(1, len(pos_data)):
        step = get_xydrift(pos_data[t-1], pos_data[t])
        weird_flags.append(int(np.abs(step).max() > max_drift))
        cumulative_shift = cumulative_shift + step
        aligned.append(xydrift_correction(pos_data[t], cumulative_shift, max_drift=max_drift))
    return np.array(aligned), weird_flags