# Setup

In [None]:
from itertools import product

import matplotlib.pyplot as plt
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import numpy as np
import torch
from tqdm.notebook import tqdm

from scipy import ndimage
from skimage.measure import find_contours
from skimage.segmentation import mark_boundaries
from skimage import segmentation
from skimage.util import img_as_float
import cv2
from skimage import data
from skimage.segmentation import (morphological_chan_vese,
                                  morphological_geodesic_active_contour,
                                  inverse_gaussian_gradient,
                                  checkerboard_level_set)

import mouse.utils.constants as const
from mouse.utils import data_util
from mouse.utils import sound_util
from mouse.segmentation import result_processing as proc
from mouse.utils import visualization

In [None]:
# Load the data referenced by const.DATA_VPA and keep the first DataFolder returned.
data_folder: data_util.DataFolder = data_util.load_data([const.DATA_VPA])[0]
# Uncomment to inspect the folder's table:
#data_folder.df

In [None]:
# Cut an excerpt from the first signal: samples between 3.8/120 and 5/120 of its length.
# NOTE(review): presumably the recording is 120 s long, so this is roughly 3.8 s to 5.0 s — confirm.
l = data_folder.signals[0].signal.shape[0]
squeak_signal = data_folder.signals[0].signal[int(l*3.8/120):int(l*5./120)]

In [None]:
%matplotlib inline
# Spectrogram of the excerpt, restricted to frequency bins above 18000
# (presumably Hz — confirm the units returned by sound_util.spectrogram).
spec = sound_util.spectrogram(squeak_signal)
# Order matters: spec.spec is masked while spec.freqs is still unfiltered,
# so both lines build the same boolean mask from the original frequency axis.
spec.spec = spec.spec[spec.freqs>18000,:]
spec.freqs = spec.freqs[spec.freqs>18000]

# log10-scaled copies for display; zero-valued bins would become -inf.
# NOTE(review): spec_plot and spec_log are built identically and spec_plot is
# not used in the cells shown — consider dropping one.
spec_plot = sound_util.SpectrogramData(spec=np.log10(spec.spec), times=spec.times, freqs=spec.freqs)
spec_log = sound_util.SpectrogramData(spec=np.log10(spec.spec), times=spec.times, freqs=spec.freqs)

In [None]:
# Preview the log10 spectrogram on a single axes.
fig, ax = plt.subplots(nrows=1, ncols=1)
visualization.draw_spectrogram(spec_log, ax=ax)

In [None]:
def store_evolution_in(lst):
    """Build an ``iter_callback`` that records level-set snapshots into ``lst``.

    Each array passed to the returned callable is copied before being appended,
    so later in-place updates by the caller do not alter stored snapshots.
    The callable returns None.
    """
    return lambda level_set: lst.append(np.copy(level_set))

In [None]:
# Input for the active-contour runs: a copy of the frequency-filtered spectrogram
# values, without the log10 scaling used for display.
# Earlier log-scaled alternative. NOTE(review): it passes `spec` instead of `spec.spec`:
# image = np.log(np.array(spec)+1e-9)
image = np.array(spec.spec)

In [None]:
def run_ACWE(image, **kwargs):
    """Segment ``image`` with morphological Chan-Vese (ACWE).

    The level set is initialised with ``checkerboard_level_set(image.shape, 6)``.
    A checkerboard start was kept over an all-ones start, which gave poor
    results. An inset rectangle of ones (10-pixel border) was also tried.

    Args:
        image: array to segment.
        **kwargs: forwarded unchanged to ``morphological_chan_vese``.

    Returns:
        ``(ls, evolution)``: the final level set and the list of level-set
        copies received by the iteration callback.
    """
    frames = []
    final_ls = morphological_chan_vese(
        image,
        init_level_set=checkerboard_level_set(image.shape, 6),
        iter_callback=store_evolution_in(frames),
        **kwargs,
    )
    return final_ls, frames

def run_GAC(image, preprocess, **kwargs):
    """Segment ``image`` with morphological geodesic active contours (GAC).

    Args:
        image: array to segment.
        preprocess: if true, ``inverse_gaussian_gradient(image)`` is used as the
            GAC input; otherwise ``image`` is passed through unchanged.
        **kwargs: forwarded to ``morphological_geodesic_active_contour``
            (e.g. iterations, smoothing, threshold, balloon).

    Returns:
        ``(ls, evolution)``: the final level set and the list of level-set
        copies received by the iteration callback.
    """
    # Both branches must bind the same name; the non-preprocess path used to
    # assign a misspelled variable and then fail with NameError below.
    gimage = inverse_gaussian_gradient(image) if preprocess else image

    # Start from a level set covering the whole image.
    init_ls = np.ones(image.shape, dtype=np.int8)

    # List with intermediate results for plotting the evolution
    evolution = []

    callback = store_evolution_in(evolution)
    ls = morphological_geodesic_active_contour(gimage, init_level_set=init_ls, iter_callback=callback, **kwargs)
    return ls, evolution

def get_results_GAC(combinations, image, preprocess=True):
    """Run ``run_GAC`` once per named parameter set, with a progress bar.

    Args:
        combinations: iterable of ``(name, kwargs)`` pairs (e.g. from ``modify_args``).
        image: array to segment.
        preprocess: forwarded to ``run_GAC``.

    Returns:
        List of ``(name, final_level_set, evolution, kwargs)`` tuples in input order.
    """
    return [
        (name, final_ls, frames, params)
        for name, params in tqdm(combinations)
        for final_ls, frames in [run_GAC(image, preprocess, **params)]
    ]

def get_results_ACWE(combinations, image):
    """Run ``run_ACWE`` once per named parameter set, with a progress bar.

    Args:
        combinations: iterable of ``(name, kwargs)`` pairs (e.g. from ``modify_args``).
        image: array to segment.

    Returns:
        List of ``(name, final_level_set, evolution, kwargs)`` tuples in input order.
    """
    return [
        (name, final_ls, frames, params)
        for name, params in tqdm(combinations)
        for final_ls, frames in [run_ACWE(image, **params)]
    ]

In [None]:
def _set_contour_label(contour_set, label):
    """Attach a legend label to a contour set across Matplotlib versions."""
    # Matplotlib 3.8+ makes ContourSet itself a labelled artist and removes
    # `.collections` in later releases; older versions only expose the
    # per-level collections.
    if hasattr(contour_set, "set_label"):
        contour_set.set_label(label)
    else:
        contour_set.collections[0].set_label(label)


def show_result_grid(rows, cols, results, image, fig=None, axes=None, figsize=(16, 8)):
    """Plot segmentation results, one panel per result.

    Each panel shows ``image`` in grayscale with 0.5-level contours. Yellow is
    the snapshot at one third of the iterations, green the snapshot at two
    thirds, and red the final level set. Panels are titled with the result name.

    Args:
        rows, cols: grid shape used when a new figure is created.
        results: sequence of ``(name, result, evolution, kwargs)`` tuples as
            returned by ``get_results_GAC`` / ``get_results_ACWE``; ``kwargs``
            must contain ``"iterations"``.
        image: background image drawn in every panel.
        fig, axes: optional existing figure and axes (single Axes or array)
            to draw into. When ``fig`` is None a new figure is created.
        figsize: size of a newly created figure.

    Raises:
        ValueError: if there are more results than axes.
    """
    if fig is None:
        fig, axes = plt.subplots(rows, cols, figsize=figsize)
    # Flatten so a single Axes, a list, or a 2-D grid can all be indexed per panel.
    ax = np.asarray(axes, dtype=object).ravel()

    results = list(results)
    if len(results) > ax.size:
        raise ValueError(f"{len(results)} results do not fit into {ax.size} axes")

    for panel, (name, result, evolution, kwargs) in zip(ax, results):
        iterations = kwargs["iterations"]

        for step, color in ((iterations // 3, 'y'), (2 * iterations // 3, 'g')):
            contour = panel.contour(evolution[step], [0.5], colors=color)
            _set_contour_label(contour, f"Iteration {step}")

        panel.imshow(image, cmap="gray")
        panel.set_axis_off()
        panel.contour(result, [0.5], colors='r')
        panel.set_title(name, fontsize=12)

    fig.tight_layout()

In [None]:
from itertools import product
# Default GAC parameters shared by every sweep unless overridden.
basic_args_GAC = dict(iterations=230, smoothing=1, threshold=0.7, balloon=-1)


def modify_args(update, basic_args=basic_args_GAC):
    """Expand a parameter grid into named keyword-argument sets.

    Every combination (Cartesian product, in key order) of the candidate values
    in ``update`` is merged over ``basic_args``, with the overridden values
    taking precedence. ``basic_args`` itself is never modified.

    Args:
        update: mapping from parameter name to a sequence of candidate values.
        basic_args: defaults shared by every combination.

    Returns:
        List of ``(name, kwargs)`` pairs, where ``name`` is ``str`` of the dict
        of overridden values, e.g. ``"{'smoothing': 2}"``.
    """
    names = list(update.keys())
    grid = product(*(update[n] for n in names))
    overrides = (dict(zip(names, combo)) for combo in grid)
    return [(str(chosen), {**basic_args, **chosen}) for chosen in overrides]


modify_args({"iterations": [1,2], 'smoothing': [1,2]})

# GAC (morphological geodesic active contours)

In [None]:
# Preview the log10 spectrogram again before the GAC sweeps.
fig, ax = plt.subplots(nrows=1, ncols=1)
visualization.draw_spectrogram(spec_log, ax=ax)

In [None]:
# GAC sweep: threshold in {0.4, 0.5, 0.6, 0.7} x balloon in {-1, -0.8, -0.6, -0.4} (16 runs).
args = modify_args({"threshold": [i/10 for i in range(4,8)], 'balloon': [-1, -0.8, -0.6, -0.4]})
results = get_results_GAC(args, image)

In [None]:
show_result_grid(rows=4, cols=4, results=results, image=image)  # 16 panels

In [None]:
# Finer GAC sweep: threshold in {0.7, 0.8, 0.9} x balloon in {-1, -0.9, -0.8, -0.7} (12 runs).
args = modify_args({"threshold": [i/10 for i in [7, 8, 9]], 'balloon': [-1, -0.9, -0.8, -0.7]})
results = get_results_GAC(args, image)

In [None]:
show_result_grid(rows=3, cols=4, results=results, image=image)  # 12 panels

In [None]:
# GAC without smoothing (smoothing=0), balloon=-1; threshold in {0.85, 0.9, 0.95} (3 runs).
basic =  {"iterations":230, "smoothing":0, "threshold": 0.9, "balloon":-1}
args = modify_args({"threshold":[0.85, 0.9, 0.95]}, basic)
results = get_results_GAC(args, image)

In [None]:
show_result_grid(rows=3, cols=1, results=results, image=image)  # 3 panels

# ACWE (morphological Chan–Vese, active contours without edges)

- `ones` initial level set: very poor results.
- `checkerboard` initial level set: acceptable results.

In [None]:
# ACWE sweep over iterations in {30, 50, 80, 100} with smoothing=3, lambda1=lambda2=1 (4 runs).
# NOTE(review): these results are not plotted before the next cell overwrites `results`.
basic =  {'iterations': 30, 'smoothing': 3, 'lambda1': 1, 'lambda2': 1}
args = modify_args({"iterations": [30, 50, 80, 100]}, basic)
results = get_results_ACWE(args, image)

In [None]:
# ACWE sweep (40 iterations, smoothing=3): lambda1 x lambda2 over {0, 0.5, 1} (9 runs).
basic =  {'iterations': 40, 'smoothing': 3, 'lambda1': 1, 'lambda2': 1}
args = modify_args({"lambda1": [0, 0.5, 1.], "lambda2": [0, 0.5, 1.]}, basic)
results = get_results_ACWE(args, image)

In [None]:
show_result_grid(rows=3, cols=3, results=results, image=image)  # 9 panels

In [None]:
# Fine ACWE sweep around lambda1 = 1: {0.995, 1, 1.005} with lambda2=1 (3 runs).
basic =  {'iterations': 40, 'smoothing': 3, 'lambda1': 1, 'lambda2': 1}
args = modify_args({"lambda1": [0.995, 1., 1.005]}, basic)
results = get_results_ACWE(args, image)

In [None]:
show_result_grid(rows=3, cols=1, results=results, image=image)  # 3 panels