In [10]:
# --- Paths -------------------------------------------------------------------
# Root folder containing the input recordings
input_root = "./szunyog_hangok_25_01_14"
# Root folder where CSV results and detected segments are written
output_root = "./szunyog_hangok_25_01_14_preprocessed_database_AST"
# Trained AST model checkpoint
MODEL_PATH = "b_mosquito/results_AST_b_mosquito_25_01_21-full/checkpoint-3966/"

# --- Run options -------------------------------------------------------------
# Minimum 'mosquito' probability for a chunk to be classified as mosquito
classification_threshold = 0.9
# Write one CSV per input recording
b_write_csv = True
# Export the detected chunks as .wav files
b_save_segments = True



In [11]:
import os
from pathlib import Path
from pydub import AudioSegment
import os
import numpy as np
import librosa
from scipy.signal import butter, filtfilt
import glob
import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import time

# Audio file extensions this notebook is meant to handle
supported_formats = [".wav"]

# Make sure the output root exists before any result is written
Path(output_root).mkdir(parents=True, exist_ok=True)



In [12]:

# Load the fine-tuned classifier and its matching feature processor
from transformers import AutoModelForAudioClassification, AutoProcessor

# Class name <-> class index mapping used by the classification head
label2id = {"not": 0, "mosquito": 1}
id2label = dict((class_idx, class_name) for class_name, class_idx in label2id.items())

processor = AutoProcessor.from_pretrained(MODEL_PATH)

model = AutoModelForAudioClassification.from_pretrained(
    MODEL_PATH,
    num_labels=2,  # Number of classes
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Run on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

# Last expression of the cell: display the model architecture
model


ASTForAudioClassification(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ASTLayer(
          (attention): ASTAttention(
            (attention): ASTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ASTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ASTIntermediate(
            (de

In [13]:
def split_audio_chunks(input_file, target_sr=16000, window_ms=1000, step_ms=500):
    """
    Split every channel of an audio file into fixed-length, overlapping chunks.

    The file is loaded with pydub, separated into mono channels, each channel is
    resampled to ``target_sr`` and then cut into windows of ``window_ms``
    milliseconds, advancing by ``step_ms`` milliseconds. A trailing remainder
    shorter than one full window is dropped.

    Parameters
    ----------
    input_file : str
        Path of the audio file to read.
    target_sr : int, optional
        Frame rate (Hz) each channel is converted to. Default 16000.
    window_ms : int, optional
        Chunk length in milliseconds. Default 1000.
    step_ms : int, optional
        Distance between consecutive chunk starts in milliseconds. Default 500
        (i.e. 0.5 s overlap with the default window).

    Returns
    -------
    (chunks, metadata) : tuple of two lists of equal length
        ``chunks[i]`` is a mono pydub segment and ``metadata[i]`` is a dict with
        keys ``file`` (base name), ``channel`` (1-based), ``start_ms``, ``end_ms``.
        Both lists are empty if the file could not be read or processed.

    Raises
    ------
    ValueError
        If ``window_ms`` or ``step_ms`` is not positive.
    """
    # Validate outside the try-block so a configuration error is not swallowed
    # and reported as a per-file read error.
    if window_ms <= 0 or step_ms <= 0:
        raise ValueError("window_ms and step_ms must be positive")

    chunks = []
    metadata = []
    try:
        # Load audio file
        audio = AudioSegment.from_file(input_file)

        # Determine the number of channels
        num_channels = audio.channels
        print(f"File: {input_file}, Number of channels: {num_channels}")

        file_label = os.path.basename(input_file)

        # Separate all channels and process each one independently
        for channel_index, channel_audio in enumerate(audio.split_to_mono()):
            channel_audio = channel_audio.set_frame_rate(target_sr)

            # len() of a pydub segment is its duration in milliseconds
            duration_ms = len(channel_audio)

            for start_ms in range(0, duration_ms - window_ms + 1, step_ms):
                end_ms = start_ms + window_ms
                chunks.append(channel_audio[start_ms:end_ms])

                # Record metadata (kept in lockstep with `chunks`)
                metadata.append({
                    "file": file_label,
                    "channel": channel_index + 1,
                    "start_ms": start_ms,
                    "end_ms": end_ms
                })

    except Exception as e:
        print(f"Error during splitting: {input_file} - {e}")
        # Discard partial results: returning only some of the channels would let
        # the caller treat an incompletely processed file as fully analysed.
        return [], []

    return chunks, metadata


In [14]:

def predict_audio_chunk(chunk, model, threshold=0.5):
    """
    Decide whether the given audio chunk contains a mosquito sound.

    The chunk's samples are normalised to float32 in [-1, 1), run through the
    module-level ``processor`` and the given ``model`` (on the module-level
    ``device``), and the softmax probability of class index 1 is compared
    against ``threshold``.

    Parameters
    ----------
    chunk : pydub segment
        Audio to classify. It is resampled to 16 kHz and mixed down to mono
        first if it is not already in that format.
    model : audio classification model
        Called as ``model(**inputs)``; its ``.logits`` output is used.
    threshold : float, optional
        Minimum class-1 probability for a positive decision. Default 0.5.

    Returns
    -------
    0-dim boolean tensor
        Truthy when the class-1 probability is >= ``threshold``.

    Raises
    ------
    ValueError
        If the sample array type is neither 16-bit ('h') nor 32-bit ('i').
    """
    # Rate declared to the processor below; the samples must really be at this
    # rate, otherwise the features are computed on mislabelled audio.
    expected_sr = 16000
    if chunk.frame_rate != expected_sr:
        chunk = chunk.set_frame_rate(expected_sr)

    # get_array_of_samples() interleaves channels, so mix down to mono first.
    if chunk.channels > 1:
        chunk = chunk.set_channels(1)

    data = chunk.get_array_of_samples()
    dtype = data.typecode  # The type of the array, e.g., 'h' or 'i'
    y = np.array(data, dtype=np.float32)  # Always convert to float32

    # Normalization depending on the data type
    if dtype == 'h':  # 16-bit integer
        y = y / (2**15)  # Normalize between -1 and 1
    elif dtype == 'i':  # 32-bit integer
        y = y / (2**31)  # Normalize between -1 and 1
    else:
        raise ValueError(f"Unsupported data format: {dtype}")

    inputs = processor(y, sampling_rate=expected_sr, return_tensors="pt", padding=True)
    # Move every input tensor to the same device as the model
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Inference only: no gradients needed
    with torch.no_grad():
        logits = model(**inputs).logits

    # Calculate probabilities with softmax
    probabilities = F.softmax(logits, dim=-1)

    # Probabilities of class 1 for all elements in the batch
    class_1_probabilities = probabilities[:, 1]

    # Logical tensor: which element reaches the requested threshold?
    thresholded = class_1_probabilities >= threshold

    # Only one chunk is processed per call, so return the single batch element
    return thresholded[0]



In [15]:
#def anal_chunks(chunks, output_dir, base_filename, metadata, model, b_save=False):
def anal_chunks(chunks, output_dir, base_filename, metadata, model, classification_threshold=0.5, b_save=False):
    """
    Classify each chunk and return the indices of those detected as mosquito sound.

    Every chunk is passed to ``predict_audio_chunk``. Positive chunks are
    optionally exported as 16 kHz mono 16-bit WAV files named
    ``{base_filename}_{channel}_{start_ms}.wav`` inside ``output_dir``.

    Parameters
    ----------
    chunks : list
        Audio chunks to classify.
    output_dir : str
        Folder for exported positive chunks (created if missing).
    base_filename : str
        Prefix used for the exported file names.
    metadata : list of dict
        One entry per chunk, with at least the keys ``channel`` and ``start_ms``.
    model : audio classification model
        Forwarded to ``predict_audio_chunk``.
    classification_threshold : float, optional
        Probability threshold forwarded to ``predict_audio_chunk``. Default 0.5.
    b_save : bool, optional
        When true, export each positive chunk to ``output_dir``. Default False.

    Returns
    -------
    list of int
        Indices into ``chunks`` / ``metadata`` of the positive chunks.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Sibling "<output_dir>_not_selected" folder. Exporting rejected chunks is
    # currently disabled, but the folder is still created so the on-disk layout
    # stays the same as before.
    not_selected_dir = os.path.join(os.path.dirname(output_dir), os.path.basename(output_dir) + "_not_selected")
    os.makedirs(not_selected_dir, exist_ok=True)

    sound_idxs = []

    for idx, chunk in enumerate(chunks):
        md = metadata[idx]

        ch = md['channel']
        start = str(int(md['start_ms']))

        if not predict_audio_chunk(chunk, model, threshold=classification_threshold):
            # Rejected chunk: nothing to record or export
            continue

        sound_idxs.append(idx)

        # Reset on every iteration so the progress message can never refer to
        # a file name left over from an earlier chunk.
        chunk_output_file = None
        if b_save:
            chunk_output_file = os.path.join(output_dir, f"{base_filename}_{ch}_{start}.wav")
            chunk.export(chunk_output_file, format="wav", parameters=["-ar", "16000", "-ac", "1", "-sample_fmt", "s16"])

        # Progress message for a subset of the positive chunks only
        if idx % 10 == 0:
            if chunk_output_file is not None:
                print(f"{idx}: Saved chunk (MOSQUITO SOUND): {chunk_output_file}")
            else:
                print(f"{idx}: Detected chunk (MOSQUITO SOUND): {base_filename}_{ch}_{start} (not saved)")

    return sound_idxs


In [None]:

# Collect every .wav recording directly under the input root
fns = glob.glob(os.path.join(input_root, "*.wav"))

# Start time measurement
start_time = time.time()

for input_fn in fns:
    base_filename = os.path.basename(input_fn)
    output_file = os.path.join(output_root, os.path.splitext(base_filename)[0] + ".csv")

    # A result CSV marks the recording as already processed
    if os.path.exists(output_file):
        print(f"already done: {input_fn}")
        continue

    # Split the recording into per-channel chunks
    chunks, metadata = split_audio_chunks(input_fn)
    if not chunks:
        continue

    # Indices of chunks classified as mosquito; segments are exported when b_save_segments is set
    sound_idxs = anal_chunks(
        chunks,
        output_root,
        base_filename,
        metadata,
        model,
        classification_threshold=classification_threshold,
        b_save=b_save_segments,
    )

    if not b_write_csv:
        continue

    if sound_idxs:
        # Keep only the metadata rows of the detected chunks and save them
        out_df = pd.DataFrame(metadata).iloc[sound_idxs]
        out_df.to_csv(output_file, index=False)
    else:
        # Header-only CSV so the recording still counts as processed
        with open(output_file, mode="w") as file:
            file.write("file,channel,start_ms,end_ms\n")
        print("no mosquito sound found.")

end_time = time.time()
print(f"finished, time: {end_time - start_time:.6f} seconds")


File: ./szunyog_hangok_25_01_14/Test0495.wav, Number of channels: 4
200: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_1_100000.wav
640: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_1_320000.wav
670: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_1_335000.wav
1060: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_1_530000.wav
1240: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_1_620000.wav
1280: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_1_640000.wav
1370: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_2_5500.wav
1440: Saved chunk (MOSQUITO SOUND): ./szunyog_hangok_25_01_14_preprocessed_database_AST\Test0495.wav_2_40500.wav
1560: Saved chunk (MOSQUIT