In [1]:
import pandas as pd
import torch
import os
import sys
import librosa
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import math

# Make the project's helper directories importable. These entries are relative
# paths, so they depend on the current working directory.
sys.path.append('../audio_preprocessing')
sys.path.append('../src')
sys.path.append('../model_training_utils')

# Project-local modules (presumably located via the sys.path entries above).
import preprocessing_func_2
from generator_to_dataset_2 import NormalisedDataSet
from gdsc_utils import PROJECT_DIR
import model_training

# Switch the working directory to PROJECT_DIR; later cells use paths such as
# 'data/metadata.csv' relative to it.
# NOTE(review): the sys.path entries above are relative, so imports performed
# after this chdir may resolve differently — confirm, or make them absolute.
os.chdir(PROJECT_DIR)

In [2]:
# Read the metadata table and split it on the `subset` column.
df = pd.read_csv('data/metadata.csv')
df_train = df[df["subset"] == "train"]
df_val = df[df["subset"] == "validation"]
# Preview the first training rows.
df_train.head()

Unnamed: 0,file_name,unique_file,path,species,label,subset,sample_rate,num_frames,length
0,Roeselianaroeselii_XC751814-dat028-019_edit1.wav,Roeselianaroeselii_XC751814-dat028-019,data/train/Roeselianaroeselii_XC751814-dat028-...,Roeselianaroeselii,56,train,44100,4586400,104.0
1,Roeselianaroeselii_XC752367-dat006-010.wav,Roeselianaroeselii_XC752367-dat006-010,data/train/Roeselianaroeselii_XC752367-dat006-...,Roeselianaroeselii,56,train,44100,337571,7.654671
2,Yoyettacelis_GBIF2465208563_IN36000894_50988.wav,Yoyettacelis_GBIF2465208563_IN36000894_50988,data/train/Yoyettacelis_GBIF2465208563_IN36000...,Yoyettacelis,64,train,44100,220500,5.0
3,Gomphocerippusrufus_XC752285-dat001-045.wav,Gomphocerippusrufus_XC752285-dat001-045,data/train/Gomphocerippusrufus_XC752285-dat001...,Gomphocerippusrufus,26,train,44100,693715,15.730499
5,Phaneropteranana_XC755717-221013-Phaneroptera-...,Phaneropteranana_XC755717-221013-Phaneroptera-...,data/train/Phaneropteranana_XC755717-221013-Ph...,Phaneropteranana,41,train,44100,88200,2.0


In [3]:
from typing import List

def load_wav(path):
    """Load the audio file at ``path`` at a 44100 Hz sample rate.

    Only the sample array is returned; the sample rate reported by
    ``librosa.load`` is dropped because it is fixed by the ``sr`` argument.
    """
    return librosa.load(path, sr=44100)[0]

def normalise_wav(wav, normalise_to:float=0.5):
    """Centre a waveform around zero and scale its peak magnitude to ``normalise_to``.

    Args:
        wav: NumPy array of samples.
        normalise_to: Largest absolute amplitude of the returned signal.

    Returns:
        The centred, rescaled array. An empty array is returned unchanged, and
        a constant (e.g. silent) signal is returned as all zeros instead of
        the NaNs a division by a zero peak would produce.
    """
    if wav.size == 0:
        # max()/min() raise on empty arrays and there is nothing to rescale.
        return wav

    # Shift so that the largest and smallest samples are symmetric about zero.
    wav = wav - 0.5*(wav.max() + wav.min())
    max_value = max(abs(wav.max()), abs(wav.min()))
    if max_value == 0:
        # Constant input: after centring every sample is zero, so scaling
        # would divide zero by zero.
        return wav
    return normalise_to*(wav/max_value)

def calculate_melsp(x, n_fft=1024, hop_length=128):
    """Return a 128-band mel projection of the dB-scaled power spectrogram of ``x``.

    Steps: STFT magnitude -> squared (power) -> ``power_to_db`` -> mel filter
    bank applied through ``melspectrogram(S=...)``.

    NOTE(review): the mel filter bank is applied to values that are already in
    dB, which differs from the usual power -> mel -> dB order — confirm intended.
    NOTE(review): ``sr`` is not forwarded to ``melspectrogram``, so librosa's
    default sample rate is used for the filter bank — confirm this is intended
    given that ``load_wav`` loads at 44100 Hz.
    """
    magnitude = np.abs(librosa.stft(x, n_fft=n_fft, hop_length=hop_length))
    power_db = librosa.power_to_db(magnitude**2)
    return librosa.feature.melspectrogram(S=power_db, n_mels=128)

def find_wav_peaks(wav:List[float], distance_between_peaks:int=66150) -> List[int]|None:
    '''
    wav is array of floats
    '''
    if len(wav) < distance_between_peaks:
        return None
    
    height:float = wav.max()/4
    start_point:int = math.floor(distance_between_peaks/2)
    end_point:int = len(wav) - 1 - start_point
    peaks, _ = signal.find_peaks(wav[start_point:end_point], distance=distance_between_peaks, height=height)
    
    if len(peaks) == 0:
        peaks, _ = signal.find_peaks(wav, distance=distance_between_peaks, height=height)
        for i in range(len(peaks)):
            if peaks[i] < start_point:
                peaks[i] = start_point + 1
            if peaks[i] + math.floor(distance_between_peaks/2) > len(wav):
                peaks[i] = len(wav) - math.floor(distance_between_peaks/2) - 1
            return peaks
        
    return peaks + start_point

In [4]:
# Scratch check: two distinct lists with equal contents compare equal.
list_1, list_2 = [1, 2], [1, 2]
print(list_1 == list_2)

True


In [5]:
# Load one training clip (row 18 by position) and show its length in samples.
path = df_train.iloc[18].path
wav = load_wav(path)
print(wav.shape)

(176400,)


In [6]:
# paths = list(df["path"].values)
# for path in paths:
#     wav = load_wav(path)
#     peaks_1 = preprocessing_func_2.find_wav_peaks(wav)
#     peaks_2 = preprocessing_func_2.find_wav_peaks(normalise_wav(wav))
#     if len(peaks_1) != len(peaks_2):
#         print(paths)

In [9]:
# Run the peak finder over the test clips, which are addressed as
# data/test/0.wav ... data/test/555.wav.
#paths = list(df["path"].values)
paths = []
stem = "data/test"
for i in range(556):
    paths.append(f"{stem}/{i}.wav")
#print(paths)
# NOTE(review): `wav` and `peaks` stay bound to the last clip after this loop
# and appear to be read by later cells — keep that in mind when editing.
for path in paths:
    wav = load_wav(path)
    wav = normalise_wav(wav)
    peaks = preprocessing_func_2.find_wav_peaks(wav)
    print(peaks)
    # Report clips that are long enough to be searched but yield no peaks.
    if peaks is not None and len(peaks) == 0:
        print(path)
    

[ 61018 150832 237232 308219]
[113151]
[  81500  152840  368693  474356  652614  735115  909917 1015936 1233845
 1392926 1579015 1806251 1894039]
[ 180658  296327  455463  578835  688600  793351  935921 1042490 1316559
 1545437 1760039]
[33076]
[ 93476 184912]
[ 106992  342434  574682  789270  995340 1206220 1427906 1653746 1871692
 2074829 2291880 2561271 2772011 2988596 3232198 3491393]
[  41421  122474  257566  372240  539530  662264  753038  887692  981336
 1113078 1243372 1425186]
[171233 254508 334826 418404 484577]
[  45829  167031  269337  363845  443845  530573  666680  772888  867082
  949270 1026920 1141028 1299401 1385759 1473109 1585066 1682097 1773880
 1862338 1943932 2037549 2156431 2277659 2364561 2477615 2593915 2680714
 2796952 2886994 2978345 3064511]
[115891 190938]
[157672 246331 342312 423011]
[  63150  151874  219551  533276  655582  726878  805101  886421  956691
 1078277 1184896 1265400 1345537 1715346 1791038 1948520 2062983 2159490
 2298475 2495777]
[  44461 

[  72639  431769  532117  615319  723278  825756  895607  968761 1035308
 1102820 1169171 1259222 1375845 1520306 1638983 1748343 1859880 1943029
 2052592 2155795 2225663 2349957 2467922 2536168 2602939 2721807 2799974
 2877948 2949656 3017538 3088239 3177334 3243887 3312253 3410512 3491888
 3598133 3706430 3781440 4062394 4129424 4219434 4306500 4802719 4870047
 4981708]
[53760]
[58798]
[  69207  192860  259987  334913  462116  583519  682283  786444  856931
  950944 1090920 1157287 1257745 1381909 1475146 1598531 1740141 1858041
 1972156 2080226 2185749 2306324 2424575 2564778 2657077 2748129 2872010
 2983597 3110251 3183958 3305057 3371733 3462306]
[49440]
[199451 271817 345343 429204 504438 614877 694333]
[ 40809 209576]
[ 308238  401331  468549  583069  652470  733058  839648  961259 1044047
 1144594 1259986 1349695 1425198 1510049 1597683 1664622 1744142 1832183
 1950213 2018467 2086604 2157930 2229259 2298223]
[ 35924 102156 173999 283191 414910]
[  49649  143193  239872  353288

[  42026  319839  450075  623229  752452  856410 1009208 1084778]
[ 41741 114174 188763 268656 378412]
[70284]
[ 50354 163566 285879 403618]
[53134]
[  39464  168437  277650  367448  476924  565610  689969  806986  882406
  954407 1047943 1117695 1244357 1321247 1414288 1480729 1552973 1620567]
[  43225  201688  283846  356272  440643  510315  586346  652973  755556
  843641  929841  999393 1094492 1179524 1263361 1333484 1402085 1504701
 1605656 1691075 1769925 1898395 1964990 2048992 2149431 2215955 2304127
 2441542 2526021 2648761 2751840 2863171 2975106 3046232 3129984 3201629
 3311703 3428667 3523971 3628344 3703570 3769795 3870463 3937379 4004711]
[ 96224 373519]
[115588]
[544717 612982 684088 861340 963779]
[  41423  249214  329944  411733  573193  650648  734326  819699  907464
 1019800 1138179 1219762 1298055 1365748 1452774 1545669 1629914 1721523
 1813805 1903699]
[ 227473  406612  609507  808619 1023038 1187971 1401069 1598966 1786498]
[ 63644 143709 237263 307643 419792 49

[174174 430123 581066]
[ 82665 159675 230365 307827]
[150528 220253 362429 557575]
[112297]
[36877]
[  39574  152999  258855  351297  425344  556605  656618  724313  811123
  906698  974234 1072895 1143303 1242143 1328947 1460065 1542116 1619155
 1742956 1822730 1923344 2032094 2126782 2215256 2322273 2417669 2513673
 2602152 2669955 2759542 2864798 2979140 3102046 3179793 3269745 3347946
 3458501 3560674 3678198 3744386 3811570 3938462 4069191 4157448 4238957
 4338942 4423721 4490185 4559488 4659976 4735497 4820675 4920895 5014547
 5093039 5202717 5312769 5380853 5454409 5523870 5615094 5740896 5848292
 5969441 6070109 6138051 6207865 6311935 6423003]
[ 64805 134111 267881]
[158531 235261]
[ 88491 248665 321185 391520 457673 568376 666426]
[ 56117 127342]
[ 60935 159249 298752 366317]
[  40890  219330  400907  570116  727371  917529 1108940 1263294 1437929
 1583700 1755935 1941159 2104281 2259596 2441312 2590687 2747228 2906666
 3068333 3226647 3384778 3541925]
[  92544  196272  30225

[ 756016  860517  935024 1028889 1145033 1222314 1316655 1406431 1472596
 1569251 1676893 1751784 1818214 1896101 1990608 2071584 2197185 2269711
 2336807 2434713]
[102430 176774 283278]
[ 82635 161206 268855]
[145406 239668 377294]
[ 86788 160186 229276 342975 414131 514717 586586 669129]
[55124]
[ 281401  350495  431505  536668  657373  724586  839646  915940  995339
 1069412]
[37532]
[ 56221 129886 204871 315348 471322 568218 640901 863667]
[ 238299  444651  524377  607785  725775  807636  874785  965191 1057807
 1147543 1215711 1282891 1375667 1471847 1562045 1653508 1741716 1832127
 2008615]
[210842 279876 359696 440050]
[ 44071 127629 233032 299219 391791]
[ 33708 101361 180573 293100 360169 465173 552236]
[ 56136 123045 206499 315841 385270]
[ 63250 151864 219729 286245 367024]
[  43445  168270  260558  386380  489552  619705  687503  796931  864685
  956178 1069627 1159761]
[189565 319481 399464 480284 550171 636954]
[87930]
[  64235  131300  216020  314582  430559  498010  564

In [8]:
# Scratch check: assigning through an index updates the list in place.
my_list = [1, 2]
for i, current in enumerate(my_list):
    my_list[i] = current + 2
print(my_list)

[3, 4]


In [9]:
# Display the largest and smallest sample values of the current `wav`.
wav.max(), wav.min()

(0.5, -0.5)

In [10]:
# Normalise `wav` with the notebook-local helper and display the extremes of the result.
normalised_wav = normalise_wav(wav)
normalised_wav.max(), normalised_wav.min()

(0.5, -0.5)

In [11]:
# Manual re-run of the peak search on `wav`, using the same 33075-sample margin
# and 66150-sample spacing that find_wav_peaks uses by default.
height = wav.max()/4
total = len(wav)-1
peaks, _ = signal.find_peaks(wav[33075:total-33075], distance=66150, height=height)
# Shift the slice-relative indices back to positions in the full waveform.
print(peaks + 33075)

[ 54554 129320 229952]


In [12]:
# 0.75 s and 1.5 s expressed in samples at 44100 Hz (the margin and spacing used above).
0.75 * 44100, 1.5 * 44100

(33075.0, 66150.0)

In [13]:
def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='equal', xmax=None):
    """Show a 2-D array as an image with a colour bar.

    Args:
        spec: Array passed to ``imshow`` and drawn with the origin at the bottom.
        title: Figure title; 'Spectrogram (db)' is used when it is falsy.
        ylabel: Label for the vertical axis.
        aspect: Aspect setting forwarded to ``imshow``.
        xmax: When truthy, the x-axis is limited to ``(0, xmax)``.
    """
    figure, axis = plt.subplots(1, 1)
    axis.set_title(title if title else 'Spectrogram (db)')
    axis.set_ylabel(ylabel)
    axis.set_xlabel('frame')
    image = axis.imshow(spec, origin='lower', aspect=aspect)
    if xmax:
        axis.set_xlim((0, xmax))
    figure.colorbar(image, ax=axis)
    # Non-blocking show so the call returns immediately.
    plt.show(block=False)

In [14]:
# Build the un-normalised data generator over every row of the metadata table
# (all subsets, since `df` rather than `df_train` is used).
paths, labels = list(df["path"].values), list(df["label"].values)

generator = preprocessing_func_2.non_normalised_data_generator(
    paths=paths,
    labels=labels,
)

In [15]:
# Pull a single item from the generator to inspect it; `item` stays bound for
# the next cell.
for item in generator:
    print(item)
    break

(tensor([[[-0.2722, -0.1581, -0.1075,  ..., -0.3408, -0.2560, -0.2538],
         [-0.1800, -0.1075, -0.1808,  ..., -0.4441, -0.5429, -0.5363],
         [-0.1474, -0.0859, -0.1747,  ..., -0.4918, -0.6582, -0.7021],
         ...,
         [-1.1810, -0.8828, -0.6520,  ..., -1.8885, -1.7781, -1.7689],
         [-1.2904, -0.8053, -0.5174,  ..., -1.8691, -1.7856, -1.7590],
         [-0.9889, -0.6990, -0.4829,  ..., -1.8702, -1.8180, -1.7690]]]), 56)


In [16]:
# Shape of the first element of the item pulled above.
item[0].shape

torch.Size([1, 128, 512])

In [17]:
# Compute mean, std and class weights from the raw generator; (128, 512)
# matches the trailing dimensions of the item shape shown above.
# NOTE(review): `generator` already yielded one item in the peek cell, so that
# item is presumably not included in these statistics — confirm.
mean, std, weights = preprocessing_func_2.get_stats_and_class_weights_of_non_normalised_data_gen(generator, (128, 512))

In [18]:
# Display the computed statistics and class weights.
mean, std, weights

(tensor([-1.4301]),
 tensor([0.7815]),
 array([ 0.43518264,  5.33501684, 19.64256198,  8.00252525,  6.44979647,
         1.28995929,  2.04803964,  0.49900273,  0.9866127 ,  1.09401611,
         1.11089039,  0.8677437 ,  2.12875056,  2.27440191,  3.42965368,
         4.96708464,  0.81076241,  0.40273659,  9.6030303 ,  0.26349778,
         3.75770751,  0.56562351,  1.70804887,  2.98025078,  3.32412587,
         0.33343855,  0.87300275,  0.78285573,  0.6892127 ,  0.21902502,
         1.0096644 ,  0.79583124,  0.82155202,  0.42118554,  0.54425235,
         3.9285124 ,  0.70958352,  2.82442068,  1.63688017,  0.47228018,
         0.99341693,  1.0096644 ,  1.80810194,  1.38950599,  5.76181818,
         0.70958352,  1.06176011,  1.24177116,  1.10804196,  2.5876429 ,
         2.63497783, 12.00378788,  8.1535163 ,  4.40955473,  3.27376033,
         1.82336018,  0.50660769,  1.32557167,  1.68803267,  0.71664405,
         5.08395722,  0.34654079,  0.6488534 ,  7.32434515,  2.00993658,
         3.6

In [21]:
import json

def save_as_json(path, description, mean, std, weights):
    """Write normalisation statistics and class weights to ``path`` as JSON.

    Args:
        path: Destination file; it is overwritten if it already exists.
        description: Free-text note stored alongside the numbers.
        mean: Scalar (or single-element array/tensor) convertible with ``float()``.
        std: Scalar (or single-element array/tensor) convertible with ``float()``.
        weights: Class weights as any array-like (NumPy array, list, ...).

    The file holds the keys "description", "mean", "std" and "weights".
    """
    my_dict = {
        "description": description,
        "mean": float(mean),
        "std": float(std),
        # np.asarray accepts lists as well as arrays, and tolist() yields plain
        # Python floats so json never has to deal with NumPy scalar types.
        "weights": np.asarray(weights, dtype=float).tolist(),
    }
    with open(path, 'w') as f:
        json.dump(my_dict, f)

# Persist the statistics computed above; the description records the window
# length and image shape they correspond to.
save_as_json(
    "audio_preprocessing/saved_data/new_data.json", 
    "seconds 1.5, image shape (128,512)", mean, std, weights)

In [23]:
# Run the module's load -> normalise -> peak-finding steps on one validation clip
# and print the peaks together with the clip length in samples.
path = "data/val/Leptophyespunctatissima_XC752570-dat141-007_edit2.wav"

wav = preprocessing_func_2.load_wav(path)
wav = preprocessing_func_2.normalise_wav(wav)
peaks = preprocessing_func_2.find_wav_peaks(wav)
#chunks = preprocessing_func_2.split_wav_by_peaks(wav, peaks)
print(peaks, len(wav))

None 66150


In [24]:
# Rebuild the raw generator (the earlier one was handed to the statistics
# helper, which presumably consumed it) and wrap it with normalisation.
paths, labels = list(df["path"].values), list(df["label"].values)

generator = preprocessing_func_2.non_normalised_data_generator(
    paths=paths,
    labels=labels,
)

# Apply the previously computed mean and std to the raw generator's items.
normalised_data_generator = preprocessing_func_2.normalised_data_generator(generator, mean, std)

In [26]:
# Pull a single normalised item to inspect it (this rebinds `item`).
for item in normalised_data_generator:
    print(item)
    break

(tensor([[[ 0.9417,  1.1581,  1.3858,  ...,  1.5099,  1.5410,  1.4609],
         [ 1.3198,  1.4610,  1.4469,  ...,  1.4531,  1.5505,  1.5703],
         [ 1.4466,  1.5302,  1.4436,  ...,  1.2608,  1.3738,  1.4447],
         ...,
         [ 1.8708,  2.0365,  2.1937,  ..., -0.4133, -0.3800, -0.4902],
         [ 1.8880,  1.9948,  2.1318,  ..., -0.4260, -0.3485, -0.3929],
         [ 1.7683,  1.9787,  2.1201,  ..., -0.5645, -0.5553, -0.5230]]]), 56)


In [31]:
from generator_to_dataset_2 import NormalisedDataSet

# Build the dataset wrapper from the two generator factories, the full metadata
# table and the normalisation statistics computed earlier in the notebook.
my_dataset = NormalisedDataSet(
    non_normalised_data_generator_fn = preprocessing_func_2.non_normalised_data_generator,
    normalised_data_generator_fn = preprocessing_func_2.normalised_data_generator,
    df = df,
    shuffle = True,
    mean = mean,
    std = std,
)