In [None]:
from itertools import product
from time import time
from copy import deepcopy
from collections import defaultdict

import matplotlib.pyplot as plt
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import numpy as np
import torch
from tqdm.notebook import tqdm
import pandas as pd

from scipy import ndimage
from skimage.measure import find_contours
from skimage.segmentation import mark_boundaries
from skimage import segmentation
from skimage.util import img_as_float
import cv2
from skimage import data
from skimage.segmentation import (morphological_geodesic_active_contour,
                                  inverse_gaussian_gradient,
                                  checkerboard_level_set)

import mouse.utils.constants as const
from mouse.utils import data_util
from mouse.utils import sound_util
from mouse import segmentation as mouse_seg
from mouse.utils import visualization
from mouse.utils import metrics

In [None]:
# Load every labelled source, but keep only the first DataFolder returned for this analysis.
data_folder: data_util.DataFolder = data_util.load_data(
    const.LABELED_SOURCES
)[0]

In [None]:
# STFT parameters used to build the spectrogram.
n_fft = 512
win_length = 256  # an alternative value of 64 was noted here previously
hop_length = 128
# Only frequency bins strictly above this cutoff are kept.
# Presumably Hz (ultrasonic band) — confirm the units of `spec.freqs` in sound_util.
MIN_FREQ = 18_000

squeak_signal = data_folder.signals[0]

# Spectrogram of the first signal between start=0.0 and end=0.1
# (presumably seconds — confirm against sound_util.signal_spectrogram).
spec = sound_util.signal_spectrogram(
    squeak_signal,
    start=0.,
    end=.1,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
)

# Compute the mask once and apply it to both the spectrogram rows and the frequency
# axis. Recomputing it from `spec.freqs` after that attribute has been filtered would
# produce a mask of the wrong length, so a single mask keeps the two in sync.
high_freq_mask = spec.freqs > MIN_FREQ
spec.spec = spec.spec[high_freq_mask, :]
spec.freqs = spec.freqs[high_freq_mask]

In [None]:
# Floor applied before log10 so zero-valued bins map to a finite value instead of -inf
# (and non-positive values never produce NaN), which would otherwise propagate into
# the segmentation.
# NOTE(review): 1e-10 is a chosen floor; it should sit below the smallest meaningful
# value in `spec.spec` — tune if the spectrogram scale differs.
LOG_EPS = 1e-10

# Independent copy of the filtered spectrogram; only its `.spec` values are replaced,
# so `spec` itself keeps the linear values.
spec_log = deepcopy(spec)
spec_log.spec = np.log10(np.maximum(spec.spec, LOG_EPS))

In [None]:
# Maps a segmentation-method label to the boxes that method produced.
results_by_method = dict()

In [None]:
# GAC-based USV detection on the (frequency-filtered, un-logged) spectrogram.
results_by_method["basic GAC segmentation"] = boxes = mouse_seg.find_USVs(spec)

In [None]:
# Same detector, run on the log10-scaled copy of the spectrogram.
results_by_method[
    "basic GAC segmentation with log preprocessing"
] = boxes = mouse_seg.find_USVs(spec_log)

In [None]:
# Ground-truth squeak boxes for the same signal, loaded from the labelled data.
# `spec` is passed along, presumably so the boxes are expressed in this
# spectrogram's coordinates — confirm in data_util.load_squeak_boxes.
real_squeaks = data_util.load_squeak_boxes(
    data_folder,
    squeak_signal.name,
    spec,
)

In [None]:
# Number of boxes each method found, as a rich-displayed Series. The comprehension
# keeps its variables local, so the notebook globals `name`/`boxes` are not rebound.
method_box_counts = pd.Series(
    {method: len(found) for method, found in results_by_method.items()},
    name="n_boxes",
    dtype="int64",
)
method_box_counts

In [None]:
# Number of ground-truth squeak boxes, to compare with the per-method counts.
n_real_squeaks = len(real_squeaks)
n_real_squeaks

In [None]:
# Disabled experiment: time metrics.intersection_over_union_elementwise on the basic GAC result.
# t = time()
# res = metrics.intersection_over_union_elementwise(cover=real_squeaks,
#                                                   target=results_by_method['basic GAC segmentation'])
# print(time() - t)
# len(res)

In [None]:
def _f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are 0 (avoids 0/0).

    Assumes scalar inputs (one value per metric cell) — TODO confirm that
    metrics.coverage returns scalars rather than arrays.
    """
    total = precision + recall
    if total == 0:
        return 0.0
    return 2 * precision * recall / total


def compare(results_by_method: dict, real_squeaks):
    """Score each segmentation method's boxes against the ground-truth squeaks.

    Args:
        results_by_method: mapping of method label -> boxes predicted by that method.
        real_squeaks: ground-truth boxes, used as the reference for every metric.

    Returns:
        pandas.DataFrame with one column per method label and one row per metric:
        'IoU', 'Cover_recall', 'Cover_precision', 'Cover_f1'.
    """
    comparison = defaultdict(dict)
    for name, boxes in tqdm(results_by_method.items()):
        scores = comparison[name]
        scores['IoU'] = metrics.intersection_over_union_global(
            ground_truth=real_squeaks, prediction=boxes, axis=0)
        cover_recall, cover_precision = metrics.coverage(
            squeaks_1=real_squeaks, squeaks_2=boxes)
        scores['Cover_recall'] = cover_recall
        scores['Cover_precision'] = cover_precision
        # Guarded F1: a method with zero precision and recall scores 0 instead of
        # raising/producing NaN from 0/0.
        scores['Cover_f1'] = _f1_score(cover_precision, cover_recall)
        # TODO: element-wise metrics were disabled here:
        #   'element precision': metrics.intersection_over_union_elementwise(target=real_squeaks, cover=boxes)
        #   'element recall':    metrics.intersection_over_union_elementwise(cover=real_squeaks, target=boxes)
        # If re-enabled, store them in `scores` (per method), not directly in `comparison`.
    return pd.DataFrame(comparison)

# Metric table: one column per segmentation method, one row per metric.
comparison_df = compare(results_by_method=results_by_method, real_squeaks=real_squeaks)
comparison_df