# Annotate experiments
Notebook used to create binary annotations for experiments, using either a computer vision approach or deep networks.  
Manual tuning and corrections are then applied to improve the quality of the detection, before it is used as ground truth.

In [17]:
%matplotlib inline

import os, sys, time, shutil
import warnings
import ipywidgets as widgets
from ipywidgets import interact

import numpy as np
import matplotlib.pyplot as plt
from skimage import io, filters
from skimage.morphology import disk

import torch

from utils_common.image import imread_to_float, to_npint, overlay_contours
from utils_common.processing import nlm_denoising
from utils_common.register_cc import register_stack
from computer_vision.cv_detector import cv_detect
from deep_learning.utils_data import normalize_range, pad_transform_stack
from deep_learning.utils_test import predict_stack
# from deep_learning

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load experiment

In [7]:
# Root folders: experiments still waiting for annotation, and experiments already annotated
to_annotate_dir = "/data/talabot/experiments/to_annotate/"
annotated_dir = "/data/talabot/experiments/annotated/"

# Sorted so that the random selection does not depend on the filesystem listing order
animal_to_annotate = sorted(os.listdir(to_annotate_dir))
animal_annotated = sorted(os.listdir(annotated_dir))

# Randomly select an animal without any annotated experiment
candidate_animals = [name for name in animal_to_annotate if name not in animal_annotated]
if not candidate_animals:
    # np.random.choice([]) would only raise a cryptic "a must be non-empty"
    raise RuntimeError("No animal left to annotate: every animal in %s also appears in %s."
                       % (to_annotate_dir, annotated_dir))
animal = np.random.choice(candidate_animals)

# Randomly select an experiment to annotate
experiments = sorted(os.listdir(os.path.join(to_annotate_dir, animal)))
if not experiments:
    raise RuntimeError("Animal %s has no experiment in %s." % (animal, to_annotate_dir))
experiment = np.random.choice(experiments)

# Load RGB stack and register it (ref_num=0: presumably frame 0 is the registration reference)
rgb_stack = imread_to_float(os.path.join(to_annotate_dir, animal, experiment, "RGB.tif"))
reg_rgb = register_stack(rgb_stack, ref_num=0)

@interact(image=(0, len(rgb_stack) - 1))
def plot_stack(image=0):
    """Display one raw frame beside the mean of the registered stack.

    The figure title recalls which animal / experiment is being annotated.

    Args:
        image: index into `rgb_stack` of the raw frame shown on the left.
    """
    fig, (ax_frame, ax_mean) = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle("Animal: %s\nExp: %s" % (animal, experiment))

    ax_frame.set_title("Raw frame %d" % image)
    ax_frame.imshow(rgb_stack[image])

    # Average of the registered stack over its first axis (presumably time)
    ax_mean.set_title("Mean temporal registered image")
    ax_mean.imshow(reg_rgb.mean(0))

    fig.tight_layout()
    plt.show()

interactive(children=(IntSlider(value=0, description='image', max=599), Output()), _dom_classes=('widget-inter…

# Computer Vision detector

In [12]:
# Test the denoising and detector to tune their parameters
print("Denoising")
@interact(image=(0, len(rgb_stack) - 1),
          h_red=widgets.BoundedIntText(value=11, min=0),
          h_green=widgets.BoundedIntText(value=11, min=0))
def plot_denoising(image=0, h_red=11, h_green=11):
    """Denoise one frame with `nlm_denoising` and show it next to the raw frame.

    The panel title reports how long the denoising took.

    Args:
        image: index of the frame of `rgb_stack` to denoise.
        h_red: forwarded to `nlm_denoising` (presumably the red-channel filtering strength).
        h_green: forwarded to `nlm_denoising` (presumably the green-channel filtering strength).
    """
    tic = time.time()
    denoised = nlm_denoising(rgb_stack, img_id=image, h_red=h_red, h_green=h_green)
    elapsed = time.time() - tic

    fig, (ax_raw, ax_denoised) = plt.subplots(1, 2, figsize=(10, 5))
    ax_raw.set_title("Raw input")
    ax_raw.imshow(rgb_stack[image])
    ax_denoised.set_title("Denoised\n(Took %f s.)" % elapsed)
    ax_denoised.imshow(denoised)
    fig.tight_layout()
    plt.show()

print("Detection")
@interact(image=(2, len(rgb_stack) - 3),
          h_red=widgets.BoundedIntText(value=11, min=0),
          h_green=widgets.BoundedIntText(value=11, min=0),
          thresh_fn=widgets.ToggleButtons(options=["Otsu", "Li", "Constant"]),
          thresh_val=widgets.BoundedIntText(value=30, min=0, max=255),
          erosion=(0,5))
def plot_detector(image=2, h_red=11, h_green=11, thresh_fn="Otsu", thresh_val=30, erosion=0):
    """Run the computer-vision detector around one frame and display its detection.

    `cv_detect` is applied to a 5-frame window centred on `image`, and only the
    detection of the central frame is displayed.

    Args:
        image: index of the frame to detect; needs 2 neighbouring frames on each
            side, i.e. must lie in [2, len(rgb_stack) - 3].
        h_red, h_green: forwarded to `cv_detect` (same names as the
            `nlm_denoising` parameters tuned above).
        thresh_fn: "Otsu", "Li" or "Constant", selecting how the threshold is computed.
        thresh_val: threshold used by "Constant", on a 0-255 scale.
        erosion: radius of the disk structuring element passed as `selem`.

    Raises:
        ValueError: if `image` leaves no room for the window, or `thresh_fn` is unknown.
    """
    half_window = 2
    if not half_window <= image < len(rgb_stack) - half_window:
        # A negative slice start would silently wrap around and give an empty/wrong window
        raise ValueError("image must be in [%d, %d], got %d."
                         % (half_window, len(rgb_stack) - half_window - 1, image))

    thresholding_fns = {
        "Otsu": filters.threshold_otsu,
        "Li": filters.threshold_li,
        # thresh_val is on a 0-255 scale; /255 presumably matches float images in [0, 1] — TODO confirm
        "Constant": lambda x: thresh_val / 255,
    }
    if thresh_fn not in thresholding_fns:
        raise ValueError("Unknown thresholding function %r (expected one of: %s)."
                         % (thresh_fn, ", ".join(thresholding_fns)))
    thresholding_fn = thresholding_fns[thresh_fn]

    start = time.time()
    window = rgb_stack[image - half_window: image + half_window + 1]
    # Keep only the detection of the central frame of the window
    detection = cv_detect(window,
                          h_red=h_red, h_green=h_green,
                          thresholding_fn=thresholding_fn,
                          registration=False, selem=disk(erosion))[half_window]
    duration = time.time() - start

    plt.figure(figsize=(13,5))
    plt.subplot(131)
    plt.title("Raw input")
    plt.imshow(rgb_stack[image])
    plt.subplot(132)
    plt.title("Detection\n(Took %f s.)" % duration)
    plt.imshow(detection)
    plt.subplot(133)
    plt.title("Detection contours")
    plt.imshow(overlay_contours(rgb_stack[image], detection))
    plt.tight_layout()
    plt.show()

Denoising


interactive(children=(IntSlider(value=0, description='image', max=599), BoundedIntText(value=11, description='…

Detection


interactive(children=(IntSlider(value=2, description='image', max=597, min=2), BoundedIntText(value=11, descri…

In [16]:
# Apply detector to the whole stack (/!\ a bit slow)
# Detector parameters (presumably chosen with the interactive detector above)
h_red = 12
h_green = 12
thresholding_fn = filters.threshold_otsu
erosion = 1

start = time.time()
cv_detection = cv_detect(
    rgb_stack,
    h_red=h_red,
    h_green=h_green,
    thresholding_fn=thresholding_fn,
    registration=False,
    selem=disk(erosion),
)
duration = time.time() - start
# divmod yields (whole minutes, remaining seconds); %d truncates both to integers
duration_str = "%d min %d s" % divmod(duration, 60)
print("CV detection took %s." % duration_str)

@interact(image=(0, len(cv_detection) - 1))
def plot_detector(image=0):
    """Browse the computer-vision detection computed on the whole stack.

    Args:
        image: frame index into `rgb_stack` / `cv_detection`.
    """
    frame = rgb_stack[image]
    mask = cv_detection[image]

    fig, (ax_raw, ax_mask, ax_contours) = plt.subplots(1, 3, figsize=(13, 5))
    ax_raw.set_title("Raw input")
    ax_raw.imshow(frame)
    ax_mask.set_title("Detection")
    ax_mask.imshow(mask)
    ax_contours.set_title("Detection contours")
    ax_contours.imshow(overlay_contours(frame, mask))
    fig.tight_layout()
    plt.show()

CV detection took 3 min 33 s.


interactive(children=(IntSlider(value=0, description='image', max=599), Output()), _dom_classes=('widget-inter…

# Deep learning detector

In [26]:
model_name = "RG_synth_190311"

model_dir = "deep_learning/models/"
input_channels = "RG"
u_depth = 4
out1_channels = 16
batch_size = 32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Load model
# The model folder presumably holds the `utils_model_save` module defining the network it was saved with.
# Add it to the import path only once, so that re-running this cell does not keep growing sys.path.
# NOTE(review): Python caches imported modules, so after changing `model_name` a re-run still uses the
# `utils_model_save` imported first — restart the kernel (or reload the module) when switching models.
model_path = os.path.join(model_dir, model_name)
if model_path not in sys.path:
    sys.path.append(model_path)
from utils_model_save import CustomUNet as ModelNet

model = ModelNet(len(input_channels), u_depth=u_depth, out1_channels=out1_channels, batchnorm=True, device=device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved from a GPU
# also loads when the CPU fallback above was selected
model.load_state_dict(torch.load(os.path.join(model_path, "model_best.pth"), map_location=device))
model.eval()

# Predict
transform = lambda stack: normalize_range(pad_transform_stack(stack, u_depth))
start = time.time()
predictions = predict_stack(model, rgb_stack, batch_size, input_channels=input_channels, 
                            channels_last=True, transform=transform)
print("Prediction took %.1f s." % (time.time() - start))
# Map the network outputs (presumably logits) to probabilities in [0, 1]
predictions = torch.sigmoid(predictions).numpy()

# Display results
# The slider range comes from `predictions` itself, so this cell no longer needs the (slow) CV cell to have run
@interact(image=(0, len(predictions) - 1),
          detection_th=widgets.FloatSlider(value=0.5, min=0, max=1, step=0.1, readout_format=".1f"))
def plot_detector(image=0, detection_th=0.5):
    """Threshold the network probabilities of one frame and display the resulting detection.

    Args:
        image: frame index into `rgb_stack` / `predictions`.
        detection_th: probability above which a pixel counts as detected.
    """
    detection = predictions[image] > detection_th

    plt.figure(figsize=(13,8))
    plt.subplot(231)
    plt.title("Raw input")
    plt.imshow(rgb_stack[image])
    plt.subplot(232)
    plt.title("Detection")
    plt.imshow(detection)
    plt.subplot(233)
    plt.title("Detection contours")
    plt.imshow(overlay_contours(rgb_stack[image], detection))
    # Probability map shown in the middle of the second row
    plt.subplot(235)
    plt.title("Prediction probabilities")
    plt.imshow(predictions[image])
    plt.tight_layout()
    plt.show()

Device: cuda:0
Prediction took 1.1 s.


interactive(children=(IntSlider(value=0, description='image', max=599), FloatSlider(value=0.5, description='de…

In [27]:
# Make final predictions
dl_detection_th = 0.5

dl_detection = predictions > dl_detection_th
# The slider range comes from `dl_detection` itself, so this cell does not require the CV detection to exist
@interact(image=(0, len(dl_detection) - 1))
def plot_detector(image=0):
    """Browse the final deep-learning detection (probabilities thresholded at `dl_detection_th`).

    Args:
        image: frame index into `rgb_stack` / `dl_detection`.
    """
    plt.figure(figsize=(13,5))
    plt.subplot(131)
    plt.title("Raw input")
    plt.imshow(rgb_stack[image])
    plt.subplot(132)
    plt.title("Detection")
    plt.imshow(dl_detection[image])
    plt.subplot(133)
    plt.title("Detection contours")
    plt.imshow(overlay_contours(rgb_stack[image], dl_detection[image]))
    plt.tight_layout()
    plt.show()

interactive(children=(IntSlider(value=0, description='image', max=599), Output()), _dom_classes=('widget-inter…

# Select and tune best results
Compare the results of the detectors, then manually tune/correct them, etc.

# Save final detection
Save the final detection as the ground-truth segmentation and move the corresponding experiment folder to the annotated directory.