In [None]:
#general packages
import numpy as np
import time, os, sys
import matplotlib.pyplot as plt
import matplotlib as mpl
import glob
import re
from tqdm import tqdm
from util import pil_imread
import tifffile as tf
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300
#import cellpose for segmentation
from cellpose import models, io
from cellpose import plot
#ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [None]:
import plotly.express as px
def plot_2d(img, zmax):
    """Show a single 2D image as a 600x600 plotly figure.
    Parameters:
    -----------
    img = image array handed to plotly express imshow
    zmax= upper intensity limit handed to plotly express imshow"""

    # Rendering options gathered in one place; img is the only positional argument.
    imshow_options = dict(
        width=600,
        height=600,
        binary_string=True,
        binary_compression_level=4,
        binary_backend='pil',
        zmax=zmax,
    )

    figure = px.imshow(img, **imshow_options)
    figure.show()

In [None]:
import plotly.express as px
def plot_slideshow(img, zmax):
    """Show an image stack as a 600x600 plotly figure with a frame slider.
    Parameters:
    -----------
    img = image stack; axis 0 is used as the animation frame axis
    zmax= upper intensity limit handed to plotly express imshow"""

    # Same rendering options as a static plot, plus animation over axis 0.
    imshow_options = dict(
        width=600,
        height=600,
        binary_string=True,
        binary_compression_level=4,
        animation_frame=0,
        binary_backend='pil',
        zmax=zmax,
    )

    figure = px.imshow(img, **imshow_options)
    figure.show()

In [None]:
# Glob pattern matching every TIFF in the HybCycle_0 folder.
directory = "/groups/CaiLab/personal/Lex/raw/2020-08-08-takei/HybCycle_0/*.tif"
# glob.glob returns paths in arbitrary (filesystem-dependent) order. Later cells
# index the loaded images by position (collection[0]), so sort the paths to make
# that index refer to the same file on every run.
files = sorted(glob.glob(directory))

In [None]:
# Display the matched paths as the cell output to confirm the glob found files.
files

In [None]:
# Load every file found above into memory; `collection` holds one array per file,
# in the same order as `files`.
collection = [pil_imread(path, swapaxes=True) for path in tqdm(files)]
# Quick-test variant kept from the original notebook (first two files only, no axis swap):
# for i in tqdm(range(2)):
#     collection.append(pil_imread(files[i], swapaxes=False))

In [None]:
# Show the shape of the first loaded image.
# Axis order noted by the original author: z,c,x,y
# NOTE(review): this depends on what pil_imread(swapaxes=True) returns -- confirm.
collection[0].shape

In [None]:
# Browse the sub-array at index 2 of the first image; plot_slideshow animates over
# its leading axis. Intensities are clipped at 3000 for display.
plot_slideshow(collection[0][2], zmax=3000)

In [None]:
# Exchange the first two axes of the first image so it can be indexed by the
# (originally) second axis first.
img = np.swapaxes(collection[0], 0, 1)
# Maximum-intensity projections along the leading axis of two sub-stacks:
# index 1 -> maxc, index 3 -> maxn.
maxc, maxn = (img[index].max(axis=0) for index in (1, 3))

In [None]:
# Stack the two projections into one array: max_cell[0] is maxc, max_cell[1] is maxn.
max_cell = np.array([maxc,maxn])

In [None]:
# Preview the second projection (max_cell[1], i.e. maxn), clipped at 5000 for display.
plot_2d(max_cell[1], zmax=5000)

In [None]:
# DEFINE CELLPOSE MODEL
# model_type='cyto' or model_type='nuclei'
# Here: the 'cyto' model, run on CPU (gpu=False).
model = models.Cellpose(gpu=False, model_type='cyto')

In [None]:
# Define the channels to run segmentation on.
# Cellpose channel codes: grayscale=0, R=1, G=2, B=3
# channels = [cytoplasm, nucleus]

channels = [1, 2]

# Segment the two-plane projection image; the first returned value is kept as the
# cytoplasm label mask.
masks_cyto, flows, styles, diams = model.eval(
    max_cell,
    diameter=150,
    channels=channels,
    flow_threshold=0.5,
    cellprob_threshold=0,
)


In [None]:
# Quick look at the cytoplasm segmentation result.
plt.imshow(masks_cyto)

In [None]:
# DISPLAY RESULTS
# Cellpose's summary figure for the cytoplasm run: the first projection
# (max_cell[0]) together with the masks and the first entry of `flows`.

fig = plt.figure(figsize=(30,30))
plot.show_segmentation(fig, max_cell[0], masks_cyto, flows[0])
#plt.tight_layout()
plt.show()

In [None]:
#export cyto masks
import os

cyto_mask_path = "/groups/CaiLab/personal/Lex/raw/2020-08-08-takei/notebook_pyfiles/Labeled_Images/MMStack_Pos0.tif"

# The original cell used os.mkdir, which raises FileExistsError when the cell is
# re-run. The relative folder is still created so that side effect is preserved,
# but exist_ok=True makes the cell idempotent.
os.makedirs("../Labeled_Images", exist_ok=True)
# "../Labeled_Images" is resolved against the working directory and is not
# necessarily the folder written to below, so also make sure the real target exists.
os.makedirs(os.path.dirname(cyto_mask_path), exist_ok=True)

tf.imwrite(cyto_mask_path, masks_cyto)

In [None]:
# DEFINE CELLPOSE MODEL
# model_type='cyto' or model_type='nuclei'
# Here: the 'nuclei' model on CPU. This rebinds `model`, so the cyto model defined
# earlier is no longer reachable under that name after this cell runs.
model = models.Cellpose(gpu=False, model_type='nuclei')

In [None]:
# define CHANNELS to run segmentation on
# grayscale=0, R=1, G=2, B=3
# channels = [cytoplasm, nucleus]

channels = [0,2]

# NOTE(review): `z1_collection` is not defined anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel (Restart & Run All). It presumably is
# a list of images prepared in a cell that was deleted -- restore its definition.
# This call also overwrites `flows`, `styles` and `diams` from the cytoplasm run.
masks_nuclear, flows, styles, diams = model.eval(z1_collection, diameter=200, 
                                         channels=channels, flow_threshold=0.4,cellprob_threshold=0)


In [None]:
# DISPLAY RESULTS
# Cellpose's summary figure for the first nuclear segmentation result.
# NOTE(review): depends on `z1_collection`, which is not defined in the visible
# notebook; this cell fails on a fresh kernel until that variable is restored.

fig = plt.figure(figsize=(30,30))
plot.show_segmentation(fig, z1_collection[0][1], masks_nuclear[0], flows[0][0])
#plt.tight_layout()
plt.show()

In [None]:
#export nuclear masks, one TIFF per entry of masks_nuclear
# (the original comment said "cyto masks", but the loop writes masks_nuclear)
nuc_mask_dir = "/groups/CaiLab/personal/Lex/Sandbox/20k_dash_3t3_exp1/nuc_masks"

# tifffile does not create missing folders; make sure the target exists so the cell
# also works on a fresh checkout and can be re-run safely.
os.makedirs(nuc_mask_dir, exist_ok=True)

for i, nuc_mask in enumerate(masks_nuclear):
    # File names carry the list index: nucmask0.tif, nucmask1.tif, ...
    tf.imwrite(os.path.join(nuc_mask_dir, "nucmask{}.tif".format(i)), nuc_mask)

---------------------------------------------------------------------------------------------------------------------

In [None]:
#read in masks
nuc_paths = glob.glob("/groups/CaiLab/personal/Lex/Sandbox/20k_dash_3t3_exp1/nuc_masks/*.tif")
cyto_paths = glob.glob("/groups/CaiLab/personal/Lex/Sandbox/20k_dash_3t3_exp1/cell_masks/*.tif")


def sort_by_index(paths, pattern):
    """Return `paths` as a list sorted by the integer captured by `pattern`.

    Parameters
    ----------
    paths = iterable of file paths
    pattern = regular expression whose first group captures the numeric index,
              e.g. 'nucmask(\\d+)'

    Returns
    -------
    list of paths ordered numerically (so index 2 sorts before index 10)

    Raises
    ------
    ValueError if a path does not contain the pattern, instead of the opaque
    AttributeError that calling .group() on a failed match would give."""

    def index_of(path):
        match = re.search(pattern, path)
        if match is None:
            raise ValueError(
                "cannot find pattern {!r} in path {!r}".format(pattern, path))
        return int(match.group(1))

    return sorted(paths, key=index_of)


#organize files numerically
nuc_paths = sort_by_index(nuc_paths, 'nucmask(\\d+)')
cyto_paths = sort_by_index(cyto_paths, 'cytomask(\\d+)')

In [None]:
# Masks are paired by position: nuclear[i] belongs with cyto[i]. Refuse to continue
# when the two folders hold a different number of files; the original loop would
# either silently drop the extra cyto masks or stop with a bare IndexError.
if len(nuc_paths) != len(cyto_paths):
    raise ValueError(
        "found {} nuclear masks but {} cyto masks; they must pair one-to-one".format(
            len(nuc_paths), len(cyto_paths)))

nuclear = []
cyto = []
for nuc_path, cyto_path in tqdm(list(zip(nuc_paths, cyto_paths))):
    nuclear.append(pil_imread(nuc_path))
    cyto.append(pil_imread(cyto_path))

In [None]:
def _match_single_mask(cyto_mask, nuc_mask, threshold):
    """Drop the cells of one labeled cyto mask that lack a nucleus, then renumber."""
    cyto_mask = np.asarray(cyto_mask)
    nuc_foreground = np.asarray(nuc_mask) > 0
    if nuc_foreground.shape != cyto_mask.shape:
        raise ValueError(
            "cyto mask shape {} does not match nuclear mask shape {}".format(
                cyto_mask.shape, nuc_foreground.shape))

    # `labels` are the distinct values actually present; `flat_index` maps every
    # pixel to its position in `labels`. Working from the labels that exist (rather
    # than assuming 0..n-1) keeps non-contiguous label sets from dividing by zero.
    labels, flat_index = np.unique(cyto_mask.ravel(), return_inverse=True)

    # Pixel count per label, and how many of those pixels lie on a nucleus.
    pixels_per_label = np.bincount(flat_index, minlength=labels.size)
    nuc_pixels_per_label = np.bincount(
        flat_index[nuc_foreground.ravel()], minlength=labels.size)

    # Keep a cell when its overlap fraction reaches the threshold; 0 is background
    # and never a cell. Every entry of pixels_per_label is >= 1 by construction.
    keep = (labels != 0) & (nuc_pixels_per_label / pixels_per_label >= threshold)

    # Lookup table old label -> new label: dropped cells and background become 0,
    # kept cells become 1..k in ascending order of their old label.
    new_labels = np.zeros(labels.size, dtype=cyto_mask.dtype)
    new_labels[keep] = np.arange(1, np.count_nonzero(keep) + 1)

    return new_labels[flat_index].reshape(cyto_mask.shape)


def nuclear_cyto_matching(cyto, nuc, threshold=0.20):
    """Match cyto masks and nuclear masks. Keep cyto masks that have nucleus.

    A cell (one non-zero label of the cyto mask) is kept when the fraction of its
    pixels where the nuclear mask is > 0 is at least `threshold`. Kept cells are
    renumbered 1..k in ascending order of their original label; everything else
    becomes 0. The input masks are not modified.

    Parameters
    ----------
    cyto=list (or tuple) of labeled arrays, or a single labeled cyto array
    nuc=list (or tuple) of arrays, or a single nuc array; paired with cyto by index
    threshold=minimum fraction of a cell's pixels that must overlap nucleus

    Returns
    -------
    A new array with the dtype and shape of `cyto` for a single mask, or a list of
    such arrays (one per input mask) when a list/tuple is given.

    Raises
    ------
    ValueError if a cyto mask and its nuclear mask differ in shape."""

    if isinstance(cyto, (list, tuple)):
        # Index-based pairing as before: nuc[i] goes with cyto[i].
        return [_match_single_mask(cyto[i], nuc[i], threshold)
                for i in tqdm(range(len(cyto)))]

    return _match_single_mask(cyto, nuc, threshold)

In [None]:
# Filter every cyto mask against its nuclear mask; this run uses a 10% overlap
# threshold instead of the function default of 20%.
cyto_new = nuclear_cyto_matching(cyto,nuclear, threshold=0.10)

In [None]:
# Write the filtered, renumbered cyto masks, one file per entry of cyto_new.
labeled_dir = "/groups/CaiLab/personal/Lex/Sandbox/20k_dash_3t3_exp1/Labeled_Images"

# No earlier cell creates this folder and tifffile will not create it, so make sure
# it exists before writing; exist_ok=True keeps the cell safe to re-run.
os.makedirs(labeled_dir, exist_ok=True)

for i, labeled_mask in enumerate(cyto_new):
    tf.imwrite(os.path.join(labeled_dir, "MMStack_Pos{}.ome.tif".format(i)),
               labeled_mask)

In [None]:
# #path to labeled images
# directory = "./Labeled_Images/*.tif"
# files=glob.glob(directory)

# #organize files numerically
# key = [int(re.search('MMStack_Pos(\\d+)', f).group(1)) for f in files]
# files = list(np.array(files)[np.argsort(key)])

In [None]:
# #read in images
# labeled = []
# for i in tqdm(range(len(files))):
#     labeled.append(pil_imread(files[i]))

In [None]:
# plt.imshow(labeled[0])

In [None]:
# plt.imshow(cyto[0])

In [None]:
# plt.imshow(nuclear[0])