# In [4]:
# Main script for acoustic anomaly detection in hydrophone recordings.
#
# AI contributions:
# - Assisted with major refactoring for memory efficiency, code quality, and readability.
# - Implemented advanced features for detailed logging, anomaly explanation, and results validation.
#

import datetime
import torch
import torchaudio
import librosa
import librosa.display
import numpy as np
from scipy.signal import butter, sosfilt
from sklearn.cluster import DBSCAN
import os
import glob
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend to prevent memory leaks
import matplotlib.pyplot as plt
from panns_inference import AudioTagging
import gc
import csv
import pathlib


# Cruise leg identifier (alternative leg: 'LEG10_S20250805_PSTATSRAADLEHMKUHL').
LEG = 'LEG9_S20250708_PSTATSRAADLEHMKUHL'
# Month to analyse, as a string (e.g. "7" for July).
MONTH = "7"

# --- Configuration ---
config = dict(
    # --- Paths (placeholders: set these to real directories before running) ---
    audio_folder_path=...,
    output_folder=...,
    audio_clips_subfolder="audio_clips",
    loud_clips_subfolder="high_quality_clips",
    hq_plots_subfolder="high_quality_plots",
    std_plots_subfolder="standard_plots",

    # --- Anomaly Quality Control ---
    # Route event groups containing a chunk at least this many times louder
    # (RMS) than the file average into the high-quality subfolders.
    separate_loud_anomalies=True,
    loudness_ratio_threshold=2.5,

    # --- Audio Processing & Visualization ---
    chunk_seconds=3.0,               # Length of each analysed chunk
    pre_context_seconds=5.0,         # Audio saved before the event group
    post_context_seconds=7.0,        # Audio saved after the event group (5 s + 2 s extra)
    plot_context_seconds=2.0,        # Shorter context for a more zoomed-in plot
    target_sample_rate=32000,
    high_pass_filter_hz=2000,

    # --- Instrument Filtering ---
    instrument_frequencies_hz=[38000, 75000, 80000],  # EK80 and ADCP frequencies
    frequency_tolerance_hz=250,      # Search half-width around each frequency
    instrument_peak_ratio=8.0,       # Band peak must exceed 8x the mean spectral magnitude
    instrument_energy_threshold=0.30,  # Max fraction of spectral energy allowed in instrument bands

    # --- Clustering & Grouping ---
    dbscan_eps=2.5,
    dbscan_min_samples=4,
    temporal_grouping_seconds=15.0,  # Max gap between consecutive anomalies in one group
    min_anomalies_in_group=2,

    # --- Model Prediction Filtering ---
    # NOTE(review): 'target_labels' and 'min_confidence' are not referenced
    # anywhere else in this file; only 'avoid_labels' is enforced.
    validate_predictions=True,
    target_labels=['Whale', 'Dolphin', 'Whistling', 'Chirp tone', 'Animal'],
    avoid_labels=['Stream', 'Water', 'Pour', 'Frying (food)', 'Drip', 'Boiling', 'Raindrop'],
    min_confidence=0.25,

    # --- General ---
    hydrophone_id="LUW6952",
)

# --- Helper Functions ---

def high_pass_filter(data, cutoff, fs, order=5):
    """Apply a Butterworth high-pass filter (second-order sections) to ``data``.

    Args:
        data: 1-D signal to filter.
        cutoff: Corner frequency in Hz.
        fs: Sample rate of ``data`` in Hz.
        order: Filter order (default 5).

    Returns:
        The filtered signal as produced by ``scipy.signal.sosfilt``. If
        ``cutoff`` is not strictly between 0 and the Nyquist frequency
        (``fs / 2``), no valid digital high-pass exists and ``data`` is
        returned unchanged.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    # butter() only accepts 0 < Wn < 1 for digital filters. A non-positive
    # cutoff previously made butter() raise; treat it like the >= Nyquist
    # case and pass the signal through untouched.
    if not 0.0 < normal_cutoff < 1.0:
        return data
    sos = butter(order, normal_cutoff, btype='high', analog=False, output='sos')
    return sosfilt(sos, data)

def calculate_aliased_frequency(original_freq, sample_rate):
    """Return the apparent frequency of a tone sampled at ``sample_rate``.

    Tones above the Nyquist frequency (``sample_rate / 2``) fold back into
    ``[0, sample_rate / 2]``. Frequencies at or below Nyquist are returned
    unchanged.

    Args:
        original_freq: Tone frequency in Hz.
        sample_rate: Sampling rate in Hz; must be positive.

    Returns:
        The aliased frequency in Hz.

    Raises:
        ValueError: If ``sample_rate`` is not positive (the folding would
            otherwise never converge).

    NOTE(review): preprocess_audio resamples with torchaudio's Resample,
    which is (to my understanding) band-limited, so content above the target
    Nyquist should be attenuated rather than folded. Folding against the
    target rate is only meaningful if aliasing already happened upstream.
    Confirm against the recorder's native sample rate.
    """
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate!r}")
    nyquist = sample_rate / 2
    if original_freq <= nyquist:
        return original_freq
    # Closed-form folding: reduce modulo fs, then mirror the upper half of the
    # band about Nyquist. Same result as repeatedly doing |f - fs|, in O(1).
    remainder = original_freq % sample_rate
    return min(remainder, sample_rate - remainder)

def is_instrument_ping_new(chunk, sr, unwanted_freqs, tolerance_hz, energy_threshold=0.30):
    """Return True if a chunk's spectral energy is concentrated in instrument bands.

    The energy ``|rFFT|**2`` in every bin within ``tolerance_hz`` of any
    frequency in ``unwanted_freqs`` is summed and compared with the chunk's
    total spectral energy.

    Args:
        chunk: 1-D audio samples.
        sr: Sample rate of ``chunk`` in Hz.
        unwanted_freqs: Band centres in Hz (process_file passes the aliased
            instrument frequencies).
        tolerance_hz: Half-width of each band in Hz.
        energy_threshold: Fraction of total energy inside the bands above
            which the chunk is flagged.

    Returns:
        bool: True when ``band_energy / total_energy > energy_threshold``.
        Silent (all-zero) chunks return False.
    """
    if np.sum(np.abs(chunk)) == 0:
        return False
    power = np.abs(np.fft.rfft(chunk)) ** 2
    freq_axis = np.fft.rfftfreq(len(chunk), 1 / sr)
    total_energy = np.sum(power)
    if total_energy == 0:
        return False
    # Take the union of all bands, so bins covered by overlapping bands (e.g.
    # two instruments aliasing to nearby frequencies) are counted once and
    # the ratio can never exceed 1.
    in_any_band = np.zeros(freq_axis.shape, dtype=bool)
    for freq_hz in unwanted_freqs:
        in_any_band |= (freq_axis >= freq_hz - tolerance_hz) & (freq_axis <= freq_hz + tolerance_hz)
    energy_ratio = np.sum(power[in_any_band]) / total_energy
    return bool(energy_ratio > energy_threshold)


def is_instrument_ping(chunk, sr, aliased_freqs, freq_tolerance, peak_ratio_threshold):
    """
    Flag a chunk as an instrument ping when any target band holds a dominant spectral peak.

    Narrow-band pings (the configuration lists EK80 and ADCP frequencies) appear as a
    tall, isolated FFT bin. For each band, the tallest bin is compared against the
    mean magnitude of the whole spectrum, which makes the test about "peakiness"
    rather than total energy.

    Args:
        chunk (np.ndarray): Audio samples to inspect.
        sr (int): Sample rate of ``chunk`` in Hz.
        aliased_freqs (list): Band centres in Hz, already corrected for aliasing.
        freq_tolerance (float): Half-width in Hz of the band around each centre.
        peak_ratio_threshold (float): How many times the mean spectral magnitude a
                                      band peak must exceed to count as a ping.

    Returns:
        bool: True if at least one band contains a dominant peak, False otherwise.
    """
    if not np.abs(chunk).sum():
        # Silent chunk: there is nothing to detect.
        return False

    magnitudes = np.abs(np.fft.rfft(chunk))
    bin_freqs = np.fft.rfftfreq(len(chunk), 1 / sr)

    mean_magnitude = np.mean(magnitudes)
    if mean_magnitude == 0:
        return False
    required_peak = mean_magnitude * peak_ratio_threshold

    def band_has_dominant_peak(centre_hz):
        in_band = (bin_freqs >= centre_hz - freq_tolerance) & (bin_freqs <= centre_hz + freq_tolerance)
        # A band lying entirely outside the spectrum cannot contain a peak.
        return bool(in_band.any()) and magnitudes[in_band].max() > required_peak

    # any() stops at the first qualifying band, mirroring an early return.
    return any(band_has_dominant_peak(centre) for centre in aliased_freqs)

def preprocess_audio(file_path, target_sr, high_pass_cutoff):
    """Load ``file_path`` as mono, resample it to ``target_sr`` and high-pass filter it.

    torchaudio is tried first. If loading (or down-mixing) raises, librosa is
    used instead at the file's native sample rate.

    Args:
        file_path: Path of the audio file to read.
        target_sr: Output sample rate in Hz.
        high_pass_cutoff: High-pass corner frequency in Hz, passed to high_pass_filter.

    Returns:
        tuple: (1-D numpy array of filtered samples, ``target_sr``).
    """
    try:
        signal, native_sr = torchaudio.load(file_path)
        channel_count = signal.shape[0]
        if channel_count > 1:
            # Average all channels into one while keeping the leading channel axis.
            signal = torch.mean(signal, dim=0, keepdim=True)
    except Exception as load_error:
        print(f"    - Torchaudio failed: {load_error}. Falling back to librosa.")
        mono_samples, native_sr = librosa.load(file_path, sr=None, mono=True)
        signal = torch.from_numpy(mono_samples).unsqueeze(0)

    to_target_rate = torchaudio.transforms.Resample(orig_freq=native_sr, new_freq=target_sr)
    flat_samples = to_target_rate(signal).numpy().flatten()
    return high_pass_filter(flat_samples, high_pass_cutoff, target_sr), target_sr

def create_zoomed_event_plot(full_waveform, sr, file_name, group, plot_directory, cfg):
    """Save a mel spectrogram zoomed in on one event group, with some context.

    The plotted window runs from ``plot_context_seconds`` before the group's
    first chunk to ``plot_context_seconds`` after the end of its last chunk
    (``time + chunk_seconds``), clipped to the waveform. The top-1 label of
    each distinct chunk prediction is printed below the plot.

    Args:
        full_waveform: 1-D numpy array holding the whole preprocessed file.
        sr: Sample rate of ``full_waveform`` in Hz.
        file_name: Base name of the source file (used in the title and output name).
        group: Time-sorted list of anomaly dicts with ``'time'`` (seconds) and
            ``'top_3'`` (list of ``(label, prob)``) entries.
        plot_directory: Existing directory in which the PNG is written.
        cfg: Configuration dict; reads ``plot_context_seconds`` and ``chunk_seconds``.

    Returns:
        str: Path of the saved PNG.
    """
    plot_context = cfg['plot_context_seconds']
    group_start_time = group[0]['time']
    group_end_time = group[-1]['time'] + cfg['chunk_seconds']

    # Time and sample boundaries of the zoomed window.
    start_sample = max(0, int((group_start_time - plot_context) * sr))
    end_sample = min(len(full_waveform), int((group_end_time + plot_context) * sr))
    zoomed_clip = full_waveform[start_sample:end_sample]
    # specshow's time axis starts at 0 for this slice; remember where the
    # slice begins in the file so the axis can be related to file time.
    window_offset_s = start_sample / sr

    n_fft, hop_length = 2048, 512
    S = librosa.feature.melspectrogram(y=zoomed_clip, sr=sr, n_mels=128, n_fft=n_fft, hop_length=hop_length)
    log_S = librosa.power_to_db(S, ref=np.max)

    # One line per distinct top-1 label, keeping the first probability seen.
    predictions_text = "--- Predictions for Event ---\n"
    seen_predictions = set()
    for item in group:
        main_label, main_prob = item['top_3'][0]
        if main_label not in seen_predictions:
            predictions_text += f"- {main_label}: {main_prob:.2f}\n"
            seen_predictions.add(main_label)

    output_filename = f"{os.path.splitext(file_name)[0]}_event_at_{group_start_time:.0f}s_zoomed.png"
    output_path = os.path.join(plot_directory, output_filename)

    fig, ax = plt.subplots(figsize=(18, 6))
    try:
        mesh = librosa.display.specshow(log_S, sr=sr, x_axis='time', y_axis='mel', ax=ax,
                                        hop_length=hop_length, cmap='magma')
        fig.colorbar(mesh, ax=ax, format='%+2.0f dB', label='Intensity (dB)')

        title = (f"Zoomed Event in {file_name}\n"
                 f"Original Time: {group_start_time:.1f}s - {group_end_time:.1f}s "
                 f"(x-axis 0 = {window_offset_s:.1f}s into file)")
        ax.set_title(title)

        fig.text(0.02, 0.02, predictions_text, fontsize=9, wrap=True, verticalalignment='bottom')
        fig.tight_layout(rect=[0.0, 0.1, 1, 0.95])
        fig.savefig(output_path, dpi=200)
    finally:
        # Release the figure even if rendering or saving fails. process_file
        # catches the error and the loop moves on to the next file, so a
        # leaked figure would otherwise accumulate.
        plt.close(fig)
    return output_path

# --- Analysis Functions ---

def _anomaly_log_entry(file_name, time_s, status, reason, label, prob, loudness_ratio, spectrogram_path=''):
    """Build one row of the per-day detailed CSV log for an anomaly candidate."""
    return {
        'file': file_name,
        'timestamp_s': f"{time_s:.2f}",
        'status': status,
        'reason': reason,
        'top_prediction': label,
        'confidence': f"{prob:.2f}",
        'details': f"LoudnessRatio={loudness_ratio:.2f}",
        'spectrogram_path': spectrogram_path,
    }

def _extract_chunk_features(waveform_np, chunk_samples, num_chunks, panns_model, batch_size=1):
    """Run the PANNs tagger over consecutive, non-overlapping full-length chunks.

    Returns:
        tuple: (embedding matrix with one row per chunk, list holding one 1-D
        clipwise class-score array per chunk).
    """
    embeddings, clipwise_outputs = [], []
    for first in range(0, num_chunks, batch_size):
        last = min(first + batch_size, num_chunks)
        # num_chunks = len(waveform) // chunk_samples, so every chunk here is
        # full length (the trailing partial chunk is never analysed) and no
        # padding is required.
        batch = waveform_np[first * chunk_samples:last * chunk_samples].reshape(last - first, chunk_samples)
        clipwise, embedding = panns_model.inference(batch)
        embeddings.append(np.asarray(embedding).reshape(last - first, -1))
        clipwise_outputs.extend(clipwise)
    return np.concatenate(embeddings, axis=0), clipwise_outputs

def _group_by_time(anomalies, max_gap_s, min_group_size):
    """Split time-sorted anomalies into runs whose consecutive gaps are <= ``max_gap_s``.

    Only runs with at least ``min_group_size`` members are returned.
    """
    if not anomalies:
        return []
    groups, current = [], [anomalies[0]]
    for anom in anomalies[1:]:
        if anom['time'] - current[-1]['time'] <= max_gap_s:
            current.append(anom)
        else:
            if len(current) >= min_group_size:
                groups.append(current)
            current = [anom]
    if len(current) >= min_group_size:
        groups.append(current)
    return groups

def process_file(audio_file_path, panns_model, class_labels, cfg):
    """Detect acoustic events in one audio file.

    Pipeline:
      1. Preprocess the file (mono, resample, high-pass).
      2. Split it into ``chunk_seconds`` chunks and tag/embed each with PANNs.
      3. Treat DBSCAN noise points (label -1) in embedding space as anomaly candidates.
      4. Discard candidates that look like instrument pings or whose top label
         is in ``avoid_labels``.
      5. Group the survivors in time and keep groups with at least
         ``min_anomalies_in_group`` members.

    For every kept group this writes one zoomed spectrogram PNG and one WAV clip
    into the subfolders of ``cfg['output_folder']``, which must already exist
    (main_analysis_loop creates them).

    Args:
        audio_file_path: Path of the audio file to analyse.
        panns_model: PANNs ``AudioTagging`` instance.
        class_labels: Sequence mapping class index to label name.
        cfg: Configuration dict (see ``config``). The optional key
            ``inference_batch_size`` (default 1) sets how many chunks go
            through the model per call.

    Returns:
        tuple: ``(file_summary, log_entries)``.
          - ``file_summary`` is None when the file is skipped, yields no event
            group, or processing raises. Otherwise it is a dict with keys
            ``'filename'``, ``'total_events'`` and ``'high_quality_events'``.
          - ``log_entries`` is a list of per-candidate dicts for the detailed CSV.
    """
    file_name = os.path.basename(audio_file_path)
    print(f"--- Processing file: {file_name} ---")
    log_entries = []

    try:
        waveform_np, sr = preprocess_audio(audio_file_path, cfg['target_sample_rate'], cfg['high_pass_filter_hz'])
        # A small epsilon keeps the loudness ratio finite for an all-silent file.
        file_rms_avg = np.sqrt(np.mean(waveform_np**2)) + 1e-9

        chunk_samples = int(cfg['chunk_seconds'] * sr)
        num_chunks = len(waveform_np) // chunk_samples
        if num_chunks == 0:
            print("  - File is too short to be analyzed. Skipping.")
            return None, log_entries
        if num_chunks < cfg['dbscan_min_samples']:
            # DBSCAN cannot form a core point from fewer than min_samples
            # points, so every chunk would be labelled noise (= anomalous).
            print(f"  - Only {num_chunks} chunk(s), fewer than dbscan_min_samples; "
                  f"no baseline to compare against. Skipping.")
            return None, log_entries

        aliased_instrument_freqs = [calculate_aliased_frequency(f, sr) for f in cfg['instrument_frequencies_hz']]

        print(f"  - Extracting features from {num_chunks} chunks...")
        # Batching several chunks per forward pass is typically much faster
        # (especially on GPU). Results should match per-chunk inference up to
        # floating-point noise, assuming the model runs in inference mode; TODO confirm.
        batch_size = max(1, int(cfg.get('inference_batch_size', 1)))
        feature_matrix, all_clipwise_outputs = _extract_chunk_features(
            waveform_np, chunk_samples, num_chunks, panns_model, batch_size)

        print("  - Clustering features to find initial anomalies...")
        db = DBSCAN(eps=cfg['dbscan_eps'], min_samples=cfg['dbscan_min_samples']).fit(feature_matrix)
        # Chunks that belong to no dense cluster (label -1) are the anomaly candidates.
        anomaly_indices = np.where(db.labels_ == -1)[0]
        del feature_matrix

        if len(anomaly_indices) == 0:
            print("No initial anomalies detected by clustering.")
            return None, log_entries

        print(f"  - Found {len(anomaly_indices)} potential anomalies. Validating...")
        validated_anomalies = []
        for i in anomaly_indices:
            start_time = i * cfg['chunk_seconds']
            anomaly_chunk = waveform_np[i * chunk_samples:(i + 1) * chunk_samples]
            prediction = all_clipwise_outputs[i]
            top_idx, top_prob = np.argmax(prediction), np.max(prediction)
            top_label = class_labels[top_idx]
            loudness_ratio = np.sqrt(np.mean(anomaly_chunk**2)) / file_rms_avg

            # Energy-ratio check first; the peak-ratio check only runs if it passes.
            if (is_instrument_ping_new(anomaly_chunk, sr, aliased_instrument_freqs,
                                       cfg['frequency_tolerance_hz'], cfg['instrument_energy_threshold'])
                    or is_instrument_ping(anomaly_chunk, sr, aliased_instrument_freqs,
                                          cfg['frequency_tolerance_hz'], cfg['instrument_peak_ratio'])):
                log_entries.append(_anomaly_log_entry(file_name, start_time, 'DISCARDED', 'Instrument Ping',
                                                      top_label, top_prob, loudness_ratio))
                continue

            # Only the avoid-list is enforced.
            # NOTE(review): cfg['target_labels'] and cfg['min_confidence'] are not consulted.
            if cfg['validate_predictions'] and top_label in cfg['avoid_labels']:
                log_entries.append(_anomaly_log_entry(file_name, start_time, 'DISCARDED', 'Failed Validation',
                                                      top_label, top_prob, loudness_ratio))
                continue

            top_3_indices = np.argsort(prediction)[::-1][:3]
            top_3_preds = [(class_labels[k], prediction[k]) for k in top_3_indices]
            validated_anomalies.append({"time": start_time, "top_3": top_3_preds, "loudness_ratio": loudness_ratio})

        del all_clipwise_outputs
        if not validated_anomalies:
            print("No anomalies passed validation.")
            return None, log_entries

        validated_anomalies.sort(key=lambda x: x['time'])
        event_groups = _group_by_time(validated_anomalies, cfg['temporal_grouping_seconds'],
                                      cfg['min_anomalies_in_group'])

        if not event_groups:
            print("No significant event groups detected (anomalies were too isolated).")
            for anom in validated_anomalies:
                label, prob = anom['top_3'][0]
                log_entries.append(_anomaly_log_entry(file_name, anom['time'], 'DISCARDED', 'Isolated Anomaly',
                                                      label, prob, anom['loudness_ratio']))
            return None, log_entries

        print(f"Found {len(event_groups)} distinct event group(s). Generating outputs...")

        output_root = cfg['output_folder']
        separate_loud = cfg['separate_loud_anomalies']
        hq_plots_path = os.path.join(output_root, cfg['hq_plots_subfolder'])
        std_plots_path = os.path.join(output_root, cfg['std_plots_subfolder'])
        base_audio_path = os.path.join(output_root, cfg['audio_clips_subfolder'])
        hq_audio_path = os.path.join(output_root, cfg['loud_clips_subfolder'])

        # A group is "high quality" if any member is loud relative to the file average.
        hq_flags = [any(anom['loudness_ratio'] >= cfg['loudness_ratio_threshold'] for anom in group)
                    for group in event_groups]

        plot_path_map = {}
        for group, is_high_quality in zip(event_groups, hq_flags):
            plot_dir = hq_plots_path if is_high_quality and separate_loud else std_plots_path
            plot_path = create_zoomed_event_plot(waveform_np, sr, file_name, group, plot_dir, cfg)
            print(f"    - Zoomed spectrogram saved: {os.path.basename(plot_path)}")
            for anom in group:
                plot_path_map[anom['time']] = plot_path

        for anom in validated_anomalies:
            label, prob = anom['top_3'][0]
            plot_path = plot_path_map.get(anom['time'])
            if plot_path:
                entry = _anomaly_log_entry(file_name, anom['time'], 'KEPT', 'Part of Event Group',
                                           label, prob, anom['loudness_ratio'], plot_path)
            else:
                entry = _anomaly_log_entry(file_name, anom['time'], 'DISCARDED', 'Isolated Anomaly',
                                           label, prob, anom['loudness_ratio'], '')
            log_entries.append(entry)

        for group, is_high_quality in zip(event_groups, hq_flags):
            output_dir = hq_audio_path if separate_loud and is_high_quality else base_audio_path
            start_time = group[0]['time']
            clip_filename = f"{os.path.splitext(file_name)[0]}_event_clip_{start_time:.0f}s.wav"
            start_sample = max(0, int((start_time - cfg['pre_context_seconds']) * sr))
            end_sample = min(len(waveform_np),
                             int((group[-1]['time'] + cfg['chunk_seconds'] + cfg['post_context_seconds']) * sr))
            # The filtered waveform is typically float64 (sosfilt output). Write
            # 32-bit float samples, which is ample precision and a more widely
            # supported WAV format.
            clip = np.ascontiguousarray(waveform_np[start_sample:end_sample], dtype=np.float32)
            torchaudio.save(os.path.join(output_dir, clip_filename), torch.from_numpy(clip).unsqueeze(0), sr)

        file_summary = {'filename': file_name, 'total_events': len(event_groups),
                        'high_quality_events': sum(hq_flags)}
        return file_summary, log_entries

    except Exception as e:
        print(f"FAILED to process file {file_name}. Error: {e}")
        return None, log_entries
    finally:
        # Drop the (potentially large) waveform before collecting garbage.
        if 'waveform_np' in locals():
            del waveform_np
        gc.collect()

def _write_csv(path, fieldnames, rows, **writer_kwargs):
    """Write ``rows`` (dicts) to ``path`` as CSV with a header row."""
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, **writer_kwargs)
        writer.writeheader()
        writer.writerows(rows)

def main_analysis_loop():
    """Analyse every day of ``MONTH`` (year 2025) for the configured hydrophone.

    Creates the output directories and loads the PANNs Cnn14 tagger once. Then,
    for each calendar day of the month:
      * a past day is skipped if its final ``..._daily_summary.csv`` exists;
      * process_file is run on every WAV whose name contains
        ``<hydrophone_id>_<YYYYMMDD>``;
      * ``..._detailed_log.csv`` is written if any log rows were produced, and
        ``..._daily_summary.csv`` is written if any file had events.

    When the day being processed is today, both CSV names get an
    ``_INCOMPLETE`` suffix so the day can be re-run once all data has arrived.
    """
    import calendar  # only needed here, to size the month

    out_dir = config['output_folder']
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(os.path.join(out_dir, config['audio_clips_subfolder']), exist_ok=True)
    os.makedirs(os.path.join(out_dir, config['std_plots_subfolder']), exist_ok=True)
    if config['separate_loud_anomalies']:
        os.makedirs(os.path.join(out_dir, config['loud_clips_subfolder']), exist_ok=True)
        os.makedirs(os.path.join(out_dir, config['hq_plots_subfolder']), exist_ok=True)

    print("Loading PANNs (Cnn14) model...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    panns_model = AudioTagging(checkpoint_path=None, device=device)
    class_labels = panns_model.labels
    print(f"Model loaded on device: {device}")

    today_str = datetime.date.today().strftime('%Y%m%d')

    year = 2025  # the recordings' date pattern has always been hard-coded to 2025
    month = int(MONTH)
    # Cover every day of the month. A fixed range(1, 31) skipped the 31st, and
    # "20250{MONTH}" produced invalid dates for months 10-12.
    days_in_month = calendar.monthrange(year, month)[1]
    days_to_process = [f"{d:02d}" for d in range(1, days_in_month + 1)]
    # For validation, restrict to ['09']: dolphin sightings were confirmed on the 9th.
    for day in days_to_process:
        processing_date_str = f"{year}{month:02d}{day}"
        is_today = processing_date_str == today_str
        date_pattern = f"{config['hydrophone_id']}_{processing_date_str}"
        print(f"\n{'='*20}\nANALYSING DAY: {date_pattern}\n{'='*20}")

        # If running on the current day, the data is assumed to be incomplete,
        # so outputs are marked INCOMPLETE and the day can be re-run later.
        suffix = "_INCOMPLETE" if is_today else ""
        if is_today:
            print("INFO: Processing for the current day. Outputs will be marked as INCOMPLETE.")

        summary_file_path = os.path.join(out_dir, f"{date_pattern}_daily_summary{suffix}.csv")
        log_file_path = os.path.join(out_dir, f"{date_pattern}_detailed_log{suffix}.csv")

        # Skip past days that already have a FINAL summary; today is always re-run.
        # NOTE(review): a summary is only written when some file had events, so a
        # past day with no events (or where every file failed) is re-processed
        # on each run.
        final_summary_path = os.path.join(out_dir, f"{date_pattern}_daily_summary.csv")
        if not is_today and os.path.exists(final_summary_path):
            print("Final summary file for this day already exists. Skipping.")
            continue

        search_pattern = os.path.join(config['audio_folder_path'], f"*{date_pattern}*.wav")
        audio_files_for_day = sorted(glob.glob(search_pattern))
        if not audio_files_for_day:
            continue

        print(f"Found {len(audio_files_for_day)} audio files to analyze for this day.")

        daily_summaries = []
        all_daily_logs = []
        for audio_file in audio_files_for_day:
            file_summary, file_logs = process_file(audio_file, panns_model, class_labels, config)
            if file_summary:
                daily_summaries.append(file_summary)
            if file_logs:
                all_daily_logs.extend(file_logs)
            print("-" * 50)

        # Detailed log: written to log_file_path so today's log carries the
        # INCOMPLETE suffix, just like the summary.
        if all_daily_logs:
            log_fieldnames = ['file', 'timestamp_s', 'status', 'reason', 'top_prediction',
                              'confidence', 'details', 'spectrogram_path']
            _write_csv(log_file_path, log_fieldnames, all_daily_logs, extrasaction='ignore')

        # High-level summary.
        if daily_summaries:
            print(f"\n--- Day {day} Complete ---")
            print(f"Found events in {len(daily_summaries)} files. Saving summary to CSV.")
            _write_csv(summary_file_path, ['filename', 'total_events', 'high_quality_events'], daily_summaries)
        else:
            print(f"\n--- Day {day} Complete --- \nNo significant events found.")

# --- Execute the Analysis ---
# Run the multi-day analysis only when this file is executed directly
# (or as a notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main_analysis_loop()

# --- Output captured from the notebook run (kept for reference) ---
# Loading PANNs (Cnn14) model...
# Checkpoint path: /home/operator0/panns_data/Cnn14_mAP=0.431.pth
# GPU number: 1
# Model loaded on device: cuda
#
# ANALYSING DAY: LUW6952_20250701
#
# ANALYSING DAY: LUW6952_20250702
#
# ANALYSING DAY: LUW6952_20250703
#
# ANALYSING DAY: LUW6952_20250704
#
# ANALYSING DAY: LUW6952_20250705
#
# ANALYSING DAY: LUW6952_20250706
#
# ANALYSING DAY: LUW6952_20250707
#
# ANALYSING DAY: LUW6952_20250708
# Final summary file for this day already exists. Skipping.
#
# ANALYSING DAY: LUW6952_20250709
# Found 236 audio files to analyze for this day.
# --- Processing file: LUW6952_20250709_000000.wav ---
#
#
# KeyboardInterrupt: 