This notebook applies TiSSNet to the stations of the MAHY catalog. For each station it appends the detected peaks (date, probability) to one pickle file.

In [1]:
import datetime
import numpy as np
import torch
import math
from torchvision.transforms import Resize
from matplotlib import pyplot as plt
from pathlib import Path
import pickle
from tqdm import tqdm
from scipy.signal import find_peaks
import pandas as pd
import seaborn as sns

from utils.data_reading.sound_data.station import StationsCatalog
from utils.physics.signal.make_spectrogram import make_spectrogram
from utils.detection.TiSSNet import TiSSNet
from utils.detection.TiSSNet import process_batch

In [2]:
# NOTE(review): machine-specific absolute path — adapt to your setup
catalog_path = "/media/plerolland/akoustik/MAHY/MAHY.csv"
tissnet_checkpoint = "../../../../data/models/i_TiSSNet/torch_save_checked-reboot-3"
out_dir = "../../../../data/detection/i_TiSSNet_checked-reboot-3/MAHY"  # where files will be saved
Path(out_dir).mkdir(parents=True, exist_ok=True)  # create output directory if needed

stations = StationsCatalog(catalog_path).filter_out_undated()  # remove stations with no start / end dates
print(stations)

# use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

model_det = TiSSNet()
# load the weights on the CPU first so that a checkpoint saved from a GPU also loads on CPU-only
# machines; the model is moved to `device` at the end of this cell
model_det.load_state_dict(torch.load(tissnet_checkpoint, map_location="cpu"))

DELTA = datetime.timedelta(seconds=100)  # duration of segments that are given to TiSSNet
OVERLAP = 0.02   # fraction of overlap between consecutive segments (no link with STFT)
STEP = (1 - OVERLAP) * DELTA  # time between the starts of two consecutive segments
batch_size = 1  # number of segments that are fed together to TiSSNet

# parameters of peak finding (TiSSNet outputs 1 value per spectrogram time bin, we use a peak finding algorithm to save only the peaks)
TISSNET_PROMINENCE = 0.05
ALLOWED_ERROR_S = 2  # minimal distance between two kept peaks, in seconds
MIN_HEIGHT = 0.05  # peaks whose height is not above this value are discarded

TIME_RES = 0.5342  # duration of each spectrogram pixel in seconds
FREQ_RES = 0.9375  # frequency extent of each spectrogram pixel in Hz
HEIGHT = 128  # number of spectrogram rows (axis 0) given to TiSSNet

model_det.to(device)

(MAHY0_MAHY01, MAHY0_MAHY02, MAHY0_MAHY03, MAHY0_MAHY04, MAHY1_MAHY11, MAHY1_MAHY12, MAHY1_MAHY13, MAHY1_MAHY14, MAHY2_MAHY21, MAHY2_MAHY22, MAHY2_MAHY23, MAHY3_MAHY31, MAHY3_MAHY32, MAHY3_MAHY33, MAHY3_MAHY34, MAHY4_MAHY41, MAHY4_MAHY42, MAHY4_MAHY43, MAHY4_MAHY44)


TiSSNet(
  (layers): Sequential(
    (0): Conv2d(1, 16, kernel_size=(8, 8), stride=(1, 1), padding=same)
    (1): LeakyReLU(negative_slope=0.3, inplace=True)
    (2): Conv2d(16, 16, kernel_size=(8, 8), stride=(1, 1), padding=same)
    (3): LeakyReLU(negative_slope=0.3, inplace=True)
    (4): Conv2d(16, 16, kernel_size=(8, 8), stride=(1, 1), padding=same)
    (5): LeakyReLU(negative_slope=0.3, inplace=True)
    (6): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(16, 32, kernel_size=(5, 8), stride=(1, 1), padding=same)
    (8): LeakyReLU(negative_slope=0.3, inplace=True)
    (9): Conv2d(32, 32, kernel_size=(5, 8), stride=(1, 1), padding=same)
    (10): LeakyReLU(negative_slope=0.3, inplace=True)
    (11): Conv2d(32, 32, kernel_size=(5, 8), stride=(1, 1), padding=same)
    (12): LeakyReLU(negative_slope=0.3, inplace=True)
    (13): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(32,

In [3]:
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import torch

import os
from concurrent.futures import ProcessPoolExecutor
import torch

BATCH_SIZE = 256  # number of segments prepared in parallel and sent to TiSSNet in one batch

def process_and_make_spec(idx, start, STEP, DELTA, end, sampling_f, HEIGHT):
    """Build the TiSSNet input for segment number `idx` of the current station.

    The segment starts at `start + idx * STEP` and ends at `start + idx * STEP + DELTA`, clipped to `end`.
    Its samples are read through the manager returned by get_manager_for_worker().

    Returns a tuple (segment start, spectrogram) where the float32 spectrogram has exactly HEIGHT rows
    and an extra leading axis of size 1, or None when the loaded data lasts 1 s or less.
    """
    seg_begin = start + idx * STEP
    seg_stop = min(end, seg_begin + DELTA)

    worker_manager = get_manager_for_worker()
    samples = worker_manager.get_segment(seg_begin, seg_stop)
    # too little data to build a useful spectrogram
    if len(samples) / sampling_f <= 1:
        return None

    spectro = make_spectrogram(samples, worker_manager.sampling_f, t_res=TIME_RES, f_res=FREQ_RES,
                               return_bins=False, normalize=True, vmin=-35, vmax=140)
    spectro = spectro.astype(np.float32)

    n_rows = spectro.shape[0]
    if n_rows < HEIGHT:
        # add zero rows before the existing ones along axis 0
        spectro = np.pad(spectro, ((HEIGHT - n_rows, 0), (0, 0)), 'constant')
    elif n_rows > HEIGHT:
        # keep only the last HEIGHT rows along axis 0
        spectro = spectro[-HEIGHT:]

    return seg_begin, spectro[np.newaxis]

_manager_cache = {}  # (pid, station dataset, station name) -> data manager, filled lazily

def get_manager_for_worker():
    """Return the data manager of the station currently being processed, cached per process.

    `global_station` is a module-level variable assigned by the parent process before the worker pool
    is created. Workers see it by inheriting the parent's globals.
    NOTE(review): this relies on the "fork" start method — confirm under spawn/forkserver.

    The cache key includes the station identity. A call made in the same process after `global_station`
    changed therefore returns the new station's manager rather than a stale one.
    """
    station = global_station
    key = (os.getpid(), station.dataset, station.name)
    if key not in _manager_cache:
        _manager_cache[key] = station.get_manager()
    return _manager_cache[key]

STATION_FILTER = "MAHY43"  # only stations whose name contains this substring are processed ("" keeps all)


def _read_resume_date(path):
    """Return the date of the last complete detection record stored in `path`, or None if there is none.

    Detection files are streams of pickled [date, prob] records appended one at a time. A run that was
    killed while writing can leave a truncated record at the end. That unreadable tail is cut off so the
    records appended by the resumed run remain readable.
    """
    last_date = None
    good_end = 0
    with open(path, "r+b") as f:
        while True:
            try:
                record = pickle.load(f)
            except (EOFError, pickle.UnpicklingError):
                break
            last_date = record[0]
            good_end = f.tell()
        dropped = f.seek(0, 2) - good_end
        if dropped:
            print(f"Dropping {dropped} unreadable trailing bytes from {path}")
            f.truncate(good_end)
    return last_date


def _write_detections(f_out, seg_start, seg_duration_s, pred, after=None):
    """Peak-pick one TiSSNet output and append a [date, prob] record for each kept peak.

    `pred` holds one value per spectrogram time bin of the segment. Bin k is therefore dated
    k * seg_duration_s / len(pred) seconds after `seg_start`. Peaks not above MIN_HEIGHT are dropped,
    as are peaks dated at or before `after` (already saved by a previous run).
    """
    n_bins = pred.shape[0]
    if n_bins == 0:
        return
    t_res = seg_duration_s / n_bins  # seconds per TiSSNet output value
    peak_idx, props = find_peaks(pred, height=0, distance=math.ceil(ALLOWED_ERROR_S / t_res),
                                 prominence=TISSNET_PROMINENCE)
    for k, height in zip(peak_idx, props["peak_heights"]):
        if not height > MIN_HEIGHT:
            continue
        date = seg_start + datetime.timedelta(seconds=k * t_res)
        if after is not None and date <= after:
            continue
        pickle.dump([date, height.astype(np.float16)], f_out)


for station in stations:
    if STATION_FILTER not in station.name:
        continue

    print(f"Starting detection on {station.name}")
    # workers read this global through get_manager_for_worker()
    # NOTE(review): relies on workers being forked after this assignment (fork start method)
    global_station = station
    manager = station.get_manager()
    out_file = f"{out_dir}/{station.dataset}_{station.name}.pkl"

    start, end = manager.dataset_start, manager.dataset_end
    steps = int(np.ceil((end - start) / STEP))  # number of overlapping segments covering the dataset

    # Resume: restart from the segment containing the last saved detection. `last_date` is recomputed
    # for every station, so a missing or empty file never reuses another station's date.
    last_date = _read_resume_date(out_file) if Path(out_file).exists() else None
    start_idx = 0 if last_date is None else max(0, int(np.floor((last_date - start) / STEP)))

    with open(out_file, "ab") as f_out, ProcessPoolExecutor() as executor:
        for i in tqdm(range(start_idx, steps, BATCH_SIZE)):
            idxs = list(range(i, min(i + BATCH_SIZE, steps)))
            n = len(idxs)
            results = executor.map(process_and_make_spec, idxs, [start] * n, [STEP] * n, [DELTA] * n,
                                   [end] * n, [manager.sampling_f] * n, [HEIGHT] * n)
            # segments with too little data come back as None
            results = [r for r in results if r is not None]
            if not results:
                continue

            times_loaded, spectros = zip(*results)

            try:
                # fast path: all spectrograms share one shape and go through TiSSNet together
                preds = process_batch(np.stack(spectros), device, model_det)
                pairs = list(zip(times_loaded, preds))
            except ValueError:
                # np.stack raises ValueError when shapes differ (e.g. the shorter last segment):
                # fall back to one spectrogram at a time
                pairs = []
                for time_, spec in zip(times_loaded, spectros):
                    try:
                        pred = process_batch(np.expand_dims(spec, 0), device, model_det)[0]
                        pairs.append((time_, pred))
                    except Exception as e:
                        print(f"Erreur dans le traitement du spectro isolé : {e}")

            for seg_start, pred in pairs:
                # the last segment is clipped at `end`: use its real duration rather than DELTA
                seg_duration_s = (min(end, seg_start + DELTA) - seg_start).total_seconds()
                _write_detections(f_out, seg_start, seg_duration_s, pred, after=last_date)

Starting detection on MAHY43


  return F.conv2d(
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
  spectro = 10*np.log10(spectro)
100%|██████████| 1378/1378 [27:50<00:00,  1.21s/it]


Take a look at the results: load the detections saved for one station and plot the number of detections per 10-day window.

In [None]:
# detections saved by the cell above for one station: a stream of pickled [date, prob] records
detection_file = f"{out_dir}/MAHY0_MAHY01.pkl"
d = []
with open(detection_file, "rb") as f:
    while True:
        try:
            d.append(pickle.load(f))
        except EOFError:
            break
print(f"{len(d)} detections found")

dates_plot = np.array([record[0] for record in d])

if len(dates_plot) == 0:
    print("Nothing to plot")
else:
    df = pd.DataFrame({'date': pd.to_datetime(dates_plot)})
    counts = df.resample('10D', on='date').size().asfreq('10D', fill_value=0)

    fig, ax = plt.subplots()
    # full dates as labels: day-of-year alone collides across years and seaborn would merge those bars
    sns.barplot(x=counts.index.strftime("%d/%m/%y"), y=counts.values, ax=ax)
    ax.set_title(f"10-day detections from {dates_plot[0].day:02d}/{dates_plot[0].month:02d}/{dates_plot[0].year}")
    ax.set_xlabel("Start of 10-day window")
    ax.set_ylabel("Number of events")
    ax.tick_params(axis="x", labelrotation=90)
    plt.show()