In [None]:
import math
import pandas as pd
import numpy as np
import librosa
import warnings
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.io import wavfile
from collections import OrderedDict
from tqdm import tqdm
import pickle

import glob
import os
from os import path

from PIL import Image

In [None]:
# ANNOTATION TABLE COLUMNS
# Time extent of an annotation box (seconds)
left_col = "Begin Time (s)"
right_col = "End Time (s)"
# Frequency extent of an annotation box (Hz)
top_col = "High Freq (Hz)"
bot_col = "Low Freq (Hz)"
# Class label of an annotation and the confidence attached to that label
class_col = "Species"
class_conf_col = "Species Confidence"

# INPUT / OUTPUT LOCATIONS
# Directory searched for *.wav recordings
recording_dir = "../data"
# Directory searched for *.txt annotation tables
annotation_dir = "../data"
# Directory that receives the spectrogram images, label CSV and class map
output_dir = "../data/spectrograms-15s-072321"

# SPECTROGRAM CONSTANTS
# Window size (n_fft) in seconds
WINDOW_SIZE_SEC = 3/20
# Hop Length in seconds
HOP_LEN_SEC = 15/400
# Number of frequency bands (y dimension of spectrogram)
N_MELS = 400
# Maximum frequency considered (highest value in y dimension)
FREQUENCY_MAX = 1500

# CHUNK CONSTANTS
# Length of one chunk in seconds
CHUNK_SIZE_SEC = 45
# Minimum fraction of a call's box that must remain visible inside a chunk
# for its annotation to be kept
MIN_BOX_PERCENT = 0.3

In [None]:
def get_file_id(fname, id_length=22):
    """Return the identifier prefix of a file name.

    The identifier is the first ``id_length`` characters of the base name
    (directory components are ignored). Base names shorter than
    ``id_length`` are returned whole.

    Parameters
    ----------
    fname : str
        Path or bare file name.
    id_length : int, optional
        Number of leading characters that make up the identifier.
        Defaults to 22, the length this notebook has always used.

    Returns
    -------
    str
        The identifier prefix.
    """
    return path.basename(fname)[:id_length]

# Constructs a list of pairs (wav_fname, annot_fname)
def get_filename_pairs(recording_dir, annotation_dir):
    """Pair every ``*.wav`` recording with its ``*.txt`` annotation file.

    A recording and an annotation file belong together when the annotation's
    base name starts with the recording's file id (see ``get_file_id``).

    Parameters
    ----------
    recording_dir : str
        Directory searched (non-recursively) for ``*.wav`` files.
    annotation_dir : str
        Directory searched (non-recursively) for ``*.txt`` files.

    Returns
    -------
    list of (str, str)
        ``(wav_fname, annot_fname)`` pairs, ordered by recording path.

    Raises
    ------
    ValueError
        If two recordings share a file id, or a recording has zero or more
        than one matching annotation file.
    """
    # glob yields files in arbitrary, filesystem-dependent order; sorting makes
    # the dataset order (and everything derived from it) reproducible.
    wav_fnames = sorted(glob.glob(path.join(recording_dir, "*.wav")))
    annot_fnames = sorted(glob.glob(path.join(annotation_dir, "*.txt")))

    res = []
    id_set = set()
    for wfname in wav_fnames:
        id_nums = get_file_id(wfname)
        if id_nums in id_set:
            raise ValueError("Duplicate Wav ID: {}".format(id_nums))
        id_set.add(id_nums)

        annots = [a for a in annot_fnames if path.basename(a).startswith(id_nums)]
        if len(annots) > 1:
            raise ValueError("More than one annotation for recording: {}".format(wfname))
        if len(annots) < 1:
            raise ValueError("No annotation for recording: {}".format(wfname))

        res.append((wfname, annots[0]))

    return res

In [None]:
# Pair every recording in recording_dir with its annotation file in annotation_dir
dataset = get_filename_pairs(recording_dir, annotation_dir)
# Display the resulting list of (wav_fname, annot_fname) pairs
dataset

In [None]:
# TODO: measure speed of different fns for opening wav files
def read_wavfile(wav_name, normalize=True, verbose=False):
    """Read a wav file, optionally min-max normalizing its samples.

    Parameters
    ----------
    wav_name : str
        Path of the wav file.
    normalize : bool, optional
        If True, convert the samples to float and rescale them linearly so
        the smallest sample becomes -0.5 and the largest becomes 0.5.
        A constant signal (no dynamic range) is returned as all zeros.
        If False, the samples are returned exactly as read.
    verbose : bool, optional
        If True, print the file name, sample count, sample rate and duration.

    Returns
    -------
    (int, numpy.ndarray)
        Sample rate and sample data.
    """
    if verbose:
        print("Reading {}".format(wav_name))
    sr, data = wavfile.read(wav_name)
    if verbose:
        print("{} samples at {} samples/sec --> {} seconds".format(data.shape[0], sr, data.shape[0]/sr))

    if normalize:
        data = data.astype(float)
        data = data - data.min()
        value_range = data.max()
        # A constant recording has zero range; dividing by it would turn every
        # sample into NaN. After the shift above such a signal is already all
        # zeros, so it is returned as-is.
        if value_range > 0:
            data = data / value_range
            data = data - 0.5

    return sr, data


def read_annotations(fname, verbose=False):
    """Load a tab-separated annotation table into a DataFrame.

    With ``verbose`` set, also prints how many rows were read and the
    comma-joined column names.
    """
    table = pd.read_csv(fname, sep="\t")
    if not verbose:
        return table

    n_rows = len(table)
    print("Read {} annotations from {}".format(n_rows, fname))
    print(",".join(table.columns))
    return table

In [None]:
def get_area(annotation):
    """Return the signed area of an annotation box.

    The area is (right - left) * (top - bottom), read from the columns named
    by ``left_col``/``right_col`` and ``bot_col``/``top_col``. No absolute
    value is taken, so the result is negative when exactly one of the two
    differences is negative.
    """
    horizontal_extent = annotation[right_col] - annotation[left_col]
    vertical_extent = annotation[top_col] - annotation[bot_col]
    return horizontal_extent * vertical_extent

# Returns the min and max db observed in all files
def get_db_bounds(wav_filenames):
    """Find the global dB range of the mel spectrograms of a set of recordings.

    Each recording is read with normalization, cut into overlapping chunks of
    ``CHUNK_SIZE_SEC`` seconds, and every chunk is converted to a dB-scaled
    mel spectrogram. The smallest and largest dB values seen across all
    chunks of all files are returned.

    Parameters
    ----------
    wav_filenames : iterable of str
        Paths of the wav files to scan.

    Returns
    -------
    (float, float) or (None, None)
        ``(min_db, max_db)``; both are None when no spectrogram was computed
        (e.g. ``wav_filenames`` is empty).
    """
    min_val, max_val = None, None
    for wfname in wav_filenames:
        sr, data = read_wavfile(wfname, normalize=True)
        n_fft = int(WINDOW_SIZE_SEC * sr)
        hop_len = int(HOP_LEN_SEC * sr)
        chunk_size = int(CHUNK_SIZE_SEC * sr)
        # Distance in samples between the starts of two consecutive chunks
        step = chunk_size - (hop_len * (N_MELS-2) + n_fft)
        for start_i in range(0, len(data), step):
            # Slicing past the end of the array simply truncates the chunk
            chunk = data[start_i:start_i + chunk_size]
            # With center=False librosa cannot frame a signal shorter than one
            # analysis window and raises; such a tail holds no full frame, so
            # it cannot contribute to the bounds anyway.
            if len(chunk) < n_fft:
                continue
            mel_spec = librosa.feature.melspectrogram(y=chunk,
                                                      sr=sr,
                                                      n_fft=n_fft,
                                                      hop_length=hop_len,
                                                      n_mels=N_MELS,
                                                      fmax=FREQUENCY_MAX,
                                                      center=False)
            mel_spec = librosa.power_to_db(mel_spec)
            temp_min = mel_spec.min()
            temp_max = mel_spec.max()
            if min_val is None or temp_min < min_val:
                min_val = temp_min
            if max_val is None or temp_max > max_val:
                max_val = temp_max
    return min_val, max_val

def process_file(wav_filename, annot_filename, min_db, max_db, verbose=False):
    """Cut one recording into spectrogram images and record its annotation boxes.

    The recording is split into overlapping chunks of ``CHUNK_SIZE_SEC``
    seconds. Each chunk is rendered as an 8-bit grayscale mel spectrogram
    (dB values scaled linearly from ``[min_db, max_db]`` to ``[0, 255]``) and
    saved to ``output_dir`` as ``<file_id>-<chunk index>.png``. Annotations
    overlapping a chunk are converted to pixel coordinates of that image,
    clipped to the image, and kept only if more than ``MIN_BOX_PERCENT`` of
    their area remains.

    Side effects in ``output_dir``:

    * ``all_labels.csv`` - the kept boxes are *appended* (no header) as
      ``image_name, class id, left, top, right, bottom``. Calling this
      function twice for the same recording therefore duplicates its rows;
      start from an empty output directory.
    * ``classes.p`` - pickled ``{class id: species}`` map. An existing map is
      loaded and extended, so ids stay consistent across recordings.

    Parameters
    ----------
    wav_filename : str
        Path of the recording.
    annot_filename : str
        Path of the tab-separated annotation table for the recording.
    min_db, max_db : float
        dB values mapped to pixel values 0 and 255 (see ``get_db_bounds``).
    verbose : bool, optional
        If True, print progress information.

    Raises
    ------
    ValueError
        If the chunk stride implied by the constants is not positive, or the
        recording is shorter than one analysis window.
    """
    sr, data = read_wavfile(wav_filename, normalize=True, verbose=verbose)
    annotations = read_annotations(annot_filename, verbose=verbose)
    file_id = get_file_id(wav_filename)

    n_fft = int(WINDOW_SIZE_SEC * sr)
    hop_len = int(HOP_LEN_SEC * sr)
    chunk_size = int(CHUNK_SIZE_SEC * sr)

    # Distance in samples between the starts of two consecutive chunks
    step = chunk_size - (hop_len * (N_MELS-2) + n_fft)
    if step <= 0:
        raise ValueError("Chunk stride must be positive, got {} samples".format(step))
    if len(data) < max(n_fft, 1):
        raise ValueError("Recording too short for a spectrogram: {}".format(wav_filename))

    # Start Indices of each chunk
    start_vals = list(range(0, len(data), step))

    # If last cut point creates a tiny chunk, remove it; the previous chunk then
    # runs to the end of the recording. The only chunk of a short recording is
    # never removed, otherwise there would be nothing left to process.
    if len(start_vals) > 1 and len(data) - start_vals[-1] < int(chunk_size / 2):
        start_vals = start_vals[:-1]

    if verbose:
        print("Start points: [{}]".format(",".join([str(s) for s in start_vals])))

    def extract_chunk(start_i, end_i, spec_name):
        """Save samples [start_i, end_i) as image spec_name; return its boxes in pixels."""
        # Never reach past the recording, so the time span used to place the
        # boxes matches the samples that were actually rendered.
        end_i = min(end_i, len(data))
        mel_spec = librosa.feature.melspectrogram(y=data[start_i:end_i],
                                                  sr=sr,
                                                  n_fft=n_fft,
                                                  hop_length=hop_len,
                                                  n_mels=N_MELS,
                                                  fmax=FREQUENCY_MAX,
                                                  center=False)
        mel_spec = librosa.power_to_db(mel_spec)
        mel_spec = np.clip((mel_spec - min_db) / (max_db - min_db) * 255, a_min=0, a_max=255)
        mel_spec = mel_spec.astype(np.uint8)
        spec_height, spec_width = mel_spec.shape

        # Keep only the annotations that overlap the chunk in time
        start_s, end_s = start_i/sr, end_i/sr
        freq_axis_low, freq_axis_high = librosa.hz_to_mel(0.0), librosa.hz_to_mel(FREQUENCY_MAX)
        chunk_annotations = annotations.loc[~((annotations[left_col] > end_s)
                                              | (annotations[right_col] < start_s))].copy()

        # Rescale the boxes to pixel coordinates of the saved image. Columns are
        # replaced one at a time so integer-typed input columns become float
        # instead of receiving floats in place.
        for col in (left_col, right_col):
            chunk_annotations[col] = ((chunk_annotations[col] - start_s)
                                      / (end_s - start_s)) * spec_width
        for col in (bot_col, top_col):
            mel_pos = librosa.hz_to_mel(chunk_annotations[col].to_numpy(dtype=float))
            # The image is saved flipped vertically (high frequencies at the
            # top), hence the "1.0 -": higher frequency -> smaller y.
            chunk_annotations[col] = (1.0 - ((mel_pos - freq_axis_low)
                                             / (freq_axis_high - freq_axis_low))) * spec_height

        # Clip the boxes to the image and snap them to whole pixels
        trimmed_annots = chunk_annotations.copy()
        for col, limit in ((left_col, spec_width), (right_col, spec_width),
                           (bot_col, spec_height), (top_col, spec_height)):
            trimmed_annots[col] = trimmed_annots[col].clip(lower=0, upper=limit).astype(int)

        # Fraction of each box's area that is still inside the image
        overlaps = get_area(trimmed_annots) / get_area(chunk_annotations)
        keep = (overlaps > MIN_BOX_PERCENT).to_numpy()
        # .copy() so the column added below is not written into a slice
        chunk_annotations = trimmed_annots.loc[keep].copy()

        if verbose:
            print("Found {} annotations in chunk".format(len(chunk_annotations)))

        # Save Chunk (rows reversed so low frequencies end up at the bottom)
        im = Image.fromarray(mel_spec[::-1, :])
        im = im.convert("L")
        im.save(path.join(output_dir, spec_name))
        if verbose:
            print("Saved spectrogram to '{}'".format(spec_name))
        chunk_annotations["image_name"] = spec_name
        return chunk_annotations

    all_chunk_annotations = []
    last_ind = len(start_vals) - 1
    for ind, start_i in enumerate(start_vals):
        # Every chunk is chunk_size long except the last, which runs to the end
        end_i = len(data) if ind == last_ind else start_i + chunk_size
        spec_name = "{}-{}.png".format(file_id, ind)
        all_chunk_annotations.append(extract_chunk(start_i, end_i, spec_name))
    all_chunk_annotations = pd.concat(all_chunk_annotations)[
        ["image_name", class_col, left_col, top_col, right_col, bot_col]
    ].copy()

    # Extend the class map shared by all recordings instead of rebuilding it per
    # file, so a class keeps the same id in every row of all_labels.csv.
    classes_path = path.join(output_dir, "classes.p")
    if path.exists(classes_path):
        with open(classes_path, "rb") as classes_file:
            class_map = pickle.load(classes_file)
    else:
        class_map = {}
    rev_class_map = {name: class_id for class_id, name in class_map.items()}
    for name in all_chunk_annotations[class_col].unique():
        if name not in rev_class_map:
            class_id = len(class_map) + 1  # ids start at 1
            class_map[class_id] = name
            rev_class_map[name] = class_id
    with open(classes_path, "wb") as classes_file:
        pickle.dump(class_map, classes_file)

    all_chunk_annotations[class_col] = all_chunk_annotations[class_col].map(rev_class_map)
    # Append: this function runs once per recording, and overwriting would keep
    # only the labels of the recording processed last.
    all_chunk_annotations.to_csv(path.join(output_dir, "all_labels.csv"),
                                 mode="a", index=False, header=False)
    print("Saved annotations to '{}'".format("all_labels.csv"))

In [None]:
# Scan every recording (first element of each dataset pair) for the global dB
# range; process_file uses these bounds to scale spectrograms to pixel values.
min_db, max_db = get_db_bounds([p[0] for p in dataset])
print(min_db, max_db)

In [None]:
# os.makedirs raises FileExistsError when output_dir already exists, so this
# cell only runs against a fresh output directory; remove or rename a previous
# run's output before re-running.
os.makedirs(output_dir)
# Render spectrogram chunks and labels for every (recording, annotation) pair
for wav_filename, annot_filename in dataset:
    process_file(wav_filename, annot_filename, min_db, max_db, verbose=True)

In [None]:
# TODO:
#  - Save metadata such as the mapping from y_px to hz and start,end time in seconds of the chunks

In [None]:
# Sanity check: draw the labelled boxes of one saved spectrogram.
# Row of all_labels.csv whose image is previewed; clamped below so that label
# files with fewer rows still produce a preview.
PREVIEW_ROW = 59

# classes.p is written by this notebook; do not point this at untrusted pickles.
with open(path.join(output_dir, "classes.p"), "rb") as classes_file:
    classes = pickle.load(classes_file)
# One color per class id; ids start at 1, hence the extra entry for index 0
colors = plt.cm.rainbow(np.linspace(0, 1, len(classes)+1)).tolist()
# Columns (no header): image name, class id, xmin, ymin, xmax, ymax
labels = pd.read_csv(path.join(output_dir, "all_labels.csv"), header=None).to_numpy()
img_name = labels[min(PREVIEW_ROW, len(labels) - 1), 0]
img_labels = labels[(labels[:,0] == img_name)]
img_path = path.join(output_dir, img_name)
with Image.open(img_path) as img:
    img_data = np.array(img)
    print(img_data.shape)
    fig, ax = plt.subplots(figsize=(15,5))
    ax.imshow(img_data, cmap='gray')

    for box in img_labels:
        xmin = box[2]
        ymin = box[3]
        xmax = box[4]
        ymax = box[5]
        label = '{}'.format(classes[int(box[1])])
        color = colors[int(box[1])]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, color=color, fill=False, linewidth=2))
        ax.text(xmin+3, ymin-12, label, size='large', color='white', bbox={'facecolor':color, 'alpha':1.0})

    ax.set_title(img_path)
    plt.show()
    plt.close(fig)