## Masking-Based Colocalization Measurement

### Notes <a id="notes"></a>
[notes](#notes) [prep](#prep) [test](#test) [run](#run)

Pipeline to measure the colocalization of one channel ("1st") within the compartments delineated by another ("2nd"). The image is first background-subtracted, then the 2nd channel is Otsu-thresholded to obtain a mask. The final measure is the ratio of the 1st channel's mean intensity within that mask to its mean intensity across the entire image. To reduce the influence of irrelevant regions, all of this is done within a bounding box surrounding the apical center of the neuromast (the lumen).
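
As a rough illustration of the measurement described above, the following sketch thresholds a mask channel with a plain NumPy version of Otsu's method and computes the mean ratio. The function names `otsu_threshold` and `mask_mean_ratio` and the toy arrays are hypothetical and only serve this example; this is not the implementation of `coloc.thresh_detect`. In practice a library routine such as `skimage.filters.threshold_otsu` could take the place of the hand-written threshold.

```python
import numpy as np

def otsu_threshold(img, nbins=256):
    """Otsu threshold from a histogram (hypothetical helper, NumPy only).

    Assumes the image contains at least two distinct values.
    """
    hist, edges = np.histogram(img.ravel(), bins=nbins)
    hist = hist.astype(np.float64)
    centers = (edges[:-1] + edges[1:]) / 2.0
    # Class weights for every candidate split
    w_low = np.cumsum(hist)
    w_high = np.cumsum(hist[::-1])[::-1]
    # Class means for every candidate split
    m_low = np.cumsum(hist * centers) / w_low
    m_high = (np.cumsum((hist * centers)[::-1]) / w_high[::-1])[::-1]
    # Between-class variance; Otsu picks the split that maximizes it
    between = w_low[:-1] * w_high[1:] * (m_low[:-1] - m_high[1:]) ** 2
    return centers[np.argmax(between)]

def mask_mean_ratio(mask_img, measure_img):
    """Mean of `measure_img` inside the Otsu mask of `mask_img`, divided by its overall mean."""
    thresh = otsu_threshold(mask_img)
    mask = mask_img > thresh
    return thresh, measure_img[mask].mean() / measure_img.mean()

# Toy data: a bright square in the mask channel that coincides with the signal
mask_ch = np.zeros((100, 100), dtype=np.uint8)
mask_ch[40:60, 40:60] = 200
measure_ch = np.full((100, 100), 10, dtype=np.uint8)
measure_ch[40:60, 40:60] = 100

thresh, ratio = mask_mean_ratio(mask_ch, measure_ch)
```

A ratio well above 1 means the measured channel is enriched inside the mask; a ratio near 1 means it is distributed without regard to the mask.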

**Note:** 8-bit conversion is done beforehand, using the Fiji macro `8bit_macro.ijm`. A fixed conversion range is used and kept the same across a given experiment. Minima are always 0 or 10000 (depending on Airyscan settings); maxima are adjusted based on the intensity range. The values are logged in `data\metadata.xlsx`.
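
For readers without the macro at hand, the idea of a fixed-range conversion can be sketched in NumPy as below. This is only an illustration: `to_8bit` is a hypothetical helper, the range values are made up rather than taken from `metadata.xlsx`, and Fiji's own rounding may differ in detail.

```python
import numpy as np

def to_8bit(img, vmin, vmax):
    """Linearly map [vmin, vmax] to [0, 255]; values outside the range are clipped."""
    scaled = (img.astype(np.float64) - vmin) / float(vmax - vmin)
    scaled = np.clip(scaled, 0.0, 1.0)
    return np.round(scaled * 255).astype(np.uint8)

# Example 16-bit values converted with an assumed range of 10000 to 30000
raw = np.array([0, 10000, 15000, 30000, 65535], dtype=np.uint16)
converted = to_8bit(raw, vmin=10000, vmax=30000)
```

Using the same `vmin` and `vmax` for every image of an experiment keeps intensities comparable across images after conversion.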

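**Note:** For this to run, the location of the apical center (the lumen position) of the neuromast has to be determined manually and its coordinates (in pixels) must be written into `<fpath>\metadata.xlsx`, which then has to be exported as a *tab-separated text file* called `<fpath>\metadata.txt`! Each row is matched to an image by its first column, which must be contained in the image's file name; the next three columns are read as the lumen coordinates in (z, y, x) order.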
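
The lookup of the lumen position in the tab-separated file might be sketched as follows. The helper `read_lumen`, the column headers and the file names are hypothetical and only serve this example; the sketch assumes the first column holds a name fragment and the next three columns hold integer coordinates.

```python
import io
import os
import tempfile
import numpy as np

def read_lumen(metadata_path, image_fname):
    """Return the lumen coordinates for `image_fname`, or None if no row matches."""
    with io.open(metadata_path, "r", encoding="utf-8") as infile:
        for line in infile:
            fields = line.strip().split("\t")
            # Skip blank or incomplete rows
            if len(fields) < 4 or not fields[0]:
                continue
            if fields[0] in image_fname:
                try:
                    return np.array([int(value) for value in fields[1:4]])
                except ValueError:
                    # Non-numeric row (e.g. a header), keep looking
                    continue
    return None

# Write a small example file and look up one image
tmpdir = tempfile.mkdtemp()
meta_path = os.path.join(tmpdir, "metadata.txt")
with io.open(meta_path, "w", encoding="utf-8") as outfile:
    outfile.write(u"sample\tz\ty\tx\n")
    outfile.write(u"sample_01\t12\t340\t512\n")

lumen = read_lumen(meta_path, "sample_01_coloc_8bit.tif")
missing = read_lumen(meta_path, "other_8bit.tif")
```

Returning `None` for a missing entry lets the caller decide whether to abort or skip the image.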


### Pipeline Outline

- Preprocessing
    - Crop to a region around the lumen
        - This was added to improve the quality of the measurements
        - It makes the measurement of total intensity more precise
        - It may also help by making the thresholding more consistent
    - Background subtraction
        - Either global based on background region *[preferred!]*
        - Or local based on heavy Gaussian background
    
    
- Thresholding of target vesicles (red or far-red)
    - Either automated thresholding *[preferred]*
        - Tested a few; Otsu looks good
    - Or full threshold series
        
    
- Measurements
    - Use the background-subtracted Cxcr7/Cxcr4 channels for measurements
    - Get means & sums within the threshold masks and in total
    - Final measure: the ratio `threshold_mean / total_mean` *[preferred]*
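
The preprocessing steps in the outline above could look roughly like the sketch below. All three helpers are hypothetical stand-ins and not the code behind `coloc.bgsub_global` or `coloc.bgsub_local`. In particular, the global variant assumes that the background level is estimated as the mean of a signal-free corner region, which is an assumption made for this example only.

```python
import numpy as np
import scipy.ndimage as ndi

def crop_around(img, center, half_size):
    """Crop a (channel, y, x) image to +/- half_size around center, clamped to the image."""
    slices = [slice(None)]  # keep all channels
    for c, h, n in zip(center, half_size, img.shape[1:]):
        slices.append(slice(max(0, c - h), min(c + h, n)))
    return img[tuple(slices)]

def bgsub_global_sketch(img, bg_level):
    """Subtract one background level from the whole image; negatives are set to 0."""
    return np.clip(img.astype(np.float64) - bg_level, 0, None)

def bgsub_local_sketch(img, sigma=10):
    """Subtract a heavily Gaussian-smoothed copy of the image; negatives are set to 0."""
    img = img.astype(np.float64)
    background = ndi.gaussian_filter(img, sigma=sigma)
    return np.clip(img - background, 0, None)

# Toy two-channel image: noisy background plus one bright square
rng = np.random.RandomState(0)
img = rng.randint(20, 30, size=(2, 200, 300)).astype(np.uint8)
img[:, 90:110, 140:160] += 100

# Crop around an assumed lumen position (y, x); the x range is clamped to the image
cropped = crop_around(img, center=(100, 150), half_size=(50, 400))

# Assumption for this example: the top-left corner contains background only
bg_level = cropped[1, :20, :20].mean()
sub_global = bgsub_global_sketch(cropped[1], bg_level)
sub_local = bgsub_local_sketch(cropped[1], sigma=10)
```

The global variant keeps relative intensities intact across the image, whereas the local variant also removes slowly varying structure and with it part of the signal in larger bright regions.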

### Prep <a id="prep"></a>
[notes](#notes) [prep](#prep) [test](#test) [run](#run)

In [None]:
# General
from __future__ import division
import os, warnings, pickle, time
import numpy as np
np.random.seed(42)
import scipy.ndimage as ndi
import matplotlib.pyplot as plt
from skimage.io import imread, imsave

# Internal
import coloc.colocalization as coloc

### Testing  <a id="test"></a>
[notes](#notes) [prep](#prep) [test](#test) [run](#run)

In [None]:
### Test Data Creation

# Settings
run_test_only  = False

# Parameters
shape   = (400, 400)
#shape   = (40, 400, 400)
offset  = 20
size    = (30, 20)
max_int = (70, 60)
bg_loc  = (30, 40)
bg_scl  = 5
chunk_s = 3

# Channel generation function
def create_channel(shape, offset, size=20, 
                   max_int=60, sig=7, PSF_sig=3,
                   bg_loc=5, bg_scl=5):

    # Null
    img = np.zeros(shape, dtype=np.uint8)

    # Signal
    pos = [s//2 - size//2 for s in shape]
    pos[0] += offset//2
    slc = tuple(slice(p, p+size) for p in pos)
    img[slc] = max_int
    
    # Smoothen signal
    img = ndi.gaussian_filter(img, sigma=sig)
    
    # Background
    img += np.abs(np.random.normal(bg_loc, bg_scl, shape)).astype(np.uint8)
    
    # PSF
    img = ndi.gaussian_filter(img, sigma=PSF_sig)    
    
    # Detector noise
    img += np.abs(np.random.normal(0, 2, shape)).astype(np.uint8)
    
    # Done
    return img
   
# Create single example image
img = np.zeros((2,)+shape, dtype=np.uint8)
img[0] = create_channel(shape, -offset//2, size=size[0], 
                        max_int=max_int[0], PSF_sig=chunk_s/3,
                        bg_loc=bg_loc[0], bg_scl=bg_scl)
img[1] = create_channel(shape,  offset//2, size=size[1], 
                        max_int=max_int[1], PSF_sig=chunk_s/3,
                        bg_loc=bg_loc[1], bg_scl=bg_scl)

# Prep for plotting
if img.ndim == 3:
    ch0_plot = img[0,...]
    ch1_plot = img[1,...]
elif img.ndim == 4:
    ch0_plot = img[0, img.shape[1]//2, ...]
    ch1_plot = img[1, img.shape[1]//2, ...]
    
# Display as RGB
rgb = np.dstack([ch0_plot, ch1_plot, np.zeros_like(ch0_plot)])
plt.imshow(rgb, interpolation='none')
plt.axis('off')
plt.show()

# Report
print("max int ch0: {}".format(np.max(img[0])))
print("max int ch1: {}".format(np.max(img[1])))

# Create full test series
# Same offsets for 2D and 3D test shapes
test_offsets = list(range(0, 101, 20))
test_imgs = [] 
for test_offset in test_offsets:
    for intensity_factor in np.linspace(1.0, 2.0, 5):
        test_img = np.zeros((2,)+shape, dtype=np.uint8)
        test_img[0] = create_channel(shape, -test_offset//2, size=size[0], 
                                     max_int=max_int[0], PSF_sig=chunk_s/3,
                                     bg_loc=bg_loc[0], bg_scl=bg_scl) * intensity_factor
        test_img[1] = create_channel(shape,  test_offset//2, size=size[1], 
                                     max_int=max_int[1], PSF_sig=chunk_s/3,
                                     bg_loc=bg_loc[1], bg_scl=bg_scl) * intensity_factor
        test_imgs.append(test_img)

In [None]:
### Background Subtraction

bgsub = np.zeros_like(img)
bgsub[0, :, :] = coloc.bgsub_global(img[0, :, :])
bgsub[1, :, :] = coloc.bgsub_global(img[1, :, :])
#bgsub[0, :, :] = coloc.bgsub_local(img[0, :, :], sigma=10)
#bgsub[1, :, :] = coloc.bgsub_local(img[1, :, :], sigma=10)

# Prep for plotting
if bgsub.ndim == 3:
    ch0_plot = bgsub[0,...]
    ch1_plot = bgsub[1,...]
elif bgsub.ndim == 4:
    ch0_plot = bgsub[0, bgsub.shape[1]//2, ...]
    ch1_plot = bgsub[1, bgsub.shape[1]//2, ...]
    
# Display as RGB
rgb = np.dstack([ch0_plot, ch1_plot, np.zeros_like(ch0_plot)])
plt.imshow(rgb, interpolation='none')
plt.axis('off')
plt.show()

In [None]:
### Thresholding & Measurement

#np.seterr(all='raise')

# Prep plot
fig, ax = plt.subplots(1, 6, figsize=(12,3))

# Otsu thresholding
threshs  = []
means    = []
sums     = []
m_ratios = []
s_ratios = []
for test_img in test_imgs:
    test_img_bgsub_ch0 = coloc.bgsub_global(test_img[0, :, :])
    test_img_bgsub_ch1 = coloc.bgsub_global(test_img[1, :, :])
    t, m, s, mr, sr = coloc.thresh_detect(test_img_bgsub_ch0, test_img_bgsub_ch1)
    #t, m, s, mr, sr, _, _, _ = thresh_detect(test_img_bgsub_ch0, test_img_bgsub_ch1)
    threshs.append(t)
    means.append(m)
    sums.append(s)
    m_ratios.append(mr)
    s_ratios.append(sr)

# Plot thresholds
ax[0].plot(threshs)
ax[0].set_ylim([min(threshs)-5,
                max(threshs)+5])
ax[0].set_xlabel('test image index')
ax[0].set_ylabel('otsu threshold')

# Plot results
ax[1].plot(means)
ax[1].set_xlabel('test image index')
ax[1].set_ylabel('foreground mean')
ax[2].plot(sums)
ax[2].set_xlabel('test image index')
ax[2].set_ylabel('foreground sum')
ax[3].plot(m_ratios)
ax[3].set_xlabel('test image index')
ax[3].set_ylabel('foreground mean / total ratio')
ax[4].plot(s_ratios)
ax[4].set_xlabel('test image index')
ax[4].set_ylabel('foreground sum / total ratio')

# Threshold series
threshs = []
means   = []
indices = []
for i, test_img in enumerate(test_imgs):
    t, m, _, _, _ = coloc.thresh_series(test_img[0,:,:], test_img[1,:,:])
    threshs.append(t)
    means.append(m)
    indices.append(np.ones_like(t)*i)

# Plot results
scat = ax[5].scatter(threshs, means, 
                     c=indices, cmap='viridis', 
                     edgecolors='face')
ax[5].set_xlabel('threshold')
ax[5].set_ylabel('foreground mean')
plt.colorbar(scat, label='test image index')

# Done
plt.tight_layout()
plt.show()

### Running the Data <a id="run"></a>
[notes](#notes) [prep](#prep) [test](#test) [run](#run)

In [None]:
### Halt in case only test runs should be done

if run_test_only:
    raise ValueError("Run terminated because `run_test_only` is set to True!")

In [None]:
### Settings

# Input data
dirpath = r'data_ex'
suffix  = r'_8bit.tif'
trigger = r'coloc'

# Processing parameters
region_size = (20, 150, 200)   # Half-widths of the crop box as (z, y, x); must have length 3. For 2D imgs, z is ignored.
if 'rev' in trigger: 
    region_size = (20, 180, 240)  # For revisions: Adjusted to increased zoom on LSM980!

In [None]:
### Retrieve file names

# Prep
fnames = [fname for fname in os.listdir(dirpath) 
          if trigger in fname and fname.endswith(suffix)]
fpaths = [os.path.join(dirpath, fname) for fname in fnames]

In [None]:
### Run pipeline

# For each file...
for fname, fpath in zip(fnames, fpaths):
    
    # Report
    print('\nProcessing image "' + fname + '"')
    
    # Load raw
    img = imread(fpath)
    
    # Organize dims and remove surplus channels
    if 2 in img.shape:
        img = np.rollaxis(img, img.shape.index(2))
    elif 3 in img.shape:
        img = np.rollaxis(img, img.shape.index(3))
        img = img[:2, ...]
    else:
        raise IOError("Opened an image that does not have a valid channel dimension.")
        
    # Get lumen position
    lumen = None
    with open(os.path.join(os.path.split(fpath)[0], r"metadata.txt"), "r") as infile:
        for line in infile:
            line = line.strip()
            line = line.split('\t')
            # Skip blank lines; an empty first field would match any file name
            if line[0] and line[0] in os.path.split(fpath)[1]:
                lumen = np.array([int(value) for value in line[1:4]])
                break
    if lumen is None:
        raise Exception("Appropriate lumen metadata not found. Aborting!")
        
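    # Crop to region around lumen
    # (slice ends are exclusive, so clamping the upper bound to `shape - 1`
    #  leaves out the last index along an axis when the box reaches the border)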
    rs  = region_size
    l   = lumen
    ims = img.shape
    if img.ndim == 4:
        img = img[:, np.max([0, l[0]-rs[0]]) : np.min([l[0]+rs[0], ims[1]-1]), 
                     np.max([0, l[1]-rs[1]]) : np.min([l[1]+rs[1], ims[2]-1]), 
                     np.max([0, l[2]-rs[2]]) : np.min([l[2]+rs[2], ims[3]-1])]
    elif img.ndim == 3:
        img = img[:, np.max([0, l[1]-rs[1]]) : np.min([l[1]+rs[1], ims[1]-1]), 
                     np.max([0, l[2]-rs[2]]) : np.min([l[2]+rs[2], ims[2]-1])]
        
    # Report
    print('  Loaded image of shape: ' + str(img.shape))
        
    # Warn about saturation
    if img[0,...].max() == 255:
        with warnings.catch_warnings():
            warnings.simplefilter('always')
            warnings.warn("There is some saturation in this image!")
            time.sleep(0.5)
    
    # Perform background subtraction
    mask_img_bgsub_global    = coloc.bgsub_global(img[1,...])
    measure_img_bgsub_global = coloc.bgsub_global(img[0,...])
    mask_img_bgsub_local     = coloc.bgsub_local(img[1,...], sigma=10)
    
    # Threshold and extract measurements, construct results dict
    results = dict()
    results['total_mean'] = np.mean(measure_img_bgsub_global)
    results['total_sum']  = np.sum(measure_img_bgsub_global)
    
    global_otsu = coloc.thresh_detect(mask_img_bgsub_global, measure_img_bgsub_global)
    results['global_otsu_thresh']     = global_otsu[0]
    results['global_otsu_mean']       = global_otsu[1]
    results['global_otsu_sum']        = global_otsu[2]
    results['global_otsu_mean_ratio'] = global_otsu[3]
    results['global_otsu_sum_ratio']  = global_otsu[4]
    
    global_series = coloc.thresh_series(mask_img_bgsub_global, measure_img_bgsub_global)
    results['global_series_threshs']     = global_series[0]
    results['global_series_means']       = global_series[1]
    results['global_series_sums']        = global_series[2]
    results['global_series_means_slope'] = global_series[3]
    results['global_series_sums_slope']  = global_series[4]
    
    local_otsu = coloc.thresh_detect(mask_img_bgsub_local, measure_img_bgsub_global)
    results['local_otsu_thresh']     = local_otsu[0]
    results['local_otsu_mean']       = local_otsu[1]
    results['local_otsu_sum']        = local_otsu[2]
    results['local_otsu_mean_ratio'] = local_otsu[3]
    results['local_otsu_sum_ratio']  = local_otsu[4]
    
    local_series = coloc.thresh_series(mask_img_bgsub_local, measure_img_bgsub_global)
    results['local_series_threshs']    = local_series[0]
    results['local_series_means']      = local_series[1]
    results['local_series_sums']       = local_series[2]
    results['local_series_means_slope'] = local_series[3]
    results['local_series_sums_slope']  = local_series[4]    
    
    # Save measurements
    with open(fpath[:-4]+"_maskcoloc.pkl", 'wb') as resultfile:
        pickle.dump(results, resultfile, pickle.HIGHEST_PROTOCOL)
        
    # Report
    print('  Processing complete!')

# Final report
print('\nALL INPUT DATA PROCESSED!')