In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import rcParams
from skimage import io
from os.path import expanduser
from tqdm import tqdm
HOME = expanduser("~")
import os, sys
import cv2
import nipy
from nipy.labs.mask import compute_mask
from skimage import exposure, img_as_float
from skimage.exposure import rescale_intensity, adjust_gamma, equalize_adapthist
from scipy import stats
import pandas as pd
%load_ext autoreload
%autoreload 2

In [None]:
# Animal whose pipeline output is being inspected.
animal = 'DK41'
DIR = f'/net/birdstore/Active_Atlas_Data/data_root/pipeline_data/{animal}'
# CH1 thumbnails that the masking experiments below read from.
INPUT = os.path.join(DIR, 'preps', 'CH1', 'thumbnail')
MASK = os.path.join(DIR, 'preps', 'thumbnail_masked')
# os.listdir order is arbitrary; sort so section order is deterministic.
masks = os.listdir(MASK)
masks.sort()
files = os.listdir(INPUT)
files.sort()

In [None]:
# Make the pipeline_utility package importable. The location can be
# overridden with the PIPELINE_UTILITY_PATH environment variable; the
# default is the previously hard-coded path.
PATH = os.environ.get('PIPELINE_UTILITY_PATH', '/home/eddyod/programming/pipeline_utility')
# Guard so re-running this cell does not keep appending duplicate entries.
if PATH not in sys.path:
    sys.path.append(PATH)
from utilities.utilities_mask import trim_edges, create_mask_pass1, fix_with_fill, equalized
from utilities.sqlcontroller import SqlController
sqlController = SqlController(animal)

In [None]:
# Section numbers that need manual fixing; filenames are zero-padded to 3 digits.
bad = []
bads = [f'{str(i).zfill(3)}.tif' for i in bad]
# Guard against an empty input directory (would otherwise divide by zero).
pct_bad = 100 * len(bads) / len(files) if files else 0.0
print(f'Need to manually fix {len(bads)} files or {pct_bad:.1f}%')

In [None]:
# Run both masking passes over a slice of sections and keep every
# intermediate so the two passes can be compared side by side.
names = []
norms = []
masks1 = []
passes1 = []
masks2 = []
passes2 = []
start, finish = (20, 30)
lowVal, highVal, threshold = (0, 0, 0)

for file in tqdm(files[start:finish]):
    infile = os.path.join(INPUT, file)
    img = io.imread(infile)
    # The equalized copy is only displayed; masking works on the unequalized image.
    normed = equalized(img)
    norms.append(normed)

    ## pass 1
    img = trim_edges(img)
    mask1 = create_mask_pass1(img)
    pass1 = cv2.bitwise_and(img, img, mask=mask1)
    masks1.append(mask1)
    passes1.append(pass1)  # previously `img2`, an undefined name (NameError)

    ## pass 2: blur the pass-1 result, then let fix_with_fill derive a new mask.
    # GaussianBlur returns a new array, so the unblurred pass1 stored above is kept.
    pass1 = cv2.GaussianBlur(pass1, (33, 33), 0)
    mask2, lowVal, highVal, threshold = fix_with_fill(pass1, debug=True)
    pass2 = cv2.bitwise_and(img, img, mask=mask2)
    masks2.append(mask2)
    passes2.append(pass2)
    name = f'{file} {round(lowVal)} {round(highVal)} {round(threshold)}'
    names.append(name)

# Set the style before the figure is created so it applies to this figure too.
plt.style.use('classic')
if names:
    # squeeze=False keeps `ax` 2-D even when only one section is plotted.
    fig, ax = plt.subplots(nrows=len(names), ncols=5, sharex=False, sharey=False, squeeze=False)
    for i, (name, norm, mask1, pass1, mask2, pass2) in enumerate(
            zip(names, norms, masks1, passes1, masks2, passes2)):
        ax[i, 0].set_title(f'{name}')
        ax[i, 0].imshow(norm, cmap="gray")
        ax[i, 1].set_title('1st pass mask')
        ax[i, 1].imshow(mask1, cmap="gray")
        ax[i, 2].set_title('1st pass img')
        ax[i, 2].imshow(pass1, cmap="gray")
        ax[i, 3].set_title('2nd pass mask')
        ax[i, 3].imshow(mask2, cmap="gray")
        ax[i, 4].set_title('2nd pass img')
        ax[i, 4].imshow(pass2, cmap="gray")
    fig.set_size_inches(18, 4 * len(names), forward=False)
    plt.tight_layout()
    plt.show()
else:
    print(f'No files in slice [{start}:{finish}] - nothing to plot')

In [None]:
def get_binary_mask(img, blur_size=225, close_size=40):
    '''
    Build a binary foreground mask from a single-channel image.

    The image is brought to 8 bits, smoothed with a large Gaussian blur,
    thresholded with Otsu's method, and then morphologically closed to
    fill small holes/gaps in the foreground.

    Parameters
    ----------
    img : numpy.ndarray
        Single-channel image. uint8 input is used as-is; any other dtype is
        divided by 256 (i.e. assumed to be 16-bit) before casting to uint8.
    blur_size : int
        Side length of the square Gaussian kernel. OpenCV requires it to be
        positive and odd.
    close_size : int
        Side length of the square structuring element used for closing.

    Returns
    -------
    numpy.ndarray
        uint8 mask with values 0 and 255.
    '''
    if img.dtype == np.uint8:
        img8 = img
    else:
        # Presumably 16-bit data: /256 maps 0..65535 onto 0..255. Applying this
        # to uint8 input (as before) would collapse every pixel to 0.
        img8 = (img / 256).astype(np.uint8)
    blurred = cv2.GaussianBlur(img8, (blur_size, blur_size), 0)
    # Threshold value 0 is ignored: THRESH_OTSU picks it from the histogram.
    _, otsu = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = np.ones((close_size, close_size), np.uint8)
    return cv2.morphologyEx(otsu, cv2.MORPH_CLOSE, kernel)

In [None]:
# Try the two-pass mask on a single section.
# Other sections worth checking: 181 (very dark), 381 (kinda light).
file = '281.tif'
infile = os.path.join(INPUT, file)
img = io.imread(infile)
img = trim_edges(img)
# The equalized copy is only used for the left-hand display panel.
normed = equalized(img)
mask1 = create_mask_pass1(img)
img = cv2.bitwise_and(img, img, mask=mask1)

mask2 = get_binary_mask(img)
fixed = cv2.bitwise_and(img, img, mask=mask2)
# CLAHE contrast enhancement; skimage returns a float image.
fixed = equalize_adapthist(fixed)

# Global figure size in inches (also affects figures created afterwards).
rcParams['figure.figsize'] = 18, 18

fig, ax = plt.subplots(1, 3)
for axis, image, title in zip(ax, (normed, mask2, fixed), ('original', 'mask2', 'fixed')):
    axis.imshow(image, cmap="gray")
    axis.set_title(title)
# End with show() so the cell does not also echo a Text(...) repr.
plt.show()
