# Entropy-energy search

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import cv2
from scipy import stats, signal
from tqdm.notebook import tqdm

from mouse.utils import constants
from mouse.utils import data_util
from mouse.utils import metrics
from mouse.utils import sound_util
from mouse.utils import visualization

## Implementation

In [None]:
# Load the labelled sources listed in constants.LABELED_SOURCES. Later cells
# index the result as data[0].df (ground-truth table) and data[0].wavs[i].
data = data_util.load_data(constants.LABELED_SOURCES)

In [None]:
def entropy(x):
    """Spectral flatness along axis 0: geometric mean divided by arithmetic mean.

    For a spectrogram (rows = frequency bins, columns = frames) this yields
    one value per column. It is 1 for a perfectly flat (noise-like) column and
    approaches 0 for tonal content (a.k.a. Wiener entropy).
    """
    values = np.asarray(x)
    geometric = stats.mstats.gmean(values, axis=0)
    arithmetic = np.mean(values, axis=0)
    return geometric / arithmetic

def moving_average(x, w):
    """Centred running mean of 1-D ``x`` over ``w`` samples; output has ``len(x)`` values.

    The input is padded with the constant 1 (``w // 2`` values before and
    ``(w - 1) // 2`` after) rather than with zeros. For the flatness curves this
    is applied to, 1 is the flat/noise-like maximum, so the padding pulls the
    averages near the borders towards 1.
    """
    before, after = w // 2, (w - 1) // 2
    padded = np.pad(x, (before, after), mode='constant', constant_values=1)
    window_sums = np.convolve(padded, np.ones(w), mode='valid')
    return window_sums / w

def moving_average_2d(x, w_height, w_width):
    """Box-filter mean over ``w_height`` rows x ``w_width`` columns, flattened.

    Rows are padded (``w_height // 2`` above, ``(w_height - 1) // 2`` below)
    with each column's mean, so the 'valid' convolution keeps one output row
    per input row. When ``w_width`` equals the number of columns, the result is
    one value per input row.
    """
    top, bottom = w_height // 2, (w_height - 1) // 2
    padded = np.pad(x, ((top, bottom), (0, 0)), 'mean')
    box = np.ones((w_height, w_width))
    window_sums = signal.convolve2d(padded, box, 'valid')
    return window_sums.flatten() / (w_height * w_width)

In [None]:
def entropy_plot(signal, window=100):
    """Plot the smoothed spectral flatness of a raw 1-D ``signal``.

    Only frequency bins above 22000 (presumably Hz) are used. The flatness
    curve is smoothed with ``moving_average`` over ``window`` frames and shown
    in a new 10x5 figure.
    NOTE(review): uses ``sound_util.spectrogram`` while other cells use
    ``sound_util.signal_spectrogram``; confirm the helper still exists.
    """
    spectrogram_data = sound_util.spectrogram(signal[None, :])
    above_cutoff = spectrogram_data.freqs > 22000
    flatness = entropy(spectrogram_data.spec[above_cutoff, :])
    smoothed = moving_average(flatness, window)
    plt.figure(figsize=(10,5))
    plt.plot(smoothed)
    plt.show()

def draw_spectrogram(times, freqs, spec, title=''):
    """Open a new 10x5 figure and draw ``spec`` on the (times, freqs) grid.

    Uses Gouraud shading and sets ``title``. Showing the figure is left to
    the caller.
    """
    figure_size = (10,5)
    plt.figure(figsize=figure_size)
    plt.pcolormesh(times, freqs, spec, shading='gouraud')
    plt.title(title)
    # Axis labels are currently disabled:
    # plt.ylabel('Frequency [Hz]')
    # plt.xlabel('Time [sec]')

In [None]:
# # Code operating on binary maps
# def mark_low_entropy(spectrogram, window=100, threshold=0.780):
#     avg_entropy = moving_average(entropy(spectrogram), window)
#     low_entropy = avg_entropy < threshold
#     return np.array([low_entropy]*spectrogram.shape[0])

# def mark_high_energy_rows(spectrogram, energy_window=10, energy_threshold=1.5):
#     mean_energy = np.mean(np.power(spectrogram, 2).numpy())
#     width = spectrogram.shape[1]

#     return moving_average_2d(np.power(spectrogram, 2).numpy(), energy_window, width) > energy_threshold * mean_energy

# def mark_high_energy(spectrogram, mask_to_check, energy_window, energy_threshold):
#     i = 0
#     while i < spectrogram.shape[1]:
#         if mask_to_check[0, i]:
#             start = i
#             stop = i
#             while stop < spectrogram.shape[1] and mask_to_check[0, stop]:
#                 stop += 1
#             marked_energy = mark_high_energy_rows(spectrogram[:,start:stop], energy_window, energy_threshold)
#             for j in range(start, stop):
#                 mask_to_check[:, j] = marked_energy
#             i = stop
#         else:
#             i += 1
#     return mask_to_check

# def mark_low_entropy_high_energy(spectrogram, entropy_window=100, entropy_threshold=0.780, energy_window=10, energy_threshold=1.5):
#     marked_entropy = mark_low_entropy(spectrogram, entropy_window, entropy_threshold)
#     #return marked_entropy
#     return mark_high_energy(spectrogram, marked_entropy, energy_window, energy_threshold)
  
# freq_cutoff = int(22/sound_data.duration * spect.freqs.shape[0])
# delta_time = int(0.25/sound_data.duration * spect.times.shape[0])
# start_time = int(4.6/sound_data.duration * spect.times.shape[0])
# end_time = start_time + delta_time

# marked = mark_low_entropy_high_energy(
#     spect.spec[freq_cutoff:, 
#     start_time:end_time], 
#     entropy_window=10,
#     entropy_threshold=0.790,
#     energy_window=20,
#     energy_threshold=1.5)
# plt.figure(figsize=(10, 10))
# plt.imshow(spect.spec[freq_cutoff:, start_time:end_time])
# plt.imshow(marked, cmap='jet', alpha=0.2)
# plt.show()    

In [None]:
def binary_mask_to_ranges(binary_mask):
    """Convert a 1-D boolean mask into inclusive ``(start, end)`` index runs.

    Every maximal run of consecutive truthy entries becomes one tuple. Both
    ends are indices of truthy entries, e.g. ``[T, T, F, T] -> [(0, 1), (3, 3)]``.
    Callers slice with ``[start:end + 1]``, so a run that touches the end of the
    mask must end at ``len(mask) - 1``, like every other run.

    Args:
        binary_mask: 1-D array-like of booleans (or values castable to bool).

    Returns:
        List of ``(start, end)`` tuples of Python ints in increasing order.
        Empty list for an empty or all-False mask.
    """
    mask = np.asarray(binary_mask, dtype=bool)
    if mask.size == 0:
        return []
    # Pad with False on both sides so every run has a rising and a falling
    # edge; the difference is then non-zero exactly at run boundaries.
    padded = np.concatenate(([False], mask, [False])).astype(np.int8)
    edges = np.flatnonzero(np.diff(padded))
    starts = edges[0::2]       # rising edge at j  -> run starts at j
    ends = edges[1::2] - 1     # falling edge at j -> run ended at j - 1
    return [(int(s), int(e)) for s, e in zip(starts, ends)]

def mark_low_entropy(spectrogram, window=100, threshold=0.780):
    """Return the column ranges whose smoothed spectral flatness is below ``threshold``.

    Flatness is computed per column with ``entropy`` and smoothed with
    ``moving_average`` over ``window`` columns. The boolean result is turned
    into ``(start, end)`` index ranges by ``binary_mask_to_ranges``.
    """
    smoothed_flatness = moving_average(entropy(spectrogram), window)
    is_tonal = smoothed_flatness < threshold
    return binary_mask_to_ranges(is_tonal)

def mark_high_energy_rows(spectrogram, energy_window=10, energy_threshold=1.5):
    """Return the row ranges whose smoothed energy exceeds a multiple of the mean.

    Energy is the squared spectrogram. Each row's energy is averaged over
    ``energy_window`` rows and all columns (``moving_average_2d``). A row is
    marked when that average is above ``energy_threshold`` times the mean
    energy of the whole input. Non-ndarray inputs are converted with
    ``.numpy()`` (e.g. a torch tensor).
    """
    array = spectrogram if isinstance(spectrogram, np.ndarray) else spectrogram.numpy()
    power = np.power(array, 2)
    row_energy = moving_average_2d(power, energy_window, array.shape[1])
    is_loud = row_energy > energy_threshold * np.mean(power)
    return binary_mask_to_ranges(is_loud)

def mark_low_entropy_high_energy(spectrogram, freqs, freq_cutoff=22000, entropy_window=100, entropy_threshold=0.780, energy_window=10, energy_threshold=1.5):
    """Detect candidate squeak boxes, in index coordinates of ``spectrogram``.

    Only rows whose frequency in ``freqs`` is above ``freq_cutoff`` are analysed:

    1. Columns with low smoothed spectral flatness are grouped into column
       ranges (``mark_low_entropy``).
    2. Inside each column range, rows with high smoothed energy are grouped
       into row ranges (``mark_high_energy_rows``).

    Every (column range, row range) pair becomes one ``data_util.SqueakBox``.

    Args:
        spectrogram: 2-D array-like, rows = frequency bins, columns = frames.
        freqs: frequency of each row, compared against ``freq_cutoff``.
        freq_cutoff: rows with ``freqs <= freq_cutoff`` are ignored.
        entropy_window, entropy_threshold: forwarded to ``mark_low_entropy``.
        energy_window, energy_threshold: forwarded to ``mark_high_energy_rows``.

    Returns:
        List of ``SqueakBox``:
        - ``freq_start``/``freq_end`` are row indices into the full
          ``spectrogram``.
        - ``t_start``/``t_end`` are column indices, treated as inclusive.
        - ``label`` is None.
        The list is empty when nothing lies above ``freq_cutoff``.
    """
    high_band = spectrogram[freqs > freq_cutoff, :]
    if high_band.shape[0] == 0 or high_band.shape[1] == 0:
        return []
    # Offset of the analysed band inside the full spectrogram. This assumes
    # freqs is ascending, so the kept rows are the last ones.
    row_offset = spectrogram.shape[0] - high_band.shape[0]

    squeak_boxes = []
    for col_start, col_end in mark_low_entropy(high_band, entropy_window, entropy_threshold):
        # Column ranges are inclusive, hence the +1 in the slice.
        block = high_band[:, col_start:col_end + 1]
        for row_start, row_end in mark_high_energy_rows(block, energy_window, energy_threshold):
            squeak_boxes.append(data_util.SqueakBox(
                freq_start=row_start + row_offset,
                freq_end=row_end + row_offset,
                t_start=col_start,
                t_end=col_end,
                label=None))
    return squeak_boxes

def merge_entropy_neighbors(squeaks, delta_freq, delta_time):
    """Chain-merge detections that follow one another in time and frequency.

    NOTE: this first definition is overridden by the ``merge_entropy_neighbors``
    defined right after it in the same cell, so it is never the one called.
    It is kept for reference.

    Boxes are visited in ``(t_start, freq_start)`` order. A box joins the
    current group when both of these hold:
    - it starts at most ``delta_time`` after the previous member ends;
    - its ``freq_start`` is within ``delta_freq`` of that member's ``freq_end``.
    Otherwise the group is collapsed into its bounding box (``label=None``) and
    a new group starts. The input list is not modified.
    """
    def _bounding_box(group):
        # Smallest box covering every member of the group.
        return data_util.SqueakBox(
            freq_start=min(s.freq_start for s in group),
            freq_end=max(s.freq_end for s in group),
            t_start=min(s.t_start for s in group),
            t_end=max(s.t_end for s in group),
            label=None)

    result = []
    group = []
    for squeak in sorted(squeaks, key=lambda s: (s.t_start, s.freq_start)):
        if group:
            last = group[-1]
            if (squeak.t_start - last.t_end <= delta_time
                    and abs(squeak.freq_start - last.freq_end) <= delta_freq):
                group.append(squeak)
                continue
            result.append(_bounding_box(group))
        group = [squeak]

    if group:
        result.append(_bounding_box(group))
    return result

def merge_entropy_neighbors(squeaks, delta_freq, delta_time):
    """Greedily merge detections that are close in both time and frequency.

    Boxes are visited in ``(t_start, freq_start)`` order and folded into a
    running bounding box. A box is absorbed when both of these hold:
    - it starts at most ``delta_time`` after the running box ends;
    - the gap between the two frequency ranges is at most ``delta_freq``.
      Overlapping ranges have a gap of 0, so a box nested inside the running
      box's band is merged too.
    Otherwise the running box is emitted and the current box starts a new one.

    Args:
        squeaks: iterable of ``data_util.SqueakBox``-like objects with
            ``t_start``, ``t_end``, ``freq_start`` and ``freq_end``. It is
            not modified (the previous version sorted it in place).
        delta_freq: maximum frequency gap, in the boxes' frequency units.
        delta_time: maximum time gap, in the boxes' time units.

    Returns:
        List of boxes. Boxes that absorbed a neighbour are new ``SqueakBox``
        objects with ``label=None``. Boxes that never merged are returned
        unchanged.
    """
    result = []
    merged = None
    for squeak in sorted(squeaks, key=lambda s: (s.t_start, s.freq_start)):
        if merged is not None:
            time_gap = squeak.t_start - merged.t_end
            # Distance between the frequency intervals (0 when they overlap).
            # This accepts everything the former edge-to-edge test accepted,
            # plus boxes overlapping the running band by more than delta_freq.
            freq_gap = max(squeak.freq_start - merged.freq_end,
                           merged.freq_start - squeak.freq_end,
                           0)
            if time_gap <= delta_time and freq_gap <= delta_freq:
                merged = data_util.SqueakBox(
                    t_start=min(squeak.t_start, merged.t_start),
                    t_end=max(squeak.t_end, merged.t_end),
                    freq_start=min(squeak.freq_start, merged.freq_start),
                    freq_end=max(squeak.freq_end, merged.freq_end),
                    label=None)
                continue
            result.append(merged)
        merged = squeak

    if merged is not None:
        result.append(merged)
    return result

def filter_by_ratio(squeaks, ratio_cutoff):
    """Keep boxes whose height-to-width ratio is below ``ratio_cutoff``.

    A box is kept when ``t_end > t_start`` and
    ``(freq_end - freq_start) / (t_end - t_start) < ratio_cutoff``.
    Order is preserved and the input is not modified.
    """
    def _is_flat_enough(box):
        duration = box.t_end - box.t_start
        if not duration > 0:
            return False
        return (box.freq_end - box.freq_start) / duration < ratio_cutoff

    return [box for box in squeaks if _is_flat_enough(box)]
            
def get_squeak_boxes(
    spectrogram, 
    entropy_window=100, 
    entropy_threshold=0.780, 
    energy_window=10, 
    energy_threshold=1.5):
    """Detect squeaks in ``spectrogram`` and return boxes in axis units.

    Runs ``mark_low_entropy_high_energy`` with its default frequency cutoff on
    ``spectrogram.spec``. The resulting row/column indices are then converted
    to values on ``spectrogram.freqs`` / ``spectrogram.times``.

    Args:
        spectrogram: object with ``spec``, ``freqs`` and ``times`` attributes
            (e.g. ``sound_util.SpectrogramData``).
        entropy_window, entropy_threshold, energy_window, energy_threshold:
            forwarded to ``mark_low_entropy_high_energy``.

    Returns:
        List of ``data_util.SqueakBox`` with bounds in the units of
        ``freqs``/``times`` and ``label=None``.
    """
    # freqs must be passed explicitly and the tuning values by keyword:
    # positionally, the second parameter of mark_low_entropy_high_energy
    # is the frequency axis.
    squeak_boxes = mark_low_entropy_high_energy(
        spectrogram.spec,
        spectrogram.freqs,
        entropy_window=entropy_window,
        entropy_threshold=entropy_threshold,
        energy_window=energy_window,
        energy_threshold=energy_threshold)

    def _index_to_axis(index, axis):
        # Linear map assuming evenly spaced axis values: index 0 -> axis[0]
        # and index len(axis) - 1 -> axis[-1] (hence n - 1, not n).
        step = (axis[-1] - axis[0]) / max(axis.shape[0] - 1, 1)
        return axis[0] + index * step

    freqs, times = spectrogram.freqs, spectrogram.times
    return [data_util.SqueakBox(
        freq_start=_index_to_axis(squeak.freq_start, freqs),
        freq_end=_index_to_axis(squeak.freq_end, freqs),
        t_start=_index_to_axis(squeak.t_start, times),
        t_end=_index_to_axis(squeak.t_end, times),
        label=None,
    ) for squeak in squeak_boxes]

In [None]:
def draw_detection(start_time = 2.5, delta_time = 1., freq_cutoff_hz=22000):
    """Run the detector on a window of the global spectrogram and plot the result.

    Takes the part of ``spect`` starting ``start_time`` seconds into the
    recording, lasting ``delta_time`` seconds, and restricted to frequencies
    at or above ``freq_cutoff_hz``. It draws:
      * raw detections (default edge colour),
      * merged and ratio-filtered detections (yellow),
      * ground-truth boxes (green).

    Relies on the notebook globals ``sound_data``, ``spect`` and ``data``.
    """
    n_cols = spect.times.shape[0]
    # Seconds -> column indices, assuming spect.times spans the whole recording.
    window_cols = int(delta_time / sound_data.duration * n_cols)
    start_col = int(start_time / sound_data.duration * n_cols)
    end_col = start_col + window_cols
    # First frequency row at or above the cutoff (freqs assumed ascending).
    # It must come from the frequency axis itself; the recording duration
    # has no relation to frequency rows.
    freq_cutoff = int((spect.freqs < freq_cutoff_hz).sum())

    spect_fragment = sound_util.SpectrogramData(
        spect.spec[freq_cutoff:, start_col:end_col],
        spect.times[start_col:end_col],
        spect.freqs[freq_cutoff:]
    )
    marked = mark_low_entropy_high_energy(
        spect_fragment.spec,
        spect_fragment.freqs,
        freq_cutoff=freq_cutoff_hz,
        entropy_window=5,
        entropy_threshold=0.79,
        energy_window=18,
        energy_threshold=2.1)
    marked_merged = merge_entropy_neighbors(marked, 35, 10)
    marked_filtered = filter_by_ratio(marked_merged, 2.1)

    # Ground truth must come from the same recording that sound_data / spect
    # are built from in the Examples section (data[0].wavs[0]).
    ground_truth = data_util.load_squeak_boxes(data[0].df, data[0].wavs[0], spect_fragment)

    axs = plt.figure(figsize=(20, 40)).subplots(1, 1)
    visualization.draw_spectrogram(spect_fragment, axs)
    visualization.draw_boxes(marked, spect_fragment.get_height(), axs)
    visualization.draw_boxes(marked_filtered, spect_fragment.get_height(), axs, edgecolor="yellow")
    visualization.draw_boxes(ground_truth, spect_fragment.get_height(), axs, edgecolor="green")

In [None]:
def draw_spectrogram_for_comparison(start_time = 2.5, delta_time = 1.):
    """Show the full-band spectrogram for ``delta_time`` seconds from ``start_time``.

    Seconds are converted to column indices of the global ``spect`` using
    ``sound_data.duration``; all frequency rows are shown.
    """
    n_cols = spect.times.shape[0]

    def to_column(seconds):
        return int(seconds / sound_data.duration * n_cols)

    span = to_column(delta_time)
    first = to_column(start_time)
    window = slice(first, first + span)
    draw_spectrogram(spect.times[window], spect.freqs[0:], spect.spec[0:, window])
    plt.show()
def draw_entropy_for_comparison(start_time = 2.5, delta_time = 1.):
    """Plot the smoothed flatness curve for ``delta_time`` seconds from ``start_time``.

    Seconds are converted to sample indices of the global ``sound_data.signal``;
    the slice is passed to ``entropy_plot`` with a 5-frame smoothing window.
    """
    n_samples = sound_data.signal.shape[0]
    span = int(delta_time / sound_data.duration * n_samples)
    begin = int(start_time / sound_data.duration * n_samples)
    entropy_plot(sound_data.signal[begin:begin + span], window=5)
    plt.show()

In [None]:
def split_draw_spectrogram(freqs, times, spec, real_squeaks, detected_boxes, data_folder, signal_name, t_delta=1.5):
    """Draw a long spectrogram as a vertical stack of ``t_delta``-wide strips.

    Each strip shows three layers:
    - ``log10(spec)`` on a colour scale shared by all strips;
    - the ground-truth boxes for that strip, reloaded with
      ``data_util.load_squeak_boxes(data_folder, signal_name, ...)``;
    - the ``detected_boxes`` overlapping the strip, in cyan.

    Args:
        freqs: frequency axis of ``spec`` (one entry per row).
        times: time axis of ``spec`` (one entry per column), same units as
            ``t_delta``.
        spec: 2-D array. log10 is taken, so zeros should be replaced first;
            the calling cells do this.
        real_squeaks: not used. Ground truth is reloaded per strip so its
            boxes are in strip coordinates. Kept for call compatibility.
        detected_boxes: boxes whose ``t_start``/``t_end`` are column indices
            into ``spec``, with ``t_end`` inclusive.
        data_folder, signal_name: forwarded to ``data_util.load_squeak_boxes``.
        t_delta: strip width along ``times``.
    """
    # Take the log once; it drives the shared colour scale and every strip.
    log_spec = np.log10(spec)
    vmin = np.min(log_spec)
    vmax = np.max(log_spec)

    t_span = times[-1] - times[0]
    plots_needed = int(t_span // t_delta) + 1
    _, axes = plt.subplots(plots_needed, 1, figsize=(11, 2*plots_needed))
    axes = axes.ravel() if plots_needed > 1 else [axes]

    t_start = times[0]
    for ax in tqdm(axes):
        t_idx = np.logical_and(times >= t_start, times < t_start + t_delta)
        t_start += t_delta
        columns = np.flatnonzero(t_idx)
        if columns.size == 0:
            # Accumulated float error in t_start can leave the last strip
            # empty; there is nothing to draw and no column to shift by.
            continue
        shift = columns[0]  # first column of spec shown in this strip
        sub_times = times[t_idx]
        sub_log_spec = log_spec[:, t_idx]
        sub_spec_len = sub_log_spec.shape[1]
        strip = sound_util.SpectrogramData(spec=sub_log_spec, freqs=freqs, times=sub_times)
        strip_truth = data_util.load_squeak_boxes(data_folder, signal_name, strip)

        sub_boxes = [
            data_util.SqueakBox(freq_start=box.freq_start,
                                freq_end=box.freq_end,
                                label=None,
                                t_start=np.clip(box.t_start - shift, 0, sub_spec_len),
                                t_end=np.clip(box.t_end - shift, 0, sub_spec_len))
            for box in detected_boxes
            # t_end is inclusive, so it belongs in the overlap test; otherwise
            # single-column boxes (t_start == t_end) are never drawn.
            if np.any(t_idx[box.t_start:box.t_end + 1])
        ]

        visualization.draw_spectrogram(strip, ax, vmin=vmin, vmax=vmax)
        visualization.draw_boxes(strip_truth, sub_log_spec.shape[0], ax, linewidth=1, facecolor='none')
        visualization.draw_boxes(sub_boxes, sub_log_spec.shape[0], ax, edgecolor='cyan', linewidth=1, facecolor='none')

In [None]:
def intersection_over_union_elementwise(prediction, ground_truth):
    """Compute the element-wise IoU lookups in both directions.

    Returns a pair:
    ``(metrics.intersection_over_union_elementwise(prediction, ground_truth),
    metrics.intersection_over_union_elementwise(ground_truth, prediction))``.
    The calling cells pass the first item as ``truth_given_pred_iou_dict`` and
    the second as ``pred_given_truth_iou_dict``.
    """
    pred_to_truth = metrics.intersection_over_union_elementwise(prediction, ground_truth)
    truth_to_pred = metrics.intersection_over_union_elementwise(ground_truth, prediction)
    return pred_to_truth, truth_to_pred

## Examples

In [None]:
# All examples below use the first entry of data[0].wavs. The helper
# functions above read sound_data.duration / sound_data.signal and
# spect.spec / spect.times / spect.freqs as globals.
sound_data = sound_util.SignalData(data[0].wavs[0])
spect = sound_util.signal_spectrogram(sound_data)

### Whole spectrogram - no denoising

In [None]:
ground_truth = data_util.load_squeak_boxes(data[0].df, data[0].wavs[0], spect)
# Tuple order: (entropy_window, entropy_threshold, energy_window,
#               energy_threshold, delta_freq, delta_time, ratio_cutoff)
# Best so far: (5, 0.71, 18, 2.1, 35, 10, 2.1)
marked = mark_low_entropy_high_energy(
    spect.spec,
    spect.freqs,
    entropy_window=5,
    entropy_threshold=0.71,
    energy_window=18,
    energy_threshold=2.1,
)
marked_merged = merge_entropy_neighbors(marked, delta_freq=35, delta_time=10)
marked_filtered = filter_by_ratio(marked_merged, ratio_cutoff=2.1)

In [None]:
# Same call as at the top of the previous cell; redundant unless that cell was skipped.
ground_truth = data_util.load_squeak_boxes(data[0].df, data[0].wavs[0], spect)

In [None]:
# Each pair is (prediction -> ground truth, ground truth -> prediction) IoU lookups,
# later passed as truth_given_pred / pred_given_truth to the metrics functions.
#positive_true_raw, true_positive_raw = intersection_over_union_elementwise(marked, ground_truth)
positive_true_filtered, true_positive_filtered = intersection_over_union_elementwise(marked_filtered, ground_truth)
positive_true_merged, true_positive_merged = intersection_over_union_elementwise(marked_merged, ground_truth)

In [None]:
print("Raw detections")
# split_draw_spectrogram takes log10, so zero bins are replaced by the
# smallest non-zero value.
s = np.array(spect.spec)
t_delta = 1
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(
    spect.freqs,
    spect.times,
    s,
    ground_truth,
    marked,
    data[0].df,
    data[0].wavs[0],
    t_delta=t_delta)  # previously t_delta was set but never passed (default 1.5 used)
plt.tight_layout()
plt.show()

In [None]:
print("Merged detections")
# threshold is the IoU cut-off passed to the metrics helpers.
precision = metrics.detection_precision(
    pred_given_truth_iou_dict=true_positive_merged, 
    truth_given_pred_iou_dict=positive_true_merged,
    threshold = 0.1)
print(f"Precision: {precision}")

recall_overall, recall_labels = metrics.detection_recall(
    pred_given_truth_iou_dict=true_positive_merged, 
    threshold = 0.1)
print(f"Overall recall: {recall_overall}")
print("Recall per label:")
for label, recall in recall_labels.items():
    print(f"                  {label} : {recall}")
# Plot spectrogram; split_draw_spectrogram takes log10, so zero bins are
# replaced by the smallest non-zero value.
s = np.array(spect.spec)
t_delta = 1
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(
    spect.freqs,
    spect.times,
    s,
    ground_truth,
    marked_merged,
    data[0].df,
    data[0].wavs[0],
    t_delta=t_delta)  # previously t_delta was set but never passed (default 1.5 used)
plt.tight_layout()
plt.show()

In [None]:
print("Filtered detections")
# threshold is the IoU cut-off passed to the metrics helpers.
precision = metrics.detection_precision(
    pred_given_truth_iou_dict=true_positive_filtered, 
    truth_given_pred_iou_dict=positive_true_filtered,
    threshold = 0.1)
print(f"Precision: {precision}")

recall_overall, recall_labels = metrics.detection_recall(
    pred_given_truth_iou_dict=true_positive_filtered, 
    threshold = 0.1)
print(f"Overall recall: {recall_overall}")
print("Recall per label:")
for label, recall in recall_labels.items():
    print(f"                  {label} : {recall}")
# Plot spectrogram; split_draw_spectrogram takes log10, so zero bins are
# replaced by the smallest non-zero value.
s = np.array(spect.spec)
t_delta = 1
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(
    spect.freqs,
    spect.times,
    s,
    ground_truth,
    marked_filtered,
    data[0].df,
    data[0].wavs[0],
    t_delta=t_delta)  # previously t_delta was set but never passed (default 1.5 used)
plt.tight_layout()
plt.show()

### Whole spectrogram - denoising

In [None]:
# adaptive_histogram_equalization
# h, w = spect.spec.shape
# num_tiles = (max(h // 50, 2), max(w // 50, 2))
# #num_tiles = (8, 16)
# clip_limit = 0.05
# alpha = 0.1

# spec_den = test_utils.adaptive_histogram_equalization(
#             test_utils.log_and_normalize(spect.spec),
#             num_tiles=num_tiles,
#             clip_limit=clip_limit,
#             alpha=alpha)

### Bilateral Filter

In [None]:
# Edge-preserving smoothing of the spectrogram: 9-pixel neighbourhood diameter,
# sigmaColor = sigmaSpace = 75. sigmaColor is in the spectrogram's value units.
# NOTE(review): confirm 75 is meaningful for these magnitudes; if the values are
# small, this behaves close to a plain Gaussian blur.
blur = cv2.bilateralFilter(spect.spec.numpy(), d=9, sigmaColor=75, sigmaSpace=75)

In [None]:
# start_time = 3.5
# delta_time = 1.
# #Spectrogram
# freq_cutoff = 0
# delta_time = int(delta_time/sound_data.duration * spect.times.shape[0])
# start_time = int(start_time/sound_data.duration * spect.times.shape[0])
# end_time = start_time + delta_time
# draw_spectrogram(spect.times[start_time:end_time], spect.freqs[freq_cutoff:], blur[freq_cutoff:, start_time:end_time])
# plt.show()

# spec_ = blur[spect.freqs > 22000, :]
# plt.figure(figsize=(10,5))
# plt.plot(moving_average(entropy(spec_[freq_cutoff:, start_time:end_time]), 5))
# plt.show()

In [None]:
ground_truth = data_util.load_squeak_boxes(data[0].df, data[0].wavs[0], spect)
# Same detector settings as the raw run (5, 0.71, 18, 2.1, 35, 10, 2.1), except
# entropy_threshold: 0.91 here, on the bilaterally filtered spectrogram.
marked_den = mark_low_entropy_high_energy(
    blur,
    spect.freqs,
    entropy_window=5,
    entropy_threshold=0.91,
    energy_window=18,
    energy_threshold=2.1,
)
marked_merged_den = merge_entropy_neighbors(marked_den, delta_freq=35, delta_time=10)
marked_filtered_den = filter_by_ratio(marked_merged_den, ratio_cutoff=2.1)

In [None]:
# IoU lookups for the denoised detections, scored against the same ground_truth.
# NOTE(review): the commented-out line refers to ground_truth_den, which is never defined.
#positive_true_raw_den, true_positive_raw_den = intersection_over_union_elementwise(marked_den, ground_truth_den)
positive_true_filtered_den, true_positive_filtered_den = intersection_over_union_elementwise(marked_filtered_den, ground_truth)
positive_true_merged_den, true_positive_merged_den = intersection_over_union_elementwise(marked_merged_den, ground_truth)

In [None]:
print("Raw detections (denoising)")
# split_draw_spectrogram takes log10, so zero bins are replaced by the
# smallest non-zero value.
s = np.array(blur)
t_delta = 1
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(
    spect.freqs,
    spect.times,
    s,
    ground_truth,
    marked_den,
    data[0].df,
    data[0].wavs[0],
    t_delta=t_delta)  # previously t_delta was set but never passed (default 1.5 used)
plt.tight_layout()
plt.show()

In [None]:
print("Merged detections (denoising)")
# threshold is the IoU cut-off passed to the metrics helpers.
precision = metrics.detection_precision(
    pred_given_truth_iou_dict=true_positive_merged_den, 
    truth_given_pred_iou_dict=positive_true_merged_den,
    threshold = 0.1)
print(f"Precision: {precision}")

recall_overall, recall_labels = metrics.detection_recall(
    pred_given_truth_iou_dict=true_positive_merged_den, 
    threshold = 0.1)
print(f"Overall recall: {recall_overall}")
print("Recall per label:")
for label, recall in recall_labels.items():
    print(f"                  {label} : {recall}")
# Plot spectrogram; split_draw_spectrogram takes log10, so zero bins are
# replaced by the smallest non-zero value.
s = np.array(blur)
t_delta = 1
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(
    spect.freqs,
    spect.times,
    s,
    ground_truth,
    marked_merged_den,
    data[0].df,
    data[0].wavs[0],
    t_delta=t_delta)  # previously t_delta was set but never passed (default 1.5 used)
plt.tight_layout()
plt.show()

In [None]:
print("Filtered detections (denoising)")
# threshold is the IoU cut-off passed to the metrics helpers.
precision = metrics.detection_precision(
    pred_given_truth_iou_dict=true_positive_filtered_den, 
    truth_given_pred_iou_dict=positive_true_filtered_den,
    threshold = 0.1)
print(f"Precision: {precision}")

recall_overall, recall_labels = metrics.detection_recall(
    pred_given_truth_iou_dict=true_positive_filtered_den, 
    threshold = 0.1)
print(f"Overall recall: {recall_overall}")
print("Recall per label:")
for label, recall in recall_labels.items():
    print(f"                  {label} : {recall}")
# Plot spectrogram; split_draw_spectrogram takes log10, so zero bins are
# replaced by the smallest non-zero value.
s = np.array(blur)
t_delta = 1
s[s==0] = np.min(s[s!=0])
split_draw_spectrogram(
    spect.freqs,
    spect.times,
    s,
    ground_truth,
    marked_filtered_den,
    data[0].df,
    data[0].wavs[0],
    t_delta=t_delta)  # previously t_delta was set but never passed (default 1.5 used)
plt.tight_layout()
plt.show()

### When doesn't it work?

In [None]:
# Spectrogram and smoothed flatness curve for the 1 s window starting at 2.5 s.
draw_spectrogram_for_comparison(start_time = 2.5, delta_time = 1.)
draw_entropy_for_comparison(start_time = 2.5, delta_time = 1.)

In [None]:
# Detections vs. ground truth on the same 1 s window starting at 2.5 s.
draw_detection(start_time = 2.5, delta_time = 1.)

### When does it work?

In [None]:
# Spectrogram and smoothed flatness curve for the 2 s window starting at 4.5 s.
draw_spectrogram_for_comparison(start_time = 4.5, delta_time = 2)
draw_entropy_for_comparison(start_time = 4.5, delta_time = 2)

In [None]:
# Detections on the first 1 s (not the full 2 s shown in the previous cell) starting at 4.5 s.
draw_detection(start_time = 4.5, delta_time = 1.)

## Evaluation

In [None]:
# Search space
# entropy_window = [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
# entropy_threshold = [0.71, 0.73, 0.75, 0.77, 0.79, 0.81, 0.83, 0.85]
# energy_window = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
# energy_threshold = [1.3, 1.5, 1.7, 1.9, 2.1]
# delta_freq = [35, 40, 45, 50]
# delta_time = [5, 10, 15, 20]
# ratio_cutoff = [1.3, 1.5, 1.7, 1.9, 2.1]
# outcome = {}

In [None]:
# ground_truth = data_util.load_squeak_boxes(data[0].df, data[0].wavs[0], spect)
# for ent_w in entropy_window:
#     for ent_t in entropy_threshold:
#         for ene_w in energy_window:
#             for ene_t in energy_threshold:
                
#                 marked = mark_low_entropy_high_energy(
#                     spect.spec, 
#                     entropy_window=ent_w,
#                     entropy_threshold=ent_t,
#                     energy_window=ene_w,
#                     energy_threshold=ene_t)
                
#                 for d_f in delta_freq:
#                     for d_t in delta_time:
#                         marked_merged = merge_entropy_neighbors(marked, d_f, d_t)
                        
#                         for r_c in ratio_cutoff:
#                             marked_filtered = filter_by_ratio(marked_merged, r_c)
                            
#                             iou = intersection_over_union_global(marked_filtered, ground_truth)
#                             outcome[(ent_w, ent_t, ene_w, ene_t, d_f, d_t, r_c)] = iou
#                             print(f"{(ent_w, ent_t, ene_w, ene_t, d_f, d_t, r_c)}: {iou}")

In [None]:
# Ground truth for the full spectrogram of data[0].wavs[0], used by the evaluation cells below.
ground_truth = data_util.load_squeak_boxes(data[0].df, data[0].wavs[0], spect)

In [None]:
#Best so far: (5, 0.71, 18, 2.1, 35, 10, 2.1)
# mark_low_entropy_high_energy requires the frequency axis as its second
# argument (it selects the rows above the frequency cutoff).
marked = mark_low_entropy_high_energy(
    spect.spec,
    spect.freqs,
    entropy_window=5,
    entropy_threshold=0.71,
    energy_window=18,
    energy_threshold=2.1)
marked_merged = merge_entropy_neighbors(marked, 35, 10)
marked_filtered = filter_by_ratio(marked_merged, 2.1)

In [None]:
# IoU lookups for the filtered and the merged detections. This overwrites the
# *_filtered pair computed in the Examples section; positive_true / true_positive
# hold the merged-detection results used by the histograms below.
positive_true_filtered, true_positive_filtered = intersection_over_union_elementwise(marked_filtered, ground_truth)
positive_true, true_positive = intersection_over_union_elementwise(marked_merged, ground_truth)

### Ground truth given prediction

In [None]:
d = metrics.iou_dict_to_plain_list(positive_true_filtered, mode='max')
#d = metrics.iou_dict_to_plain_list(positive_true, mode='max')
bins = 101
# Fix the bin range to [0, 1] (max IoU is assumed to lie there). Without an
# explicit range numpy spans [min(d), max(d)], so bar k would not correspond
# to IoU ~ k% as the x ticks suggest.
hist = np.histogram(d, bins=bins, range=(0., 1.))
plt.figure(figsize=(20, 5))
plt.title("IoU histogram (ground truth given prediction)")
# Place each bar at its bin's left edge, with the x axis in IoU percent.
plt.bar(hist[1][:-1] * 100, hist[0], width=0.8 * 100 / bins, align='edge')
plt.xticks([2*i for i in range(51)])
plt.show()

### Prediction given ground truth

In [None]:
# !
#d = metrics.iou_dict_to_plain_list(true_positive_filtered, mode='max')
d = metrics.iou_dict_to_plain_list(true_positive, mode='sum')
bins = 101
# Bins start at IoU 0 and end at 1, or at the largest value if summed IoUs
# exceed 1. Without an explicit range numpy spans [min(d), max(d)], so bar
# positions would not correspond to IoU values.
upper = max(1., float(np.max(d))) if len(d) > 0 else 1.
hist = np.histogram(d, bins=bins, range=(0., upper))
plt.figure(figsize=(20, 5))
plt.title("IoU histogram (prediction given ground truth)")
# Place each bar at its bin's left edge, with the x axis in IoU percent.
plt.bar(hist[1][:-1] * 100, hist[0], width=0.8 * 100 * upper / bins, align='edge')
plt.xticks(np.arange(0, 100 * upper + 1, 2))
plt.show()

In [None]:
# !
#d = metrics.iou_dict_to_label_lists(true_positive_filtered, mode='max')
d = metrics.iou_dict_to_label_lists(true_positive, mode='sum')
bins = 101
# Use one explicit bin range for every label: [0, 1], widened to the largest
# value if summed IoUs exceed 1. Without it numpy picks a data-dependent range
# per label, so bars are neither IoU-aligned nor comparable across labels.
all_values = [value for values in d.values() for value in values]
upper = max(1., float(np.max(all_values))) if len(all_values) > 0 else 1.
for (label, intersections) in d.items():
    hist = np.histogram(intersections, bins=bins, range=(0., upper))
    plt.figure(figsize=(20, 5))
    plt.title(f"IoU histogram ({label} prediction given ground truth)")
    # Place each bar at its bin's left edge, with the x axis in IoU percent.
    plt.bar(hist[1][:-1] * 100, hist[0], width=0.8 * 100 * upper / bins, align='edge')
    plt.xticks(np.arange(0, 100 * upper + 1, 2))
    plt.show()

## Save spectrograms

In [None]:
# !mkdir figures
# !mkdir figures/preprocessed
# !mkdir figures/log

In [None]:
# folders = data_util.load_data(constants.LABELED_SOURCES)

In [None]:
# for folder in folders:
#     for sig in tqdm(folder.signals):
# #         if sig.name in ['ch1-2018-11-20_10-20-34_0000006.wav', 'ch1-2018-11-20_10-23-08_0000008.wav',
# #                        'ch1-2018-11-20_10-31-42_0000014.wav', 'ch1-2018-11-20_10-39-58_0000019.wav',
# #                        'ch1-2018-11-20_10-29-02_0000012.wav', 'ch1-2018-11-20_10-37-25_0000017.wav',
# #                        'ch1-2018-11-20_10-42-38_0000021.wav']:
# #             continue
# #         print(sig.name)
#         spec = sound_util.signal_spectrogram(sig)
        
#         ground_truth = data_util.load_squeak_boxes(folder, sig.name, spec)
#         #Best so far: (5, 0.71, 18, 2.1, 35, 10, 2.1)
#         marked = mark_low_entropy_high_energy(
#                      spec.spec, 
#                      spec.freqs,
#                      entropy_window=5,
#                      entropy_threshold=0.71,
#                      energy_window=18,
#                      energy_threshold=2.1)
#         marked_merged = merge_entropy_neighbors(marked, 35, 10)
#         marked_filtered = filter_by_ratio(marked_merged, 2.1)
        
#         raw = [marked, marked_merged, marked_filtered]
#         raw_name = ['raw_5_0.91_18_2.1',
#                     'merged_5_0.91_18_2.1',
#                     'filtered_5_0.91_18_2.1']
        
#         blur = cv2.bilateralFilter(spec.spec.numpy(), 9,75,75)
#         marked_den = mark_low_entropy_high_energy(
#                      blur, 
#                      spec.freqs,
#                      entropy_window=5,
#                      entropy_threshold=0.91,
#                      energy_window=18,
#                      energy_threshold=2.1)
#         marked_merged_den = merge_entropy_neighbors(marked_den, 35, 10)
#         marked_filtered_den = filter_by_ratio(marked_merged_den, 2.1)
        
#         preprocessed = [marked_den, marked_merged_den, marked_filtered_den]
#         preprocessed_name = ['raw_5_0.91_18_2.1_bilateral_9_9_75',
#                         'merged_5_0.91_18_2.1_bilateral_9_9_75',
#                         'filtered_5_0.91_18_2.1_bilateral_9_9_75']
        
        
#         for i in range(3):
            
#             plt.figure()
#             s = np.array(blur)
#             t_delta = 1
#             s[s==0] = np.min(s[s!=0])
#             split_draw_spectrogram(
#                 spec.freqs,
#                 spec.times,
#                 s,
#                 ground_truth,
#                 preprocessed[i],
#                 folder,
#                 sig.name,
#                 t_delta=t_delta)
#             plt.tight_layout()
#             plt.savefig(fname=f'figures/preprocessed/{sig.folder}-{sig.name}--{preprocessed_name[i]}',format='svg', facecolor='w')
#             plt.figure()
            
#             s = np.array(spec.spec)
#             t_delta = 1
#             s[s==0] = np.min(s[s!=0])
#             split_draw_spectrogram(
#                 spec.freqs, 
#                 spec.times, 
#                 s,
#                 ground_truth,
#                 raw[i],
#                 folder,
#                 sig.name,
#                 t_delta=t_delta)
#             plt.tight_layout()

#             plt.savefig(fname=f'figures/log/{sig.folder}-{sig.name}--{raw_name[i]}',format='svg', facecolor='w')