In [None]:
from utils.data_reading.sound_data.station import StationsCatalog

import datetime
from scipy import signal
import numpy as np
import torch
from torchvision.transforms import Resize
from matplotlib import pyplot as plt
from utils.detection.TiSSNet import TiSSNet  # seems useless but enables PyTorch to retrieve the model definition

In [None]:
# Paths (relative to this notebook) of the demo dataset catalog and of the TiSSNet checkpoint.
catalog_path = "../../../data/demo/dataset.yaml"
tissnet_checkpoint = "../../../data/models/TiSSNet/torch_save"

# Open the catalog, keep the first station of the "OHASISBIO_2018_raw" dataset
# and ask it for the manager object that gives access to its samples.
stations = StationsCatalog(catalog_path)
elan_raw = stations.by_dataset("OHASISBIO_2018_raw").stations[0]
manager = elan_raw.get_manager()

# Read a 100 s segment starting 100 s after `manager.dataset_start`.
hundred_seconds = datetime.timedelta(seconds=100)
date_start = manager.dataset_start + hundred_seconds
date_end = date_start + hundred_seconds
data = manager.get_segment(date_start, date_end)

In [None]:
# --- Spectrogram ------------------------------------------------------------
# Power spectrogram of the segment (fs=240, 256-sample windows, 128-sample overlap).
psd = signal.spectrogram(data, fs=240, nperseg=256, noverlap=128)[-1]
# Floor the power at the smallest normal float32 so that log10 never yields -inf:
# an infinite value can become NaN during the bilinear resize below, and NaN would
# survive the clipping step. Floored cells still end up far below -35 dB, so they
# are clipped exactly like any other very quiet cell.
psd = np.maximum(psd, np.finfo(np.float32).tiny)
# Convert to dB and flip the first axis (frequency, assuming `data` is 1-D — TODO confirm).
# `.copy()` makes the flipped view contiguous, since torch.from_numpy rejects negative strides.
spectrogram = 10 * np.log10(psd).astype(np.float32)[::-1].copy()
spectrogram = spectrogram[np.newaxis, :, :]  # add a dummy channel dimension (grayscale, one value per pixel)

# --- Model ------------------------------------------------------------------
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
# (calling .cpu() afterwards is too late: torch.load itself would already have failed).
# NOTE(review): torch.load unpickles a full model object here, so only load trusted
# checkpoints. Recent PyTorch releases default to weights_only=True, which rejects
# full-model pickles; pass weights_only=False if the installed version requires it.
model_det = torch.load(tissnet_checkpoint, map_location="cpu").cpu()
model_det.eval()  # inference mode for layers such as dropout / batch normalization

# --- Network input ----------------------------------------------------------
# Resize the frequency axis to 128 rows while keeping the time axis untouched.
# The time length is the LAST axis: after the channel axis was added above,
# shape[1] is the number of frequency bins, not the number of time bins.
input_data = Resize((128, spectrogram.shape[-1]))(torch.from_numpy(spectrogram))
# Normalization: clip to [-35, 140] dB, then map that range linearly onto [0, 1].
input_data = torch.clamp(input_data, min=-35, max=140)
input_data = (input_data + 35) / (35 + 140)

# --- Inference --------------------------------------------------------------
# NOTE(review): input_data is (channel, frequency, time) with no batch axis;
# presumably the model accepts unbatched input — confirm against TiSSNet.
with torch.no_grad():  # no gradient back propagation is needed (we do not train any network here)
    res = model_det(input_data).numpy()

# --- Figures ----------------------------------------------------------------
f = plt.figure(1)
ax_spec = f.gca()
ax_spec.imshow(input_data[0].numpy(), aspect="auto", cmap="jet", vmin=0, vmax=1)
ax_spec.set_title("Normalized spectrogram (network input)")
ax_spec.set_xlabel("time bin")
ax_spec.set_ylabel("frequency bin")

g = plt.figure(2)
ax_out = g.gca()
ax_out.plot(res)
ax_out.set_xlim(0, len(res))
ax_out.set_ylim(0, 1)
ax_out.set_title("TiSSNet output")

# plt.show() works with both inline and GUI backends, whereas Figure.show()
# only warns when the notebook uses a non-GUI backend.
plt.show()