# Setup

In [None]:
from itertools import product
from time import time

import matplotlib.pyplot as plt
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import numpy as np
from tqdm.notebook import tqdm

from skimage.measure import find_contours
from skimage.segmentation import mark_boundaries
from skimage import segmentation
from skimage.util import img_as_float
import cv2
from skimage import data
from skimage.segmentation import (morphological_geodesic_active_contour,
                                  inverse_gaussian_gradient,
                                  checkerboard_level_set)
from scipy import ndimage

import mouse.utils.constants as const
from mouse.utils import data_util
from mouse.utils import sound_util
from mouse import segmentation as mouse_seg
from mouse.utils import visualization
from mouse.utils import metrics

In [None]:
def split_draw_spectrogram(freqs, times, spec, real_squeaks, detected_boxes, data_folder, signal_name, t_delta=1.5, apply_log=True):
    """Draw a long spectrogram as a vertical stack of ``t_delta``-wide panels.

    Each panel shows the columns of ``spec`` whose ``times`` fall in the
    half-open window ``[t_start, t_start + t_delta)``. The panel is overlaid
    with the ground-truth squeak boxes, loaded per panel through
    ``data_util.load_squeak_boxes``, and with ``detected_boxes`` drawn in cyan
    after being shifted to panel-local column indices. All panels share one
    colour scale (min/max over the whole, possibly log-scaled, spectrogram).

    Args:
        freqs: frequency axis of ``spec`` (one value per row).
        times: time axis of ``spec`` (one value per column); presumably
            sorted ascending — TODO confirm.
        spec: 2-D array indexed ``[freq, time]``.
        real_squeaks: NOTE(review): not used. The ground-truth boxes are
            reloaded for every panel from ``data_folder``/``signal_name``.
            The parameter is kept for interface compatibility.
        detected_boxes: boxes with ``t_start``/``t_end`` column indices into
            the full ``spec`` and ``freq_start``/``freq_end``.
        data_folder: forwarded to ``data_util.load_squeak_boxes``.
        signal_name: forwarded to ``data_util.load_squeak_boxes``.
        t_delta: panel width, in the units of ``times``.
        apply_log: if True, plot ``np.log10(spec)``. Zeros become ``-inf``,
            so callers should replace them beforehand.
    """
    t_span = times[-1] - times[0]
    plots_needed = int(t_span // t_delta) + 1
    _, axes = plt.subplots(plots_needed, 1, figsize=(11, 2 * plots_needed))
    # plt.subplots returns a bare Axes for a single row; normalise to a sequence.
    axes = axes.ravel() if plots_needed > 1 else [axes]

    if apply_log:
        print('applying log')
        s = np.log10(spec)
    else:
        s = np.array(spec)
    # One colour range for every panel so the panels are visually comparable.
    vmin = np.min(s)
    vmax = np.max(s)

    t_start = times[0]
    for ax in tqdm(axes):
        t_idx = np.logical_and(times >= t_start, times < t_start + t_delta)
        sub_times = times[t_idx]
        sub_spec = s[:, t_idx]
        sub_spec_len = sub_spec.shape[1]
        real_squeaks = data_util.load_squeak_boxes(data_folder, signal_name,
                                                   sound_util.SpectrogramData(spec=sub_spec, freqs=freqs, times=sub_times))

        # Column of this panel's first frame in the full spectrogram. It depends
        # only on the panel, so compute it once here instead of once per box.
        # If the panel is empty, no box can overlap it and the value is unused.
        shift = np.flatnonzero(t_idx)[0] if np.any(t_idx) else 0

        sub_boxes = []
        for box in detected_boxes:
            if np.any(t_idx[box.t_start:box.t_end]):
                sub_boxes.append(data_util.SqueakBox(freq_start=box.freq_start,
                                                     freq_end=box.freq_end,
                                                     label=None,
                                                     t_start=np.clip(box.t_start - shift, 0, sub_spec_len),
                                                     t_end=np.clip(box.t_end - shift, 0, sub_spec_len)))

        visualization.draw_spectrogram(sound_util.SpectrogramData(spec=sub_spec, freqs=freqs, times=sub_times), ax, vmin=vmin, vmax=vmax)
        visualization.draw_boxes(real_squeaks, sub_spec.shape[0], ax, linewidth=1, facecolor='none')
        visualization.draw_boxes(sub_boxes, sub_spec.shape[0], ax, edgecolor='cyan', linewidth=1, facecolor='none')
        t_start += t_delta
        
def show_result_grid(rows, cols, results, image, fig=None, axes=None, figsize=(16, 8)):
    """Plot GAC results on a grid, one panel per result.

    Each panel shows ``image`` in grayscale with three level-set contours at
    level 0.5: the snapshot after ``iterations // 3`` steps (yellow), the one
    after ``2 * iterations // 3`` steps (green), and the final result (red).

    Args:
        rows: number of grid rows used when axes have to be created.
        cols: number of grid columns used when axes have to be created.
        results: iterable of ``(name, result, evolution, kwargs)`` tuples as
            built by ``get_results_GAC``. ``kwargs["iterations"]`` must exist
            and ``evolution`` must be indexable at the intermediate steps.
        image: background image drawn in every panel.
        fig: figure to draw into. If None, a new figure is created and any
            ``axes`` argument is ignored (the original behaviour).
        axes: axes to draw into when ``fig`` is given. If omitted,
            ``fig.subplots(rows, cols)`` creates them.
        figsize: size of a newly created figure.
    """
    if fig is None:
        fig, axes = plt.subplots(rows, cols, figsize=figsize)
    elif axes is None:
        axes = fig.subplots(rows, cols)
    # axes may be a single Axes, an ndarray of Axes, or a caller-supplied
    # list; np.ravel turns every case into a flat indexable sequence.
    ax = np.ravel(axes)

    for i, (name, result, evolution, kwargs) in enumerate(results):
        iterations = kwargs["iterations"]

        for step, color in ((iterations // 3, 'y'), (2 * iterations // 3, 'g')):
            contour = ax[i].contour(evolution[step], [0.5], colors=color)
            label = f"Iteration {step}"
            # Matplotlib >= 3.8 makes ContourSet itself an Artist and deprecates
            # ``.collections`` (removed in 3.10). Older versions label through it.
            if hasattr(contour, "set_label"):
                contour.set_label(label)
            else:
                contour.collections[0].set_label(label)

        ax[i].imshow(image, cmap="gray")
        ax[i].set_axis_off()
        ax[i].contour(result, [0.5], colors='r')
        ax[i].set_title(name, fontsize=12)

    fig.tight_layout()
    
def store_evolution_in(lst):
    """Build an iteration callback that records level-set snapshots in ``lst``.

    Every call of the returned function appends a *copy* of its argument to
    ``lst``. Later in-place updates by the caller therefore do not rewrite the
    recorded history. The callback returns None.
    """
    def _record(level_set):
        lst.append(np.copy(level_set))
    return _record
    
def run_GAC(image, preprocess, **kwargs):
    """Run scikit-image's morphological geodesic active contour on ``image``.

    Args:
        image: 2-D array to segment.
        preprocess: if True, pass ``inverse_gaussian_gradient(image)`` to the
            contour; otherwise pass ``image`` unchanged.
        **kwargs: forwarded to ``morphological_geodesic_active_contour``
            (this notebook uses iterations, smoothing, threshold, balloon).

    Returns:
        ``(ls, evolution)``: the final level set, and the list of level-set
        snapshots recorded through the ``iter_callback``.
    """
    # Both branches must bind ``gimage``; it is the image the contour evolves on.
    gimage = inverse_gaussian_gradient(image) if preprocess else image

    # Initial level set covers the whole image.
    init_ls = np.ones(image.shape, dtype=np.int8)

    # List with intermediate results for plotting the evolution
    evolution = []
    callback = store_evolution_in(evolution)
    ls = morphological_geodesic_active_contour(gimage, init_level_set=init_ls, iter_callback=callback, **kwargs)
    return ls, evolution

def get_results_GAC(combinations, image, preprocess=True):
    """Run ``run_GAC`` on ``image`` once per ``(name, kwargs)`` configuration.

    Returns a list of ``(name, final_level_set, evolution, kwargs)`` tuples in
    the order of ``combinations``, with a tqdm progress bar while running.
    """
    return [
        (label, *run_GAC(image, preprocess, **gac_kwargs), gac_kwargs)
        for label, gac_kwargs in tqdm(combinations)
    ]

In [None]:
from itertools import product
# Baseline GAC keyword arguments; each experiment overrides a subset of them.
basic_args_GAC = {
    "iterations": 230,
    "smoothing": 1,
    "threshold": 0.7,
    "balloon": -1,
}


def modify_args(update, basic_args=basic_args_GAC):
    """Expand ``update`` into every combination of overridden GAC arguments.

    Args:
        update: mapping from argument name to a list of candidate values.
        basic_args: defaults that every combination is layered on top of.

    Returns:
        A list of ``(name, kwargs)`` pairs, one per element of the Cartesian
        product of ``update``'s value lists (in ``itertools.product`` order).
        ``name`` is ``str`` of the overriding dict and ``kwargs`` is
        ``basic_args`` with those overrides applied. An empty ``update``
        yields a single ``("{}", basic_args copy)`` pair.
    """
    names = list(update)
    grid = product(*(update[n] for n in names))
    return [
        (str(overrides), {**basic_args, **overrides})
        for overrides in (dict(zip(names, combo)) for combo in grid)
    ]


modify_args({"iterations": [1,2], 'smoothing': [1,2]})

In [None]:
def show_result_grid(rows, cols, results, image, fig=None, axes=None, figsize=(16, 8)):
    """Plot GAC results on a grid, one panel per result.

    Each panel shows ``image`` in grayscale with three level-set contours at
    level 0.5: the snapshot after ``iterations // 3`` steps (yellow), the one
    after ``2 * iterations // 3`` steps (green), and the final result (red).

    Args:
        rows: number of grid rows used when axes have to be created.
        cols: number of grid columns used when axes have to be created.
        results: iterable of ``(name, result, evolution, kwargs)`` tuples as
            built by ``get_results_GAC``. ``kwargs["iterations"]`` must exist
            and ``evolution`` must be indexable at the intermediate steps.
        image: background image drawn in every panel.
        fig: figure to draw into. If None, a new figure is created and any
            ``axes`` argument is ignored (the original behaviour).
        axes: axes to draw into when ``fig`` is given. If omitted,
            ``fig.subplots(rows, cols)`` creates them.
        figsize: size of a newly created figure.
    """
    if fig is None:
        fig, axes = plt.subplots(rows, cols, figsize=figsize)
    elif axes is None:
        axes = fig.subplots(rows, cols)
    # axes may be a single Axes, an ndarray of Axes, or a caller-supplied
    # list; np.ravel turns every case into a flat indexable sequence.
    ax = np.ravel(axes)

    for i, (name, result, evolution, kwargs) in enumerate(results):
        iterations = kwargs["iterations"]

        for step, color in ((iterations // 3, 'y'), (2 * iterations // 3, 'g')):
            contour = ax[i].contour(evolution[step], [0.5], colors=color)
            label = f"Iteration {step}"
            # Matplotlib >= 3.8 makes ContourSet itself an Artist and deprecates
            # ``.collections`` (removed in 3.10). Older versions label through it.
            if hasattr(contour, "set_label"):
                contour.set_label(label)
            else:
                contour.collections[0].set_label(label)

        ax[i].imshow(image, cmap="gray")
        ax[i].set_axis_off()
        ax[i].contour(result, [0.5], colors='r')
        ax[i].set_title(name, fontsize=12)

    fig.tight_layout()

## Load signal and generate spectrogram

In [None]:
# Load the data folders listed in const.LABELED_SOURCES and display how many were found.
folders = data_util.load_data(const.LABELED_SOURCES)
len(folders)

In [None]:
# Work with the second loaded folder; display how many signals it contains.
data_folder: data_util.DataFolder = folders[1]
len(data_folder.signals)

In [None]:
# STFT parameters; load_spec further down reads these same globals.
n_fft=512
win_length=256#64
hop_length=128

squeak_signal = data_folder.signals[0]

# NOTE(review): start/end look like fractions of the signal (0..1 = whole signal) — confirm in sound_util.
spec = sound_util.signal_spectrogram(squeak_signal, start=0., end=1.,
                                                  n_fft=n_fft,
                                              win_length=win_length,
                                              hop_length=hop_length)
# Keep only frequency bins above 18000 (presumably Hz, i.e. the ultrasonic band).
spec.spec = spec.spec[spec.freqs>18000,:]
spec.freqs = spec.freqs[spec.freqs>18000]
# NOTE(review): this `image` is reassigned in a later cell before it is ever used.
image = img_as_float(spec.spec)
spec.spec = np.array(spec.spec)
# plt.figure(figsize=(20, 20))
# plt.imshow(np.log(spec[:, :3000]))
# Unprocessed copy of the spectrogram, used later to plot the raw (log-scaled) view.
spec_orginal = sound_util.SpectrogramData(spec=spec.spec, times=spec.times, freqs=spec.freqs)
spec.spec.shape

# GAC (morphological geodesic active contours)

### Tuning parameters

In [None]:
# Range of log10 magnitudes over the first 1000 frames (exact zeros would give -inf).
np.min(np.log10(spec.spec[:, :1000])), np.max(np.log10(spec.spec[:, :1000]))

In [None]:
# Same range on the linear scale.
np.min(spec.spec[:, :1000]), np.max(spec.spec[:, :1000])

In [None]:
# NOTE(review): presumably the frame index at 13 s of a 30 s recording, used to pick the crop below — confirm.
spec.spec.shape[1] * 13/30

In [None]:
# Crop the window of frames used for GAC parameter tuning and preview it on a log scale.
image = spec.spec[:, 22150:22950]
plt.figure(figsize=(10, 3))
plt.imshow(np.log10(image))

In [None]:
# Build the tuning image: crop, convert to log10, then flood-fill from seed (row 100, col 50).
image = spec.spec[:, 22150:22950]

# image = segmentation.flood_fill(image, (100,50), 0, tolerance=1., connectivity=300) # no preprocessing
# print(np.min(image[image > 0]))

image = np.log10(image)
# Pixels connected to the seed whose values lie within +/- tolerance of the seed value are
# set to -3.3 (presumably flattening the background noise — confirm the seed lies in noise).
# NOTE(review): connectivity=300 treats pixels within squared distance 300 (~17 px) as
# neighbours, far wider than skimage's default — confirm this is intended.
image = segmentation.flood_fill(image, (100,50), -3.3, tolerance=.8, connectivity=300) # log10

# image = ndimage.median_filter(image, 4)

# num_tiles = tuple([image.shape[i]//25 for i in (0, 1)])
# clip_limit = 0.05
# alpha = 0.2
# print(num_tiles)
# # image = adaptive_histogram_equalization(log_and_normalize(image), num_tiles, alpha=alpha, clip_limit=clip_limit)
# # image = ndimage.median_filter(image, 8)
# image = np.log10(image)

plt.figure(figsize=(10, 3))
plt.imshow(image)

In [None]:
# Grid search: threshold in {0.7, 0.8, 0.9} x balloon in {-1, -0.9, -0.8} (9 runs) on top of basic_args_GAC.
args = modify_args({"threshold": [i/10 for i in [7, 8, 9]], 'balloon': [-1, -0.9, -0.8]})
results = get_results_GAC(args, image)

In [None]:
# One panel per configuration of the threshold/balloon grid.
show_result_grid(3, 3, results, image)

In [None]:
# Vary only the smoothing (0, 1, 2) with the other arguments fixed.
basic =  {"iterations":230, "smoothing":0, "threshold": 0.7, "balloon":-1.}
args = modify_args({"smoothing":[0, 1, 2]}, basic)
results = get_results_GAC(args, image)

In [None]:
# One row per smoothing value.
show_result_grid(3, 1, results, image)

In [None]:
# Vary only the number of iterations with the other arguments fixed.
basic =  {"iterations":230, "smoothing":1, "threshold": 0.7, "balloon":-1.}
args = modify_args({"iterations":[100, 130, 150, 300]}, basic)
results = get_results_GAC(args, image)

In [None]:
# One row per iteration count.
show_result_grid(4, 1, results, image)

### Run GAC on whole spectrogram

In [None]:
# Preprocess the whole spectrogram, detect USV boxes with find_USVs, and time the run.
t = time()
kwargs =  {"iterations":150, "smoothing":0, "threshold": 0.9, "balloon":-1} # no processing
# ls = run_GAC(image, True, **kwargs)

# Keep an untouched copy so the spectrogram can be restored after detection.
spec_cp = np.copy(spec.spec)
spec.spec = np.array(spec.spec)

try:
    ### PREPROCESSING SECTION
    ### uncomment a subsection to add preprocessing
    ## log10
    # kwargs =  {"iterations":230, "smoothing":1, "threshold": 0.8, "balloon":-1}
    # spec_log = np.log10(spec.spec)
    # spec.spec = segmentation.flood_fill(spec_log, (100,50), -3.3, tolerance=1., connectivity=300) # log10

    ## flood
    spec.spec = segmentation.flood_fill(spec.spec, (100,50), 0, tolerance=1., connectivity=300)

    # median filter; bad method?
    # kwargs =  {"iterations":230, "smoothing":1, "threshold": 0.7, "balloon":-0.9}
    # spec.spec = ndimage.median_filter(spec.spec, 4)

    ### PREPROCESSING ENDS

    boxes = mouse_seg.find_USVs(spec, **kwargs)
    print(time() - t, "[s]")
finally:
    # Restore even when preprocessing or detection raises. Otherwise re-running
    # this or later cells would silently operate on the preprocessed spectrogram.
    spec.spec = spec_cp

len(boxes)

In [None]:
# Ground-truth boxes for this signal, and the detected boxes kept by data_util.filter_boxes.
real_squeaks = data_util.load_squeak_boxes(data_folder, squeak_signal.name, spec)
filtered_boxes = data_util.filter_boxes(spec, boxes)
print(f"{len(filtered_boxes)} left out of {len(boxes)}")
len(real_squeaks)

In [None]:
# fig, ax = plt.subplots(figsize=(20, 5))
# visualization.draw_spectrogram(spec_log, ax)
# visualization.draw_boxes(real_squeaks, spec.get_height(), ax, linewidth=1, facecolor='none')
# visualization.draw_boxes(boxes, spec.get_height(), ax, edgecolor='b', linewidth=1, facecolor='none')
# plt.show()

In [None]:

# fig, ax = plt.subplots(figsize=(20, 5))
# visualization.draw_spectrogram(spec, ax)
# visualization.draw_boxes(real_squeaks, spec.get_height(), ax, linewidth=1, facecolor='none')
# visualization.draw_boxes(filtered_boxes, spec.get_height(), ax, edgecolor='g', linewidth=2, facecolor='none')
# plt.show()

In [None]:
# Coverage of the ground truth by the filtered detections (computed before "junk" labels are dropped).
metrics.coverage(real_squeaks, filtered_boxes)

In [None]:
# Exclude ground-truth boxes labelled "junk" before computing recall/precision.
real_squeaks = [r for r in real_squeaks if r.label != "junk"]

In [None]:
# Matching threshold passed to the detection metrics in the next cell.
threshold = 0.

In [None]:
# NOTE(review): the meaning of `threshold` (presumably a minimum overlap) is defined in mouse.utils.metrics — confirm.
print(metrics.detection_recall(ground_truth=real_squeaks, prediction=filtered_boxes, threshold=threshold))
print(metrics.detection_precision(ground_truth=real_squeaks, prediction=filtered_boxes, threshold=threshold))

In [None]:
# def get_ommited(spec, boxes_to_cover, boxes_covering):
#     to_cover_mask = np.ones(spec.shape[1])
#     covering_mask = np.zeros(spec.shape[1])
#     for box in boxes_to_cover:
#         to_cover_mask[box.t_start:box.t_end] = 0
#     for box in boxes_covering:
#         covering_mask[box.t_start:box.t_end] = 1
#     return np.logical_not(np.logical_or(to_cover_mask, covering_mask))

# ommited = get_ommited(spec.spec, real_squeaks, filtered_boxes)

In [None]:
# Create the output directories for saved figures. os.makedirs creates the parent
# "figures" too, exist_ok makes re-running the cell harmless (plain `mkdir` errors
# once the directories exist), and no shell is needed.
import os

for _fig_dir in ("figures/preprocessed", "figures/log"):
    os.makedirs(_fig_dir, exist_ok=True)

In [None]:
# Recall/precision for the plot titles below.
# NOTE(review): detection_recall appears to return a sequence whose first element is the recall — confirm in metrics.
threshold = 0
recall = metrics.detection_recall(ground_truth=real_squeaks, prediction=filtered_boxes, threshold=threshold)[0]
prec = metrics.detection_precision(ground_truth=real_squeaks, prediction=filtered_boxes, threshold=threshold)

In [None]:
# Plot the spectrogram held in spec.spec in 1-unit panels, with ground truth and filtered detections.
s = np.array(spec.spec)
t_delta = 1
# Replace exact zeros by the smallest non-zero value so the log10 applied inside
# split_draw_spectrogram (apply_log defaults to True) stays finite.
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(spec.freqs, spec.times, s, real_squeaks, filtered_boxes, data_folder,  squeak_signal.name, t_delta=t_delta)
# split_draw_spectrogram(spec.freqs, spec.times, np.array(spec.spec), real_squeaks, boxes, data_folder,  squeak_signal.name)
# plt.title(f"recall={recall:.2f} | precision={prec:.2f} | threshold={threshold}")
plt.tight_layout()
# plt.savefig(fname=f'figures/preprocessed/{squeak_signal.folder}-{squeak_signal.name}--flood',format='svg', facecolor='w')
plt.show()

In [None]:
# Same plot, but on the unprocessed copy spec_orginal (log-scaled inside split_draw_spectrogram).
s = np.array(spec_orginal.spec)
t_delta = 1
# Replace exact zeros so log10 stays finite.
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(spec.freqs, spec.times, s, real_squeaks, filtered_boxes, data_folder,  squeak_signal.name, t_delta=t_delta)
# split_draw_spectrogram(spec.freqs, spec.times, np.array(spec.spec), real_squeaks, boxes, data_folder,  squeak_signal.name)
plt.tight_layout()
# plt.title(f"recall={recall:.2f} | precision={prec:.2f} | threshold={threshold}")

# plt.savefig(fname=f'figures/log/{squeak_signal.folder}-{squeak_signal.name}--flood',format='svg', facecolor='w')
plt.show()

## Test all preprocessing methods and save the results

In [None]:
def load_spec(signal, start=0., end=.05, min_freq=18000):
    """Compute the spectrogram of ``signal``, keeping only bins above ``min_freq``.

    Uses the notebook-level STFT settings ``n_fft``, ``win_length`` and
    ``hop_length``, read at call time.

    Args:
        signal: a signal accepted by ``sound_util.signal_spectrogram``.
        start: forwarded to ``signal_spectrogram``.
        end: forwarded to ``signal_spectrogram``. NOTE(review): the default
            0.05 is far below the ``end=1.`` used in the exploration cell
            earlier in this notebook. Confirm it is not a leftover debugging
            value.
        min_freq: keep only frequency bins strictly above this value
            (presumably Hz).

    Returns:
        The spectrogram object with ``spec`` rows and ``freqs`` restricted to
        the kept band, and ``spec`` converted to a NumPy array.
    """
    spec = sound_util.signal_spectrogram(signal, start=start, end=end,
                                         n_fft=n_fft,
                                         win_length=win_length,
                                         hop_length=hop_length)
    # Build the band mask once and apply it to the rows and the frequency axis.
    keep = spec.freqs > min_freq
    spec.spec = spec.spec[keep, :]
    spec.freqs = spec.freqs[keep]

    spec.spec = np.array(spec.spec)
    return spec

In [None]:
# Preprocessing methods to compare: name -> (find_USVs keyword arguments,
# function applied to the raw spectrogram array before detection).
kwargs =  {"iterations":230, "smoothing":0, "threshold": 0.9, "balloon":-1} # no processing

# median filter; bad method?
# kwargs =  {"iterations":230, "smoothing":1, "threshold": 0.7, "balloon":-0.9}
# spec.spec = ndimage.median_filter(spec.spec, 8)

methods = {#'no-processing': (kwargs, lambda spec: spec),
           'flood': (kwargs, lambda spec: segmentation.flood_fill(spec, (100,50), 0, tolerance=1., connectivity=300))#,
           # 'log10-and-flood': ({"iterations":230, "smoothing":1, "threshold": 0.8, "balloon":-1},
           #                    lambda spec: segmentation.flood_fill(np.log10(spec), (100,50), -3.3, tolerance=1., connectivity=300))
          }

In [None]:
# For each folder's signals: apply every preprocessing method, detect USVs, and save two SVGs per
# method, one of the preprocessed spectrogram and one of the raw spectrogram on a log scale.
for folder in folders:
    # NOTE(review): only the first signal of each folder is processed — confirm intended.
    for sig in tqdm(folder.signals[:1]):
        spec_copy = load_spec(sig)

        for name, (args, method) in methods.items():
            # Fresh wrapper, so reassigning spec.spec below leaves spec_copy.spec untouched.
            spec = sound_util.SpectrogramData(spec=spec_copy.spec, times=spec_copy.times, freqs=spec_copy.freqs)
            spec.spec = method(spec.spec)
            boxes = mouse_seg.find_USVs(spec, **args)

            real_squeaks = data_util.load_squeak_boxes(folder, sig.name, spec)
            filtered_boxes = data_util.filter_boxes(spec, boxes)
            print(f"{len(filtered_boxes)} left out of {len(boxes)}  | {len(real_squeaks)} real boxes")

#             threshold = 0
#             recall = metrics.detection_recall(ground_truth=real_squeaks, prediction=filtered_boxes, threshold=threshold)[0]
#             prec = metrics.detection_precision(ground_truth=real_squeaks, prediction=filtered_boxes, threshold=threshold)

            t_delta = 1

            # split_draw_spectrogram creates its own figure, which becomes the
            # current figure that savefig writes. savefig uses fname verbatim
            # (format= does not append an extension), so add ".svg" explicitly.
            s = np.array(spec.spec)
            split_draw_spectrogram(spec.freqs, spec.times, s, real_squeaks, filtered_boxes, folder,  sig.name, t_delta=t_delta, apply_log=False)
            plt.tight_layout()
            plt.savefig(fname=f'figures/preprocessed/{sig.folder}-{sig.name}--{name}.svg', format='svg', facecolor='w')
            plt.close('all')

            # Raw spectrogram on a log scale; clamp tiny values so log10 stays finite.
            s = np.array(spec_copy.spec)
            s[s <= 1e-6] = 1e-6
            split_draw_spectrogram(spec.freqs, spec.times, s, real_squeaks, filtered_boxes, folder,  sig.name, t_delta=t_delta, apply_log=True)
            plt.tight_layout()
            plt.savefig(fname=f'figures/log/{sig.folder}-{sig.name}--{name}.svg', format='svg', facecolor='w')
            plt.close('all')