In [1]:
import math
import numpy as np
import glob2
import torch
import datetime
from pathlib import Path
import pickle
from tqdm.notebook import tqdm
from torchvision.transforms import Resize
from scipy.signal import find_peaks

from utils.detection.TiSSNet import TiSSNet
from utils.data_reading.sound_data.sound_file_manager import make_manager
from utils.physics.signal.make_spectrogram import make_spectrogram

In [2]:
# --- Input / output locations ---
sound_file_path = "/media/plerolland/akoustik/GEODAMS/2024"
# Each entry under sound_file_path is handed to make_manager; sorted so that the
# processing order (and the printed log) is deterministic across runs.
paths = sorted(glob2.glob(f"{sound_file_path}/*"))
tissnet_checkpoint = "../../../data/models/TiSSNet/torch_save"
out_root = "../../../data/detection"

# --- Segmentation of the recordings into model windows ---
OVERLAP = 0.02  # fractional overlap between consecutive model windows (unrelated to the STFT)
# Window length chosen so that consecutive windows start exactly 1 h apart:
# STEP = (1 - OVERLAP) * DELTA = 3600 s.
DELTA = datetime.timedelta(seconds=3600 / (1 - OVERLAP))
STEP = (1 - OVERLAP) * DELTA

# --- Spectrogram parameters ---
TIME_RES = 0.5342  # duration of each spectrogram pixel, in seconds
FREQ_RES = 0.9375  # frequency span of each spectrogram pixel, in Hz
REQ_HEIGHT = 128  # number of frequency rows the spectrogram is resized to before entering TiSSNet

# --- Peak picking on the TiSSNet output ---
TISSNET_PROMINENCE = 0.05  # minimum peak prominence passed to find_peaks
ALLOWED_ERROR_S = 5  # minimum spacing between two kept peaks, in seconds (converted to pixels)
MIN_HEIGHT = 0.05  # peaks whose height does not exceed this value are discarded

batch_size = 16  # number of windows sent to the model in one forward pass

# Use the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Build the detector and load its trained weights directly onto the selected device.
model_det = TiSSNet().to(device)
model_det.load_state_dict(torch.load(tissnet_checkpoint, map_location=device))
# Switch to inference mode: disables dropout and makes batch-norm layers (if TiSSNet has
# any) use their running statistics. torch.no_grad() alone does not do this.
model_det.eval()

def process_batch(batch, model=None, target_device=None):
    """Run the detector on a batch of inputs and return its output as a NumPy array.

    Args:
        batch: sequence of equally-shaped inputs (torch CPU tensors or NumPy arrays).
            They are stacked along a new leading axis, so each element should be one
            sample without its batch dimension.
        model: module to evaluate; defaults to the global ``model_det``.
        target_device: device the stacked batch is moved to; defaults to the global
            ``device``.

    Returns:
        np.ndarray: the model output for the whole batch, copied back to the CPU.

    Raises:
        ValueError: if the batch elements do not all share the same shape.
    """
    if model is None:
        model = model_det
    if target_device is None:
        target_device = device

    ragged_msg = "not rectangular array: all batch elements must have the same shape"
    try:
        batch_array = np.array(batch)
    except ValueError as err:  # recent NumPy refuses ragged input outright
        raise ValueError(ragged_msg) from err
    if batch_array.dtype == object:  # older NumPy silently builds an object array instead
        raise ValueError(ragged_msg)

    inputs = torch.from_numpy(batch_array).to(target_device)
    with torch.no_grad():
        res = model(inputs).cpu().numpy()
    del inputs
    torch.cuda.empty_cache()  # release cached GPU memory between batches (no-op without CUDA)
    return res

In [4]:
out_dir = f"{out_root}/TiSSNet/"
Path(out_dir).mkdir(parents=True, exist_ok=True)


def _resume_index(out_file, start):
    """Return the index of the first window to process for ``out_file``.

    The output file is a stream of pickled ``[date, probability]`` records appended in
    chronological order. Processing restarts at the window index derived from the last
    recorded date, so detections of that window may be written a second time.
    A truncated trailing record (e.g. left by an interrupted write) is cut off so that
    records appended afterwards remain readable.
    """
    out_file = Path(out_file)
    if not out_file.exists():
        return 0

    last_date, good_size = None, 0
    with open(out_file, "rb") as f:
        while True:
            try:
                last_date = pickle.load(f)[0]
            except (EOFError, pickle.UnpicklingError):
                break
            good_size = f.tell()  # end offset of the last fully-read record

    if good_size < out_file.stat().st_size:
        print(f"Truncating incomplete trailing record in {out_file}")
        with open(out_file, "r+b") as f:
            f.truncate(good_size)

    if last_date is None:  # file exists but holds no complete record
        return 0
    return math.floor((last_date - start) / STEP)


def _segment_input(manager, seg_start, seg_end):
    """Build the (1, REQ_HEIGHT, width) TiSSNet input tensor for one time window."""
    data = manager.get_segment(seg_start, seg_end)
    # NOTE(review): vmin/vmax look like dB clipping bounds for normalization — confirm.
    spectrogram = make_spectrogram(data, manager.sampling_f, t_res=TIME_RES, f_res=FREQ_RES,
                                   return_bins=False, normalize=True, vmin=-35, vmax=140).astype(np.float32)
    spectrogram = spectrogram[np.newaxis, :, :]  # single (grayscale) channel dimension
    return Resize((REQ_HEIGHT, spectrogram.shape[-1]))(torch.from_numpy(spectrogram))


def _predict_in_order(inputs):
    """Run ``process_batch`` on ``inputs`` and return the per-input outputs in input order.

    Windows clipped at the end of a dataset are narrower than the others and cannot be
    stacked with them, so every distinct input shape is sent through the model as its
    own (properly batched) sub-batch.
    """
    groups = {}
    for pos, x in enumerate(inputs):
        groups.setdefault(tuple(x.shape), []).append(pos)

    outputs = [None] * len(inputs)
    for positions in groups.values():
        group_out = process_batch([inputs[pos] for pos in positions])
        for pos, out in zip(positions, group_out):
            outputs[pos] = out
    return outputs


def _save_peaks(out_file, seg_dates, outputs):
    """Pick peaks in each window's output and append them to ``out_file``.

    Each kept peak is written as a pickled ``[date, probability(float16)]`` list.
    Assumes each element of ``outputs`` is 1-D, as required by ``find_peaks``.
    """
    with open(out_file, "ab") as f:
        for seg_start, probs in zip(seg_dates, outputs):
            peak_idx, props = find_peaks(probs, height=0, distance=ALLOWED_ERROR_S / TIME_RES,
                                         prominence=TISSNET_PROMINENCE)
            for idx, height in zip(peak_idx, props["peak_heights"]):
                if height > MIN_HEIGHT:
                    date = seg_start + datetime.timedelta(seconds=float(idx * TIME_RES))
                    pickle.dump([date, height.astype(np.float16)], f)


for path in paths:
    manager = make_manager(path)
    out_file = f"{out_dir}/{manager.name}"

    print(f"Starting detection on {manager.name}")

    start, end = manager.dataset_start, manager.dataset_end
    steps = math.ceil((end - start) / STEP)
    # if some detection has already been run, we start where it was stopped
    start_idx = _resume_index(out_file, start)
    batch_dates, batch_process = [], []

    for seg_idx in tqdm(range(steps), smoothing=0.001):
        if seg_idx < start_idx:
            continue  # only advances the progress bar when resuming from an old detection file

        # index multiplication (rather than repeated addition of STEP) avoids accumulating rounding errors
        seg_start = start + seg_idx * STEP
        seg_end = min(end, seg_start + DELTA)
        if seg_start >= seg_end:
            break

        batch_dates.append(seg_start)
        batch_process.append(_segment_input(manager, seg_start, seg_end))

        if len(batch_process) == batch_size:
            _save_peaks(out_file, batch_dates, _predict_in_order(batch_process))
            batch_dates, batch_process = [], []  # also releases the processed inputs

    # Flush the final, incomplete batch; otherwise the last windows of every dataset
    # (fewer than batch_size of them, including the shortened end windows) are lost.
    if batch_process:
        _save_peaks(out_file, batch_dates, _predict_in_order(batch_process))
        batch_dates, batch_process = [], []

Starting detection on HAMS-East


  0%|          | 0/8061 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,
  spectro = 10*np.log10(spectro)


Starting detection on HAMS-Centre


  0%|          | 0/7963 [00:00<?, ?it/s]

Starting detection on HAMS-North


  0%|          | 0/7988 [00:00<?, ?it/s]

Starting detection on HAMS-South


  0%|          | 0/8002 [00:00<?, ?it/s]